Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,165 changes: 1,165 additions & 0 deletions server/router/api/v1/auth_passkey.go

Large diffs are not rendered by default.

46 changes: 25 additions & 21 deletions server/router/api/v1/connect_interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"reflect"
"runtime/debug"

Expand All @@ -30,27 +31,7 @@ func NewMetadataInterceptor() *MetadataInterceptor {

func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// Convert HTTP headers to gRPC metadata
header := req.Header()
md := metadata.MD{}

// Copy important headers for client info extraction
if ua := header.Get("User-Agent"); ua != "" {
md.Set("user-agent", ua)
}
if xff := header.Get("X-Forwarded-For"); xff != "" {
md.Set("x-forwarded-for", xff)
}
if xri := header.Get("X-Real-Ip"); xri != "" {
md.Set("x-real-ip", xri)
}
// Forward Cookie header for authentication methods that need it (e.g., RefreshToken)
if cookie := header.Get("Cookie"); cookie != "" {
md.Set("cookie", cookie)
}

// Set metadata in context so services can use metadata.FromIncomingContext()
ctx = metadata.NewIncomingContext(ctx, md)
ctx = metadata.NewIncomingContext(ctx, metadataFromHeaders(req.Header(), ""))

// Execute the request
resp, err := next(ctx, req)
Expand All @@ -67,6 +48,29 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc
}
}

func metadataFromHeaders(header http.Header, host string) metadata.MD {
md := metadata.MD{}

setMetadataHeader(md, header, "User-Agent", "user-agent")
setMetadataHeader(md, header, "X-Forwarded-For", "x-forwarded-for")
setMetadataHeader(md, header, "X-Real-Ip", "x-real-ip")
setMetadataHeader(md, header, "Cookie", "cookie")
setMetadataHeader(md, header, "Origin", "origin")
setMetadataHeader(md, header, "X-Forwarded-Host", "x-forwarded-host")
setMetadataHeader(md, header, "X-Forwarded-Proto", "x-forwarded-proto")
if host != "" {
md.Set("host", host)
}

return md
}

func setMetadataHeader(md metadata.MD, header http.Header, httpHeader, metadataKey string) {
if value := header.Get(httpHeader); value != "" {
md.Set(metadataKey, value)
}
}

func isNilAnyResponse(resp connect.AnyResponse) bool {
if resp == nil {
return true
Expand Down
1 change: 1 addition & 0 deletions server/router/api/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
s.registerPasskeyRoutes(gwGroup)
// Register SSE endpoint with same CORS as rest of /api/v1.
RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret)
handler := echo.WrapHandler(http.MaxBytesHandler(gwMux, maxAPIRequestBytes))
Expand Down
25 changes: 21 additions & 4 deletions server/router/frontend/dist/index.html
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
<!DOCTYPE html>
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate" />
<meta http-equiv="Pragma" content="no-cache" />
<meta http-equiv="Expires" content="0" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
<link rel="icon" type="image/webp" href="/logo.webp" />
<link rel="manifest" href="/site.webmanifest" />
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=no" />
<meta name="theme-color" content="#faf9f5" />
<meta name="mobile-web-app-capable" content="yes" />
<meta name="apple-mobile-web-app-capable" content="yes" />
<meta name="apple-mobile-web-app-status-bar-style" content="default" />
<!-- memos.metadata.head -->
<title>Memos</title>
<script type="module" crossorigin src="/assets/index-BwnUwF1C.js"></script>
<link rel="modulepreload" crossorigin href="/assets/utils-vendor-CUDuyvje.js">
<link rel="modulepreload" crossorigin href="/assets/leaflet-vendor-BT6BJd6h.js">
<link rel="modulepreload" crossorigin href="/assets/mermaid-vendor-DWNEOl4-.js">
<link rel="stylesheet" crossorigin href="/assets/index-D0jMfv2K.css">
</head>
<body>
No embeddable frontend found.
<body class="text-base w-full min-h-svh">
<div id="root" class="relative w-full min-h-full"></div>
<!-- memos.metadata.body -->
</body>
</html>
6 changes: 3 additions & 3 deletions store/db/mysql/user_setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
stmt := "INSERT INTO `user_setting` (`user_id`, `key`, `value`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?"
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value, upsert.Value); err != nil {
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, store.UserSettingKeyString(upsert.Key), upsert.Value, upsert.Value); err != nil {
return nil, err
}
return upsert, nil
Expand All @@ -22,7 +22,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
where, args := []string{"1 = 1"}, []any{}

if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "`key` = ?"), append(args, v.String())
where, args = append(where, "`key` = ?"), append(args, store.UserSettingKeyString(v))
}
if v := find.UserID; v != nil {
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
Expand All @@ -46,7 +46,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
userSetting.Key = store.ParseUserSettingKey(keyString)
userSettingList = append(userSettingList, userSetting)
}

Expand Down
6 changes: 3 additions & 3 deletions store/db/postgres/user_setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, store.UserSettingKeyString(upsert.Key), upsert.Value); err != nil {
return nil, err
}
return upsert, nil
Expand All @@ -29,7 +29,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
where, args := []string{"1 = 1"}, []any{}

if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, store.UserSettingKeyString(v))
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
Expand Down Expand Up @@ -59,7 +59,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
userSetting.Key = store.ParseUserSettingKey(keyString)
userSettingList = append(userSettingList, userSetting)
}

Expand Down
6 changes: 3 additions & 3 deletions store/db/sqlite/user_setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, store.UserSettingKeyString(upsert.Key), upsert.Value); err != nil {
return nil, err
}
return upsert, nil
Expand All @@ -29,7 +29,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
where, args := []string{"1 = 1"}, []any{}

if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, v.String())
where, args = append(where, "key = ?"), append(args, store.UserSettingKeyString(v))
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
Expand Down Expand Up @@ -59,7 +59,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
userSetting.Key = store.ParseUserSettingKey(keyString)
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
Expand Down
194 changes: 194 additions & 0 deletions store/passkey_setting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package store

import (
"context"
"encoding/json"
"time"

"github.com/pkg/errors"

Check failure on line 8 in store/passkey_setting.go

View workflow job for this annotation

GitHub Actions / Static Checks

File is not properly formatted (goimports)
storepb "github.com/usememos/memos/proto/gen/store"
)

const (
// UserSettingKeyPasskeys stores WebAuthn passkeys for a user.
UserSettingKeyPasskeys storepb.UserSetting_Key = 100

userSettingKeyPasskeysString = "PASSKEYS"
)

// Passkey stores a user's WebAuthn credential metadata.
type Passkey struct {
ID string `json:"id"`
Label string `json:"label"`
CredentialID string `json:"credentialId"`
PublicKey string `json:"publicKey"`
Algorithm int32 `json:"algorithm"`
SignCount uint32 `json:"signCount"`
Transports []string `json:"transports,omitempty"`
AddedTs int64 `json:"addedTs"`
LastUsedTs int64 `json:"lastUsedTs,omitempty"`
}

type passkeysUserSetting struct {
Passkeys []*Passkey `json:"passkeys"`
}

// UserSettingKeyString converts a user setting key to the persisted database value.
func UserSettingKeyString(key storepb.UserSetting_Key) string {
if key == UserSettingKeyPasskeys {
return userSettingKeyPasskeysString
}
return key.String()
}

// ParseUserSettingKey converts a persisted database value back to a user setting key.
func ParseUserSettingKey(key string) storepb.UserSetting_Key {
if key == userSettingKeyPasskeysString {
return UserSettingKeyPasskeys
}
return storepb.UserSetting_Key(storepb.UserSetting_Key_value[key])
}

// GetUserPasskeys returns the passkeys registered for the user.
func (s *Store) GetUserPasskeys(ctx context.Context, userID int32) ([]*Passkey, error) {
cacheKey := getUserSettingCacheKey(userID, userSettingKeyPasskeysString)
if cache, ok := s.userSettingCache.Get(ctx, cacheKey); ok {
if passkeys, ok := cache.([]*Passkey); ok {
return clonePasskeys(passkeys), nil
}
}

settings, err := s.driver.ListUserSettings(ctx, &FindUserSetting{
UserID: &userID,
Key: UserSettingKeyPasskeys,
})
if err != nil {
return nil, err
}
if len(settings) == 0 {
s.userSettingCache.Set(ctx, cacheKey, []*Passkey{})
return []*Passkey{}, nil
}
if len(settings) > 1 {
return nil, errors.Errorf("expected 1 passkey setting, got %d", len(settings))
}

passkeys, err := unmarshalPasskeys(settings[0].Value)
if err != nil {
return nil, err
}
s.userSettingCache.Set(ctx, cacheKey, clonePasskeys(passkeys))
return passkeys, nil
}

// AddUserPasskey stores a new passkey for the user.
func (s *Store) AddUserPasskey(ctx context.Context, userID int32, passkey *Passkey) error {
passkeys, err := s.GetUserPasskeys(ctx, userID)
if err != nil {
return err
}

passkeys = append(passkeys, clonePasskey(passkey))
return s.upsertUserPasskeys(ctx, userID, passkeys)
}

// UpdateUserPasskey updates an existing passkey for the user.
func (s *Store) UpdateUserPasskey(ctx context.Context, userID int32, passkey *Passkey) error {
passkeys, err := s.GetUserPasskeys(ctx, userID)
if err != nil {
return err
}

updated := false
for i, existing := range passkeys {
if existing.ID == passkey.ID {
passkeys[i] = clonePasskey(passkey)
updated = true
break
}
}
if !updated {
return errors.Errorf("passkey %s not found", passkey.ID)
}

return s.upsertUserPasskeys(ctx, userID, passkeys)
}

// DeleteUserPasskey deletes a passkey for the user.
func (s *Store) DeleteUserPasskey(ctx context.Context, userID int32, passkeyID string) error {
passkeys, err := s.GetUserPasskeys(ctx, userID)
if err != nil {
return err
}

filtered := make([]*Passkey, 0, len(passkeys))
deleted := false
for _, existing := range passkeys {
if existing.ID == passkeyID {
deleted = true
continue
}
filtered = append(filtered, clonePasskey(existing))
}
if !deleted {
return errors.Errorf("passkey %s not found", passkeyID)
}

return s.upsertUserPasskeys(ctx, userID, filtered)
}

func (s *Store) upsertUserPasskeys(ctx context.Context, userID int32, passkeys []*Passkey) error {
value, err := json.Marshal(&passkeysUserSetting{
Passkeys: clonePasskeys(passkeys),
})
if err != nil {
return errors.Wrap(err, "failed to marshal passkeys")
}

if _, err := s.driver.UpsertUserSetting(ctx, &UserSetting{
UserID: userID,
Key: UserSettingKeyPasskeys,
Value: string(value),
}); err != nil {
return err
}

s.userSettingCache.Set(ctx, getUserSettingCacheKey(userID, userSettingKeyPasskeysString), clonePasskeys(passkeys))
return nil
}

func unmarshalPasskeys(value string) ([]*Passkey, error) {
setting := &passkeysUserSetting{}
if value == "" {
return []*Passkey{}, nil
}
if err := json.Unmarshal([]byte(value), setting); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal passkeys")
}
if setting.Passkeys == nil {
return []*Passkey{}, nil
}
return clonePasskeys(setting.Passkeys), nil
}

func clonePasskeys(passkeys []*Passkey) []*Passkey {
cloned := make([]*Passkey, 0, len(passkeys))
for _, passkey := range passkeys {
cloned = append(cloned, clonePasskey(passkey))
}
return cloned
}

func clonePasskey(passkey *Passkey) *Passkey {
if passkey == nil {
return nil
}
cloned := *passkey
cloned.Transports = append([]string(nil), passkey.Transports...)
return &cloned
}

// NewDefaultPasskeyLabel returns the default label for a newly created passkey.
func NewDefaultPasskeyLabel(now time.Time) string {
return "Passkey " + now.Format("2006-01-02 15:04")
}
Loading
Loading