diff --git a/httperrors/status.go b/httperrors/status.go index 08731abc..0ff6735c 100644 --- a/httperrors/status.go +++ b/httperrors/status.go @@ -381,3 +381,53 @@ func WrapStatus(err error, httpStatus int, message string) *Status { func WrapStatusf(err error, httpStatus int, format string, a ...any) *Status { return Wrapf(err, httpStatus, ReasonFromStatus(httpStatus), format, a...) } + +// AlwaysWrap is like Wrap but always sets the underlying httpStatus, reason, and message +// fields, even if err is already a Status error. The original error is preserved as +// the cause. If err is nil, it returns an OK status (consistent with Wrap). +// +// Note: the computed StatusCode()/Reason() follow the Status invariant that a non-error +// cause cannot be OK. If err is non-nil and httpStatus is 2xx, StatusCode() returns 500 +// and Reason() returns "UNKNOWN", even though the stored fields are set as given. +// +// When wrapping an existing Status, structural details (field violations, metadata, etc.) +// are preserved via Clone, but the localized key is reset to match the new reason. +// Use WithLocalized/WithLocalizedArgs on the result if custom localization is needed. +func AlwaysWrap(err error, httpStatus int, reason, message string) *Status { + if err == nil { + return New(http.StatusOK, ReasonOK, "") + } + s, _ := FromError(err) + // Only clone when err is a *StatusError, since FromError returns a shared + // pointer in that case. Other branches already create a fresh *Status. + var se *StatusError + if errors.As(err, &se) { + s = Clone(s) + } + s.cause = errors.WithStack(err) + s.httpStatus = httpStatus + s.message = message + s.reason = reason + s.localized = &Localized{key: s.Reason()} + return s +} + +func AlwaysWrapf(err error, httpStatus int, reason, format string, a ...any) *Status { + return AlwaysWrap(err, httpStatus, reason, fmt.Sprintf(format, a...)) +} + +// AlwaysWrapStatus is like WrapStatus but always sets the given status code and message, +// even if err is already a Status error. The original error is preserved as the cause. +// If err is nil, it returns an OK status. See AlwaysWrap for details on the +// 2xx status + non-nil error edge case. +func AlwaysWrapStatus(err error, httpStatus int, message string) *Status { + return AlwaysWrap(err, httpStatus, ReasonFromStatus(httpStatus), message) +} + +// AlwaysWrapStatusf is like WrapStatusf but always sets the given status code and +// formatted message, even if err is already a Status error. The original error is +// preserved as the cause. If err is nil, it returns an OK status. See AlwaysWrap +// for details on the 2xx status + non-nil error edge case. +func AlwaysWrapStatusf(err error, httpStatus int, format string, a ...any) *Status { + return AlwaysWrapf(err, httpStatus, ReasonFromStatus(httpStatus), format, a...) +} diff --git a/httperrors/status_test.go b/httperrors/status_test.go index e8273155..296da18d 100644 --- a/httperrors/status_test.go +++ b/httperrors/status_test.go @@ -443,6 +443,156 @@ func TestWrapStatusf(t *testing.T) { assert.Equal(t, "failed to query users", s.Message()) } +func TestAlwaysWrap(t *testing.T) { + t.Run("plain error", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrap(originalErr, http.StatusInternalServerError, ReasonInternal, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + assert.True(t, errors.Is(wrapped.Err(), originalErr)) + }) + + t.Run("nil error", func(t *testing.T) { + wrapped := AlwaysWrap(nil, http.StatusInternalServerError, ReasonInternal, "internal server error") + require.NotNil(t, wrapped) + assert.Equal(t, http.StatusOK, wrapped.StatusCode()) + assert.Equal(t, ReasonOK, wrapped.Reason()) + assert.Equal(t, "", wrapped.Message()) + }) + + t.Run("overrides existing StatusError", func(t *testing.T) { + original := New(http.StatusNotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrap(original.Err(), http.StatusInternalServerError, ReasonInternal, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) + + t.Run("preserves details from existing StatusError", func(t *testing.T) { + original := New(http.StatusBadRequest, ReasonInvalidArgument, "validation failed"). + WithFieldViolations( + NewFieldViolation("email", "field.email.invalid", "Email is invalid"), + ). + WithMetadata(map[string]string{"key": "value"}) + + wrapped := AlwaysWrap(original.Err(), http.StatusInternalServerError, ReasonInternal, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + assert.Len(t, wrapped.FieldViolations(), 1) + assert.Equal(t, "email", wrapped.FieldViolations()[0].Field()) + }) + + t.Run("does not mutate original StatusError", func(t *testing.T) { + original := New(http.StatusNotFound, "NOT_FOUND", "resource not found"). + WithMetadata(map[string]string{"key": "value"}) + + _ = AlwaysWrap(original.Err(), http.StatusInternalServerError, ReasonInternal, "internal server error") + + assert.Equal(t, http.StatusNotFound, original.StatusCode()) + assert.Equal(t, "NOT_FOUND", original.Reason()) + assert.Equal(t, "resource not found", original.Message()) + }) + + t.Run("localized key matches new reason", func(t *testing.T) { + original := New(http.StatusNotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrap(original.Err(), http.StatusInternalServerError, "DATABASE_ERROR", "database failed") + require.NotNil(t, wrapped) + + localized := wrapped.Localized() + require.NotNil(t, localized) + assert.Equal(t, "DATABASE_ERROR", localized.Key()) + }) + + t.Run("non-nil error with 2xx status becomes 500", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrap(originalErr, http.StatusOK, ReasonOK, "should become unknown") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonUnknown, wrapped.Reason()) + assert.Equal(t, "should become unknown", wrapped.Message()) + assert.True(t, errors.Is(wrapped.Err(), originalErr)) + + localized := wrapped.Localized() + require.NotNil(t, localized) + assert.Equal(t, ReasonUnknown, localized.Key()) + }) +} + +func TestAlwaysWrapf(t *testing.T) { + t.Run("plain error with format", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrapf(originalErr, http.StatusInternalServerError, ReasonInternal, "error for %s: %d", "user", 42) + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "error for user: 42", wrapped.Message()) + }) + + t.Run("overrides existing StatusError with format", func(t *testing.T) { + original := New(http.StatusNotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrapf(original.Err(), http.StatusInternalServerError, ReasonInternal, "error for %s", "user") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "error for user", wrapped.Message()) + }) +} + +func TestAlwaysWrapStatus(t *testing.T) { + t.Run("plain error", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrapStatus(originalErr, http.StatusInternalServerError, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) + + t.Run("overrides existing StatusError", func(t *testing.T) { + original := New(http.StatusNotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrapStatus(original.Err(), http.StatusInternalServerError, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) +} + +func TestAlwaysWrapStatusf(t *testing.T) { + t.Run("plain error with format", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrapStatusf(originalErr, http.StatusInternalServerError, "error for %s", "user") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "error for user", wrapped.Message()) + }) + + t.Run("overrides existing StatusError", func(t *testing.T) { + original := New(http.StatusNotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrapStatusf(original.Err(), http.StatusInternalServerError, "error for %s", "user") + require.NotNil(t, wrapped) + + assert.Equal(t, http.StatusInternalServerError, wrapped.StatusCode()) + assert.Equal(t, ReasonInternal, wrapped.Reason()) + assert.Equal(t, "error for user", wrapped.Message()) + }) +} + func TestString(t *testing.T) { s := New(http.StatusNotFound, "NOT_FOUND", "user not found") str := s.String() diff --git a/statusx/code.go b/statusx/code.go index b9046878..0662a344 100644 --- a/statusx/code.go +++ b/statusx/code.go @@ -70,3 +70,19 @@ func WrapCode(err error, code codes.Code, message string) *Status { func WrapCodef(err error, code codes.Code, format string, a ...any) *Status { return Wrapf(err, code, ReasonFromCode(code).String(), format, a...) } + +// AlwaysWrapCode is like WrapCode but always sets the given code and message, +// even if err is already a Status error. The original error is preserved as the cause. +// If err is nil, it returns an OK status. See AlwaysWrap for details on the +// codes.OK + non-nil error edge case. +func AlwaysWrapCode(err error, code codes.Code, message string) *Status { + return AlwaysWrap(err, code, ReasonFromCode(code).String(), message) +} + +// AlwaysWrapCodef is like WrapCodef but always sets the given code and formatted message, +// even if err is already a Status error. The original error is preserved as the cause. +// If err is nil, it returns an OK status. See AlwaysWrap for details on the +// codes.OK + non-nil error edge case. +func AlwaysWrapCodef(err error, code codes.Code, format string, a ...any) *Status { + return AlwaysWrapf(err, code, ReasonFromCode(code).String(), format, a...) +} diff --git a/statusx/status.go b/statusx/status.go index 7ff52c1e..4b2e3d17 100644 --- a/statusx/status.go +++ b/statusx/status.go @@ -458,3 +458,38 @@ func Wrap(err error, c codes.Code, reason, message string) *Status { func Wrapf(err error, c codes.Code, reason, format string, a ...any) *Status { return Wrap(err, c, reason, fmt.Sprintf(format, a...)) } + +// AlwaysWrap is like Wrap but always sets the underlying code, reason, and message +// fields, even if err is already a Status error. The original error is preserved as +// the cause. If err is nil, it returns an OK status (consistent with Wrap). +// +// Note: the computed Code()/Reason() follow the Status invariant that a non-nil cause +// cannot be OK. If err is non-nil and c == codes.OK, Code() returns codes.Unknown and +// Reason() returns "UNKNOWN", even though the stored fields are set as given. +// +// When wrapping an existing Status, structural details (field violations, metadata, etc.) +// are preserved via Clone, but the localized key is reset to match the new reason. +// Use WithLocalized/WithLocalizedArgs on the result if custom localization is needed. +func AlwaysWrap(err error, c codes.Code, reason, message string) *Status { + if err == nil { + return New(codes.OK, statusv1.ErrorReason_OK.String(), "") + } + s, _ := FromError(err) + // Only clone when err is a *StatusError, since FromError returns a shared + // pointer in that case. Other branches already create a fresh *Status. + var se *StatusError + if errors.As(err, &se) { + s = Clone(s) + } + s.cause = errors.WithStack(err) + s.code = c + s.message = message + s.errorInfo.Reason = reason + // Immediately fix key to creation-time reason + s.localized = &statusv1.Localized{Key: s.Reason()} + return s +} + +func AlwaysWrapf(err error, c codes.Code, reason, format string, a ...any) *Status { + return AlwaysWrap(err, c, reason, fmt.Sprintf(format, a...)) +} diff --git a/statusx/status_test.go b/statusx/status_test.go index 49e7ed54..f2c50fd9 100644 --- a/statusx/status_test.go +++ b/statusx/status_test.go @@ -262,10 +262,11 @@ func TestWrap(t *testing.T) { } { - status, _ := status.New(codes.NotFound, "resource not found").WithDetails(&errdetails.ErrorInfo{ + st, err := status.New(codes.NotFound, "resource not found").WithDetails(&errdetails.ErrorInfo{ Reason: "NOT_FOUND", }) - wrapped := Wrap(status.Err(), codes.Internal, statusv1.ErrorReason_INTERNAL.String(), "internal server error") + require.NoError(t, err) + wrapped := Wrap(st.Err(), codes.Internal, statusv1.ErrorReason_INTERNAL.String(), "internal server error") assert.Equal(t, codes.NotFound, wrapped.Code()) assert.Equal(t, "NOT_FOUND", wrapped.Reason()) assert.Equal(t, "resource not found", wrapped.Message()) @@ -287,6 +288,170 @@ func TestWrap(t *testing.T) { }) } +func TestAlwaysWrap(t *testing.T) { + t.Run("plain error", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrap(originalErr, codes.Internal, "INTERNAL_ERROR", "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, "INTERNAL_ERROR", wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + assert.True(t, errors.Is(wrapped.Err(), originalErr)) + }) + + t.Run("nil error", func(t *testing.T) { + wrapped := AlwaysWrap(nil, codes.Internal, "INTERNAL_ERROR", "internal server error") + require.NotNil(t, wrapped) + assert.Equal(t, codes.OK, wrapped.Code()) + assert.Equal(t, statusv1.ErrorReason_OK.String(), wrapped.Reason()) + assert.Equal(t, "", wrapped.Message()) + }) + + t.Run("overrides existing StatusError", func(t *testing.T) { + original := New(codes.NotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrap(original.Err(), codes.Internal, "INTERNAL_ERROR", "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, "INTERNAL_ERROR", wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) + + t.Run("overrides existing gRPC status error", func(t *testing.T) { + st, err := status.New(codes.NotFound, "resource not found").WithDetails(&errdetails.ErrorInfo{ + Reason: "NOT_FOUND", + }) + require.NoError(t, err) + wrapped := AlwaysWrap(st.Err(), codes.Internal, "INTERNAL_ERROR", "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, "INTERNAL_ERROR", wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) + + t.Run("preserves details from existing StatusError", func(t *testing.T) { + original := New(codes.InvalidArgument, "VALIDATION_FAILED", "validation failed"). + WithFieldViolations( + NewFieldViolation("email", "field.email.invalid", "Email is invalid"), + ). + WithMetadata(map[string]string{"key": "value"}) + + wrapped := AlwaysWrap(original.Err(), codes.Internal, "INTERNAL_ERROR", "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, "INTERNAL_ERROR", wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + assert.NotNil(t, wrapped.badRequest) + assert.Len(t, wrapped.badRequest.FieldViolations, 1) + assert.Equal(t, "email", wrapped.badRequest.FieldViolations[0].Field) + }) + + t.Run("does not mutate original StatusError", func(t *testing.T) { + original := New(codes.NotFound, "NOT_FOUND", "resource not found"). + WithMetadata(map[string]string{"key": "value"}) + + _ = AlwaysWrap(original.Err(), codes.Internal, "INTERNAL_ERROR", "internal server error") + + assert.Equal(t, codes.NotFound, original.Code()) + assert.Equal(t, "NOT_FOUND", original.Reason()) + assert.Equal(t, "resource not found", original.Message()) + }) + + t.Run("localized key matches new reason", func(t *testing.T) { + original := New(codes.NotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrap(original.Err(), codes.Internal, "DATABASE_ERROR", "database failed") + require.NotNil(t, wrapped) + + localized := wrapped.Localized() + require.NotNil(t, localized) + assert.Equal(t, "DATABASE_ERROR", localized.Key) + }) + + t.Run("non-nil error with OK code becomes unknown", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrap(originalErr, codes.OK, statusv1.ErrorReason_OK.String(), "should become unknown") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Unknown, wrapped.Code()) + assert.Equal(t, statusv1.ErrorReason_UNKNOWN.String(), wrapped.Reason()) + assert.Equal(t, "should become unknown", wrapped.Message()) + assert.True(t, errors.Is(wrapped.Err(), originalErr)) + + localized := wrapped.Localized() + require.NotNil(t, localized) + assert.Equal(t, statusv1.ErrorReason_UNKNOWN.String(), localized.Key) + }) +} + +func TestAlwaysWrapf(t *testing.T) { + t.Run("plain error with format", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrapf(originalErr, codes.Internal, "INTERNAL_ERROR", "error for %s: %d", "user", 42) + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, "INTERNAL_ERROR", wrapped.Reason()) + assert.Equal(t, "error for user: 42", wrapped.Message()) + }) + + t.Run("overrides existing StatusError with format", func(t *testing.T) { + original := New(codes.NotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrapf(original.Err(), codes.Internal, "INTERNAL_ERROR", "error for %s", "user") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, "INTERNAL_ERROR", wrapped.Reason()) + assert.Equal(t, "error for user", wrapped.Message()) + }) +} + +func TestAlwaysWrapCode(t *testing.T) { + t.Run("plain error", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrapCode(originalErr, codes.Internal, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, statusv1.ErrorReason_INTERNAL.String(), wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) + + t.Run("overrides existing StatusError", func(t *testing.T) { + original := New(codes.NotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrapCode(original.Err(), codes.Internal, "internal server error") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, statusv1.ErrorReason_INTERNAL.String(), wrapped.Reason()) + assert.Equal(t, "internal server error", wrapped.Message()) + }) +} + +func TestAlwaysWrapCodef(t *testing.T) { + t.Run("plain error with format", func(t *testing.T) { + originalErr := errors.New("original error") + wrapped := AlwaysWrapCodef(originalErr, codes.Internal, "error for %s", "user") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, statusv1.ErrorReason_INTERNAL.String(), wrapped.Reason()) + assert.Equal(t, "error for user", wrapped.Message()) + }) + + t.Run("overrides existing StatusError", func(t *testing.T) { + original := New(codes.NotFound, "NOT_FOUND", "resource not found") + wrapped := AlwaysWrapCodef(original.Err(), codes.Internal, "error for %s", "user") + require.NotNil(t, wrapped) + + assert.Equal(t, codes.Internal, wrapped.Code()) + assert.Equal(t, statusv1.ErrorReason_INTERNAL.String(), wrapped.Reason()) + assert.Equal(t, "error for user", wrapped.Message()) + }) +} + func TestGRPCStatus(t *testing.T) { s := New(codes.PermissionDenied, statusv1.ErrorReason_PERMISSION_DENIED.String(), "permission denied"). WithLocalized("error.permission_denied", "access", "user").