diff --git a/ssh/server.go b/ssh/server.go index 3c0fcc953e..2bc8dccd1b 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -779,7 +779,11 @@ userAuthLoop: candidate.user = s.user candidate.pubKeyData = pubKeyData candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey) + var pse *PartialSuccessError _, isPartialSuccessError := candidate.result.(*PartialSuccessError) + if !isPartialSuccessError { + isPartialSuccessError = errors.As(candidate.result, &pse) + } if isPartialSuccessError && config.VerifiedPublicKeyCallback != nil { return nil, errors.New("ssh: invalid library usage: PublicKeyCallback must not return partial success when VerifiedPublicKeyCallback is defined") } @@ -804,8 +808,12 @@ userAuthLoop: if len(payload) > 0 { return nil, parseError(msgUserAuthRequest) } - _, isPartialSuccessError := candidate.result.(*PartialSuccessError) - if candidate.result == nil || isPartialSuccessError { + var pse2 *PartialSuccessError + _, isPartialSuccessError2 := candidate.result.(*PartialSuccessError) + if !isPartialSuccessError2 { + isPartialSuccessError2 = errors.As(candidate.result, &pse2) + } + if candidate.result == nil || isPartialSuccessError2 { okMsg := userAuthPubKeyOkMsg{ Algo: algo, PubKey: pubKeyData, @@ -946,7 +954,8 @@ userAuthLoop: var failureMsg userAuthFailureMsg - if partialSuccess, ok := authErr.(*PartialSuccessError); ok { + var partialSuccess *PartialSuccessError + if ok := errors.As(authErr, &partialSuccess); ok { // Permissions are not preserved between authentication steps. To // avoid confusion about the final state of the connection, we // disallow returning non-nil Permissions combined with diff --git a/ssh/server_multi_auth_test.go b/ssh/server_multi_auth_test.go index 3b39802437..800763e7b4 100644 --- a/ssh/server_multi_auth_test.go +++ b/ssh/server_multi_auth_test.go @@ -410,3 +410,68 @@ func TestDynamicAuthCallbacks(t *testing.T) { t.Fatal("server not returned partial success") } } + +// TestPartialSuccessErrorWrappedInBannerError verifies that a PartialSuccessError +// wrapped inside a BannerError is correctly detected via errors.As, rather than +// silently treated as an authentication failure. Prior to the fix, the direct +// type assertion authErr.(*PartialSuccessError) would fail when authErr is a +// *BannerError, causing the partial-success state to be lost and authFailures +// to be incremented instead. +func TestPartialSuccessErrorWrappedInBannerError(t *testing.T) { + username := "testuser" + errPwdAuthFailed := errors.New("password auth failed") + + // The PasswordCallback returns a BannerError wrapping a PartialSuccessError. + // This is valid API usage: the callback wants to both send a banner message + // AND signal partial success requiring a second factor. + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if string(password) == clientPassword { + // First factor OK; wrap PartialSuccessError in a BannerError so + // a banner is also sent to the client. + return nil, &BannerError{ + Err: &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + }, + }, + Message: "First factor accepted; please provide second factor.", + } + } + return nil, errPwdAuthFailed + }, + } + + clientConfig := &ClientConfig{ + User: username, + Auth: []AuthMethod{ + // Two password attempts: first triggers partial success, second completes login. + Password(clientPassword), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: BannerDisplayStderr(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %v (PartialSuccessError wrapped in BannerError was not detected)", err) + } + + // Expected sequence: + // [0] ErrNoAuth (none method) + // [1] BannerError wrapping PartialSuccessError (first password) + // [2] nil (second password succeeds) + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %d, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + var pse *PartialSuccessError + if !errors.As(serverAuthErrors[1], &pse) { + t.Fatalf("expected a PartialSuccessError (possibly wrapped) at index 1, got: %v", serverAuthErrors[1]) + } +}