Graceful shutdown for ServerMux (#2341)

This commit is contained in:
Jesse Lucas 2016-08-12 00:33:55 -04:00 committed by Harshavardhana
parent 0b225269e1
commit ef0a108dde
3 changed files with 257 additions and 32 deletions

View File

@ -90,24 +90,6 @@ type serverCmdConfig struct {
ignoredDisks []string ignoredDisks []string
} }
// configureServer configure a new server instance
func configureServer(srvCmdConfig serverCmdConfig) *MuxServer {
// Minio server config
apiServer := &MuxServer{
Server: http.Server{
Addr: srvCmdConfig.serverAddr,
// Adding timeout of 10 minutes for unresponsive client connections.
ReadTimeout: 10 * time.Minute,
WriteTimeout: 10 * time.Minute,
Handler: configureServerHandler(srvCmdConfig),
MaxHeaderBytes: 1 << 20,
},
}
// Returns configured HTTP server.
return apiServer
}
// getListenIPs - gets all the ips to listen on. // getListenIPs - gets all the ips to listen on.
func getListenIPs(httpServerConf *http.Server) (hosts []string, port string) { func getListenIPs(httpServerConf *http.Server) (hosts []string, port string) {
host, port, err := net.SplitHostPort(httpServerConf.Addr) host, port, err := net.SplitHostPort(httpServerConf.Addr)
@ -263,12 +245,14 @@ func serverMain(c *cli.Context) {
disks := c.Args() disks := c.Args()
// Configure server. // Configure server.
apiServer := configureServer(serverCmdConfig{ handler := configureServerHandler(serverCmdConfig{
serverAddr: serverAddress, serverAddr: serverAddress,
disks: disks, disks: disks,
ignoredDisks: ignoredDisks, ignoredDisks: ignoredDisks,
}) })
apiServer := NewMuxServer(serverAddress, handler)
// Fetch endpoints which we are going to serve from. // Fetch endpoints which we are going to serve from.
endPoints := finalizeEndpoints(tls, &apiServer.Server) endPoints := finalizeEndpoints(tls, &apiServer.Server)

View File

@ -18,11 +18,14 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time"
) )
var defaultHTTP2Methods = []string{ var defaultHTTP2Methods = []string{
@ -120,31 +123,41 @@ func (c *MuxConn) Read(b []byte) (int, error) {
type MuxListener struct { type MuxListener struct {
net.Listener net.Listener
config *tls.Config config *tls.Config
wg *sync.WaitGroup
} }
// NewMuxListener - creates new MuxListener, returns error when cert/key files are not found // NewMuxListener - creates new MuxListener, returns error when cert/key files are not found
// or invalid // or invalid
func NewMuxListener(listener net.Listener, certPath, keyPath string) (MuxListener, error) { func NewMuxListener(listener net.Listener, wg *sync.WaitGroup, certPath, keyPath string) (*MuxListener, error) {
var err error var err error
config := &tls.Config{} var config *tls.Config
if config.NextProtos == nil { config = nil
config.NextProtos = []string{"http/1.1", "h2"}
if certPath != "" {
config = &tls.Config{}
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1", "h2"}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return &MuxListener{}, err
}
} }
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath) l := &MuxListener{Listener: listener, config: config, wg: wg}
if err != nil {
return MuxListener{}, err return l, nil
}
return MuxListener{Listener: listener, config: config}, nil
} }
// Accept - peek the protocol to decide if we should wrap the // Accept - peek the protocol to decide if we should wrap the
// network stream with the TLS server // network stream with the TLS server
func (l MuxListener) Accept() (net.Conn, error) { func (l *MuxListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept() c, err := l.Listener.Accept()
if err != nil { if err != nil {
return c, err return c, err
} }
cmux := NewMuxConn(c) cmux := NewMuxConn(c)
protocol := cmux.PeekProtocol() protocol := cmux.PeekProtocol()
if protocol == "tls" { if protocol == "tls" {
@ -153,9 +166,46 @@ func (l MuxListener) Accept() (net.Conn, error) {
return cmux, nil return cmux, nil
} }
// Close Listener
func (l *MuxListener) Close() error {
if l == nil {
return nil
}
return l.Listener.Close()
}
// MuxServer - the main mux server // MuxServer - the main mux server
type MuxServer struct { type MuxServer struct {
http.Server http.Server
listener *MuxListener
WaitGroup *sync.WaitGroup
GracefulTimeout time.Duration
mu sync.Mutex // guards closed and conns
closed bool
conns map[net.Conn]http.ConnState // except terminal states
}
// NewMuxServer constructor to create a MuxServer
func NewMuxServer(addr string, handler http.Handler) *MuxServer {
m := &MuxServer{
Server: http.Server{
Addr: addr,
// Adding timeout of 10 minutes for unresponsive client connections.
ReadTimeout: 10 * time.Minute,
WriteTimeout: 10 * time.Minute,
Handler: handler,
MaxHeaderBytes: 1 << 20,
},
WaitGroup: &sync.WaitGroup{},
GracefulTimeout: 5 * time.Second,
}
// Track connection state
m.connState()
// Returns configured HTTP server.
return m
} }
// ListenAndServeTLS - similar to the http.Server version. However, it has the // ListenAndServeTLS - similar to the http.Server version. However, it has the
@ -166,10 +216,13 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error {
if err != nil { if err != nil {
return err return err
} }
mux, err := NewMuxListener(listener, mustGetCertFile(), mustGetKeyFile()) mux, err := NewMuxListener(listener, m.WaitGroup, mustGetCertFile(), mustGetKeyFile())
if err != nil { if err != nil {
return err return err
} }
m.listener = mux
err = http.Serve(mux, err = http.Serve(mux,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// We reach here when MuxListener.MuxConn is not wrapped with tls.Server // We reach here when MuxListener.MuxConn is not wrapped with tls.Server
@ -194,7 +247,19 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error {
// ListenAndServe - Same as the http.Server version // ListenAndServe - Same as the http.Server version
func (m *MuxServer) ListenAndServe() error { func (m *MuxServer) ListenAndServe() error {
return m.Server.ListenAndServe() listener, err := net.Listen("tcp", m.Server.Addr)
if err != nil {
return err
}
mux, err := NewMuxListener(listener, m.WaitGroup, "", "")
if err != nil {
return err
}
m.listener = mux
return m.Server.Serve(mux)
} }
func longestWord(strings []string) int { func longestWord(strings []string) int {
@ -212,3 +277,89 @@ func longestWord(strings []string) int {
return maxLen return maxLen
} }
// Close initiates the graceful shutdown
func (m *MuxServer) Close() error {
if m.closed {
return errors.New("Server has been closed")
}
m.mu.Lock()
m.Server.SetKeepAlivesEnabled(false)
m.closed = true
m.mu.Unlock()
if err := m.listener.Close(); err != nil {
return err
}
// force connections to close after timeout
wait := make(chan struct{})
go func() {
defer close(wait)
m.mu.Lock()
for c, st := range m.conns {
// Force close any idle and new connections.
if st == http.StateIdle || st == http.StateNew {
c.Close()
}
}
m.mu.Unlock()
// Wait for all connections to be gracefully closed
m.WaitGroup.Wait()
}()
// We block until all active connections are closed or the GracefulTimeout happens
select {
case <-time.After(m.GracefulTimeout):
return nil
case <-wait:
return nil
}
}
// connState setups the ConnState tracking hook to know which connections are idle
func (m *MuxServer) connState() {
// Set our ConnState to track idle connections
m.Server.ConnState = func(c net.Conn, cs http.ConnState) {
m.mu.Lock()
defer m.mu.Unlock()
switch cs {
case http.StateNew:
// New connections increment the WaitGroup and are added the the conns dictionary
m.WaitGroup.Add(1)
if m.conns == nil {
m.conns = make(map[net.Conn]http.ConnState)
}
m.conns[c] = cs
case http.StateActive:
// Only update status to StateActive if it's in the conns dictionary
if _, ok := m.conns[c]; ok {
m.conns[c] = cs
}
case http.StateIdle:
// Only update status to StateIdle if it's in the conns dictionary
if _, ok := m.conns[c]; ok {
m.conns[c] = cs
}
// If we've already closed then we need to close this connection.
// We don't allow connections to become idle after server is closed
if m.closed {
c.Close()
}
case http.StateHijacked, http.StateClosed:
// If the connection is hijacked or closed we forget it
m.forgetConn(c)
}
}
}
// forgetConn removes c from conns and decrements WaitGroup
func (m *MuxServer) forgetConn(c net.Conn) {
if _, ok := m.conns[c]; ok {
delete(m.conns, c)
m.WaitGroup.Done()
}
}

90
server-mux_test.go Normal file
View File

@ -0,0 +1,90 @@
/*
* Minio Cloud Storage, (C) 2015, 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
func TestClose(t *testing.T) {
// Create ServerMux
m := NewMuxServer("", nil)
if err := m.Close(); err != nil {
t.Error("Server errored while trying to Close", err)
}
}
func TestMuxServer(t *testing.T) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()
// Create ServerMux
m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
}))
// Set the test server config to the mux
ts.Config = &m.Server
ts.Start()
// Create a MuxListener
ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "")
if err != nil {
t.Fatal(err)
}
m.listener = ml
client := http.Client{}
res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if string(got) != "hello" {
t.Errorf("got %q, want hello", string(got))
}
// Make sure there is only 1 connection
m.mu.Lock()
if len(m.conns) < 1 {
t.Fatal("Should have 1 connections")
}
m.mu.Unlock()
// Close the server
m.Close()
// Make sure there are zero connections
m.mu.Lock()
if len(m.conns) < 0 {
t.Fatal("Should have 0 connections")
}
m.mu.Unlock()
}