diff --git a/e2e_test.go b/e2e_test.go index 770974c..95a0c25 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -877,3 +877,50 @@ func TestE2EMaxConnectionAge_ConnectionAgeSet(t *testing.T) { require.NotEmpty(t, connectionStart) }) } + +type closeTrackingReadCloser struct { + inner io.ReadCloser + isClosed *bool +} + +func (c *closeTrackingReadCloser) Read(p []byte) (n int, err error) { + return c.inner.Read(p) +} + +func (c *closeTrackingReadCloser) Close() error { + *c.isClosed = true + return c.inner.Close() +} + +func TestE2EStatusCodesWithoutAResponseBody(t *testing.T) { + flavours(t, func(t *testing.T, flav e2eFlavour) { + ctx, cancel := flav.Context() + defer cancel() + + responseBodyWasClosed := false + + svc := Service(func(req Request) Response { + response := req.Response([]byte("hello")) + response.Header.Set("content-type", "text/plain") + response.StatusCode = http.StatusNotModified // Not Modified + oldBody := response.Body + response.Body = &closeTrackingReadCloser{ + inner: oldBody, + isClosed: &responseBodyWasClosed, + } + return response + }) + svc = svc.Filter(ErrorFilter) + s := flav.Serve(svc) + defer s.Stop(context.Background()) + + req := NewRequest(ctx, "GET", flav.URL(s), nil) + rsp := req.Send().Response() + require.NoError(t, rsp.Error) + assert.Equal(t, http.StatusNotModified, rsp.StatusCode) + rspBody, err := rsp.BodyBytes(true) + require.NoError(t, err) + assert.Empty(t, rspBody) + assert.True(t, responseBodyWasClosed) + }) +} diff --git a/http.go b/http.go index 7c5ee88..d217fd7 100644 --- a/http.go +++ b/http.go @@ -117,34 +117,40 @@ func HttpHandler(svc Service) http.Handler { } rsp := svc(req) - // If the connection was hijacked, we should not attempt to write anything out if rsp.hijacked { return } + // If the connection has been hijacked, the hijacker is responsible for any + // resource cleanup (as per http.Hijacker) + if rsp.Body != nil { + defer rsp.Body.Close() + } + rwHeader := rw.Header() for k, v := range rsp.Header { rwHeader[k] = v } rw.WriteHeader(rsp.StatusCode) - if rsp.Body != nil && bodyAllowedForStatus(rsp.StatusCode) { - defer rsp.Body.Close() - buf := *httpChunkBufPool.Get().(*[]byte) - defer httpChunkBufPool.Put(&buf) - if isStreamingRsp(rsp) { - // Streaming responses use copyChunked(), which takes care of flushing transparently - if _, err := copyChunked(rw, rsp.Body, buf); err != nil { - slog.Log(slog.Eventf(copyErrSeverity(err), req, "Couldn't send streaming response body", err)) - - // Prevent the client from accidentally consuming a truncated stream by aborting the response. - // The official way of interrupting an HTTP reply mid-stream is panic(http.ErrAbortHandler), which - // works for both HTTP/1.1 and HTTP.2. https://github.com/golang/go/issues/17790 - panic(http.ErrAbortHandler) - } - } else { - if _, err := io.CopyBuffer(rw, rsp.Body, buf); err != nil { - slog.Log(slog.Eventf(copyErrSeverity(err), req, "Couldn't send response body", err)) - } + if rsp.Body == nil || !bodyAllowedForStatus(rsp.StatusCode) { + return + } + + buf := *httpChunkBufPool.Get().(*[]byte) + defer httpChunkBufPool.Put(&buf) + if isStreamingRsp(rsp) { + // Streaming responses use copyChunked(), which takes care of flushing transparently + if _, err := copyChunked(rw, rsp.Body, buf); err != nil { + slog.Log(slog.Eventf(copyErrSeverity(err), req, "Couldn't send streaming response body", err)) + + // Prevent the client from accidentally consuming a truncated stream by aborting the response. + // The official way of interrupting an HTTP reply mid-stream is panic(http.ErrAbortHandler), which + // works for both HTTP/1.1 and HTTP.2. https://github.com/golang/go/issues/17790 + panic(http.ErrAbortHandler) + } + } else { + if _, err := io.CopyBuffer(rw, rsp.Body, buf); err != nil { + slog.Log(slog.Eventf(copyErrSeverity(err), req, "Couldn't send response body", err)) } } })