Refactoring MuxServer Close() method to always wait for conns to close before returning. Adding lock around ServerMux listener setting to protect against data race. Adding additional tests to server-mux_test.go to make sure open connections are closed and for ListenAndServe. (#2467)

This commit is contained in:
Jesse Lucas 2016-08-17 03:18:23 -04:00 committed by Harshavardhana
parent 674fdc4304
commit 0b7dfab17a
2 changed files with 102 additions and 29 deletions

View File

@ -181,7 +181,7 @@ type MuxServer struct {
listener *MuxListener listener *MuxListener
WaitGroup *sync.WaitGroup WaitGroup *sync.WaitGroup
GracefulTimeout time.Duration GracefulTimeout time.Duration
mu sync.Mutex // guards closed and conns mu sync.Mutex // guards closed, conns, and listener
closed bool closed bool
conns map[net.Conn]http.ConnState // except terminal states conns map[net.Conn]http.ConnState // except terminal states
} }
@ -221,7 +221,9 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error {
return err return err
} }
m.mu.Lock()
m.listener = mux m.listener = mux
m.mu.Unlock()
err = http.Serve(mux, err = http.Serve(mux,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -257,7 +259,9 @@ func (m *MuxServer) ListenAndServe() error {
return err return err
} }
m.mu.Lock()
m.listener = mux m.listener = mux
m.mu.Unlock()
return m.Server.Serve(mux) return m.Server.Serve(mux)
} }
@ -280,42 +284,39 @@ func longestWord(strings []string) int {
// Close initiates the graceful shutdown // Close initiates the graceful shutdown
func (m *MuxServer) Close() error { func (m *MuxServer) Close() error {
m.mu.Lock()
if m.closed { if m.closed {
return errors.New("Server has been closed") return errors.New("Server has been closed")
} }
m.mu.Lock()
m.Server.SetKeepAlivesEnabled(false)
m.closed = true m.closed = true
m.mu.Unlock()
// Make sure a listener was set
if err := m.listener.Close(); err != nil { if err := m.listener.Close(); err != nil {
return err return err
} }
// force connections to close after timeout m.SetKeepAlivesEnabled(false)
wait := make(chan struct{})
go func() {
defer close(wait)
m.mu.Lock()
for c, st := range m.conns { for c, st := range m.conns {
// Force close any idle and new connections. // Force close any idle and new connections. Waiting for other connections
// to close on their own (within the timeout period)
if st == http.StateIdle || st == http.StateNew { if st == http.StateIdle || st == http.StateNew {
c.Close() c.Close()
} }
} }
// If the GracefulTimeout happens then forcefully close all connections
t := time.AfterFunc(m.GracefulTimeout, func() {
for c := range m.conns {
c.Close()
}
})
defer t.Stop()
m.mu.Unlock() m.mu.Unlock()
// Wait for all connections to be gracefully closed // Block until all connections are closed
m.WaitGroup.Wait() m.WaitGroup.Wait()
}()
// We block until all active connections are closed or the GracefulTimeout happens
select {
case <-time.After(m.GracefulTimeout):
return nil return nil
case <-wait:
return nil
}
} }
// connState setups the ConnState tracking hook to know which connections are idle // connState setups the ConnState tracking hook to know which connections are idle

View File

@ -17,8 +17,10 @@
package main package main
import ( import (
"bufio"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -82,9 +84,79 @@ func TestMuxServer(t *testing.T) {
// Make sure there are zero connections // Make sure there are zero connections
m.mu.Lock() m.mu.Lock()
if len(m.conns) < 0 { if len(m.conns) > 0 {
t.Fatal("Should have 0 connections") t.Fatal("Should have 0 connections")
} }
m.mu.Unlock() m.mu.Unlock()
}
func TestServerCloseBlocking(t *testing.T) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()
// Create ServerMux
m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
}))
// Set the test server config to the mux
ts.Config = &m.Server
ts.Start()
// Create a MuxListener
// var err error
ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "")
if err != nil {
t.Fatal(err)
}
m.listener = ml
dial := func() net.Conn {
c, cerr := net.Dial("tcp", ts.Listener.Addr().String())
if cerr != nil {
t.Fatal(err)
}
return c
}
// Dial to open a StateNew but don't send anything
cnew := dial()
defer cnew.Close()
// Dial another connection but idle after a request to have StateIdle
cidle := dial()
defer cidle.Close()
cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
_, err = http.ReadResponse(bufio.NewReader(cidle), nil)
if err != nil {
t.Fatal(err)
}
// Make sure we don't block forever.
m.Close()
// Make sure there are zero connections
m.mu.Lock()
if len(m.conns) > 0 {
t.Fatal("Should have 0 connections")
}
m.mu.Unlock()
}
func TestListenAndServe(t *testing.T) {
m := NewMuxServer("", nil)
stopc := make(chan struct{})
errc := make(chan error)
go func() { errc <- m.ListenAndServe() }()
go func() { errc <- m.Close(); close(stopc) }()
select {
case err := <-errc:
if err != nil {
t.Fatal(err)
}
case <-stopc:
return
}
} }