diff --git a/docs/docs/admin.md b/docs/docs/admin.md index d096afa..5cbfd97 100644 --- a/docs/docs/admin.md +++ b/docs/docs/admin.md @@ -218,6 +218,10 @@ Errors: - 401 `invalid token` or `missing authorization` or `invalid authorization` - 403 `forbidden` +Validation notes: + +- `flag` must be at most 72 bytes (bcrypt input limit). + --- ## List Registration Keys @@ -613,6 +617,10 @@ Errors: - 403 `forbidden` - 404 `challenge not found` +Validation notes: + +- When provided, `flag` must be at most 72 bytes (bcrypt input limit). + --- ## Get Challenge Detail (Admin) diff --git a/docs/docs/auth.md b/docs/docs/auth.md index c329ae0..6ac8626 100644 --- a/docs/docs/auth.md +++ b/docs/docs/auth.md @@ -33,6 +33,10 @@ Errors: - 400 `invalid input` - 409 `user already exists` +Validation notes: + +- `password` must be at most 72 bytes (bcrypt input limit). + `registration_key` must be an admin-created alphanumeric code. Keys can be reused up to their configured `max_uses` and assign the user to the key's team. diff --git a/docs/docs/users.md b/docs/docs/users.md index 3633720..6e1e10d 100644 --- a/docs/docs/users.md +++ b/docs/docs/users.md @@ -84,6 +84,7 @@ Errors: - 400 `invalid input` - 401 `invalid token` or `missing authorization` or `invalid authorization` - 403 `user blocked` +- 409 `user already exists` (username already in use) Notes: diff --git a/internal/bootstrap/testenv_test.go b/internal/bootstrap/testenv_test.go index 88052b4..6630827 100644 --- a/internal/bootstrap/testenv_test.go +++ b/internal/bootstrap/testenv_test.go @@ -65,7 +65,10 @@ func startPostgres(ctx context.Context) (testcontainers.Container, config.DBConf "POSTGRES_PASSWORD": "smctf", "POSTGRES_DB": "smctf_test", }, - WaitingFor: wait.ForListeningPort("5432/tcp"), + WaitingFor: wait.ForAll( + wait.ForListeningPort("5432/tcp").SkipExternalCheck(), + wait.ForLog("database system is ready to accept connections"), + ), } container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ diff --git a/internal/db/testenv_test.go b/internal/db/testenv_test.go index db72034..54eeadd 100644 --- a/internal/db/testenv_test.go +++ b/internal/db/testenv_test.go @@ -63,7 +63,10 @@ func startPostgres(ctx context.Context) (testcontainers.Container, config.DBConf "POSTGRES_PASSWORD": "smctf", "POSTGRES_DB": "smctf_test", }, - WaitingFor: wait.ForListeningPort("5432/tcp"), + WaitingFor: wait.ForAll( + wait.ForListeningPort("5432/tcp").SkipExternalCheck(), + wait.ForLog("database system is ready to accept connections"), + ), } container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ diff --git a/internal/http/handlers/testenv_test.go b/internal/http/handlers/testenv_test.go index 143d2ac..3de96c6 100644 --- a/internal/http/handlers/testenv_test.go +++ b/internal/http/handlers/testenv_test.go @@ -182,7 +182,10 @@ func startHandlerPostgres(ctx context.Context) (testcontainers.Container, config "POSTGRES_PASSWORD": "smctf", "POSTGRES_DB": "smctf_test", }, - WaitingFor: wait.ForListeningPort("5432/tcp"), + WaitingFor: wait.ForAll( + wait.ForListeningPort("5432/tcp").SkipExternalCheck(), + wait.ForLog("database system is ready to accept connections"), + ), } container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ diff --git a/internal/http/integration/admin_test.go b/internal/http/integration/admin_test.go index 3728132..e240ac9 100644 --- a/internal/http/integration/admin_test.go +++ b/internal/http/integration/admin_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "testing" "time" @@ -74,6 +75,20 @@ func TestAdminCreateChallenge(t *testing.T) { decodeJSON(t, rec, &resp) assertFieldErrors(t, resp.Details, map[string]string{"category": "invalid"}) + + rec = doRequest(t, env.router, http.MethodPost, "/api/admin/challenges", map[string]any{ + "title": "Ch4", + "description": "desc", + "category": "Web", + "points": 100, + "flag": strings.Repeat("a", 73), + "is_active": true, + }, authHeader(adminAccess)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status %d: %s", rec.Code, rec.Body.String()) + } + decodeJSON(t, rec, &resp) + assertFieldErrors(t, resp.Details, map[string]string{"flag": "max bytes is 72"}) } func TestAdminUpdateChallenge(t *testing.T) { @@ -163,6 +178,15 @@ func TestAdminUpdateChallenge(t *testing.T) { t.Fatalf("expected flag hash to be updated") } + rec = doRequest(t, env.router, http.MethodPut, "/api/admin/challenges/"+itoa(created.ID), map[string]any{ + "flag": strings.Repeat("a", 73), + }, authHeader(adminAccess)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status %d: %s", rec.Code, rec.Body.String()) + } + decodeJSON(t, rec, &errResp) + assertFieldErrors(t, errResp.Details, map[string]string{"flag": "max bytes is 72"}) + nullCases := []struct { name string body map[string]any diff --git a/internal/http/integration/auth_test.go b/internal/http/integration/auth_test.go index 2de5ce9..5380235 100644 --- a/internal/http/integration/auth_test.go +++ b/internal/http/integration/auth_test.go @@ -4,6 +4,7 @@ import ( "net/http" "smctf/internal/models" "smctf/internal/service" + "strings" "testing" ) @@ -168,6 +169,29 @@ func TestRegister(t *testing.T) { t.Fatalf("unexpected error: %s", resp.Error) } }) + + t.Run("password too long", func(t *testing.T) { + env := setupTest(t, testCfg) + admin := ensureAdminUser(t, env) + key := createRegistrationKey(t, env, admin.ID) + body := map[string]string{ + "email": "user@example.com", + "username": "user1", + "password": strings.Repeat("a", 73), + "registration_key": key.Code, + } + + rec := doRequest(t, env.router, http.MethodPost, "/api/auth/register", body, nil) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status %d: %s", rec.Code, rec.Body.String()) + } + + var resp errorResp + decodeJSON(t, rec, &resp) + assertFieldErrors(t, resp.Details, map[string]string{ + "password": "max bytes is 72", + }) + }) } func TestLogin(t *testing.T) { @@ -315,6 +339,12 @@ func TestUpdateMe(t *testing.T) { if resp.ID != userID || resp.Email != "user@example.com" || resp.Username != "newuser" || resp.Role != models.UserRole { t.Fatalf("unexpected response: %+v", resp) } + + _, _, _ = registerAndLogin(t, env, "user2@example.com", "user2", "strong-password") + rec = doRequest(t, env.router, http.MethodPut, "/api/me", map[string]string{"username": "user2"}, authHeader(access)) + if rec.Code != http.StatusConflict { + t.Fatalf("status %d: %s", rec.Code, rec.Body.String()) + } } func TestMeSolved(t *testing.T) { diff --git a/internal/http/integration/testenv_test.go b/internal/http/integration/testenv_test.go index dd87005..1881750 100644 --- a/internal/http/integration/testenv_test.go +++ b/internal/http/integration/testenv_test.go @@ -222,7 +222,10 @@ func startPostgres(ctx context.Context) (testcontainers.Container, config.DBConf "POSTGRES_PASSWORD": "smctf", "POSTGRES_DB": "smctf_test", }, - WaitingFor: wait.ForListeningPort("5432/tcp"), + WaitingFor: wait.ForAll( + wait.ForListeningPort("5432/tcp").SkipExternalCheck(), + wait.ForLog("database system is ready to accept connections"), + ), } container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ diff --git a/internal/repo/testenv_test.go b/internal/repo/testenv_test.go index 5d10322..90a6b99 100644 --- a/internal/repo/testenv_test.go +++ b/internal/repo/testenv_test.go @@ -96,7 +96,10 @@ func startPostgres(ctx context.Context) (testcontainers.Container, config.DBConf "POSTGRES_PASSWORD": "smctf", "POSTGRES_DB": "smctf_test", }, - WaitingFor: wait.ForListeningPort("5432/tcp"), + WaitingFor: wait.ForAll( + wait.ForListeningPort("5432/tcp").SkipExternalCheck(), + wait.ForLog("database system is ready to accept connections"), + ), } container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ diff --git a/internal/repo/user_repo.go b/internal/repo/user_repo.go index 222c50f..133ac02 100644 --- a/internal/repo/user_repo.go +++ b/internal/repo/user_repo.go @@ -2,6 +2,7 @@ package repo import ( "context" + "strings" "smctf/internal/models" @@ -12,6 +13,23 @@ type UserRepo struct { db *bun.DB } +func (r *UserRepo) ExistsByUsername(ctx context.Context, username string, excludeUserID *int64) (bool, error) { + query := r.db.NewSelect(). + TableExpr("users AS u"). + Where("u.username = ?", strings.TrimSpace(username)) + + if excludeUserID != nil { + query = query.Where("u.id != ?", *excludeUserID) + } + + count, err := query.Count(ctx) + if err != nil { + return false, wrapError("userRepo.ExistsByUsername", err) + } + + return count > 0, nil +} + func NewUserRepo(db *bun.DB) *UserRepo { return &UserRepo{db: db} } diff --git a/internal/repo/user_repo_test.go b/internal/repo/user_repo_test.go index 04ad710..d4351bf 100644 --- a/internal/repo/user_repo_test.go +++ b/internal/repo/user_repo_test.go @@ -156,3 +156,35 @@ func TestUserRepoNotFoundCases(t *testing.T) { t.Fatalf("expected GetByEmailOrUsername not found, got %v", err) } } + +func TestUserRepoExistsByUsername(t *testing.T) { + env := setupRepoTest(t) + user := createUserWithNewTeam(t, env, "exists@example.com", "exists-user", "pass", models.UserRole) + + exists, err := env.userRepo.ExistsByUsername(context.Background(), "exists-user", nil) + if err != nil { + t.Fatalf("ExistsByUsername: %v", err) + } + + if !exists { + t.Fatalf("expected exists=true") + } + + exists, err = env.userRepo.ExistsByUsername(context.Background(), "exists-user", &user.ID) + if err != nil { + t.Fatalf("ExistsByUsername with exclude id: %v", err) + } + + if exists { + t.Fatalf("expected exists=false when excluding same user") + } + + exists, err = env.userRepo.ExistsByUsername(context.Background(), " missing-user ", nil) + if err != nil { + t.Fatalf("ExistsByUsername missing: %v", err) + } + + if exists { + t.Fatalf("expected exists=false for missing username") + } +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 88e5d3b..b6c2259 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -47,6 +47,7 @@ func (s *AuthService) Register(ctx context.Context, email, username, password, r validator.Required("password", password) validator.Required("registration_key", registrationKey) validator.Email("email", email) + validator.MaxBytes("password", password, bcryptInputMaxBytes) if registrationKey != "" && !isRegistrationCode(registrationKey) { validator.fields = append(validator.fields, FieldError{Field: "registration_key", Reason: "invalid"}) diff --git a/internal/service/auth_service_test.go b/internal/service/auth_service_test.go index 44465f8..d2ffe42 100644 --- a/internal/service/auth_service_test.go +++ b/internal/service/auth_service_test.go @@ -56,6 +56,22 @@ func TestAuthServiceRegisterValidation(t *testing.T) { } } +func TestAuthServiceRegisterPasswordTooLong(t *testing.T) { + env := setupServiceTest(t) + admin := createUserWithNewTeam(t, env, "admin@example.com", models.AdminRole, "pass", models.AdminRole) + key := createRegistrationKey(t, env, "ABCDEFGHJKLMNPQ7", admin.ID) + + _, err := env.authSvc.Register(context.Background(), "user@example.com", "user1", strings.Repeat("a", 73), key.Code, "") + var ve *ValidationError + if !errors.As(err, &ve) { + t.Fatalf("expected validation error, got %v", err) + } + + if len(ve.Fields) == 0 || ve.Fields[0].Field != "password" || ve.Fields[0].Reason != "max bytes is 72" { + t.Fatalf("unexpected validation details: %+v", ve.Fields) + } +} + func TestAuthServiceRegisterUserExists(t *testing.T) { env := setupServiceTest(t) admin := createUserWithNewTeam(t, env, "admin@example.com", models.AdminRole, "pass", models.AdminRole) diff --git a/internal/service/ctf_service.go b/internal/service/ctf_service.go index 1b09868..76ae09d 100644 --- a/internal/service/ctf_service.go +++ b/internal/service/ctf_service.go @@ -135,6 +135,7 @@ func (s *CTFService) CreateChallenge(ctx context.Context, title, description, ca validator.Required("flag", flag) validator.NonNegative("points", points) validator.NonNegative("minimum_points", minimumPoints) + validator.MaxBytes("flag", flag, bcryptInputMaxBytes) if previousChallengeID != nil { validator.PositiveID("previous_challenge_id", *previousChallengeID) } @@ -241,6 +242,7 @@ func (s *CTFService) UpdateChallenge(ctx context.Context, id int64, title, descr if value == "" { validator.fields = append(validator.fields, FieldError{Field: "flag", Reason: "required"}) } else { + validator.MaxBytes("flag", value, bcryptInputMaxBytes) normalizedFlag = &value } } diff --git a/internal/service/ctf_service_test.go b/internal/service/ctf_service_test.go index 11ce3ec..2517fdb 100644 --- a/internal/service/ctf_service_test.go +++ b/internal/service/ctf_service_test.go @@ -295,6 +295,20 @@ func TestCTFServiceUpdateChallenge(t *testing.T) { } } +func TestCTFServiceChallengeFlagTooLong(t *testing.T) { + env := setupServiceTest(t) + longFlag := strings.Repeat("a", 73) + + if _, err := env.ctfSvc.CreateChallenge(context.Background(), "Title", "Desc", "Misc", 100, 50, longFlag, true, false, nil, nil, nil); !errors.Is(err, ErrInvalidInput) { + t.Fatalf("expected invalid input for create long flag, got %v", err) + } + + challenge := createChallenge(t, env, "Old", 50, "FLAG{2}", true) + if _, err := env.ctfSvc.UpdateChallenge(context.Background(), challenge.ID, nil, nil, nil, nil, nil, &longFlag, nil, nil, nil, nil, nil, false); !errors.Is(err, ErrInvalidInput) { + t.Fatalf("expected invalid input for update long flag, got %v", err) + } +} + func TestCTFServiceDeleteChallenge(t *testing.T) { env := setupServiceTest(t) challenge := createChallenge(t, env, "Delete", 50, "FLAG{3}", true) diff --git a/internal/service/testenv_test.go b/internal/service/testenv_test.go index 57aa4ce..df53987 100644 --- a/internal/service/testenv_test.go +++ b/internal/service/testenv_test.go @@ -146,7 +146,10 @@ func startPostgres(ctx context.Context) (testcontainers.Container, config.DBConf "POSTGRES_PASSWORD": "smctf", "POSTGRES_DB": "smctf_test", }, - WaitingFor: wait.ForListeningPort("5432/tcp"), + WaitingFor: wait.ForAll( + wait.ForListeningPort("5432/tcp").SkipExternalCheck(), + wait.ForLog("database system is ready to accept connections"), + ), } container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 10ac657..82a9d67 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "smctf/internal/db" "smctf/internal/models" "smctf/internal/repo" ) @@ -82,10 +83,26 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, username } if username != nil { - user.Username = *username + normalizedUsername := normalizeTrim(*username) + if normalizedUsername == "" { + return nil, NewValidationError(FieldError{Field: "username", Reason: "required"}) + } + + exists, err := s.userRepo.ExistsByUsername(ctx, normalizedUsername, &userID) + if err != nil { + return nil, fmt.Errorf("user.UpdateProfile username exists: %w", err) + } + if exists { + return nil, ErrUserExists + } + + user.Username = normalizedUsername } if err := s.userRepo.Update(ctx, user); err != nil { + if db.IsUniqueViolation(err) { + return nil, ErrUserExists + } return nil, fmt.Errorf("user.UpdateProfile: %w", err) } diff --git a/internal/service/user_service_test.go b/internal/service/user_service_test.go index 3e3cd9a..8fac8ea 100644 --- a/internal/service/user_service_test.go +++ b/internal/service/user_service_test.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "strings" "testing" "smctf/internal/models" @@ -159,6 +160,37 @@ func TestUserServiceUpdateProfileWithoutUsernameChange(t *testing.T) { } } +func TestUserServiceUpdateProfileDuplicateUsername(t *testing.T) { + env := setupServiceTest(t) + user1 := createUserWithNewTeam(t, env, "dup1@example.com", "dup-user-1", "pass", models.UserRole) + _ = createUserWithNewTeam(t, env, "dup2@example.com", "dup-user-2", "pass", models.UserRole) + + dup := "dup-user-2" + if _, err := env.userSvc.UpdateProfile(context.Background(), user1.ID, &dup); !errors.Is(err, ErrUserExists) { + t.Fatalf("expected ErrUserExists, got %v", err) + } +} + +func TestUserServiceUpdateProfileTrimAndRequired(t *testing.T) { + env := setupServiceTest(t) + user := createUserWithNewTeam(t, env, "trim@example.com", "trim-user", "pass", models.UserRole) + + newName := " trimmed-name " + updated, err := env.userSvc.UpdateProfile(context.Background(), user.ID, &newName) + if err != nil { + t.Fatalf("expected trim update success, got %v", err) + } + + if updated.Username != "trimmed-name" { + t.Fatalf("expected trimmed username, got %q", updated.Username) + } + + blank := strings.Repeat(" ", 5) + if _, err := env.userSvc.UpdateProfile(context.Background(), user.ID, &blank); err == nil { + t.Fatalf("expected validation error for blank username") + } +} + func TestUserServiceMoveUserTeamValidationAndNotFound(t *testing.T) { env := setupServiceTest(t) user := createUserWithNewTeam(t, env, "move2@example.com", "move2", "pass", models.UserRole) diff --git a/internal/service/validation.go b/internal/service/validation.go index ab4a8ed..860f6f2 100644 --- a/internal/service/validation.go +++ b/internal/service/validation.go @@ -1,10 +1,13 @@ package service import ( + "fmt" "net/mail" "strings" ) +const bcryptInputMaxBytes = 72 + type fieldValidator struct { fields []FieldError } @@ -41,6 +44,12 @@ func (v *fieldValidator) Email(field, value string) { } } +func (v *fieldValidator) MaxBytes(field, value string, max int) { + if len(value) > max { + v.fields = append(v.fields, FieldError{Field: field, Reason: fmt.Sprintf("max bytes is %d", max)}) + } +} + func (v *fieldValidator) Error() error { if len(v.fields) == 0 { return nil diff --git a/internal/service/validation_test.go b/internal/service/validation_test.go index aa16aac..ab5c399 100644 --- a/internal/service/validation_test.go +++ b/internal/service/validation_test.go @@ -2,6 +2,7 @@ package service import ( "errors" + "strings" "testing" ) @@ -12,6 +13,7 @@ func TestFieldValidator(t *testing.T) { v.Required("username", " ") v.NonNegative("points", -1) v.PositiveID("challenge_id", 0) + v.MaxBytes("password", strings.Repeat("a", 73), bcryptInputMaxBytes) err := v.Error() @@ -20,8 +22,8 @@ func TestFieldValidator(t *testing.T) { t.Fatalf("expected validation error, got %v", err) } - if len(ve.Fields) != 5 { - t.Fatalf("expected 5 fields, got %d", len(ve.Fields)) + if len(ve.Fields) != 6 { + t.Fatalf("expected 6 fields, got %d", len(ve.Fields)) } }