diff --git a/cmd/server-mux.go b/cmd/server-mux.go index ed47e155e..44fe05d79 100644 --- a/cmd/server-mux.go +++ b/cmd/server-mux.go @@ -134,16 +134,88 @@ func (c *ConnMux) Read(b []byte) (int, error) { return c.Conn.Read(b) } -// ListenerMux - encapuslates the standard net.Listener to inspect +// ListenerMux wraps the standard net.Listener to inspect // the communication protocol upon network connection +// ListenerMux also wraps net.Listener to ensure that once +// Listener.Close returns, the underlying socket has been closed. +// +// - https://github.com/golang/go/issues/10527 +// +// The default Listener returns from Close before the underlying +// socket has been closed if another goroutine has an active +// reference (e.g. is in Accept). +// +// The following sequence of events can happen: +// +// Goroutine 1 is running Accept, and is blocked, waiting for epoll +// +// Goroutine 2 calls Close. It sees an extra reference, and so cannot +// destroy the socket, but instead decrements a reference, marks the +// connection as closed and unblocks epoll. +// +// Goroutine 2 returns to the caller, makes a new connection. +// The new connection is sent to the socket (since it hasn't been destroyed) +// +// Goroutine 1 returns from epoll, and accepts the new connection. +// +// To avoid accepting connections after Close, we block Goroutine 2 +// from returning from Close till Accept returns an error to the user. type ListenerMux struct { net.Listener config *tls.Config + // Cond is used to signal Close when there are no references to the listener. + cond *sync.Cond + refs int +} + +// IsClosed - Returns if the underlying listener is closed fully. +func (l *ListenerMux) IsClosed() bool { + l.cond.L.Lock() + defer l.cond.L.Unlock() + return l.refs == 0 +} + +func (l *ListenerMux) incRef() { + l.cond.L.Lock() + l.refs++ + l.cond.L.Unlock() +} + +func (l *ListenerMux) decRef() { + l.cond.L.Lock() + l.refs-- + newRefs := l.refs + l.cond.L.Unlock() + if newRefs == 0 { + l.cond.Broadcast() + } +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *ListenerMux) Close() error { + if l == nil { + return nil + } + + if err := l.Listener.Close(); err != nil { + return err + } + + l.cond.L.Lock() + for l.refs > 0 { + l.cond.Wait() + } + l.cond.L.Unlock() + return nil } // Accept - peek the protocol to decide if we should wrap the // network stream with the TLS server func (l *ListenerMux) Accept() (net.Conn, error) { + l.incRef() + defer l.decRef() + conn, err := l.Listener.Accept() if err != nil { return conn, err @@ -156,14 +228,6 @@ func (l *ListenerMux) Accept() (net.Conn, error) { return connMux, nil } -// Close Listener -func (l *ListenerMux) Close() error { - if l == nil { - return nil - } - return l.Listener.Close() -} - // ServerMux - the main mux server type ServerMux struct { *http.Server @@ -215,6 +279,7 @@ func initListeners(serverAddr string, tls *tls.Config) ([]*ListenerMux, error) { listeners = append(listeners, &ListenerMux{ Listener: listener, config: tls, + cond: sync.NewCond(&sync.Mutex{}), }) return listeners, nil } @@ -239,6 +304,7 @@ func initListeners(serverAddr string, tls *tls.Config) ([]*ListenerMux, error) { listeners = append(listeners, &ListenerMux{ Listener: listener, config: tls, + cond: sync.NewCond(&sync.Mutex{}), }) } return listeners, nil @@ -294,7 +360,10 @@ func (m *ServerMux) ListenAndServeTLS(certFile, keyFile string) (err error) { } }), ) - errorIf(serr, "Unable to serve incoming requests.") + // Do not print the error if the listener is closed. + if !listener.IsClosed() { + errorIf(serr, "Unable to serve incoming requests.") + } }(listener) } // Waits for all http.Serve's to return. @@ -321,7 +390,10 @@ func (m *ServerMux) ListenAndServe() error { go func(listener *ListenerMux) { defer wg.Done() serr := m.Server.Serve(listener) - errorIf(serr, "Unable to serve incoming requests.") + // Do not print the error if the listener is closed. + if !listener.IsClosed() { + errorIf(serr, "Unable to serve incoming requests.") + } }(listener) } // Wait for all the http.Serve to finish. diff --git a/cmd/server-mux_test.go b/cmd/server-mux_test.go index efcc392f9..e7d49476d 100644 --- a/cmd/server-mux_test.go +++ b/cmd/server-mux_test.go @@ -37,6 +37,79 @@ import ( "time" ) +func TestListenerAcceptAfterClose(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 16; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + for i := 0; i < 10; i++ { + runTest(t) + } + }() + } + wg.Wait() +} + +func runTest(t *testing.T) { + const connectionsBeforeClose = 1 + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + ln = &ListenerMux{ + Listener: ln, + config: &tls.Config{}, + cond: sync.NewCond(&sync.Mutex{}), + } + + addr := ln.Addr().String() + waitForListener := make(chan error) + go func() { + defer close(waitForListener) + + var connCount int + for { + conn, aerr := ln.Accept() + if aerr != nil { + return + } + + connCount++ + if connCount > connectionsBeforeClose { + waitForListener <- errUnexpected + return + } + conn.Close() + } + }() + + for i := 0; i < connectionsBeforeClose; i++ { + err = dial(addr) + if err != nil { + t.Fatal(err) + } + } + + ln.Close() + dial(addr) + + err = <-waitForListener + if err != nil { + t.Fatal(err) + } +} + +func dial(addr string) error { + conn, err := net.Dial("tcp", addr) + if err == nil { + conn.Close() + } + return err +} + // Tests initalizing listeners. func TestInitListeners(t *testing.T) { testCases := []struct { @@ -125,6 +198,7 @@ func TestServerMux(t *testing.T) { lm := &ListenerMux{ Listener: ts.Listener, config: &tls.Config{}, + cond: sync.NewCond(&sync.Mutex{}), } m.listeners = []*ListenerMux{lm} @@ -178,6 +252,7 @@ func TestServerCloseBlocking(t *testing.T) { lm := &ListenerMux{ Listener: ts.Listener, config: &tls.Config{}, + cond: sync.NewCond(&sync.Mutex{}), } m.listeners = []*ListenerMux{lm}