Allow idiomatic usage of middlewares in gorilla/mux (#9802)

Historically due to lack of support for middlewares
we ended up writing wrapped handlers for all
middlewares on top of the gorilla/mux, this causes
multiple issues when we want to let's say

- Overload r.Body with some custom implementation
  to track the incoming Reads()
- Add other sort of top level checks to avoid
  DDOSing the server with large incoming HTTP
  bodies.

Since 1.7.x release gorilla/mux provides proper
use of middlewares, which are honored by the muxer
directly. This makes sure that Go can honor its
own internal ServeHTTP(w, r) implementation where
Go net/http can wrap into its own customer readers.

This PR as a side-affect fixes rare issues of client
hangs which were reported in the wild but never really
understood or fixed in our codebase.

Fixes #9759
Fixes #7266
Fixes #6540
Fixes #5455
Fixes #5150

Refer https://github.com/boto/botocore/pull/1328 for
one variation of the same issue in #9759
This commit is contained in:
Harshavardhana 2020-06-11 08:19:55 -07:00 committed by GitHub
parent ff94b1b0a9
commit a42df3d364
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 10 deletions

View File

@ -203,13 +203,16 @@ func StartGateway(ctx *cli.Context, gw Gateway) {
// Add API router. // Add API router.
registerAPIRouter(router, encryptionEnabled, allowSSEKMS) registerAPIRouter(router, encryptionEnabled, allowSSEKMS)
// Use all the middlewares
router.Use(registerMiddlewares)
var getCert certs.GetCertificateFunc var getCert certs.GetCertificateFunc
if globalTLSCerts != nil { if globalTLSCerts != nil {
getCert = globalTLSCerts.GetCertificate getCert = globalTLSCerts.GetCertificate
} }
httpServer := xhttp.NewServer([]string{globalCLIContext.Addr}, httpServer := xhttp.NewServer([]string{globalCLIContext.Addr},
criticalErrorHandler{registerHandlers(router, globalHandlers...)}, getCert) criticalErrorHandler{router}, getCert)
httpServer.BaseContext = func(listener net.Listener) context.Context { httpServer.BaseContext = func(listener net.Listener) context.Context {
return GlobalContext return GlobalContext
} }

View File

@ -34,14 +34,14 @@ import (
"github.com/rs/cors" "github.com/rs/cors"
) )
// HandlerFunc - useful to chain different middleware http.Handler // MiddlewareFunc - useful to chain different http.Handler middlewares
type HandlerFunc func(http.Handler) http.Handler type MiddlewareFunc func(http.Handler) http.Handler
func registerHandlers(h http.Handler, handlerFns ...HandlerFunc) http.Handler { func registerMiddlewares(next http.Handler) http.Handler {
for _, hFn := range handlerFns { for _, handlerFn := range globalHandlers {
h = hFn(h) next = handlerFn(next)
} }
return h return next
} }
// Adds limiting body size middleware // Adds limiting body size middleware
@ -795,6 +795,7 @@ func (h criticalErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
defer func() { defer func() {
if err := recover(); err == logger.ErrCritical { // handle if err := recover(); err == logger.ErrCritical { // handle
writeErrorResponse(r.Context(), w, errorCodes.ToAPIErr(ErrInternalError), r.URL, guessIsBrowserReq(r)) writeErrorResponse(r.Context(), w, errorCodes.ToAPIErr(ErrInternalError), r.URL, guessIsBrowserReq(r))
return
} else if err != nil { } else if err != nil {
panic(err) // forward other panic calls panic(err) // forward other panic calls
} }

View File

@ -38,7 +38,7 @@ func registerDistXLRouters(router *mux.Router, endpointZones EndpointZones) {
} }
// List of some generic handlers which are applied for all incoming requests. // List of some generic handlers which are applied for all incoming requests.
var globalHandlers = []HandlerFunc{ var globalHandlers = []MiddlewareFunc{
// set x-amz-request-id header. // set x-amz-request-id header.
addCustomHeaders, addCustomHeaders,
// set HTTP security headers such as Content-Security-Policy. // set HTTP security headers such as Content-Security-Policy.
@ -118,6 +118,6 @@ func configureServerHandler(endpointZones EndpointZones) (http.Handler, error) {
router.NotFoundHandler = http.HandlerFunc(httpTraceAll(errorResponseHandler)) router.NotFoundHandler = http.HandlerFunc(httpTraceAll(errorResponseHandler))
router.MethodNotAllowedHandler = http.HandlerFunc(httpTraceAll(errorResponseHandler)) router.MethodNotAllowedHandler = http.HandlerFunc(httpTraceAll(errorResponseHandler))
// Register rest of the handlers. router.Use(registerMiddlewares)
return registerHandlers(router, globalHandlers...), nil return router, nil
} }