diff --git a/server/router/api/v1/auth_passkey.go b/server/router/api/v1/auth_passkey.go new file mode 100644 index 0000000000000..01f95d9dfef67 --- /dev/null +++ b/server/router/api/v1/auth_passkey.go @@ -0,0 +1,1165 @@ +package v1 + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "math/big" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v5" + "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/usememos/memos/internal/util" + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" +) + +const ( + passkeySessionDuration = 5 * time.Minute + + passkeyRegistrationFlow = "registration" + passkeyAuthenticationFlow = "authentication" + passkeySessionAudienceBase = "passkey:" + passkeyCredentialType = "public-key" + + passkeyAlgES256 = -7 + passkeyAlgEdDSA = -8 + passkeyAlgRS256 = -257 +) + +type passkeySessionClaims struct { + Flow string `json:"flow"` + Challenge string `json:"challenge"` + RPID string `json:"rpId"` + Origin string `json:"origin"` + Username string `json:"username"` + jwt.RegisteredClaims +} + +type passkeyRPJSON struct { + Name string `json:"name"` + ID string `json:"id"` +} + +type passkeyUserJSON struct { + ID string `json:"id"` + Name string `json:"name"` + DisplayName string `json:"displayName"` +} + +type passkeyPubKeyCredentialParamJSON struct { + Type string `json:"type"` + Alg int32 `json:"alg"` +} + +type passkeyCredentialDescriptorJSON struct { + Type string `json:"type"` + ID string `json:"id"` + Transports []string `json:"transports,omitempty"` +} + +type passkeyAuthenticatorSelectionJSON struct { + ResidentKey string `json:"residentKey,omitempty"` + UserVerification string `json:"userVerification,omitempty"` +} + +type beginPasskeyRegistrationResponse struct { + State string `json:"state"` + PublicKey passkeyCreationOptionsJSON `json:"publicKey"` +} + +type passkeyCreationOptionsJSON struct { + Challenge string `json:"challenge"` + RP passkeyRPJSON `json:"rp"` + User passkeyUserJSON `json:"user"` + PubKeyCredParams []passkeyPubKeyCredentialParamJSON `json:"pubKeyCredParams"` + Timeout int `json:"timeout"` + Attestation string `json:"attestation"` + ExcludeCredentials []passkeyCredentialDescriptorJSON `json:"excludeCredentials,omitempty"` + AuthenticatorSelection *passkeyAuthenticatorSelectionJSON `json:"authenticatorSelection,omitempty"` +} + +type beginPasskeyAuthenticationRequest struct { + Username string `json:"username"` +} + +type beginPasskeyAuthenticationResponse struct { + State string `json:"state"` + PublicKey passkeyRequestOptionsJSON `json:"publicKey"` +} + +type passkeyRequestOptionsJSON struct { + Challenge string `json:"challenge"` + RPID string `json:"rpId"` + Timeout int `json:"timeout"` + UserVerification string `json:"userVerification,omitempty"` + AllowCredentials []passkeyCredentialDescriptorJSON `json:"allowCredentials,omitempty"` +} + +type finishPasskeyRegistrationRequest struct { + State string `json:"state"` + Credential passkeyRegistrationCredentialJSON `json:"credential"` +} + +type passkeyRegistrationCredentialJSON struct { + ID string `json:"id"` + RawID string `json:"rawId"` + Type string `json:"type"` + Response passkeyRegistrationResponseJSON `json:"response"` +} + +type passkeyRegistrationResponseJSON struct { + ClientDataJSON string `json:"clientDataJSON"` + AttestationObject string `json:"attestationObject"` + Transports []string `json:"transports,omitempty"` +} + +type finishPasskeyAuthenticationRequest struct { + State string `json:"state"` + Credential passkeyAuthenticationCredentialJSON `json:"credential"` +} + +type passkeyAuthenticationCredentialJSON struct { + ID string `json:"id"` + RawID string `json:"rawId"` + Type string `json:"type"` + Response passkeyAuthenticationResponseJSON `json:"response"` +} + +type passkeyAuthenticationResponseJSON struct { + ClientDataJSON string `json:"clientDataJSON"` + AuthenticatorData string `json:"authenticatorData"` + Signature string `json:"signature"` + UserHandle string `json:"userHandle,omitempty"` +} + +type finishPasskeyAuthenticationResponse struct { + AccessToken string `json:"accessToken"` + AccessTokenExpiresAt string `json:"accessTokenExpiresAt"` +} + +type listPasskeysResponse struct { + Passkeys []passkeyJSON `json:"passkeys"` +} + +type passkeyJSON struct { + ID string `json:"id"` + Label string `json:"label"` + Transports []string `json:"transports,omitempty"` + AddedTs int64 `json:"addedTs"` + LastUsedTs int64 `json:"lastUsedTs,omitempty"` +} + +type passkeyRelyingParty struct { + ID string + Name string + Origin string +} + +type passkeyClientData struct { + Type string `json:"type"` + Challenge string `json:"challenge"` + Origin string `json:"origin"` +} + +type parsedPasskeyAuthData struct { + RPIDHash []byte + Flags byte + SignCount uint32 + CredentialID []byte + CredentialPublicKey []byte +} + +func (s *APIV1Service) registerPasskeyRoutes(group *echo.Group) { + group.GET("/api/v1/auth/passkeys", s.listPasskeysHandler) + group.DELETE("/api/v1/auth/passkeys/:passkeyID", s.deletePasskeyHandler) + group.POST("/api/v1/auth/passkeys/registration/begin", s.beginPasskeyRegistrationHandler) + group.POST("/api/v1/auth/passkeys/registration/finish", s.finishPasskeyRegistrationHandler) + group.POST("/api/v1/auth/passkeys/authentication/begin", s.beginPasskeyAuthenticationHandler) + group.POST("/api/v1/auth/passkeys/authentication/finish", s.finishPasskeyAuthenticationHandler) +} + +func (s *APIV1Service) listPasskeysHandler(c *echo.Context) error { + ctx, currentUser, err := s.authenticateNativeRequest(c, true) + if err != nil { + return s.writeNativeError(c, err) + } + + passkeys, err := s.Store.GetUserPasskeys(ctx, currentUser.ID) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to load passkeys")) + } + + response := listPasskeysResponse{ + Passkeys: make([]passkeyJSON, 0, len(passkeys)), + } + for _, passkey := range passkeys { + response.Passkeys = append(response.Passkeys, passkeyJSON{ + ID: passkey.ID, + Label: passkey.Label, + Transports: append([]string(nil), passkey.Transports...), + AddedTs: passkey.AddedTs, + LastUsedTs: passkey.LastUsedTs, + }) + } + + s.applyNativeResponseHeaders(ctx, c) + return c.JSON(http.StatusOK, response) +} + +func (s *APIV1Service) deletePasskeyHandler(c *echo.Context) error { + ctx, currentUser, err := s.authenticateNativeRequest(c, true) + if err != nil { + return s.writeNativeError(c, err) + } + + passkeyID := strings.TrimSpace(c.Param("passkeyID")) + if passkeyID == "" { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "passkey id is required")) + } + + if err := s.Store.DeleteUserPasskey(ctx, currentUser.ID, passkeyID); err != nil { + if strings.Contains(err.Error(), "not found") { + return s.writeNativeError(c, status.Errorf(codes.NotFound, "passkey not found")) + } + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to delete passkey")) + } + + s.applyNativeResponseHeaders(ctx, c) + return c.NoContent(http.StatusNoContent) +} + +func (s *APIV1Service) beginPasskeyRegistrationHandler(c *echo.Context) error { + ctx, currentUser, err := s.authenticateNativeRequest(c, true) + if err != nil { + return s.writeNativeError(c, err) + } + + rp, err := s.resolvePasskeyRelyingParty(ctx) + if err != nil { + return s.writeNativeError(c, err) + } + + challenge, err := randomBase64URL(32) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to generate challenge")) + } + + passkeys, err := s.Store.GetUserPasskeys(ctx, currentUser.ID) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to load passkeys")) + } + + state, err := s.signPasskeySessionToken(currentUser, passkeyRegistrationFlow, challenge, rp) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to create passkey session")) + } + + response := beginPasskeyRegistrationResponse{ + State: state, + PublicKey: passkeyCreationOptionsJSON{ + Challenge: challenge, + RP: passkeyRPJSON{ + Name: rp.Name, + ID: rp.ID, + }, + User: passkeyUserJSON{ + ID: base64.RawURLEncoding.EncodeToString([]byte(strconv.Itoa(int(currentUser.ID)))), + Name: currentUser.Username, + DisplayName: currentUser.Nickname, + }, + PubKeyCredParams: []passkeyPubKeyCredentialParamJSON{ + {Type: passkeyCredentialType, Alg: passkeyAlgES256}, + {Type: passkeyCredentialType, Alg: passkeyAlgEdDSA}, + {Type: passkeyCredentialType, Alg: passkeyAlgRS256}, + }, + Timeout: int((60 * time.Second) / time.Millisecond), + Attestation: "none", + ExcludeCredentials: buildPasskeyCredentialDescriptors(passkeys), + AuthenticatorSelection: &passkeyAuthenticatorSelectionJSON{ + ResidentKey: "preferred", + UserVerification: "preferred", + }, + }, + } + + s.applyNativeResponseHeaders(ctx, c) + return c.JSON(http.StatusOK, response) +} + +func (s *APIV1Service) finishPasskeyRegistrationHandler(c *echo.Context) error { + ctx, currentUser, err := s.authenticateNativeRequest(c, true) + if err != nil { + return s.writeNativeError(c, err) + } + + request := &finishPasskeyRegistrationRequest{} + if err := decodeNativeJSONBody(c, request); err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid request body")) + } + + claims, err := s.parsePasskeySessionToken(request.State, passkeyRegistrationFlow) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid passkey session")) + } + if claims.Subject != strconv.Itoa(int(currentUser.ID)) { + return s.writeNativeError(c, status.Errorf(codes.PermissionDenied, "passkey session does not belong to current user")) + } + + rp, err := s.resolvePasskeyRelyingParty(ctx) + if err != nil { + return s.writeNativeError(c, err) + } + if rp.ID != claims.RPID || rp.Origin != claims.Origin { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "relying party changed during registration")) + } + + if request.Credential.Type != passkeyCredentialType { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid credential type")) + } + + rawCredentialID, err := decodeBase64URL(request.Credential.RawID) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid credential id")) + } + if request.Credential.ID != request.Credential.RawID { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "credential id mismatch")) + } + + clientDataJSON, err := decodeBase64URL(request.Credential.Response.ClientDataJSON) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid client data")) + } + if err := validatePasskeyClientData(clientDataJSON, "webauthn.create", claims.Challenge, claims.Origin); err != nil { + return s.writeNativeError(c, err) + } + + attestationObject, err := decodeBase64URL(request.Credential.Response.AttestationObject) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid attestation object")) + } + authData, err := parseAttestationAuthData(attestationObject) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid attestation object")) + } + if err := validatePasskeyAuthData(authData, rp.ID, true); err != nil { + return s.writeNativeError(c, err) + } + if !bytes.Equal(authData.CredentialID, rawCredentialID) { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "credential id mismatch")) + } + + algorithm, err := extractCOSEAlgorithm(authData.CredentialPublicKey) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "unsupported public key")) + } + + passkeys, err := s.Store.GetUserPasskeys(ctx, currentUser.ID) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to load passkeys")) + } + for _, existing := range passkeys { + if existing.CredentialID == request.Credential.RawID { + return s.writeNativeError(c, status.Errorf(codes.AlreadyExists, "passkey already exists")) + } + } + + passkey := &store.Passkey{ + ID: util.GenUUID(), + Label: buildPasskeyLabel(s.extractClientInfo(ctx), time.Now()), + CredentialID: request.Credential.RawID, + PublicKey: base64.RawURLEncoding.EncodeToString(authData.CredentialPublicKey), + Algorithm: algorithm, + SignCount: authData.SignCount, + Transports: normalizePasskeyTransports(request.Credential.Response.Transports), + AddedTs: time.Now().Unix(), + } + if err := s.Store.AddUserPasskey(ctx, currentUser.ID, passkey); err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to save passkey")) + } + + s.applyNativeResponseHeaders(ctx, c) + return c.NoContent(http.StatusNoContent) +} + +func (s *APIV1Service) beginPasskeyAuthenticationHandler(c *echo.Context) error { + ctx, _, err := s.authenticateNativeRequest(c, false) + if err != nil { + return s.writeNativeError(c, err) + } + + request := &beginPasskeyAuthenticationRequest{} + if err := decodeNativeJSONBody(c, request); err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid request body")) + } + request.Username = strings.TrimSpace(request.Username) + if request.Username == "" { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "username is required")) + } + + user, err := s.Store.GetUser(ctx, &store.FindUser{ + Username: &request.Username, + }) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to get user")) + } + if user == nil || user.RowStatus == store.Archived { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "passkey sign in is not available")) + } + + passkeys, err := s.Store.GetUserPasskeys(ctx, user.ID) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to load passkeys")) + } + if len(passkeys) == 0 { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "passkey sign in is not available")) + } + + rp, err := s.resolvePasskeyRelyingParty(ctx) + if err != nil { + return s.writeNativeError(c, err) + } + + challenge, err := randomBase64URL(32) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to generate challenge")) + } + + state, err := s.signPasskeySessionToken(user, passkeyAuthenticationFlow, challenge, rp) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to create passkey session")) + } + + response := beginPasskeyAuthenticationResponse{ + State: state, + PublicKey: passkeyRequestOptionsJSON{ + Challenge: challenge, + RPID: rp.ID, + Timeout: int((60 * time.Second) / time.Millisecond), + UserVerification: "preferred", + AllowCredentials: buildPasskeyCredentialDescriptors(passkeys), + }, + } + + s.applyNativeResponseHeaders(ctx, c) + return c.JSON(http.StatusOK, response) +} + +func (s *APIV1Service) finishPasskeyAuthenticationHandler(c *echo.Context) error { + ctx, _, err := s.authenticateNativeRequest(c, false) + if err != nil { + return s.writeNativeError(c, err) + } + + request := &finishPasskeyAuthenticationRequest{} + if err := decodeNativeJSONBody(c, request); err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid request body")) + } + + claims, err := s.parsePasskeySessionToken(request.State, passkeyAuthenticationFlow) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid passkey session")) + } + + userID, err := util.ConvertStringToInt32(claims.Subject) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid passkey session")) + } + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to get user")) + } + if user == nil || user.RowStatus == store.Archived { + return s.writeNativeError(c, status.Errorf(codes.PermissionDenied, "user is unavailable")) + } + + rp, err := s.resolvePasskeyRelyingParty(ctx) + if err != nil { + return s.writeNativeError(c, err) + } + if rp.ID != claims.RPID || rp.Origin != claims.Origin { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "relying party changed during authentication")) + } + + if request.Credential.Type != passkeyCredentialType { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid credential type")) + } + if request.Credential.ID != request.Credential.RawID { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "credential id mismatch")) + } + + clientDataJSON, err := decodeBase64URL(request.Credential.Response.ClientDataJSON) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid client data")) + } + if err := validatePasskeyClientData(clientDataJSON, "webauthn.get", claims.Challenge, claims.Origin); err != nil { + return s.writeNativeError(c, err) + } + + authenticatorData, err := decodeBase64URL(request.Credential.Response.AuthenticatorData) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid authenticator data")) + } + authData, err := parseAssertionAuthData(authenticatorData) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid authenticator data")) + } + if err := validatePasskeyAuthData(authData, rp.ID, false); err != nil { + return s.writeNativeError(c, err) + } + + passkeys, err := s.Store.GetUserPasskeys(ctx, user.ID) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to load passkeys")) + } + var matched *store.Passkey + for _, passkey := range passkeys { + if passkey.CredentialID == request.Credential.RawID { + matched = passkey + break + } + } + if matched == nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "passkey not found")) + } + + signature, err := decodeBase64URL(request.Credential.Response.Signature) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.InvalidArgument, "invalid signature")) + } + if err := verifyPasskeySignature(matched, authenticatorData, clientDataJSON, signature); err != nil { + return s.writeNativeError(c, status.Errorf(codes.PermissionDenied, "passkey verification failed")) + } + if matched.SignCount > 0 && authData.SignCount > 0 && authData.SignCount <= matched.SignCount { + return s.writeNativeError(c, status.Errorf(codes.PermissionDenied, "passkey sign count is invalid")) + } + + updatedPasskey := *matched + updatedPasskey.SignCount = authData.SignCount + updatedPasskey.LastUsedTs = time.Now().Unix() + if err := s.Store.UpdateUserPasskey(ctx, user.ID, &updatedPasskey); err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to update passkey")) + } + + accessToken, accessExpiresAt, err := s.doSignIn(ctx, user) + if err != nil { + return s.writeNativeError(c, status.Errorf(codes.Internal, "failed to sign in")) + } + + s.applyNativeResponseHeaders(ctx, c) + return c.JSON(http.StatusOK, finishPasskeyAuthenticationResponse{ + AccessToken: accessToken, + AccessTokenExpiresAt: accessExpiresAt.Format(time.RFC3339), + }) +} + +func (s *APIV1Service) authenticateNativeRequest(c *echo.Context, requireAuth bool) (context.Context, *store.User, error) { + ctx := WithHeaderCarrier(c.Request().Context()) + ctx = metadata.NewIncomingContext(ctx, metadataFromHeaders(c.Request().Header, c.Request().Host)) + + authenticator := auth.NewAuthenticator(s.Store, s.Secret) + result := authenticator.Authenticate(ctx, c.Request().Header.Get("Authorization")) + if result == nil { + if requireAuth { + return ctx, nil, status.Errorf(codes.Unauthenticated, "authentication required") + } + return auth.ApplyToContext(ctx, nil), nil, nil + } + + ctx = auth.ApplyToContext(ctx, result) + currentUser, err := s.fetchCurrentUser(ctx) + if err != nil { + return ctx, nil, status.Errorf(codes.Internal, "failed to get current user") + } + if currentUser == nil { + return ctx, nil, status.Errorf(codes.Unauthenticated, "user not found") + } + return ctx, currentUser, nil +} + +func (s *APIV1Service) applyNativeResponseHeaders(ctx context.Context, c *echo.Context) { + if carrier := GetHeaderCarrier(ctx); carrier != nil { + for key, value := range carrier.All() { + c.Response().Header().Add(key, value) + } + } +} + +func (s *APIV1Service) writeNativeError(c *echo.Context, err error) error { + httpStatus := http.StatusInternalServerError + message := "internal server error" + if st, ok := status.FromError(err); ok { + message = st.Message() + switch st.Code() { + case codes.InvalidArgument, codes.FailedPrecondition: + httpStatus = http.StatusBadRequest + case codes.Unauthenticated: + httpStatus = http.StatusUnauthorized + case codes.PermissionDenied: + httpStatus = http.StatusForbidden + case codes.NotFound: + httpStatus = http.StatusNotFound + case codes.AlreadyExists: + httpStatus = http.StatusConflict + default: + httpStatus = http.StatusInternalServerError + } + } + return c.JSON(httpStatus, map[string]string{"message": message}) +} + +func decodeNativeJSONBody(c *echo.Context, target any) error { + defer c.Request().Body.Close() + return json.NewDecoder(c.Request().Body).Decode(target) +} + +func (s *APIV1Service) resolvePasskeyRelyingParty(ctx context.Context) (*passkeyRelyingParty, error) { + instanceTitle := "Memos" + if instanceSetting, err := s.Store.GetInstanceGeneralSetting(ctx); err == nil { + if title := strings.TrimSpace(instanceSetting.CustomProfile.GetTitle()); title != "" { + instanceTitle = title + } + } + + if md, ok := metadata.FromIncomingContext(ctx); ok { + if origin := firstMetadataValue(md, "origin"); origin != "" { + parsed, err := url.Parse(origin) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return nil, status.Errorf(codes.InvalidArgument, "invalid origin") + } + return &passkeyRelyingParty{ + ID: parsed.Hostname(), + Name: instanceTitle, + Origin: parsed.Scheme + "://" + parsed.Host, + }, nil + } + + host := firstMetadataValue(md, "x-forwarded-host", "host") + proto := firstMetadataValue(md, "x-forwarded-proto") + if host != "" { + if proto == "" { + proto = "https" + } + return &passkeyRelyingParty{ + ID: stripPort(host), + Name: instanceTitle, + Origin: proto + "://" + host, + }, nil + } + } + + if s.Profile != nil && s.Profile.InstanceURL != "" { + parsed, err := url.Parse(s.Profile.InstanceURL) + if err == nil && parsed.Scheme != "" && parsed.Host != "" { + return &passkeyRelyingParty{ + ID: parsed.Hostname(), + Name: instanceTitle, + Origin: parsed.Scheme + "://" + parsed.Host, + }, nil + } + } + + return nil, status.Errorf(codes.FailedPrecondition, "unable to determine relying party") +} + +func firstMetadataValue(md metadata.MD, keys ...string) string { + for _, key := range keys { + values := md.Get(key) + if len(values) > 0 && values[0] != "" { + return values[0] + } + } + return "" +} + +func stripPort(host string) string { + if strings.HasPrefix(host, "[") { + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + return strings.Trim(parsedHost, "[]") + } + } + if strings.Count(host, ":") == 1 { + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + return parsedHost + } + } + return host +} + +func (s *APIV1Service) signPasskeySessionToken(user *store.User, flow, challenge string, rp *passkeyRelyingParty) (string, error) { + claims := &passkeySessionClaims{ + Flow: flow, + Challenge: challenge, + RPID: rp.ID, + Origin: rp.Origin, + Username: user.Username, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(int(user.ID)), + Issuer: auth.Issuer, + Audience: jwt.ClaimStrings{passkeySessionAudienceBase + flow}, + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(passkeySessionDuration)), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = auth.KeyID + return token.SignedString([]byte(s.Secret)) +} + +func (s *APIV1Service) parsePasskeySessionToken(tokenString, expectedFlow string) (*passkeySessionClaims, error) { + claims := &passkeySessionClaims{} + _, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { + if token.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.New("unexpected signing method") + } + return []byte(s.Secret), nil + }, jwt.WithIssuer(auth.Issuer), jwt.WithAudience(passkeySessionAudienceBase+expectedFlow)) + if err != nil { + return nil, err + } + if claims.Flow != expectedFlow { + return nil, errors.New("unexpected passkey flow") + } + return claims, nil +} + +func buildPasskeyCredentialDescriptors(passkeys []*store.Passkey) []passkeyCredentialDescriptorJSON { + descriptors := make([]passkeyCredentialDescriptorJSON, 0, len(passkeys)) + for _, passkey := range passkeys { + descriptors = append(descriptors, passkeyCredentialDescriptorJSON{ + Type: passkeyCredentialType, + ID: passkey.CredentialID, + Transports: append([]string(nil), passkey.Transports...), + }) + } + return descriptors +} + +func buildPasskeyLabel(clientInfo *storepb.RefreshTokensUserSetting_ClientInfo, now time.Time) string { + if clientInfo != nil { + parts := []string{} + if clientInfo.Browser != "" { + parts = append(parts, clientInfo.Browser) + } + if clientInfo.Os != "" { + parts = append(parts, clientInfo.Os) + } + if len(parts) > 0 { + return strings.Join(parts, " / ") + } + } + return store.NewDefaultPasskeyLabel(now) +} + +func normalizePasskeyTransports(transports []string) []string { + seen := map[string]struct{}{} + normalized := make([]string, 0, len(transports)) + for _, transport := range transports { + transport = strings.TrimSpace(strings.ToLower(transport)) + if transport == "" { + continue + } + if _, exists := seen[transport]; exists { + continue + } + seen[transport] = struct{}{} + normalized = append(normalized, transport) + } + return normalized +} + +func randomBase64URL(size int) (string, error) { + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func decodeBase64URL(value string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(value) +} + +func validatePasskeyClientData(clientDataJSON []byte, expectedType, expectedChallenge, expectedOrigin string) error { + payload := &passkeyClientData{} + if err := json.Unmarshal(clientDataJSON, payload); err != nil { + return status.Errorf(codes.InvalidArgument, "invalid client data") + } + if payload.Type != expectedType { + return status.Errorf(codes.InvalidArgument, "unexpected client data type") + } + if payload.Challenge != expectedChallenge { + return status.Errorf(codes.InvalidArgument, "unexpected challenge") + } + if payload.Origin != expectedOrigin { + return status.Errorf(codes.InvalidArgument, "unexpected origin") + } + return nil +} + +func parseAttestationAuthData(attestationObject []byte) (*parsedPasskeyAuthData, error) { + decoder := newCBORDecoder(attestationObject) + value, err := decoder.Decode() + if err != nil { + return nil, err + } + attestation, ok := value.(map[any]any) + if !ok { + return nil, errors.New("invalid attestation object") + } + authDataValue, ok := attestation["authData"].([]byte) + if !ok { + return nil, errors.New("attestation authData missing") + } + return parseAuthenticatorData(authDataValue, true) +} + +func parseAssertionAuthData(authenticatorData []byte) (*parsedPasskeyAuthData, error) { + return parseAuthenticatorData(authenticatorData, false) +} + +func validatePasskeyAuthData(authData *parsedPasskeyAuthData, rpID string, requireAttestedCredential bool) error { + if len(authData.RPIDHash) != sha256.Size { + return status.Errorf(codes.InvalidArgument, "invalid rp id hash") + } + expectedHash := sha256.Sum256([]byte(rpID)) + if !bytes.Equal(authData.RPIDHash, expectedHash[:]) { + return status.Errorf(codes.InvalidArgument, "rp id hash mismatch") + } + if authData.Flags&0x01 == 0 { + return status.Errorf(codes.InvalidArgument, "user presence is required") + } + if requireAttestedCredential && len(authData.CredentialID) == 0 { + return status.Errorf(codes.InvalidArgument, "attested credential data missing") + } + return nil +} + +func parseAuthenticatorData(data []byte, requireAttestedCredential bool) (*parsedPasskeyAuthData, error) { + if len(data) < 37 { + return nil, errors.New("authenticator data too short") + } + result := &parsedPasskeyAuthData{ + RPIDHash: append([]byte(nil), data[:32]...), + Flags: data[32], + SignCount: binary.BigEndian.Uint32(data[33:37]), + } + if !requireAttestedCredential { + return result, nil + } + if result.Flags&0x40 == 0 { + return nil, errors.New("attested credential flag missing") + } + offset := 37 + if len(data) < offset+16+2 { + return nil, errors.New("attested credential data too short") + } + offset += 16 // Skip AAGUID. + credentialIDLength := int(binary.BigEndian.Uint16(data[offset : offset+2])) + offset += 2 + if len(data) < offset+credentialIDLength { + return nil, errors.New("credential id is truncated") + } + result.CredentialID = append([]byte(nil), data[offset:offset+credentialIDLength]...) + offset += credentialIDLength + + keyDecoder := newCBORDecoder(data[offset:]) + if _, err := keyDecoder.Decode(); err != nil { + return nil, err + } + result.CredentialPublicKey = append([]byte(nil), data[offset:offset+keyDecoder.Offset()]...) + return result, nil +} + +func extractCOSEAlgorithm(publicKey []byte) (int32, error) { + key, err := parseCOSEPublicKey(publicKey) + if err != nil { + return 0, err + } + return key.Algorithm, nil +} + +func verifyPasskeySignature(passkey *store.Passkey, authenticatorData, clientDataJSON, signature []byte) error { + publicKeyBytes, err := decodeBase64URL(passkey.PublicKey) + if err != nil { + return err + } + publicKey, err := parseCOSEPublicKey(publicKeyBytes) + if err != nil { + return err + } + + clientDataHash := sha256.Sum256(clientDataJSON) + signedData := append(append([]byte{}, authenticatorData...), clientDataHash[:]...) + + switch key := publicKey.PublicKey.(type) { + case *ecdsa.PublicKey: + sum := sha256.Sum256(signedData) + if !ecdsa.VerifyASN1(key, sum[:], signature) { + return errors.New("ecdsa verification failed") + } + case ed25519.PublicKey: + if !ed25519.Verify(key, signedData, signature) { + return errors.New("ed25519 verification failed") + } + case *rsa.PublicKey: + sum := sha256.Sum256(signedData) + if err := rsa.VerifyPKCS1v15(key, crypto.SHA256, sum[:], signature); err != nil { + return err + } + default: + return errors.New("unsupported passkey algorithm") + } + return nil +} + +type parsedCOSEPublicKey struct { + Algorithm int32 + PublicKey any +} + +func parseCOSEPublicKey(raw []byte) (*parsedCOSEPublicKey, error) { + decoder := newCBORDecoder(raw) + value, err := decoder.Decode() + if err != nil { + return nil, err + } + keyMap, ok := value.(map[any]any) + if !ok { + return nil, errors.New("invalid cose key") + } + + kty, err := cborInt(keyMap[int64(1)]) + if err != nil { + return nil, err + } + alg, err := cborInt(keyMap[int64(3)]) + if err != nil { + return nil, err + } + + switch kty { + case 2: // EC2 + crv, err := cborInt(keyMap[int64(-1)]) + if err != nil { + return nil, err + } + if crv != 1 { + return nil, errors.New("unsupported elliptic curve") + } + x, ok := keyMap[int64(-2)].([]byte) + if !ok { + return nil, errors.New("invalid ec x coordinate") + } + y, ok := keyMap[int64(-3)].([]byte) + if !ok { + return nil, errors.New("invalid ec y coordinate") + } + return &parsedCOSEPublicKey{ + Algorithm: int32(alg), + PublicKey: &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + }, + }, nil + case 1: // OKP + crv, err := cborInt(keyMap[int64(-1)]) + if err != nil { + return nil, err + } + if crv != 6 { + return nil, errors.New("unsupported okp curve") + } + x, ok := keyMap[int64(-2)].([]byte) + if !ok { + return nil, errors.New("invalid okp key") + } + return &parsedCOSEPublicKey{ + Algorithm: int32(alg), + PublicKey: ed25519.PublicKey(x), + }, nil + case 3: // RSA + n, ok := keyMap[int64(-1)].([]byte) + if !ok { + return nil, errors.New("invalid rsa modulus") + } + e, ok := keyMap[int64(-2)].([]byte) + if !ok { + return nil, errors.New("invalid rsa exponent") + } + return &parsedCOSEPublicKey{ + Algorithm: int32(alg), + PublicKey: &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(new(big.Int).SetBytes(e).Int64()), + }, + }, nil + default: + return nil, errors.New("unsupported key type") + } +} + +func cborInt(value any) (int64, error) { + switch v := value.(type) { + case int64: + return v, nil + case uint64: + return int64(v), nil + case int: + return int64(v), nil + default: + return 0, errors.New("unexpected cbor integer") + } +} + +type cborDecoder struct { + data []byte + offset int +} + +func newCBORDecoder(data []byte) *cborDecoder { + return &cborDecoder{data: data} +} + +func (d *cborDecoder) Offset() int { + return d.offset +} + +func (d *cborDecoder) Decode() (any, error) { + if d.offset >= len(d.data) { + return nil, errors.New("unexpected end of cbor data") + } + initial := d.data[d.offset] + d.offset++ + + majorType := initial >> 5 + additionalInfo := initial & 0x1f + + length, err := d.readArgument(additionalInfo) + if err != nil { + return nil, err + } + + switch majorType { + case 0: + return int64(length), nil + case 1: + return -1 - int64(length), nil + case 2: + if !d.hasBytes(int(length)) { + return nil, errors.New("invalid cbor byte string") + } + value := append([]byte(nil), d.data[d.offset:d.offset+int(length)]...) + d.offset += int(length) + return value, nil + case 3: + if !d.hasBytes(int(length)) { + return nil, errors.New("invalid cbor text string") + } + value := string(d.data[d.offset : d.offset+int(length)]) + d.offset += int(length) + return value, nil + case 4: + values := make([]any, 0, int(length)) + for i := uint64(0); i < length; i++ { + value, err := d.Decode() + if err != nil { + return nil, err + } + values = append(values, value) + } + return values, nil + case 5: + values := make(map[any]any, int(length)) + for i := uint64(0); i < length; i++ { + key, err := d.Decode() + if err != nil { + return nil, err + } + value, err := d.Decode() + if err != nil { + return nil, err + } + values[key] = value + } + return values, nil + case 7: + switch additionalInfo { + case 20: + return false, nil + case 21: + return true, nil + case 22: + return nil, nil + default: + return nil, errors.New("unsupported cbor simple value") + } + default: + return nil, errors.New("unsupported cbor major type") + } +} + +func (d *cborDecoder) readArgument(additionalInfo byte) (uint64, error) { + switch { + case additionalInfo < 24: + return uint64(additionalInfo), nil + case additionalInfo == 24: + if !d.hasBytes(1) { + return 0, errors.New("invalid cbor uint8") + } + value := uint64(d.data[d.offset]) + d.offset++ + return value, nil + case additionalInfo == 25: + if !d.hasBytes(2) { + return 0, errors.New("invalid cbor uint16") + } + value := uint64(binary.BigEndian.Uint16(d.data[d.offset : d.offset+2])) + d.offset += 2 + return value, nil + case additionalInfo == 26: + if !d.hasBytes(4) { + return 0, errors.New("invalid cbor uint32") + } + value := uint64(binary.BigEndian.Uint32(d.data[d.offset : d.offset+4])) + d.offset += 4 + return value, nil + case additionalInfo == 27: + if !d.hasBytes(8) { + return 0, errors.New("invalid cbor uint64") + } + value := binary.BigEndian.Uint64(d.data[d.offset : d.offset+8]) + d.offset += 8 + return value, nil + default: + return 0, errors.New("unsupported cbor additional info") + } +} + +func (d *cborDecoder) hasBytes(size int) bool { + return d.offset+size <= len(d.data) +} diff --git a/server/router/api/v1/connect_interceptors.go b/server/router/api/v1/connect_interceptors.go index 9ea26f3b09d03..2543a351aa2ec 100644 --- a/server/router/api/v1/connect_interceptors.go +++ b/server/router/api/v1/connect_interceptors.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "net/http" "reflect" "runtime/debug" @@ -30,27 +31,7 @@ func NewMetadataInterceptor() *MetadataInterceptor { func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { - // Convert HTTP headers to gRPC metadata - header := req.Header() - md := metadata.MD{} - - // Copy important headers for client info extraction - if ua := header.Get("User-Agent"); ua != "" { - md.Set("user-agent", ua) - } - if xff := header.Get("X-Forwarded-For"); xff != "" { - md.Set("x-forwarded-for", xff) - } - if xri := header.Get("X-Real-Ip"); xri != "" { - md.Set("x-real-ip", xri) - } - // Forward Cookie header for authentication methods that need it (e.g., RefreshToken) - if cookie := header.Get("Cookie"); cookie != "" { - md.Set("cookie", cookie) - } - - // Set metadata in context so services can use metadata.FromIncomingContext() - ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.NewIncomingContext(ctx, metadataFromHeaders(req.Header(), "")) // Execute the request resp, err := next(ctx, req) @@ -67,6 +48,29 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc } } +func metadataFromHeaders(header http.Header, host string) metadata.MD { + md := metadata.MD{} + + setMetadataHeader(md, header, "User-Agent", "user-agent") + setMetadataHeader(md, header, "X-Forwarded-For", "x-forwarded-for") + setMetadataHeader(md, header, "X-Real-Ip", "x-real-ip") + setMetadataHeader(md, header, "Cookie", "cookie") + setMetadataHeader(md, header, "Origin", "origin") + setMetadataHeader(md, header, "X-Forwarded-Host", "x-forwarded-host") + setMetadataHeader(md, header, "X-Forwarded-Proto", "x-forwarded-proto") + if host != "" { + md.Set("host", host) + } + + return md +} + +func setMetadataHeader(md metadata.MD, header http.Header, httpHeader, metadataKey string) { + if value := header.Get(httpHeader); value != "" { + md.Set(metadataKey, value) + } +} + func isNilAnyResponse(resp connect.AnyResponse) bool { if resp == nil { return true diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index ad974b4a5a187..ee8efbf603e5b 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -122,6 +122,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{"*"}, })) + s.registerPasskeyRoutes(gwGroup) // Register SSE endpoint with same CORS as rest of /api/v1. RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret) handler := echo.WrapHandler(http.MaxBytesHandler(gwMux, maxAPIRequestBytes)) diff --git a/server/router/frontend/dist/index.html b/server/router/frontend/dist/index.html index a612ed1f7de3a..5009ae7a546ad 100644 --- a/server/router/frontend/dist/index.html +++ b/server/router/frontend/dist/index.html @@ -1,11 +1,28 @@ - + - + + + + + + + + + + + + Memos + + + + + - - No embeddable frontend found. + +
+ diff --git a/store/db/mysql/user_setting.go b/store/db/mysql/user_setting.go index 7fb075913c8a1..30daf6328834b 100644 --- a/store/db/mysql/user_setting.go +++ b/store/db/mysql/user_setting.go @@ -12,7 +12,7 @@ import ( func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) { stmt := "INSERT INTO `user_setting` (`user_id`, `key`, `value`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?" - if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value, upsert.Value); err != nil { + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, store.UserSettingKeyString(upsert.Key), upsert.Value, upsert.Value); err != nil { return nil, err } return upsert, nil @@ -22,7 +22,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED { - where, args = append(where, "`key` = ?"), append(args, v.String()) + where, args = append(where, "`key` = ?"), append(args, store.UserSettingKeyString(v)) } if v := find.UserID; v != nil { where, args = append(where, "`user_id` = ?"), append(args, *find.UserID) @@ -46,7 +46,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ); err != nil { return nil, err } - userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString]) + userSetting.Key = store.ParseUserSettingKey(keyString) userSettingList = append(userSettingList, userSetting) } diff --git a/store/db/postgres/user_setting.go b/store/db/postgres/user_setting.go index a66a47a678eff..1d35c07e0c593 100644 --- a/store/db/postgres/user_setting.go +++ b/store/db/postgres/user_setting.go @@ -19,7 +19,7 @@ func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) ( ON CONFLICT(user_id, key) DO UPDATE SET value = EXCLUDED.value ` - if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil { + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, store.UserSettingKeyString(upsert.Key), upsert.Value); err != nil { return nil, err } return upsert, nil @@ -29,7 +29,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED { - where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String()) + where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, store.UserSettingKeyString(v)) } if v := find.UserID; v != nil { where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID) @@ -59,7 +59,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ); err != nil { return nil, err } - userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString]) + userSetting.Key = store.ParseUserSettingKey(keyString) userSettingList = append(userSettingList, userSetting) } diff --git a/store/db/sqlite/user_setting.go b/store/db/sqlite/user_setting.go index f5809492831b2..4c3824413053a 100644 --- a/store/db/sqlite/user_setting.go +++ b/store/db/sqlite/user_setting.go @@ -19,7 +19,7 @@ func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) ( ON CONFLICT(user_id, key) DO UPDATE SET value = EXCLUDED.value ` - if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil { + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, store.UserSettingKeyString(upsert.Key), upsert.Value); err != nil { return nil, err } return upsert, nil @@ -29,7 +29,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED { - where, args = append(where, "key = ?"), append(args, v.String()) + where, args = append(where, "key = ?"), append(args, store.UserSettingKeyString(v)) } if v := find.UserID; v != nil { where, args = append(where, "user_id = ?"), append(args, *find.UserID) @@ -59,7 +59,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ); err != nil { return nil, err } - userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString]) + userSetting.Key = store.ParseUserSettingKey(keyString) userSettingList = append(userSettingList, userSetting) } if err := rows.Err(); err != nil { diff --git a/store/passkey_setting.go b/store/passkey_setting.go new file mode 100644 index 0000000000000..023ba038d5b62 --- /dev/null +++ b/store/passkey_setting.go @@ -0,0 +1,194 @@ +package store + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + storepb "github.com/usememos/memos/proto/gen/store" +) + +const ( + // UserSettingKeyPasskeys stores WebAuthn passkeys for a user. + UserSettingKeyPasskeys storepb.UserSetting_Key = 100 + + userSettingKeyPasskeysString = "PASSKEYS" +) + +// Passkey stores a user's WebAuthn credential metadata. +type Passkey struct { + ID string `json:"id"` + Label string `json:"label"` + CredentialID string `json:"credentialId"` + PublicKey string `json:"publicKey"` + Algorithm int32 `json:"algorithm"` + SignCount uint32 `json:"signCount"` + Transports []string `json:"transports,omitempty"` + AddedTs int64 `json:"addedTs"` + LastUsedTs int64 `json:"lastUsedTs,omitempty"` +} + +type passkeysUserSetting struct { + Passkeys []*Passkey `json:"passkeys"` +} + +// UserSettingKeyString converts a user setting key to the persisted database value. +func UserSettingKeyString(key storepb.UserSetting_Key) string { + if key == UserSettingKeyPasskeys { + return userSettingKeyPasskeysString + } + return key.String() +} + +// ParseUserSettingKey converts a persisted database value back to a user setting key. +func ParseUserSettingKey(key string) storepb.UserSetting_Key { + if key == userSettingKeyPasskeysString { + return UserSettingKeyPasskeys + } + return storepb.UserSetting_Key(storepb.UserSetting_Key_value[key]) +} + +// GetUserPasskeys returns the passkeys registered for the user. +func (s *Store) GetUserPasskeys(ctx context.Context, userID int32) ([]*Passkey, error) { + cacheKey := getUserSettingCacheKey(userID, userSettingKeyPasskeysString) + if cache, ok := s.userSettingCache.Get(ctx, cacheKey); ok { + if passkeys, ok := cache.([]*Passkey); ok { + return clonePasskeys(passkeys), nil + } + } + + settings, err := s.driver.ListUserSettings(ctx, &FindUserSetting{ + UserID: &userID, + Key: UserSettingKeyPasskeys, + }) + if err != nil { + return nil, err + } + if len(settings) == 0 { + s.userSettingCache.Set(ctx, cacheKey, []*Passkey{}) + return []*Passkey{}, nil + } + if len(settings) > 1 { + return nil, errors.Errorf("expected 1 passkey setting, got %d", len(settings)) + } + + passkeys, err := unmarshalPasskeys(settings[0].Value) + if err != nil { + return nil, err + } + s.userSettingCache.Set(ctx, cacheKey, clonePasskeys(passkeys)) + return passkeys, nil +} + +// AddUserPasskey stores a new passkey for the user. +func (s *Store) AddUserPasskey(ctx context.Context, userID int32, passkey *Passkey) error { + passkeys, err := s.GetUserPasskeys(ctx, userID) + if err != nil { + return err + } + + passkeys = append(passkeys, clonePasskey(passkey)) + return s.upsertUserPasskeys(ctx, userID, passkeys) +} + +// UpdateUserPasskey updates an existing passkey for the user. +func (s *Store) UpdateUserPasskey(ctx context.Context, userID int32, passkey *Passkey) error { + passkeys, err := s.GetUserPasskeys(ctx, userID) + if err != nil { + return err + } + + updated := false + for i, existing := range passkeys { + if existing.ID == passkey.ID { + passkeys[i] = clonePasskey(passkey) + updated = true + break + } + } + if !updated { + return errors.Errorf("passkey %s not found", passkey.ID) + } + + return s.upsertUserPasskeys(ctx, userID, passkeys) +} + +// DeleteUserPasskey deletes a passkey for the user. +func (s *Store) DeleteUserPasskey(ctx context.Context, userID int32, passkeyID string) error { + passkeys, err := s.GetUserPasskeys(ctx, userID) + if err != nil { + return err + } + + filtered := make([]*Passkey, 0, len(passkeys)) + deleted := false + for _, existing := range passkeys { + if existing.ID == passkeyID { + deleted = true + continue + } + filtered = append(filtered, clonePasskey(existing)) + } + if !deleted { + return errors.Errorf("passkey %s not found", passkeyID) + } + + return s.upsertUserPasskeys(ctx, userID, filtered) +} + +func (s *Store) upsertUserPasskeys(ctx context.Context, userID int32, passkeys []*Passkey) error { + value, err := json.Marshal(&passkeysUserSetting{ + Passkeys: clonePasskeys(passkeys), + }) + if err != nil { + return errors.Wrap(err, "failed to marshal passkeys") + } + + if _, err := s.driver.UpsertUserSetting(ctx, &UserSetting{ + UserID: userID, + Key: UserSettingKeyPasskeys, + Value: string(value), + }); err != nil { + return err + } + + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userID, userSettingKeyPasskeysString), clonePasskeys(passkeys)) + return nil +} + +func unmarshalPasskeys(value string) ([]*Passkey, error) { + setting := &passkeysUserSetting{} + if value == "" { + return []*Passkey{}, nil + } + if err := json.Unmarshal([]byte(value), setting); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal passkeys") + } + if setting.Passkeys == nil { + return []*Passkey{}, nil + } + return clonePasskeys(setting.Passkeys), nil +} + +func clonePasskeys(passkeys []*Passkey) []*Passkey { + cloned := make([]*Passkey, 0, len(passkeys)) + for _, passkey := range passkeys { + cloned = append(cloned, clonePasskey(passkey)) + } + return cloned +} + +func clonePasskey(passkey *Passkey) *Passkey { + if passkey == nil { + return nil + } + cloned := *passkey + cloned.Transports = append([]string(nil), passkey.Transports...) + return &cloned +} + +// NewDefaultPasskeyLabel returns the default label for a newly created passkey. +func NewDefaultPasskeyLabel(now time.Time) string { + return "Passkey " + now.Format("2006-01-02 15:04") +} diff --git a/store/user_setting.go b/store/user_setting.go index 40a1c04806cfa..fd282c2a009e0 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -51,7 +51,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetti if userSetting == nil { return nil, errors.New("unexpected nil user setting") } - s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, UserSettingKeyString(userSetting.Key)), userSetting) return userSetting, nil } @@ -70,7 +70,7 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([] if userSetting == nil { continue } - s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, UserSettingKeyString(userSetting.Key)), userSetting) userSettings = append(userSettings, userSetting) } return userSettings, nil @@ -78,7 +78,7 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([] func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*storepb.UserSetting, error) { if find.UserID != nil { - if cache, ok := s.userSettingCache.Get(ctx, getUserSettingCacheKey(*find.UserID, find.Key.String())); ok { + if cache, ok := s.userSettingCache.Get(ctx, getUserSettingCacheKey(*find.UserID, UserSettingKeyString(find.Key))); ok { userSetting, ok := cache.(*storepb.UserSetting) if ok { return userSetting, nil @@ -98,7 +98,7 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto } userSetting := list[0] - s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, UserSettingKeyString(userSetting.Key)), userSetting) return userSetting, nil } diff --git a/web/src/components/PasskeyDialog.tsx b/web/src/components/PasskeyDialog.tsx new file mode 100644 index 0000000000000..d0ae54615e7e2 --- /dev/null +++ b/web/src/components/PasskeyDialog.tsx @@ -0,0 +1,171 @@ +import { LoaderIcon, PlusIcon, TrashIcon } from "lucide-react"; +import { useEffect, useState } from "react"; +import { toast } from "react-hot-toast"; +import ConfirmDialog from "@/components/ConfirmDialog"; +import { Button } from "@/components/ui/button"; +import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/components/ui/dialog"; +import { handleError } from "@/lib/error"; +import { useTranslate } from "@/utils/i18n"; +import { createPasskey, deletePasskey, getPasskeyErrorKey, listPasskeys, type Passkey, supportsPasskeys } from "@/utils/passkey"; + +interface Props { + open: boolean; + onOpenChange: (open: boolean) => void; +} + +function PasskeyDialog({ open, onOpenChange }: Props) { + const t = useTranslate(); + const [passkeys, setPasskeys] = useState([]); + const [isLoading, setIsLoading] = useState(false); + const [isCreating, setIsCreating] = useState(false); + const [deletingPasskey, setDeletingPasskey] = useState(undefined); + + const loadPasskeys = async () => { + setIsLoading(true); + try { + const passkeys = await listPasskeys(); + setPasskeys(passkeys); + } catch (error: unknown) { + await handleError(error, toast.error, { + context: "List passkeys", + }); + } finally { + setIsLoading(false); + } + }; + + useEffect(() => { + if (!open) { + return; + } + + void loadPasskeys(); + }, [open]); + + const handleCreatePasskey = async () => { + if (isCreating) { + return; + } + + try { + setIsCreating(true); + await createPasskey(); + toast.success(t("message.passkey-created")); + await loadPasskeys(); + } catch (error: unknown) { + const errorKey = getPasskeyErrorKey(error, "create"); + if (errorKey) { + console.error(error); + toast.error(t(errorKey)); + return; + } + await handleError(error, toast.error, { + fallbackMessage: "Failed to create passkey.", + }); + } finally { + setIsCreating(false); + } + }; + + const formatUnixTime = (value?: number) => { + if (!value) { + return t("setting.account.passkey-never-used"); + } + return new Date(value * 1000).toLocaleString(); + }; + + const confirmDeletePasskey = async () => { + if (!deletingPasskey) { + return; + } + + try { + await deletePasskey(deletingPasskey.id); + setPasskeys((prev) => prev.filter((passkey) => passkey.id !== deletingPasskey.id)); + toast.success(t("message.passkey-deleted")); + } catch (error: unknown) { + await handleError(error, toast.error, { + context: "Delete passkey", + }); + throw error; + } + }; + + return ( + <> + + + + {t("setting.account.passkey-title")} + {t("setting.account.passkey-dialog-description")} + +
+
+
+

{t("setting.account.passkey-list-title")}

+

+ {supportsPasskeys() ? t("setting.account.passkey-description") : t("auth.passkey-unsupported")} +

+
+ +
+ +
+ {isLoading ? ( +
+ + {t("setting.account.passkey-loading")} +
+ ) : passkeys.length === 0 ? ( +
+ {t("setting.account.no-passkeys-found")} +
+ ) : ( + passkeys.map((passkey) => ( +
+
+
+

{passkey.label}

+ {passkey.transports && passkey.transports.length > 0 && ( +

{passkey.transports.join(", ")}

+ )} +
+ +
+
+

+ {t("setting.account.passkey-added-at")}: {formatUnixTime(passkey.addedTs)} +

+

+ {t("setting.account.passkey-last-used-at")}: {formatUnixTime(passkey.lastUsedTs)} +

+
+
+ )) + )} +
+
+
+
+ + !open && setDeletingPasskey(undefined)} + title={deletingPasskey ? t("setting.account.passkey-deletion", { label: deletingPasskey.label }) : ""} + description={t("setting.account.passkey-deletion-description")} + confirmLabel={t("common.delete")} + cancelLabel={t("common.cancel")} + onConfirm={confirmDeletePasskey} + confirmVariant="destructive" + /> + + ); +} + +export default PasskeyDialog; diff --git a/web/src/components/PasswordSignInForm.tsx b/web/src/components/PasswordSignInForm.tsx index d9c0198582d40..0bd7d4a1878e4 100644 --- a/web/src/components/PasswordSignInForm.tsx +++ b/web/src/components/PasswordSignInForm.tsx @@ -13,17 +13,20 @@ import useNavigateTo from "@/hooks/useNavigateTo"; import { handleError } from "@/lib/error"; import { ROUTES } from "@/router/routes"; import { useTranslate } from "@/utils/i18n"; +import { getPasskeyErrorKey, signInWithPasskey, supportsPasskeys } from "@/utils/passkey"; interface PasswordSignInFormProps { + allowPasswordAuth?: boolean; redirectPath?: string; } -function PasswordSignInForm({ redirectPath }: PasswordSignInFormProps) { +function PasswordSignInForm({ allowPasswordAuth = true, redirectPath }: PasswordSignInFormProps) { const t = useTranslate(); const navigateTo = useNavigateTo(); const { profile } = useInstance(); const { initialize } = useAuth(); const actionBtnLoadingState = useLoading(false); + const passkeyBtnLoadingState = useLoading(false); const [username, setUsername] = useState(profile.demo ? "demo" : ""); const [password, setPassword] = useState(profile.demo ? "secret" : ""); @@ -39,10 +42,17 @@ function PasswordSignInForm({ redirectPath }: PasswordSignInFormProps) { const handleFormSubmit = (e: React.FormEvent) => { e.preventDefault(); - handleSignInButtonClick(); + if (allowPasswordAuth) { + handleSignInButtonClick(); + } else { + handlePasskeySignInButtonClick(); + } }; const handleSignInButtonClick = async () => { + if (!allowPasswordAuth) { + return; + } if (username === "" || password === "") { return; } @@ -73,6 +83,38 @@ function PasswordSignInForm({ redirectPath }: PasswordSignInFormProps) { actionBtnLoadingState.setFinish(); }; + const handlePasskeySignInButtonClick = async () => { + if (username === "") { + return; + } + + if (passkeyBtnLoadingState.isLoading) { + return; + } + + try { + passkeyBtnLoadingState.setLoading(); + const response = await signInWithPasskey(username); + if (response.accessToken) { + setAccessToken(response.accessToken, response.accessTokenExpiresAt ? new Date(response.accessTokenExpiresAt) : undefined); + } + await initialize(); + navigateTo(redirectPath || ROUTES.ROOT, { replace: true }); + } catch (error: unknown) { + const errorKey = getPasskeyErrorKey(error, "sign-in"); + if (errorKey) { + console.error(error); + toast.error(t(errorKey)); + return; + } + handleError(error, toast.error, { + fallbackMessage: "Failed to sign in with passkey.", + }); + } finally { + passkeyBtnLoadingState.setFinish(); + } + }; + return (
@@ -91,27 +133,42 @@ function PasswordSignInForm({ redirectPath }: PasswordSignInFormProps) { required />
-
- {t("common.password")} - -
+ {allowPasswordAuth && ( +
+ {t("common.password")} + +
+ )} -
- + )} + + {!supportsPasskeys() &&

{t("auth.passkey-unsupported")}

}
); diff --git a/web/src/components/Settings/MyAccountSection.tsx b/web/src/components/Settings/MyAccountSection.tsx index d8fb97a755266..3c8c0012121af 100644 --- a/web/src/components/Settings/MyAccountSection.tsx +++ b/web/src/components/Settings/MyAccountSection.tsx @@ -4,6 +4,7 @@ import useCurrentUser from "@/hooks/useCurrentUser"; import { useDialog } from "@/hooks/useDialog"; import { useTranslate } from "@/utils/i18n"; import ChangeMemberPasswordDialog from "../ChangeMemberPasswordDialog"; +import PasskeyDialog from "../PasskeyDialog"; import UpdateAccountDialog from "../UpdateAccountDialog"; import UserAvatar from "../UserAvatar"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "../ui/dropdown-menu"; @@ -16,6 +17,7 @@ const MyAccountSection = () => { const user = useCurrentUser(); const accountDialog = useDialog(); const passwordDialog = useDialog(); + const passkeyDialog = useDialog(); return ( @@ -41,6 +43,7 @@ const MyAccountSection = () => { + {t("setting.account.passkey-title")} {t("setting.account.change-password")} @@ -57,6 +60,9 @@ const MyAccountSection = () => { {/* Change Password Dialog */} + + {/* Passkey Dialog */} + ); }; diff --git a/web/src/locales/en-GB.json b/web/src/locales/en-GB.json index fc3b0afd2852c..3c85a75e03d8e 100644 --- a/web/src/locales/en-GB.json +++ b/web/src/locales/en-GB.json @@ -96,12 +96,24 @@ }, "account": { "change-password": "Change password", + "create-passkey": "Create Passkey", "email-note": "Optional", "export-memos": "Export Memos", "nickname-note": "Displayed in the banner", "openapi-reset": "Reset OpenAPI Key", "openapi-sample-post": "Hello #memos from {{url}}", "openapi-title": "OpenAPI", + "no-passkeys-found": "No passkeys yet.", + "passkey-added-at": "Added at", + "passkey-deletion": "Delete passkey `{{label}}`?", + "passkey-deletion-description": "This removes the selected passkey from your account. You will not be able to sign in with it afterwards.", + "passkey-description": "Create a passkey to sign in without a password on this browser or your synced devices.", + "passkey-dialog-description": "Use passkeys to sign in without a password on this browser or your synced devices.", + "passkey-last-used-at": "Last used at", + "passkey-list-title": "Your passkeys", + "passkey-loading": "Loading passkeys...", + "passkey-never-used": "Never used", + "passkey-title": "Passkeys", "reset-api": "Reset API", "title": "Account Information", "update-information": "Update Information", @@ -213,7 +225,14 @@ "create-your-account": "Create your account", "host-tip": "You are registering as the Site Host.", "new-password": "New password", + "passkey-already-exists": "A passkey for this account is already available on this device.", + "passkey-create-cancelled": "Passkey creation was cancelled or timed out.", + "passkey-security-error": "Passkeys are only available in a secure browser context for this site.", + "passkey-sign-in-cancelled": "Passkey sign-in was cancelled or timed out.", + "passkey-sign-in-unavailable": "Passkey sign-in is not available for this account.", + "passkey-unsupported": "This browser does not support passkeys.", "repeat-new-password": "Repeat the new password", + "sign-in-with-passkey": "Sign in with passkey", "sign-in-tip": "Already have an account?", "sign-up-tip": "Don't have an account yet?" }, @@ -463,6 +482,8 @@ "new-password-not-match": "New passwords do not match.", "no-data": "No data found.", "password-changed": "Password Changed", + "passkey-created": "Passkey created", + "passkey-deleted": "Passkey deleted", "password-not-match": "Passwords do not match.", "restored-successfully": "Restored successfully", "succeed-copy-content": "Content copied successfully.", diff --git a/web/src/locales/en.json b/web/src/locales/en.json index 9b24142a7e5f3..f3e45858ec593 100644 --- a/web/src/locales/en.json +++ b/web/src/locales/en.json @@ -11,8 +11,15 @@ "create-your-account": "Create your account", "host-tip": "You are registering as the Site Host.", "new-password": "New password", + "passkey-already-exists": "A passkey for this account is already available on this device.", + "passkey-create-cancelled": "Passkey creation was cancelled or timed out.", + "passkey-security-error": "Passkeys are only available in a secure browser context for this site.", + "passkey-sign-in-cancelled": "Passkey sign-in was cancelled or timed out.", + "passkey-sign-in-unavailable": "Passkey sign-in is not available for this account.", + "passkey-unsupported": "This browser does not support passkeys.", "protected-memo-notice": "This memo is not public. Sign in to continue.", "repeat-new-password": "Repeat the new password", + "sign-in-with-passkey": "Sign in with passkey", "sign-in-tip": "Already have an account?", "sign-up-tip": "Don't have an account yet?" }, @@ -307,6 +314,8 @@ "new-password-not-match": "New passwords do not match.", "no-data": "No data found.", "password-changed": "Password Changed", + "passkey-created": "Passkey created", + "passkey-deleted": "Passkey deleted", "password-not-match": "Passwords do not match.", "restored-successfully": "Restored successfully", "succeed-copy-content": "Content copied successfully.", @@ -386,12 +395,24 @@ }, "account": { "change-password": "Change password", + "create-passkey": "Create Passkey", "email-note": "Optional", "export-memos": "Export Memos", "nickname-note": "Displayed in the banner", "openapi-reset": "Reset OpenAPI Key", "openapi-sample-post": "Hello #memos from {{url}}", "openapi-title": "OpenAPI", + "no-passkeys-found": "No passkeys yet.", + "passkey-added-at": "Added at", + "passkey-deletion": "Delete passkey `{{label}}`?", + "passkey-deletion-description": "This removes the selected passkey from your account. You will not be able to sign in with it afterwards.", + "passkey-description": "Create a passkey to sign in without a password on this browser or your synced devices.", + "passkey-dialog-description": "Use passkeys to sign in without a password on this browser or your synced devices.", + "passkey-last-used-at": "Last used at", + "passkey-list-title": "Your passkeys", + "passkey-loading": "Loading passkeys...", + "passkey-never-used": "Never used", + "passkey-title": "Passkeys", "reset-api": "Reset API", "title": "Account Information", "update-information": "Update Information", diff --git a/web/src/locales/zh-Hans.json b/web/src/locales/zh-Hans.json index 01d36a193e0dd..0fd55053c6746 100644 --- a/web/src/locales/zh-Hans.json +++ b/web/src/locales/zh-Hans.json @@ -10,8 +10,15 @@ "create-your-account": "创建您的账户", "host-tip": "您正在注册为站点管理员。", "new-password": "新密码", + "passkey-already-exists": "该账号的通行密钥已存在于当前设备上。", + "passkey-create-cancelled": "通行密钥创建已取消或已超时。", + "passkey-security-error": "当前站点不是受支持的安全环境,无法使用通行密钥。", + "passkey-sign-in-cancelled": "通行密钥登录已取消或已超时。", + "passkey-sign-in-unavailable": "当前账号暂不可使用通行密钥登录。", + "passkey-unsupported": "当前浏览器不支持通行密钥。", "protected-memo-notice": "此备忘录不是公开的。请先登录后继续。", "repeat-new-password": "重复新密码", + "sign-in-with-passkey": "使用通行密钥登录", "sign-in-tip": "已有账户?", "sign-up-tip": "还没有账户?" }, @@ -269,6 +276,8 @@ "new-password-not-match": "新密码不一致。", "no-data": "未找到任何数据。", "password-changed": "密码已修改", + "passkey-created": "通行密钥已创建", + "passkey-deleted": "通行密钥已删除", "password-not-match": "密码不一致。", "restored-successfully": "恢复成功", "succeed-copy-content": "复制内容到剪贴板成功。", @@ -493,12 +502,24 @@ }, "account": { "change-password": "修改密码", + "create-passkey": "创建通行密钥", "email-note": "可选", "export-memos": "导出备忘录", "nickname-note": "显示在横幅中", "openapi-reset": "重置 OpenAPI 密钥(Key)", "openapi-sample-post": "您好 #memos 来自 {{url}}", "openapi-title": "OpenAPI 接口", + "no-passkeys-found": "暂无通行密钥。", + "passkey-added-at": "添加时间", + "passkey-deletion": "删除通行密钥 `{{label}}`?", + "passkey-deletion-description": "删除后将无法再使用这个通行密钥登录当前账号。", + "passkey-description": "创建一个通行密钥,以便在当前浏览器或已同步的设备上免密码登录。", + "passkey-dialog-description": "你可以在这里查看已绑定的通行密钥,并继续添加新的通行密钥。", + "passkey-last-used-at": "最后使用时间", + "passkey-list-title": "已绑定的通行密钥", + "passkey-loading": "正在加载通行密钥...", + "passkey-never-used": "从未使用", + "passkey-title": "通行密钥", "reset-api": "重置 API", "title": "账号信息", "update-information": "更新个人信息", diff --git a/web/src/pages/SignIn.tsx b/web/src/pages/SignIn.tsx index 8e33bf67dc279..61ccb2a7e9bc9 100644 --- a/web/src/pages/SignIn.tsx +++ b/web/src/pages/SignIn.tsx @@ -77,11 +77,7 @@ const SignIn = () => {

{instanceGeneralSetting.customProfile?.title || "Memos"}

- {!instanceGeneralSetting.disallowPasswordAuth ? ( - - ) : ( - identityProviderList.length === 0 &&

Password auth is not allowed.

- )} + {!instanceGeneralSetting.disallowUserRegistration && !instanceGeneralSetting.disallowPasswordAuth && (

{t("auth.sign-up-tip")} @@ -92,14 +88,12 @@ const SignIn = () => { )} {identityProviderList.length > 0 && ( <> - {!instanceGeneralSetting.disallowPasswordAuth && ( -

- -
- {t("common.or")} -
+
+ +
+ {t("common.or")}
- )} +
{identityProviderList.map((identityProvider) => (