server-mux: Rewrite graceful shutdown mechanism (#3771)

Old code uses waitgroup Add() and Wait() in different threads,
which eventually can lead to a race.
This commit is contained in:
Anis Elleuch 2017-02-18 22:28:54 +01:00 committed by Harshavardhana
parent d12f3e06b1
commit 7e84c7427d
2 changed files with 51 additions and 29 deletions

View File

@ -26,9 +26,14 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
const (
serverShutdownPoll = 500 * time.Millisecond
)
// The value chosen below is longest word chosen // The value chosen below is longest word chosen
// from all the http verbs comprising of // from all the http verbs comprising of
// "PRI", "OPTIONS", "GET", "HEAD", "POST", // "PRI", "OPTIONS", "GET", "HEAD", "POST",
@ -324,11 +329,13 @@ type ServerMux struct {
handler http.Handler handler http.Handler
listeners []*ListenerMux listeners []*ListenerMux
gracefulWait *sync.WaitGroup // Current number of concurrent http requests
currentReqs int32
// Time to wait before forcing server shutdown
gracefulTimeout time.Duration gracefulTimeout time.Duration
mu sync.Mutex // guards closed, and listener mu sync.Mutex // guards closing, and listeners
closed bool closing bool
} }
// NewServerMux constructor to create a ServerMux // NewServerMux constructor to create a ServerMux
@ -339,7 +346,6 @@ func NewServerMux(addr string, handler http.Handler) *ServerMux {
// Wait for 5 seconds for new incoming connnections, otherwise // Wait for 5 seconds for new incoming connnections, otherwise
// forcibly close them during graceful stop or restart. // forcibly close them during graceful stop or restart.
gracefulTimeout: 5 * time.Second, gracefulTimeout: 5 * time.Second,
gracefulWait: &sync.WaitGroup{},
} }
// Returns configured HTTP server. // Returns configured HTTP server.
@ -452,11 +458,22 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
} }
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
} else { } else {
// Execute registered handlers, protect with a waitgroup
// to accomplish a graceful shutdown when the user asks to quit // Return ServiceUnavailable for clients which are sending requests
m.gracefulWait.Add(1) // in shutdown phase
m.mu.Lock()
closing := m.closing
m.mu.Unlock()
if closing {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// Execute registered handlers, update currentReqs to keep
// tracks of current requests currently processed by the server
atomic.AddInt32(&m.currentReqs, 1)
m.handler.ServeHTTP(w, r) m.handler.ServeHTTP(w, r)
m.gracefulWait.Done() atomic.AddInt32(&m.currentReqs, -1)
} }
}) })
@ -481,12 +498,12 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
func (m *ServerMux) Close() error { func (m *ServerMux) Close() error {
m.mu.Lock() m.mu.Lock()
if m.closed { if m.closing {
m.mu.Unlock() m.mu.Unlock()
return errors.New("Server has been closed") return errors.New("Server has been closed")
} }
// Closed completely. // Closed completely.
m.closed = true m.closing = true
// Close the listeners. // Close the listeners.
for _, listener := range m.listeners { for _, listener := range m.listeners {
@ -497,19 +514,18 @@ func (m *ServerMux) Close() error {
} }
m.mu.Unlock() m.mu.Unlock()
// Prepare for a graceful shutdown // Starting graceful shutdown. Check if all requests are finished
waitSignal := make(chan struct{}) // in regular interval or force the shutdown
go func() { ticker := time.NewTicker(serverShutdownPoll)
defer close(waitSignal) defer ticker.Stop()
m.gracefulWait.Wait() for {
}()
select { select {
// Wait for everything to be properly closed
case <-waitSignal:
// Forced shutdown
case <-time.After(m.gracefulTimeout): case <-time.After(m.gracefulTimeout):
}
return nil return nil
case <-ticker.C:
if atomic.LoadInt32(&m.currentReqs) <= 0 {
return nil
}
}
}
} }

View File

@ -198,21 +198,27 @@ func TestServerMux(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Check if one listener is ready // Check if one listener is ready
m.mu.Lock() m.mu.Lock()
if len(m.listeners) == 0 { listenersCount := len(m.listeners)
m.mu.Unlock() m.mu.Unlock()
if listenersCount == 0 {
continue continue
} }
m.mu.Lock()
listenerAddr := m.listeners[0].Addr().String()
m.mu.Unlock() m.mu.Unlock()
// Issue the GET request // Issue the GET request
client := http.Client{} client := http.Client{}
m.mu.Lock() res, err = client.Get("http://" + listenerAddr)
res, err = client.Get("http://" + m.listeners[0].Addr().String())
m.mu.Unlock()
if err != nil { if err != nil {
continue continue
} }
// Read the request response // Read the request response
got, err = ioutil.ReadAll(res.Body) got, err = ioutil.ReadAll(res.Body)
if err != nil {
continue
}
// We've got a response, quit the loop
break
} }
// Check for error persisted after 5 times // Check for error persisted after 5 times