mirror of
https://github.com/minio/minio.git
synced 2025-10-28 23:35:01 -04:00
fix(api): Don't send multiple responses for one request (#21651)
fix(api): Don't send responses twice. In some cases multiple responses are being sent for one request, causing the API server to incorrectly drop connections. This change introduces a ResponseWriter which tracks whether a response has already been sent. This is used to prevent a response being sent if something already has (e.g. by a preconditions check function). Fixes #21633. Co-authored-by: Menno Finlay-Smits <hello@menno.io>
This commit is contained in:
parent
c6d3aac5c4
commit
52eee5a2f1
@ -889,6 +889,12 @@ func generateMultiDeleteResponse(quiet bool, deletedObjects []DeletedObject, err
|
||||
}
|
||||
|
||||
func writeResponse(w http.ResponseWriter, statusCode int, response []byte, mType mimeType) {
|
||||
// Don't write a response if one has already been written.
|
||||
// Fixes https://github.com/minio/minio/issues/21633
|
||||
if headersAlreadyWritten(w) {
|
||||
return
|
||||
}
|
||||
|
||||
if statusCode == 0 {
|
||||
statusCode = 200
|
||||
}
|
||||
@ -1015,3 +1021,45 @@ func writeCustomErrorResponseJSON(ctx context.Context, w http.ResponseWriter, er
|
||||
encodedErrorResponse := encodeResponseJSON(errorResponse)
|
||||
writeResponse(w, err.HTTPStatusCode, encodedErrorResponse, mimeJSON)
|
||||
}
|
||||
|
||||
type unwrapper interface {
|
||||
Unwrap() http.ResponseWriter
|
||||
}
|
||||
|
||||
// headersAlreadyWritten returns true if the headers have already been written
|
||||
// to this response writer. It will unwrap the ResponseWriter if possible to try
|
||||
// and find a trackingResponseWriter.
|
||||
func headersAlreadyWritten(w http.ResponseWriter) bool {
|
||||
for {
|
||||
if trw, ok := w.(*trackingResponseWriter); ok {
|
||||
return trw.headerWritten
|
||||
} else if uw, ok := w.(unwrapper); ok {
|
||||
w = uw.Unwrap()
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trackingResponseWriter wraps a ResponseWriter and notes when WriterHeader has
|
||||
// been called. This allows high level request handlers to check if something
|
||||
// has already sent the header.
|
||||
type trackingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *trackingResponseWriter) WriteHeader(statusCode int) {
|
||||
if !w.headerWritten {
|
||||
w.headerWritten = true
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *trackingResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *trackingResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
@ -18,8 +18,12 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/gzhttp"
|
||||
)
|
||||
|
||||
// Tests object location.
|
||||
@ -122,3 +126,89 @@ func TestGetURLScheme(t *testing.T) {
|
||||
t.Errorf("Expected %s, got %s", httpsScheme, gotScheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackingResponseWriter(t *testing.T) {
|
||||
rw := httptest.NewRecorder()
|
||||
trw := &trackingResponseWriter{ResponseWriter: rw}
|
||||
trw.WriteHeader(123)
|
||||
if !trw.headerWritten {
|
||||
t.Fatal("headerWritten was not set by WriteHeader call")
|
||||
}
|
||||
|
||||
_, err := trw.Write([]byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("Write unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that WriteHeader and Write were called on the underlying response writer
|
||||
resp := rw.Result()
|
||||
if resp.StatusCode != 123 {
|
||||
t.Fatalf("unexpected status: %v", resp.StatusCode)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading response body failed: %v", err)
|
||||
}
|
||||
if string(body) != "hello" {
|
||||
t.Fatalf("response body incorrect: %v", string(body))
|
||||
}
|
||||
|
||||
// Check that Unwrap works
|
||||
if trw.Unwrap() != rw {
|
||||
t.Fatalf("Unwrap returned wrong result: %v", trw.Unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadersAlreadyWritten(t *testing.T) {
|
||||
rw := httptest.NewRecorder()
|
||||
trw := &trackingResponseWriter{ResponseWriter: rw}
|
||||
|
||||
if headersAlreadyWritten(trw) {
|
||||
t.Fatal("headers have not been written yet")
|
||||
}
|
||||
|
||||
trw.WriteHeader(123)
|
||||
if !headersAlreadyWritten(trw) {
|
||||
t.Fatal("headers were written")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadersAlreadyWrittenWrapped(t *testing.T) {
|
||||
rw := httptest.NewRecorder()
|
||||
trw := &trackingResponseWriter{ResponseWriter: rw}
|
||||
wrap1 := &gzhttp.NoGzipResponseWriter{ResponseWriter: trw}
|
||||
wrap2 := &gzhttp.NoGzipResponseWriter{ResponseWriter: wrap1}
|
||||
|
||||
if headersAlreadyWritten(wrap2) {
|
||||
t.Fatal("headers have not been written yet")
|
||||
}
|
||||
|
||||
wrap2.WriteHeader(123)
|
||||
if !headersAlreadyWritten(wrap2) {
|
||||
t.Fatal("headers were written")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteResponseHeadersNotWritten(t *testing.T) {
|
||||
rw := httptest.NewRecorder()
|
||||
trw := &trackingResponseWriter{ResponseWriter: rw}
|
||||
|
||||
writeResponse(trw, 299, []byte("hello"), "application/foo")
|
||||
|
||||
resp := rw.Result()
|
||||
if resp.StatusCode != 299 {
|
||||
t.Fatal("response wasn't written")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteResponseHeadersWritten(t *testing.T) {
|
||||
rw := httptest.NewRecorder()
|
||||
rw.Code = -1
|
||||
trw := &trackingResponseWriter{ResponseWriter: rw, headerWritten: true}
|
||||
|
||||
writeResponse(trw, 200, []byte("hello"), "application/foo")
|
||||
|
||||
if rw.Code != -1 {
|
||||
t.Fatalf("response was written when it shouldn't have been (Code=%v)", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@ -218,6 +218,8 @@ func s3APIMiddleware(f http.HandlerFunc, flags ...s3HFlag) http.HandlerFunc {
|
||||
handlerName := getHandlerName(f, "objectAPIHandlers")
|
||||
|
||||
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
w = &trackingResponseWriter{ResponseWriter: w}
|
||||
|
||||
// Wrap the actual handler with the appropriate tracing middleware.
|
||||
var tracedHandler http.HandlerFunc
|
||||
if handlerFlags.has(traceHdrsS3HFlag) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user