Refactor HTTP server to address bugs (#4636)

* Refactor HTTP server to address bugs
* Remove unnecessary goroutine to start multiple TCP listeners.
* HTTP server waits for shutdown to maximum of Server.ShutdownTimeout
  than per serverShutdownPoll.
* Handles new connection errors properly.
* Handles read and write timeout properly.
* Handles error on start of HTTP server properly by exiting minio
  process.

Fixes #4494 #4476 & fixed review comments
This commit is contained in:
Bala FA 2017-07-13 05:03:21 +05:30 committed by Dee Koder
parent 2d23cd4f39
commit c3dd7c1f6c
14 changed files with 1737 additions and 1161 deletions

View File

@ -17,6 +17,7 @@
package cmd
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
@ -24,35 +25,34 @@ import (
"path/filepath"
)
func parsePublicCertFile(certFile string) (certs []*x509.Certificate, err error) {
var bytes []byte
if bytes, err = ioutil.ReadFile(certFile); err != nil {
return certs, err
func parsePublicCertFile(certFile string) (x509Certs []*x509.Certificate, err error) {
// Read certificate file.
var data []byte
if data, err = ioutil.ReadFile(certFile); err != nil {
return nil, err
}
// Parse all certs in the chain.
var block *pem.Block
var cert *x509.Certificate
current := bytes
current := data
for len(current) > 0 {
if block, current = pem.Decode(current); block == nil {
err = fmt.Errorf("Could not read PEM block from file %s", certFile)
return certs, err
var pemBlock *pem.Block
if pemBlock, current = pem.Decode(current); pemBlock == nil {
return nil, fmt.Errorf("Could not read PEM block from file %s", certFile)
}
if cert, err = x509.ParseCertificate(block.Bytes); err != nil {
return certs, err
var x509Cert *x509.Certificate
if x509Cert, err = x509.ParseCertificate(pemBlock.Bytes); err != nil {
return nil, err
}
certs = append(certs, cert)
x509Certs = append(x509Certs, x509Cert)
}
if len(certs) == 0 {
err = fmt.Errorf("Empty public certificate file %s", certFile)
if len(x509Certs) == 0 {
return nil, fmt.Errorf("Empty public certificate file %s", certFile)
}
return certs, err
return x509Certs, nil
}
func getRootCAs(certsCAsDir string) (*x509.CertPool, error) {
@ -81,7 +81,7 @@ func getRootCAs(certsCAsDir string) (*x509.CertPool, error) {
for _, caFile := range caFiles {
caCert, err := ioutil.ReadFile(caFile)
if err != nil {
return rootCAs, err
return nil, err
}
rootCAs.AppendCertsFromPEM(caCert)
@ -90,19 +90,26 @@ func getRootCAs(certsCAsDir string) (*x509.CertPool, error) {
return rootCAs, nil
}
func getSSLConfig() (publicCerts []*x509.Certificate, rootCAs *x509.CertPool, secureConn bool, err error) {
func getSSLConfig() (x509Certs []*x509.Certificate, rootCAs *x509.CertPool, tlsCert *tls.Certificate, secureConn bool, err error) {
if !(isFile(getPublicCertFile()) && isFile(getPrivateKeyFile())) {
return publicCerts, rootCAs, secureConn, err
return nil, nil, nil, false, nil
}
if publicCerts, err = parsePublicCertFile(getPublicCertFile()); err != nil {
return publicCerts, rootCAs, secureConn, err
if x509Certs, err = parsePublicCertFile(getPublicCertFile()); err != nil {
return nil, nil, nil, false, err
}
var cert tls.Certificate
if cert, err = tls.LoadX509KeyPair(getPublicCertFile(), getPrivateKeyFile()); err != nil {
return nil, nil, nil, false, err
}
tlsCert = &cert
if rootCAs, err = getRootCAs(getCADir()); err != nil {
return publicCerts, rootCAs, secureConn, err
return nil, nil, nil, false, err
}
secureConn = true
return publicCerts, rootCAs, secureConn, err
return x509Certs, rootCAs, tlsCert, secureConn, nil
}

View File

@ -20,11 +20,15 @@ import (
"errors"
"fmt"
"net/url"
"os"
"os/signal"
"runtime"
"strings"
"syscall"
"github.com/gorilla/mux"
"github.com/minio/cli"
miniohttp "github.com/minio/minio/pkg/http"
)
const azureGatewayTemplate = `NAME:
@ -314,8 +318,8 @@ func gatewayMain(ctx *cli.Context, backendType gatewayBackend) {
// Check and load SSL certificates.
var err error
globalPublicCerts, globalRootCAs, globalIsSSL, err = getSSLConfig()
fatalIf(err, "Invalid SSL key file")
globalPublicCerts, globalRootCAs, globalTLSCertificate, globalIsSSL, err = getSSLConfig()
fatalIf(err, "Invalid SSL certificate file")
initNSLock(false) // Enable local namespace lock.
@ -359,17 +363,15 @@ func gatewayMain(ctx *cli.Context, backendType gatewayBackend) {
}
apiServer := NewServerMux(ctx.GlobalString("address"), registerHandlers(router, handlerFns...))
globalHTTPServer = miniohttp.NewServer([]string{ctx.GlobalString("address")}, registerHandlers(router, handlerFns...), globalTLSCertificate)
// Start server, automatically configures TLS if certs are available.
go func() {
cert, key := "", ""
if globalIsSSL {
cert, key = getPublicCertFile(), getPrivateKeyFile()
}
fatalIf(apiServer.ListenAndServe(cert, key), "Failed to start minio server")
globalHTTPServerErrorCh <- globalHTTPServer.Start()
}()
signal.Notify(globalOSSignalCh, os.Interrupt, syscall.SIGTERM)
// Once endpoints are finalized, initialize the new object api.
globalObjLayerMutex.Lock()
globalObjectAPI = newObject
@ -391,8 +393,8 @@ func gatewayMain(ctx *cli.Context, backendType gatewayBackend) {
checkUpdate(mode)
// Print gateway startup message.
printGatewayStartupMessage(getAPIEndpoints(apiServer.Addr), backendType)
printGatewayStartupMessage(getAPIEndpoints(ctx.String("address")), backendType)
}
<-globalServiceDoneCh
handleSignals()
}

View File

@ -17,12 +17,15 @@
package cmd
import (
"crypto/tls"
"crypto/x509"
"os"
"runtime"
"time"
humanize "github.com/dustin/go-humanize"
"github.com/fatih/color"
miniohttp "github.com/minio/minio/pkg/http"
)
// minio configuration related constants.
@ -106,6 +109,12 @@ var (
// IsSSL indicates if the server is configured with SSL.
globalIsSSL bool
globalTLSCertificate *tls.Certificate
globalHTTPServer *miniohttp.Server
globalHTTPServerErrorCh = make(chan error)
globalOSSignalCh = make(chan os.Signal, 1)
// List of admin peers.
globalAdminPeers = adminPeers{}

View File

@ -17,11 +17,15 @@
package cmd
import (
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"github.com/minio/cli"
"github.com/minio/dsync"
miniohttp "github.com/minio/minio/pkg/http"
)
var serverFlags = []cli.Flag{
@ -149,8 +153,8 @@ func serverMain(ctx *cli.Context) {
// Check and load SSL certificates.
var err error
globalPublicCerts, globalRootCAs, globalIsSSL, err = getSSLConfig()
fatalIf(err, "Invalid SSL key file")
globalPublicCerts, globalRootCAs, globalTLSCertificate, globalIsSSL, err = getSSLConfig()
fatalIf(err, "Invalid SSL certificate file")
if !quietFlag {
// Check for new updates from dl.minio.io.
@ -176,43 +180,47 @@ func serverMain(ctx *cli.Context) {
initNSLock(globalIsDistXL)
// Configure server.
handler, err := configureServerHandler(globalEndpoints)
// Declare handler to avoid lint errors.
var handler http.Handler
handler, err = configureServerHandler(globalEndpoints)
fatalIf(err, "Unable to configure one of server's RPC services.")
// Initialize a new HTTP server.
apiServer := NewServerMux(globalMinioAddr, handler)
// Initialize S3 Peers inter-node communication only in distributed setup.
initGlobalS3Peers(globalEndpoints)
// Initialize Admin Peers inter-node communication only in distributed setup.
initGlobalAdminPeers(globalEndpoints)
// Start server, automatically configures TLS if certs are available.
globalHTTPServer = miniohttp.NewServer([]string{globalMinioAddr}, handler, globalTLSCertificate)
globalHTTPServer.UpdateBytesReadFunc = globalConnStats.incInputBytes
globalHTTPServer.UpdateBytesWrittenFunc = globalConnStats.incOutputBytes
globalHTTPServer.ErrorLogFunc = errorIf
go func() {
cert, key := "", ""
if globalIsSSL {
cert, key = getPublicCertFile(), getPrivateKeyFile()
}
fatalIf(apiServer.ListenAndServe(cert, key), "Failed to start minio server.")
globalHTTPServerErrorCh <- globalHTTPServer.Start()
}()
signal.Notify(globalOSSignalCh, os.Interrupt, syscall.SIGTERM)
newObject, err := newObjectLayer(globalEndpoints)
fatalIf(err, "Initializing object layer failed")
if err != nil {
errorIf(err, "Initializing object layer failed")
err = globalHTTPServer.Shutdown()
errorIf(err, "Unable to shutdown http server")
os.Exit(1)
}
globalObjLayerMutex.Lock()
globalObjectAPI = newObject
globalObjLayerMutex.Unlock()
// Prints the formatted startup message once object layer is initialized.
apiEndpoints := getAPIEndpoints(apiServer.Addr)
apiEndpoints := getAPIEndpoints(globalMinioAddr)
printStartupMessage(apiEndpoints)
// Set uptime time after object layer has initialized.
globalBootTime = UTCNow()
// Waits on the server.
<-globalServiceDoneCh
handleSignals()
}
// Initialize object layer with the supplied disks, objectLayer is nil upon any error.

View File

@ -1,526 +0,0 @@
/*
* Minio Cloud Storage, (C) 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cmd
import (
"bufio"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
serverShutdownPoll = 500 * time.Millisecond
)
// The value chosen below is longest word chosen
// from all the http verbs comprising of
// "PRI", "OPTIONS", "GET", "HEAD", "POST",
// "PUT", "DELETE", "TRACE", "CONNECT".
const (
maxHTTPVerbLen = 7
)
// HTTP2 PRI method.
var httpMethodPRI = "PRI"
var defaultHTTP2Methods = []string{
httpMethodPRI,
}
var defaultHTTP1Methods = []string{
http.MethodOptions,
http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodDelete,
http.MethodTrace,
http.MethodConnect,
}
// ConnMux - Peeks into the incoming connection for relevant
// protocol without advancing the underlying net.Conn (io.Reader).
// ConnMux - allows us to multiplex between TLS and Regular HTTP
// connections on the same listeners.
type ConnMux struct {
net.Conn
// To peek net.Conn incoming data
peeker *bufio.Reader
}
// NewConnMux - creates a new ConnMux instance
func NewConnMux(c net.Conn) *ConnMux {
br := bufio.NewReader(c)
return &ConnMux{
Conn: c,
peeker: bufio.NewReader(br),
}
}
// List of protocols to be detected by PeekProtocol function.
const (
protocolTLS = "tls"
protocolHTTP1 = "http"
protocolHTTP2 = "http2"
)
// PeekProtocol - reads the first bytes, then checks if it is similar
// to one of the default http methods. Returns error if there are any
// errors in peeking over the connection.
func (c *ConnMux) PeekProtocol() (string, error) {
// Peek for HTTP verbs.
buf, err := c.peeker.Peek(maxHTTPVerbLen)
if err != nil {
return "", err
}
// Check for HTTP2 methods first.
for _, m := range defaultHTTP2Methods {
if strings.HasPrefix(string(buf), m) {
return protocolHTTP2, nil
}
}
// Check for HTTP1 methods.
for _, m := range defaultHTTP1Methods {
if strings.HasPrefix(string(buf), m) {
return protocolHTTP1, nil
}
}
// Default to TLS, this is not a real indication
// that the connection is TLS but that will be
// validated later by doing a handshake.
return protocolTLS, nil
}
// Read reads from the tcp session for data sent by
// the client, additionally sets deadline for 15 secs
// after each successful read. Deadline cancels and
// returns error if the client does not send any
// data in 15 secs. Also keeps track of the total
// bytes received from the client.
func (c *ConnMux) Read(b []byte) (n int, err error) {
// Update total incoming number of bytes.
defer func() {
globalConnStats.incInputBytes(n)
}()
n, err = c.peeker.Read(b)
if err != nil {
return n, err
}
// Read deadline was already set previously, set again
// after a successful read operation for future read
// operations.
c.Conn.SetReadDeadline(UTCNow().Add(defaultTCPReadTimeout))
// Success.
return n, nil
}
// Write to the client over a tcp session, additionally
// keeps track of the total bytes written by the server.
func (c *ConnMux) Write(b []byte) (n int, err error) {
// Update total outgoing number of bytes.
defer func() {
globalConnStats.incOutputBytes(n)
}()
// Call the conn write wrapper.
return c.Conn.Write(b)
}
// Close closes the underlying tcp connection.
func (c *ConnMux) Close() (err error) {
// Make sure that we always close a connection,
return c.Conn.Close()
}
// ListenerMux wraps the standard net.Listener to inspect
// the communication protocol upon network connection
// ListenerMux also wraps net.Listener to ensure that once
// Listener.Close returns, the underlying socket has been closed.
//
// - https://github.com/golang/go/issues/10527
//
// The default Listener returns from Close before the underlying
// socket has been closed if another goroutine has an active
// reference (e.g. is in Accept).
//
// The following sequence of events can happen:
//
// Goroutine 1 is running Accept, and is blocked, waiting for epoll
//
// Goroutine 2 calls Close. It sees an extra reference, and so cannot
// destroy the socket, but instead decrements a reference, marks the
// connection as closed and unblocks epoll.
//
// Goroutine 2 returns to the caller, makes a new connection.
// The new connection is sent to the socket (since it hasn't been destroyed)
//
// Goroutine 1 returns from epoll, and accepts the new connection.
//
// To avoid accepting connections after Close, we block Goroutine 2
// from returning from Close till Accept returns an error to the user.
type ListenerMux struct {
net.Listener
config *tls.Config
// acceptResCh is a channel for transporting wrapped net.Conn (regular or tls)
// after peeking the content of the latter
acceptResCh chan ListenerMuxAcceptRes
// Cond is used to signal Close when there are no references to the listener.
cond *sync.Cond
refs int
}
// ListenerMuxAcceptRes contains then final net.Conn data (wrapper by tls or not) to be sent to the http handler
type ListenerMuxAcceptRes struct {
conn net.Conn
err error
}
// Default keep alive interval timeout, on your Linux system to figure out
// maximum probes sent
//
// > cat /proc/sys/net/ipv4/tcp_keepalive_probes
// ! 9
//
// Final value of total keep-alive comes upto 9 x 10 * seconds = 1.5 minutes.
const defaultKeepAliveTimeout = 10 * time.Second // 10 seconds.
// Timeout to close and return error to the client when not sending any data.
const defaultTCPReadTimeout = 15 * time.Second // 15 seconds.
// newListenerMux listens and wraps accepted connections with tls after protocol peeking
func newListenerMux(listener net.Listener, config *tls.Config) *ListenerMux {
l := ListenerMux{
Listener: listener,
config: config,
cond: sync.NewCond(&sync.Mutex{}),
acceptResCh: make(chan ListenerMuxAcceptRes),
}
// Start listening, wrap connections with tls when needed
go func() {
// Extract tcp listener.
tcpListener, ok := l.Listener.(*net.TCPListener)
if !ok {
l.acceptResCh <- ListenerMuxAcceptRes{err: errInvalidArgument}
return
}
// Loop for accepting new connections
for {
// Use accept TCP method to receive the connection.
conn, err := tcpListener.AcceptTCP()
if err != nil {
l.acceptResCh <- ListenerMuxAcceptRes{err: err}
continue
}
// Enable Read timeout
conn.SetReadDeadline(UTCNow().Add(defaultTCPReadTimeout))
// Enable keep alive for each connection.
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(defaultKeepAliveTimeout)
// Allocate new conn muxer.
connMux := NewConnMux(conn)
// Wrap the connection with ConnMux to be able to peek the data in the incoming connection
// and decide if we need to wrap the connection itself with a TLS or not
go func(connMux *ConnMux) {
protocol, cerr := connMux.PeekProtocol()
if cerr != nil {
// io.EOF is usually returned by non-http clients,
// just close the connection to avoid any leak.
if cerr != io.EOF {
errorIf(cerr, "Unable to peek into incoming protocol")
}
connMux.Close()
return
}
switch protocol {
case protocolTLS:
tlsConn := tls.Server(connMux, l.config)
// Make sure to handshake so that we know that this
// is a TLS connection, if not we should close and reject
// such a connection.
if cerr = tlsConn.Handshake(); cerr != nil {
// Close for junk message.
tlsConn.Close()
return
}
l.acceptResCh <- ListenerMuxAcceptRes{
conn: tlsConn,
}
default:
l.acceptResCh <- ListenerMuxAcceptRes{
conn: connMux,
}
}
}(connMux)
}
}()
return &l
}
// IsClosed - Returns if the underlying listener is closed fully.
func (l *ListenerMux) IsClosed() bool {
l.cond.L.Lock()
defer l.cond.L.Unlock()
return l.refs == 0
}
func (l *ListenerMux) incRef() {
l.cond.L.Lock()
l.refs++
l.cond.L.Unlock()
}
func (l *ListenerMux) decRef() {
l.cond.L.Lock()
l.refs--
newRefs := l.refs
l.cond.L.Unlock()
if newRefs == 0 {
l.cond.Broadcast()
}
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (l *ListenerMux) Close() error {
if l == nil {
return nil
}
if err := l.Listener.Close(); err != nil {
return err
}
l.cond.L.Lock()
for l.refs > 0 {
l.cond.Wait()
}
l.cond.L.Unlock()
return nil
}
// Accept - peek the protocol to decide if we should wrap the
// network stream with the TLS server
func (l *ListenerMux) Accept() (net.Conn, error) {
l.incRef()
defer l.decRef()
res := <-l.acceptResCh
return res.conn, res.err
}
// ServerMux - the main mux server
type ServerMux struct {
Addr string
handler http.Handler
listeners []*ListenerMux
// Current number of concurrent http requests
currentReqs int32
// Time to wait before forcing server shutdown
gracefulTimeout time.Duration
mu sync.RWMutex // guards closing, and listeners
closing bool
}
// NewServerMux constructor to create a ServerMux
func NewServerMux(addr string, handler http.Handler) *ServerMux {
m := &ServerMux{
Addr: addr,
handler: handler,
// Wait for 5 seconds for new incoming connnections, otherwise
// forcibly close them during graceful stop or restart.
gracefulTimeout: 5 * time.Second,
}
// Returns configured HTTP server.
return m
}
// Initialize listeners on all ports.
func initListeners(serverAddr string, tls *tls.Config) ([]*ListenerMux, error) {
host, port, err := net.SplitHostPort(serverAddr)
if err != nil {
return nil, err
}
var listeners []*ListenerMux
if host == "" {
var listener net.Listener
listener, err = net.Listen("tcp", serverAddr)
if err != nil {
return nil, err
}
listeners = append(listeners, newListenerMux(listener, tls))
return listeners, nil
}
var addrs []string
if net.ParseIP(host) != nil {
addrs = append(addrs, host)
} else {
addrs, err = net.LookupHost(host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, errUnexpected
}
}
for _, addr := range addrs {
var listener net.Listener
listener, err = net.Listen("tcp", net.JoinHostPort(addr, port))
if err != nil {
return nil, err
}
listeners = append(listeners, newListenerMux(listener, tls))
}
return listeners, nil
}
// ListenAndServe - serve HTTP requests with protocol multiplexing support
// TLS is actived when certFile and keyFile parameters are not empty.
func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
tlsEnabled := certFile != "" && keyFile != ""
config := &tls.Config{
// Causes servers to use Go's default ciphersuite preferences,
// which are tuned to avoid attacks. Does nothing on clients.
PreferServerCipherSuites: true,
// Set minimum version to TLS 1.2
MinVersion: tls.VersionTLS12,
} // Always instantiate.
if tlsEnabled {
// Configure TLS in the server
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1", "h2"}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
}
go m.handleServiceSignals()
listeners, err := initListeners(m.Addr, config)
if err != nil {
return err
}
m.mu.Lock()
m.listeners = listeners
m.mu.Unlock()
// All http requests start to be processed by httpHandler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if tlsEnabled && r.TLS == nil {
// TLS is enabled but request is not TLS
// configured - return error to client.
writeErrorResponse(w, ErrInsecureClientRequest, &url.URL{})
} else {
// Return ServiceUnavailable for clients which are sending requests
// in shutdown phase
m.mu.RLock()
closing := m.closing
m.mu.RUnlock()
if closing {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// Execute registered handlers, update currentReqs to keep
// track of concurrent requests processing on the server
atomic.AddInt32(&m.currentReqs, 1)
m.handler.ServeHTTP(w, r)
atomic.AddInt32(&m.currentReqs, -1)
}
})
var wg = &sync.WaitGroup{}
for _, listener := range listeners {
wg.Add(1)
go func(listener *ListenerMux) {
defer wg.Done()
serr := http.Serve(listener, httpHandler)
// Do not print the error if the listener is closed.
if !listener.IsClosed() {
errorIf(serr, "Unable to serve incoming requests.")
}
}(listener)
}
// Wait for all http.Serve's to return.
wg.Wait()
return nil
}
// Close initiates the graceful shutdown
func (m *ServerMux) Close() error {
m.mu.Lock()
if m.closing {
m.mu.Unlock()
return errors.New("Server has been closed")
}
// Closed completely.
m.closing = true
// Close the listeners.
for _, listener := range m.listeners {
if err := listener.Close(); err != nil {
m.mu.Unlock()
return err
}
}
m.mu.Unlock()
// Starting graceful shutdown. Check if all requests are finished
// in regular interval or force the shutdown
ticker := time.NewTicker(serverShutdownPoll)
defer ticker.Stop()
for {
select {
case <-time.After(m.gracefulTimeout):
return nil
case <-ticker.C:
if atomic.LoadInt32(&m.currentReqs) <= 0 {
return nil
}
}
}
}

View File

@ -1,506 +0,0 @@
/*
* Minio Cloud Storage, (C) 2015, 2016, 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cmd
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"net/http"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
)
func TestListenerAcceptAfterClose(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 16; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
runTest(t)
}
}()
}
wg.Wait()
}
func runTest(t *testing.T) {
const connectionsBeforeClose = 1
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
ln = newListenerMux(ln, &tls.Config{})
addr := ln.Addr().String()
waitForListener := make(chan error)
go func() {
defer close(waitForListener)
var connCount int
for {
conn, aerr := ln.Accept()
if aerr != nil {
return
}
connCount++
if connCount > connectionsBeforeClose {
waitForListener <- errUnexpected
return
}
conn.Close()
}
}()
for i := 0; i < connectionsBeforeClose; i++ {
err = dial(addr)
if err != nil {
t.Fatal(err)
}
}
ln.Close()
dial(addr)
err = <-waitForListener
if err != nil {
t.Fatal(err)
}
}
func dial(addr string) error {
conn, err := net.Dial("tcp", addr)
if err == nil {
conn.Close()
}
return err
}
// Tests initializing listeners.
func TestInitListeners(t *testing.T) {
testCases := []struct {
serverAddr string
shouldPass bool
}{
// Test 1 with ip and port.
{
serverAddr: net.JoinHostPort("127.0.0.1", "0"),
shouldPass: true,
},
// Test 2 only port.
{
serverAddr: net.JoinHostPort("", "0"),
shouldPass: true,
},
// Test 3 with no port error.
{
serverAddr: "127.0.0.1",
shouldPass: false,
},
// Test 4 with 'foobar' host not resolvable.
{
serverAddr: "foobar:9000",
shouldPass: false,
},
}
for i, testCase := range testCases {
listeners, err := initListeners(testCase.serverAddr, &tls.Config{})
if testCase.shouldPass {
if err != nil {
t.Fatalf("Test %d: Unable to initialize listeners %s", i+1, err)
}
for _, listener := range listeners {
if err = listener.Close(); err != nil {
t.Fatalf("Test %d: Unable to close listeners %s", i+1, err)
}
}
}
if err == nil && !testCase.shouldPass {
t.Fatalf("Test %d: Should fail but is successful", i+1)
}
}
// Windows doesn't have 'localhost' hostname.
if runtime.GOOS != globalWindowsOSName {
listeners, err := initListeners("localhost:"+getFreePort(), &tls.Config{})
if err != nil {
t.Fatalf("Test 3: Unable to initialize listeners %s", err)
}
for _, listener := range listeners {
if err = listener.Close(); err != nil {
t.Fatalf("Test 3: Unable to close listeners %s", err)
}
}
}
}
func TestClose(t *testing.T) {
// Create ServerMux
m := NewServerMux("", nil)
if err := m.Close(); err != nil {
t.Error("Server errored while trying to Close", err)
}
// Closing again should return an error.
if err := m.Close(); err.Error() != "Server has been closed" {
t.Error("Unexepcted error expected \"Server has been closed\", got", err)
}
}
func TestServerMux(t *testing.T) {
var err error
var got []byte
var res *http.Response
// Create ServerMux
m := NewServerMux("127.0.0.1:0", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
}))
// Start serving requests
go m.ListenAndServe("", "")
// Issue a GET request. Since we started server in a goroutine, it could be not ready
// at this point. So we allow until 5 failed retries before declare there is an error
for i := 0; i < 5; i++ {
// Sleep one second
time.Sleep(1 * time.Second)
// Check if one listener is ready
m.mu.Lock()
listenersCount := len(m.listeners)
m.mu.Unlock()
if listenersCount == 0 {
continue
}
m.mu.Lock()
listenerAddr := m.listeners[0].Addr().String()
m.mu.Unlock()
// Issue the GET request
client := http.Client{}
res, err = client.Get("http://" + listenerAddr)
if err != nil {
continue
}
// Read the request response
got, err = ioutil.ReadAll(res.Body)
if err != nil {
continue
}
// We've got a response, quit the loop
break
}
// Check for error persisted after 5 times
if err != nil {
t.Fatal(err)
}
// Check the web service response
if string(got) != "hello" {
t.Errorf("got %q, want hello", string(got))
}
}
func TestServerCloseBlocking(t *testing.T) {
// Create ServerMux
m := NewServerMux("127.0.0.1:0", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
}))
// Start serving requests in a goroutine
go m.ListenAndServe("", "")
// Dial, try until 5 times before declaring a failure
dial := func() (net.Conn, error) {
var c net.Conn
var err error
for i := 0; i < 5; i++ {
// Sleep one second in case of the server is not ready yet
time.Sleep(1 * time.Second)
// Check if there is at least one listener configured
m.mu.Lock()
if len(m.listeners) == 0 {
m.mu.Unlock()
continue
}
m.mu.Unlock()
// Run the actual Dial
m.mu.Lock()
c, err = net.Dial("tcp", m.listeners[0].Addr().String())
m.mu.Unlock()
if err != nil {
continue
}
}
return c, err
}
// Dial to open a StateNew but don't send anything
cnew, err := dial()
if err != nil {
t.Fatal(err)
}
defer cnew.Close()
// Dial another connection but idle after a request to have StateIdle
cidle, err := dial()
if err != nil {
t.Fatal(err)
}
defer cidle.Close()
cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
_, err = http.ReadResponse(bufio.NewReader(cidle), nil)
if err != nil {
t.Fatal(err)
}
// Make sure we don't block forever.
m.Close()
}
func TestServerListenAndServePlain(t *testing.T) {
wait := make(chan struct{})
addr := net.JoinHostPort("127.0.0.1", getFreePort())
errc := make(chan error)
once := &sync.Once{}
// Initialize done channel specifically for each tests.
globalServiceDoneCh = make(chan struct{}, 1)
// Create ServerMux and when we receive a request we stop waiting
m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
once.Do(func() { close(wait) })
}))
// ListenAndServe in a goroutine, but we don't know when it's ready
go func() { errc <- m.ListenAndServe("", "") }()
// Keep trying the server until it's accepting connections
go func() {
client := http.Client{Timeout: time.Millisecond * 10}
for {
res, _ := client.Get("http://" + addr)
if res != nil && res.StatusCode == http.StatusOK {
break
}
}
}()
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
select {
case err := <-errc:
if err != nil {
t.Fatal(err)
}
case <-wait:
return
}
}()
// Wait until we get an error or wait closed
wg.Wait()
// Shutdown the ServerMux
m.Close()
}
func TestServerListenAndServeTLS(t *testing.T) {
rootPath, err := newTestConfig(globalMinioDefaultRegion)
if err != nil {
t.Fatalf("Init Test config failed")
}
defer removeAll(rootPath)
wait := make(chan struct{})
addr := net.JoinHostPort("127.0.0.1", getFreePort())
errc := make(chan error)
once := &sync.Once{}
// Initialize done channel specifically for each tests.
globalServiceDoneCh = make(chan struct{}, 1)
// Create ServerMux and when we receive a request we stop waiting
m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
once.Do(func() { close(wait) })
}))
// Create a cert
err = createConfigDir()
if err != nil {
t.Fatal(err)
}
certFile := getPublicCertFile()
keyFile := getPrivateKeyFile()
defer os.RemoveAll(certFile)
defer os.RemoveAll(keyFile)
err = generateTestCert(addr)
if err != nil {
t.Error(err)
return
}
// ListenAndServe in a goroutine, but we don't know when it's ready
go func() { errc <- m.ListenAndServe(certFile, keyFile) }()
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := http.Client{
Timeout: time.Millisecond * 10,
Transport: tr,
}
// Keep trying the server until it's accepting connections
start := UTCNow()
for {
res, _ := client.Get("https://" + addr)
if res != nil && res.StatusCode == http.StatusOK {
break
}
// Explicit check to terminate loop after 5 minutes
// (for investigational purpose of issue #4461)
if UTCNow().Sub(start) >= 5*time.Minute {
t.Fatalf("Failed to establish connection after 5 minutes")
}
}
// Once a request succeeds, subsequent requests should
// work fine.
res, err := client.Get("http://" + addr)
if err != nil {
t.Errorf("Got unexpected error: %v", err)
}
// Without TLS we expect a Bad-Request response from the server.
if !(res != nil && res.StatusCode == http.StatusBadRequest && res.Request.URL.Scheme == httpScheme) {
t.Fatalf("Plaintext request to TLS server did not have expected response!")
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Error reading body")
}
// Check that the expected error is received.
bodyStr := string(body)
apiErr := getAPIError(ErrInsecureClientRequest)
if !(strings.Contains(bodyStr, apiErr.Code) && strings.Contains(bodyStr, apiErr.Description)) {
t.Fatalf("Plaintext request to TLS server did not have expected response body!")
}
wg.Done()
}()
wg.Add(1)
go func() {
defer wg.Done()
select {
case err := <-errc:
if err != nil {
t.Error(err)
return
}
case <-wait:
return
}
}()
// Wait until we get an error or wait closed
wg.Wait()
// Shutdown the ServerMux
m.Close()
}
// generateTestCert creates a cert and a key used for testing only
func generateTestCert(host string) error {
certPath := getPublicCertFile()
keyPath := getPrivateKeyFile()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Minio Test Cert"},
},
NotBefore: UTCNow(),
NotAfter: UTCNow().Add(time.Minute * 1),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
if ip := net.ParseIP(host); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
}
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return err
}
certOut, err := os.Create(certPath)
if err != nil {
return err
}
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
keyOut.Close()
return nil
}

View File

@ -19,8 +19,6 @@ package cmd
import (
"os"
"os/exec"
"syscall"
"time"
)
// Type of service signals currently supported.
@ -64,59 +62,3 @@ func restartProcess() error {
cmd.Stderr = os.Stderr
return cmd.Start()
}
// Handles all serviceSignal and execute service functions.
func (m *ServerMux) handleServiceSignals() error {
// Custom exit function
runExitFn := func(err error) {
// If global profiler is set stop before we exit.
if globalProfiler != nil {
globalProfiler.Stop()
}
// Call user supplied user exit function
fatalIf(err, "Unable to gracefully complete service operation.")
// We are usually done here, close global service done channel.
globalServiceDoneCh <- struct{}{}
}
// Wait for SIGTERM in a go-routine.
trapCh := signalTrap(os.Interrupt, syscall.SIGTERM)
go func(trapCh <-chan bool) {
<-trapCh
globalServiceSignalCh <- serviceStop
}(trapCh)
// Start listening on service signal. Monitor signals.
for {
signal := <-globalServiceSignalCh
switch signal {
case serviceStatus:
/// We don't do anything for this.
case serviceRestart:
if err := m.Close(); err != nil {
errorIf(err, "Unable to close server gracefully")
}
if err := restartProcess(); err != nil {
errorIf(err, "Unable to restart the server.")
}
runExitFn(nil)
case serviceStop:
log.Println("Received signal to exit.")
go func() {
time.Sleep(serverShutdownPoll + time.Millisecond*100)
log.Println("Waiting for active connections to terminate - press Ctrl+C to quit immediately.")
}()
if err := m.Close(); err != nil {
errorIf(err, "Unable to close server gracefully")
}
objAPI := newObjectLayerFn()
if objAPI == nil {
// Server not initialized yet, exit happily.
runExitFn(nil)
} else {
runExitFn(objAPI.Shutdown())
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Minio Client, (C) 2015 Minio, Inc.
* Minio Cloud Storage, (C) 2015, 2016, 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,32 +18,67 @@ package cmd
import (
"os"
"os/signal"
)
// signalTrap traps the registered signals and notifies the caller.
func signalTrap(sig ...os.Signal) <-chan bool {
// channel to notify the caller.
trapCh := make(chan bool, 1)
func handleSignals() {
// Custom exit function
exit := func(state bool) {
// If global profiler is set stop before we exit.
if globalProfiler != nil {
globalProfiler.Stop()
}
go func(chan<- bool) {
// channel to receive signals.
sigCh := make(chan os.Signal, 1)
defer close(sigCh)
if state {
os.Exit(0)
}
// `signal.Notify` registers the given channel to
// receive notifications of the specified signals.
signal.Notify(sigCh, sig...)
os.Exit(1)
}
// Wait for the signal.
<-sigCh
stopProcess := func() bool {
var err, oerr error
// Once signal has been received stop signal Notify handler.
signal.Stop(sigCh)
err = globalHTTPServer.Shutdown()
errorIf(err, "Unable to shutdown http server")
// Notify the caller.
trapCh <- true
}(trapCh)
if objAPI := newObjectLayerFn(); objAPI != nil {
oerr = objAPI.Shutdown()
errorIf(oerr, "Unable to shutdown object layer")
}
return trapCh
return (err == nil && oerr == nil)
}
for {
select {
case err := <-globalHTTPServerErrorCh:
errorIf(err, "http server exited abnormally")
var oerr error
if objAPI := newObjectLayerFn(); objAPI != nil {
oerr = objAPI.Shutdown()
errorIf(oerr, "Unable to shutdown object layer")
}
exit(err == nil && oerr == nil)
case osSignal := <-globalOSSignalCh:
log.Printf("Exiting on signal %v\n", osSignal)
exit(stopProcess())
case signal := <-globalServiceSignalCh:
switch signal {
case serviceStatus:
// Ignore this at the moment.
case serviceRestart:
log.Println("Restarting on service signal")
err := globalHTTPServer.Shutdown()
errorIf(err, "Unable to shutdown http server")
rerr := restartProcess()
errorIf(rerr, "Unable to restart the server")
exit(err == nil && rerr == nil)
case serviceStop:
log.Println("Stopping on service signal")
exit(stopProcess())
}
}
}
}

88
pkg/http/bufconn.go Normal file
View File

@ -0,0 +1,88 @@
/*
* Minio Cloud Storage, (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package http
import (
"bufio"
"net"
"time"
)
// BufConn - is a generic stream-oriented network connection supporting buffered reader and read/write timeout.
type BufConn struct {
net.Conn
bufReader *bufio.Reader // buffered reader wraps reader in net.Conn.
readTimeout time.Duration // sets the read timeout in the connection.
writeTimeout time.Duration // sets the write timeout in the connection.
updateBytesReadFunc func(int) // function to be called to update bytes read.
updateBytesWrittenFunc func(int) // function to be called to update bytes written.
}
func (c *BufConn) setReadTimeout() {
if c.readTimeout != 0 {
c.SetReadDeadline(time.Now().UTC().Add(c.readTimeout))
}
}
func (c *BufConn) setWriteTimeout() {
if c.writeTimeout != 0 {
c.SetWriteDeadline(time.Now().UTC().Add(c.writeTimeout))
}
}
// Peek - returns the next n bytes without advancing the reader. It just wraps bufio.Reader.Peek().
func (c *BufConn) Peek(n int) ([]byte, error) {
c.setReadTimeout()
return c.bufReader.Peek(n)
}
// Read - reads data from the connection using wrapped buffered reader.
func (c *BufConn) Read(b []byte) (n int, err error) {
c.setReadTimeout()
n, err = c.bufReader.Read(b)
if err == nil && c.updateBytesReadFunc != nil {
c.updateBytesReadFunc(n)
}
return n, err
}
// Write - writes data to the connection.
func (c *BufConn) Write(b []byte) (n int, err error) {
c.setWriteTimeout()
n, err = c.Conn.Write(b)
if err == nil && c.updateBytesWrittenFunc != nil {
c.updateBytesWrittenFunc(n)
}
return n, err
}
// newBufConn - creates a new connection object wrapping net.Conn.
func newBufConn(c net.Conn, readTimeout, writeTimeout time.Duration,
updateBytesReadFunc, updateBytesWrittenFunc func(int)) *BufConn {
return &BufConn{
Conn: c,
bufReader: bufio.NewReader(c),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
updateBytesReadFunc: updateBytesReadFunc,
updateBytesWrittenFunc: updateBytesWrittenFunc,
}
}

108
pkg/http/bufconn_test.go Normal file
View File

@ -0,0 +1,108 @@
/*
* Minio Cloud Storage, (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package http
import (
"bufio"
"io"
"net"
"sync"
"testing"
"time"
)
// Test bufconn handles read timeout properly by reading two messages beyond deadline.
func TestBuffConnReadTimeout(t *testing.T) {
port := getNextPort()
l, err := net.Listen("tcp", "localhost:"+port)
if err != nil {
t.Fatalf("unable to create listener. %v", err)
}
defer l.Close()
tcpListener, ok := l.(*net.TCPListener)
if !ok {
t.Fatalf("failed to assert to net.TCPListener")
}
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
tcpConn, terr := tcpListener.AcceptTCP()
if terr != nil {
t.Fatalf("failed to accept new connection. %v", terr)
}
bufconn := newBufConn(tcpConn, 1*time.Second, 1*time.Second, nil, nil)
defer bufconn.Close()
// Read a line
var b = make([]byte, 12)
_, terr = bufconn.Read(b)
if terr != nil {
t.Fatalf("failed to read from client. %v", terr)
}
received := string(b)
if received != "message one\n" {
t.Fatalf(`server: expected: "message one\n", got: %v`, received)
}
// Wait for more than read timeout to simulate processing.
time.Sleep(3 * time.Second)
_, terr = bufconn.Read(b)
if terr != nil {
t.Fatalf("failed to read from client. %v", terr)
}
received = string(b)
if received != "message two\n" {
t.Fatalf(`server: expected: "message two\n", got: %v`, received)
}
// Send a response.
_, terr = io.WriteString(bufconn, "messages received\n")
if terr != nil {
t.Fatalf("failed to write to client. %v", terr)
}
}()
c, err := net.Dial("tcp", "localhost:"+port)
if err != nil {
t.Fatalf("unable to connect to server. %v", err)
}
defer c.Close()
_, err = io.WriteString(c, "message one\n")
if err != nil {
t.Fatalf("failed to write to server. %v", err)
}
_, err = io.WriteString(c, "message two\n")
if err != nil {
t.Fatalf("failed to write to server. %v", err)
}
received, err := bufio.NewReader(c).ReadString('\n')
if err != nil {
t.Fatalf("failed to read from server. %v", err)
}
if received != "messages received\n" {
t.Fatalf(`client: expected: "messages received\n", got: %v`, received)
}
wg.Wait()
}

320
pkg/http/listener.go Normal file
View File

@ -0,0 +1,320 @@
/*
* Minio Cloud Storage, (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package http
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"sync"
"syscall"
"time"
)
var sslRequiredErrMsg = []byte("HTTP/1.0 403 Forbidden\r\n\r\nSSL required")
// HTTP methods.
var methods = []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodConnect,
http.MethodOptions,
http.MethodTrace,
"PRI", // HTTP 2 method
}
// maximum length of above methods + one space.
var methodMaxLen = getMethodMaxLen() + 1
func getMethodMaxLen() int {
maxLen := 0
for _, method := range methods {
if len(method) > maxLen {
maxLen = len(method)
}
}
return maxLen
}
func isHTTPMethod(s string) bool {
for _, method := range methods {
if s == method {
return true
}
}
return false
}
type acceptResult struct {
conn net.Conn
err error
}
// httpListener - HTTP listener capable of handling multiple server addresses.
type httpListener struct {
mutex sync.Mutex // to guard Close() method.
tcpListeners []*net.TCPListener // underlaying TCP listeners.
acceptCh chan acceptResult // channel where all TCP listeners write accepted connection.
doneCh chan struct{} // done channel for TCP listener goroutines.
tlsConfig *tls.Config // TLS configuration
tcpKeepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
updateBytesReadFunc func(int) // function to be called to update bytes read in BufConn.
updateBytesWrittenFunc func(int) // function to be called to update bytes written in BufConn.
errorLogFunc func(error, string, ...interface{}) // function to be called on errors.
}
// start - starts separate goroutine for each TCP listener. A valid insecure/TLS HTTP new connection is passed to httpListener.acceptCh.
func (listener *httpListener) start() {
listener.acceptCh = make(chan acceptResult)
listener.doneCh = make(chan struct{})
// Closure to send acceptResult to acceptCh.
// It returns true if the result is sent else false if returns when doneCh is closed.
send := func(result acceptResult, doneCh <-chan struct{}) bool {
select {
case listener.acceptCh <- result:
// Successfully written to acceptCh
return true
case <-doneCh:
// As stop signal is received, close accepted connection.
if result.conn != nil {
result.conn.Close()
}
return false
}
}
// Closure to handle single connection.
handleConn := func(tcpConn *net.TCPConn, doneCh <-chan struct{}) {
// Tune accepted TCP connection.
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(listener.tcpKeepAliveTimeout)
bufconn := newBufConn(tcpConn, listener.readTimeout, listener.writeTimeout,
listener.updateBytesReadFunc, listener.updateBytesWrittenFunc)
// Peek bytes of maximum length of all HTTP methods.
data, err := bufconn.Peek(methodMaxLen)
if err != nil {
if listener.errorLogFunc != nil {
listener.errorLogFunc(err,
"Error in reading from new connection %s at server %s",
bufconn.RemoteAddr(), bufconn.LocalAddr())
}
bufconn.Close()
return
}
// Return bufconn if read data is a valid HTTP method.
tokens := strings.SplitN(string(data), " ", 2)
if isHTTPMethod(tokens[0]) {
if listener.tlsConfig == nil {
send(acceptResult{bufconn, nil}, doneCh)
} else {
// As TLS is configured and we got plain text HTTP request,
// return 403 (forbidden) error.
bufconn.Write(sslRequiredErrMsg)
bufconn.Close()
}
return
}
if listener.tlsConfig != nil {
// As the listener is configured with TLS, try to do TLS handshake, drop the connection if it fails.
tlsConn := tls.Server(bufconn, listener.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if listener.errorLogFunc != nil {
listener.errorLogFunc(err,
"TLS handshake failed with new connection %s at server %s",
bufconn.RemoteAddr(), bufconn.LocalAddr())
}
bufconn.Close()
return
}
// Check whether the connection contains HTTP request or not.
bufconn = newBufConn(tlsConn, listener.readTimeout, listener.writeTimeout,
listener.updateBytesReadFunc, listener.updateBytesWrittenFunc)
// Peek bytes of maximum length of all HTTP methods.
data, err := bufconn.Peek(methodMaxLen)
if err != nil {
if listener.errorLogFunc != nil {
listener.errorLogFunc(err,
"Error in reading from new TLS connection %s at server %s",
bufconn.RemoteAddr(), bufconn.LocalAddr())
}
bufconn.Close()
return
}
// Return bufconn if read data is a valid HTTP method.
tokens := strings.SplitN(string(data), " ", 2)
if isHTTPMethod(tokens[0]) {
send(acceptResult{bufconn, nil}, doneCh)
return
}
}
if listener.errorLogFunc != nil {
listener.errorLogFunc(errors.New("junk message"),
"Received non-HTTP message from new connection %s at server %s",
bufconn.RemoteAddr(), bufconn.LocalAddr())
}
bufconn.Close()
return
}
// Closure to handle TCPListener until done channel is closed.
handleListener := func(tcpListener *net.TCPListener, doneCh <-chan struct{}) {
for {
tcpConn, err := tcpListener.AcceptTCP()
if err != nil {
// Returns when send fails.
if !send(acceptResult{nil, err}, doneCh) {
return
}
} else {
go handleConn(tcpConn, doneCh)
}
}
}
// Start separate goroutine for each TCP listener to handle connection.
for _, tcpListener := range listener.tcpListeners {
go handleListener(tcpListener, listener.doneCh)
}
}
// Accept - reads from httpListener.acceptCh for one of previously accepted TCP connection and returns the same.
func (listener *httpListener) Accept() (conn net.Conn, err error) {
result, ok := <-listener.acceptCh
if ok {
return result.conn, result.err
}
return nil, syscall.EINVAL
}
// Close - closes underneath all TCP listeners.
func (listener *httpListener) Close() (err error) {
listener.mutex.Lock()
defer listener.mutex.Unlock()
if listener.doneCh == nil {
return syscall.EINVAL
}
for i := range listener.tcpListeners {
listener.tcpListeners[i].Close()
}
close(listener.doneCh)
listener.doneCh = nil
return nil
}
// Addr - net.Listener interface compatible method returns net.Addr. In case of multiple TCP listeners, it returns '0.0.0.0' as IP address.
func (listener *httpListener) Addr() (addr net.Addr) {
addr = listener.tcpListeners[0].Addr()
if len(listener.tcpListeners) == 1 {
return addr
}
tcpAddr := addr.(*net.TCPAddr)
if ip := net.ParseIP("0.0.0.0"); ip != nil {
tcpAddr.IP = ip
}
addr = tcpAddr
return addr
}
// Addrs - returns all address information of TCP listeners.
func (listener *httpListener) Addrs() (addrs []net.Addr) {
for i := range listener.tcpListeners {
addrs = append(addrs, listener.tcpListeners[i].Addr())
}
return addrs
}
// newHTTPListener - creates new httpListener object which is interface compatible to net.Listener.
// httpListener is capable to
// * listen to multiple addresses
// * controls incoming connections only doing HTTP protocol
func newHTTPListener(serverAddrs []string,
tlsConfig *tls.Config,
tcpKeepAliveTimeout time.Duration,
readTimeout time.Duration,
writeTimeout time.Duration,
updateBytesReadFunc func(int),
updateBytesWrittenFunc func(int),
errorLogFunc func(error, string, ...interface{})) (listener *httpListener, err error) {
var tcpListeners []*net.TCPListener
// Close all opened listeners on error
defer func() {
if err == nil {
return
}
for _, tcpListener := range tcpListeners {
// Ignore error on close.
tcpListener.Close()
}
}()
for _, serverAddr := range serverAddrs {
var l net.Listener
if l, err = net.Listen("tcp", serverAddr); err != nil {
return nil, err
}
tcpListener, ok := l.(*net.TCPListener)
if !ok {
return nil, fmt.Errorf("unexpected listener type found %v, expected net.TCPListener", l)
}
tcpListeners = append(tcpListeners, tcpListener)
}
listener = &httpListener{
tcpListeners: tcpListeners,
tlsConfig: tlsConfig,
tcpKeepAliveTimeout: tcpKeepAliveTimeout,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
updateBytesReadFunc: updateBytesReadFunc,
updateBytesWrittenFunc: updateBytesWrittenFunc,
errorLogFunc: errorLogFunc,
}
listener.start()
return listener, nil
}

817
pkg/http/listener_test.go Normal file
View File

@ -0,0 +1,817 @@
/*
* Minio Cloud Storage, (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package http
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/minio/minio-go/pkg/set"
)
var serverPort uint32 = 60000
func getNextPort() string {
return strconv.Itoa(int(atomic.AddUint32(&serverPort, 1)))
}
func getTLSCert() (tls.Certificate, error) {
keyPEMBlock := []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEApEkbPrT6wzcWK1W5atQiGptvuBsRdf8MCg4u6SN10QbslA5k
6BYRdZfFeRpwAwYyzkumug6+eBJatDZEd7+0FF86yxB7eMTSiHKRZ5Mi5ZyCFsez
dndknGBeK6I80s1jd5ZsLLuMKErvbNwSbfX+X6d2mBeYW8Scv9N+qYnNrHHHohvX
oxy1gZ18EhhogQhrD22zaqg/jtmOT8ImUiXzB1mKInt2LlSkoRYuBzepkDJrsE1L
/cyYZbtcO/ASDj+/qQAuQ66v9pNyJkIQ7bDOUyxaT5Hx9XvbqI1OqUVAdGLLi+eZ
IFguFyYd0lemwdN/IDvxftzegTO3cO0D28d1UQIDAQABAoIBAB42x8j3lerTNcOQ
h4JLM157WcedSs/NsVQkGaKM//0KbfYo04wPivR6jjngj9suh6eDKE2tqoAAuCfO
lzcCzca1YOW5yUuDv0iS8YT//XoHF7HC1pGiEaHk40zZEKCgX3u98XUkpPlAFtqJ
euY4SKkk7l24cS/ncACjj/b0PhxJoT/CncuaaJKqmCc+vdL4wj1UcrSNPZqRjDR/
sh5DO0LblB0XrqVjpNxqxM60/IkbftB8YTnyGgtO2tbTPr8KdQ8DhHQniOp+WEPV
u/iXt0LLM7u62LzadkGab2NDWS3agnmdvw2ADtv5Tt8fZ7WnPqiOpNyD5Bv1a3/h
YBw5HsUCgYEA0Sfv6BiSAFEby2KusRoq5UeUjp/SfL7vwpO1KvXeiYkPBh2XYVq2
azMnOw7Rz5ixFhtUtto2XhYdyvvr3dZu1fNHtxWo9ITBivqTGGRNwfiaQa58Bugo
gy7vCdIE/f6xE5LYIovBnES2vs/ZayMyhTX84SCWd0pTY0kdDA8ePGsCgYEAyRSA
OTzX43KUR1G/trpuM6VBc0W6YUNYzGRa1TcUxBP4K7DfKMpPGg6ulqypfoHmu8QD
L+z+iQmG9ySSuvScIW6u8LgkrTwZga8y2eb/A2FAVYY/bnelef1aMkis+bBX2OQ4
QAg2uq+pkhpW1k5NSS9lVCPkj4e5Ur9RCm9fRDMCgYAf3CSIR03eLHy+Y37WzXSh
TmELxL6sb+1Xx2Y+cAuBCda3CMTpeIb3F2ivb1d4dvrqsikaXW0Qse/B3tQUC7kA
cDmJYwxEiwBsajUD7yuFE5hzzt9nse+R5BFXfp1yD1zr7V9tC7rnUfRAZqrozgjB
D/NAW9VvwGupYRbCon7plwKBgQCRPfeoYGRoa9ji8w+Rg3QaZeGyy8jmfGjlqg9a
NyEOyIXXuThYFFmyrqw5NZpwQJBTTDApK/xnK7SLS6WY2Rr1oydFxRzo7KJX5B7M
+md1H4gCvqeOuWmThgbij1AyQsgRaDehOM2fZ0cKu2/B+Gkm1c9RSWPMsPKR7JMz
AGNFtQKBgQCRCFIdGJHnvz35vJfLoihifCejBWtZbAnZoBHpF3xMCtV755J96tUf
k1Tv9hz6WfSkOSlwLq6eGZY2dCENJRW1ft1UelpFvCjbfrfLvoFFLs3gu0lfqXHi
CS6fjhn9Ahvz10yD6fd4ixRUjoJvULzI0Sxc1O95SYVF1lIAuVr9Hw==
-----END RSA PRIVATE KEY-----`)
certPEMBlock := []byte(`-----BEGIN CERTIFICATE-----
MIIDXTCCAkWgAwIBAgIJAKlqK5HKlo9MMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTcwNjE5MTA0MzEyWhcNMjcwNjE3MTA0MzEyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEApEkbPrT6wzcWK1W5atQiGptvuBsRdf8MCg4u6SN10QbslA5k6BYRdZfF
eRpwAwYyzkumug6+eBJatDZEd7+0FF86yxB7eMTSiHKRZ5Mi5ZyCFsezdndknGBe
K6I80s1jd5ZsLLuMKErvbNwSbfX+X6d2mBeYW8Scv9N+qYnNrHHHohvXoxy1gZ18
EhhogQhrD22zaqg/jtmOT8ImUiXzB1mKInt2LlSkoRYuBzepkDJrsE1L/cyYZbtc
O/ASDj+/qQAuQ66v9pNyJkIQ7bDOUyxaT5Hx9XvbqI1OqUVAdGLLi+eZIFguFyYd
0lemwdN/IDvxftzegTO3cO0D28d1UQIDAQABo1AwTjAdBgNVHQ4EFgQUqMVdMIA1
68Dv+iwGugAaEGUSd0IwHwYDVR0jBBgwFoAUqMVdMIA168Dv+iwGugAaEGUSd0Iw
DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAjQVoqRv2HlE5PJIX/qk5
oMOKZlHTyJP+s2HzOOVt+eCE/jNdfC7+8R/HcPldQs7p9GqH2F6hQ9aOtDhJVEaU
pjxCi4qKeZ1kWwqv8UMBXW92eHGysBvE2Gmm/B1JFl8S2GR5fBmheZVnYW893MoI
gp+bOoCcIuMJRqCra4vJgrOsQjgRElQvd2OlP8qQzInf/fRqO/AnZPwMkGr3+KZ0
BKEOXtmSZaPs3xEsnvJd8wrTgA0NQK7v48E+gHSXzQtaHmOLqisRXlUOu2r1gNCJ
rr3DRiUP6V/10CZ/ImeSJ72k69VuTw9vq2HzB4x6pqxF2X7JQSLUCS2wfNN13N0d
9A==
-----END CERTIFICATE-----`)
return tls.X509KeyPair(certPEMBlock, keyPEMBlock)
}
func getTLSConfig(t *testing.T) *tls.Config {
tlsCert, err := getTLSCert()
if err != nil {
t.Fatalf("Unable to parse private/certificate data. %v\n", err)
}
tlsConfig := &tls.Config{
PreferServerCipherSuites: true,
MinVersion: tls.VersionTLS12,
NextProtos: []string{"http/1.1", "h2"},
}
tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert)
return tlsConfig
}
func getNonLoopBackIP(t *testing.T) string {
localIP4 := set.NewStringSet()
addrs, err := net.InterfaceAddrs()
if err != nil {
t.Fatalf("%s. Unable to get IP addresses of this host.", err)
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip.To4() != nil {
localIP4.Add(ip.String())
}
}
// Filter ipList by IPs those do not start with '127.'.
nonLoopBackIPs := localIP4.FuncMatch(func(ip string, matchString string) bool {
return !strings.HasPrefix(ip, "127.")
}, "")
if len(nonLoopBackIPs) == 0 {
t.Fatalf("No non-loop back IP address found for this host")
}
nonLoopBackIP := nonLoopBackIPs.ToSlice()[0]
return nonLoopBackIP
}
// Test getMethodMaxLen()
func TestGetMethodMaxLen(t *testing.T) {
l := getMethodMaxLen()
if l != (methodMaxLen - 1) {
t.Fatalf("expected: %v, got: %v", (methodMaxLen - 1), l)
}
}
// Test isHTTPMethod()
func TestIsHTTPMethod(t *testing.T) {
testCases := []struct {
method string
expectedResult bool
}{
{"", false},
{"get", false},
{"put", false},
{"UPLOAD", false},
{"OPTIONS", true},
{"GET", true},
{"HEAD", true},
{"POST", true},
{"PUT", true},
{"DELETE", true},
{"TRACE", true},
{"CONNECT", true},
{"PRI", true},
}
for _, testCase := range testCases {
result := isHTTPMethod(testCase.method)
if result != testCase.expectedResult {
t.Fatalf("expected: %v, got: %v", testCase.expectedResult, result)
}
}
}
func TestNewHTTPListener(t *testing.T) {
errMsg := ": no such host"
if runtime.GOOS == "windows" {
errMsg = ": No such host is known."
}
remoteAddrErrMsg := "listen tcp 93.184.216.34:9000: bind: cannot assign requested address"
if runtime.GOOS == "windows" {
remoteAddrErrMsg = "listen tcp 93.184.216.34:9000: bind: The requested address is not valid in its context."
}
tlsConfig := getTLSConfig(t)
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
tcpKeepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
updateBytesReadFunc func(int)
updateBytesWrittenFunc func(int)
errorLogFunc func(error, string, ...interface{})
expectedErr error
}{
{[]string{"93.184.216.34:9000"}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, errors.New(remoteAddrErrMsg)},
{[]string{"example.org:9000"}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, errors.New(remoteAddrErrMsg)},
{[]string{"unknown-host"}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, errors.New("listen tcp: missing port in address unknown-host")},
{[]string{"unknown-host:9000"}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, errors.New("listen tcp: lookup unknown-host" + errMsg)},
{[]string{"localhost:9000", "93.184.216.34:9000"}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, errors.New(remoteAddrErrMsg)},
{[]string{"localhost:9000", "unknown-host:9000"}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, errors.New("listen tcp: lookup unknown-host" + errMsg)},
{[]string{"localhost:" + getNextPort()}, nil, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, nil},
{[]string{"localhost:" + getNextPort()}, tlsConfig, time.Duration(0), time.Duration(0), time.Duration(0), nil, nil, nil, nil},
}
for _, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
testCase.tcpKeepAliveTimeout,
testCase.readTimeout,
testCase.writeTimeout,
testCase.updateBytesReadFunc,
testCase.updateBytesWrittenFunc,
testCase.errorLogFunc,
)
if testCase.expectedErr == nil {
if err != nil {
t.Fatalf("error: expected = <nil>, got = %v", err)
}
} else if err == nil {
t.Fatalf("error: expected = %v, got = <nil>", testCase.expectedErr)
} else {
var match bool
if strings.HasSuffix(testCase.expectedErr.Error(), errMsg) {
match = strings.HasSuffix(err.Error(), errMsg)
} else {
match = (testCase.expectedErr.Error() == err.Error())
}
if !match {
t.Fatalf("error: expected = %v, got = %v", testCase.expectedErr, err)
}
}
if err == nil {
listener.Close()
}
}
}
func TestHTTPListenerStartClose(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 6; i++ {
casePorts = append(casePorts, getNextPort())
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
}{
{[]string{"localhost:" + casePorts[0]}, nil},
{[]string{nonLoopBackIP + ":" + casePorts[1]}, nil},
{[]string{"127.0.0.1:" + casePorts[2], nonLoopBackIP + ":" + casePorts[2]}, nil},
{[]string{"localhost:" + casePorts[3]}, tlsConfig},
{[]string{nonLoopBackIP + ":" + casePorts[4]}, tlsConfig},
{[]string{"127.0.0.1:" + casePorts[5], nonLoopBackIP + ":" + casePorts[5]}, tlsConfig},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
nil,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
for _, serverAddr := range testCase.serverAddrs {
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
conn.Close()
}
listener.Close()
}
}
func TestHTTPListenerAddr(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 6; i++ {
casePorts = append(casePorts, getNextPort())
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
expectedAddr string
}{
{[]string{"localhost:" + casePorts[0]}, nil, "127.0.0.1:" + casePorts[0]},
{[]string{nonLoopBackIP + ":" + casePorts[1]}, nil, nonLoopBackIP + ":" + casePorts[1]},
{[]string{"127.0.0.1:" + casePorts[2], nonLoopBackIP + ":" + casePorts[2]}, nil, "0.0.0.0:" + casePorts[2]},
{[]string{"localhost:" + casePorts[3]}, tlsConfig, "127.0.0.1:" + casePorts[3]},
{[]string{nonLoopBackIP + ":" + casePorts[4]}, tlsConfig, nonLoopBackIP + ":" + casePorts[4]},
{[]string{"127.0.0.1:" + casePorts[5], nonLoopBackIP + ":" + casePorts[5]}, tlsConfig, "0.0.0.0:" + casePorts[5]},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
nil,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
addr := listener.Addr()
if addr.String() != testCase.expectedAddr {
t.Fatalf("Test %d: addr: expected = %v, got = %v", i+1, testCase.expectedAddr, addr)
}
listener.Close()
}
}
func TestHTTPListenerAddrs(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 6; i++ {
casePorts = append(casePorts, getNextPort())
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
expectedAddrs set.StringSet
}{
{[]string{"localhost:" + casePorts[0]}, nil, set.CreateStringSet("127.0.0.1:" + casePorts[0])},
{[]string{nonLoopBackIP + ":" + casePorts[1]}, nil, set.CreateStringSet(nonLoopBackIP + ":" + casePorts[1])},
{[]string{"127.0.0.1:" + casePorts[2], nonLoopBackIP + ":" + casePorts[2]}, nil, set.CreateStringSet("127.0.0.1:"+casePorts[2], nonLoopBackIP+":"+casePorts[2])},
{[]string{"localhost:" + casePorts[3]}, tlsConfig, set.CreateStringSet("127.0.0.1:" + casePorts[3])},
{[]string{nonLoopBackIP + ":" + casePorts[4]}, tlsConfig, set.CreateStringSet(nonLoopBackIP + ":" + casePorts[4])},
{[]string{"127.0.0.1:" + casePorts[5], nonLoopBackIP + ":" + casePorts[5]}, tlsConfig, set.CreateStringSet("127.0.0.1:"+casePorts[5], nonLoopBackIP+":"+casePorts[5])},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
nil,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
addrs := listener.Addrs()
addrSet := set.NewStringSet()
for _, addr := range addrs {
addrSet.Add(addr.String())
}
if !addrSet.Equals(testCase.expectedAddrs) {
t.Fatalf("Test %d: addr: expected = %v, got = %v", i+1, testCase.expectedAddrs, addrs)
}
listener.Close()
}
}
func TestHTTPListenerAccept(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 6; i++ {
casePorts = append(casePorts, getNextPort())
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
request string
reply string
}{
{[]string{"localhost:" + casePorts[0]}, nil, "GET / HTTP/1.0\n", "200 OK\n"},
{[]string{nonLoopBackIP + ":" + casePorts[1]}, nil, "POST / HTTP/1.0\n", "200 OK\n"},
{[]string{"127.0.0.1:" + casePorts[2], nonLoopBackIP + ":" + casePorts[2]}, nil, "CONNECT \n", "200 OK\n"},
{[]string{"localhost:" + casePorts[3]}, tlsConfig, "GET / HTTP/1.0\n", "200 OK\n"},
{[]string{nonLoopBackIP + ":" + casePorts[4]}, tlsConfig, "POST / HTTP/1.0\n", "200 OK\n"},
{[]string{"127.0.0.1:" + casePorts[5], nonLoopBackIP + ":" + casePorts[5]}, tlsConfig, "CONNECT \n", "200 OK\n"},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
nil,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
for _, serverAddr := range testCase.serverAddrs {
var conn net.Conn
var err error
if testCase.tlsConfig == nil {
conn, err = net.Dial("tcp", serverAddr)
} else {
conn, err = tls.Dial("tcp", serverAddr, &tls.Config{InsecureSkipVerify: true})
}
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
if _, err = io.WriteString(conn, testCase.request); err != nil {
t.Fatalf("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
serverConn, err := listener.Accept()
if err != nil {
t.Fatalf("Test %d: accept: expected = <nil>, got = %v", i+1, err)
}
request, err := bufio.NewReader(serverConn).ReadString('\n')
if err != nil {
t.Fatalf("Test %d: request read: expected = <nil>, got = %v", i+1, err)
}
if testCase.request != request {
t.Fatalf("Test %d: request: expected = %v, got = %v", i+1, testCase.request, request)
}
if _, err = io.WriteString(serverConn, testCase.reply); err != nil {
t.Fatalf("Test %d: reply send: expected = <nil>, got = %v", i+1, err)
}
reply, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
t.Fatalf("Test %d: reply read: expected = <nil>, got = %v", i+1, err)
}
if testCase.reply != reply {
t.Fatalf("Test %d: reply: expected = %v, got = %v", i+1, testCase.reply, reply)
}
serverConn.Close()
conn.Close()
}
listener.Close()
}
}
func TestHTTPListenerAcceptPeekError(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 2; i++ {
casePorts = append(casePorts, getNextPort())
}
errorFunc := func(err error, template string, args ...interface{}) {
msg := fmt.Sprintf("error: %v. ", err)
msg += fmt.Sprintf(template, args...)
fmt.Println(msg)
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
request string
}{
{[]string{"127.0.0.1:" + casePorts[0], nonLoopBackIP + ":" + casePorts[0]}, nil, "CONN"},
{[]string{"127.0.0.1:" + casePorts[1], nonLoopBackIP + ":" + casePorts[1]}, tlsConfig, "CONN"},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
errorFunc,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
go func() {
serverConn, aerr := listener.Accept()
if aerr == nil {
t.Fatalf("Test %d: accept: expected = <error>, got = <nil>", i+1)
}
if serverConn != nil {
t.Fatalf("Test %d: accept: server expected = <nil>, got = %v", i+1, serverConn)
}
}()
for _, serverAddr := range testCase.serverAddrs {
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
if _, err = io.WriteString(conn, testCase.request); err != nil {
t.Fatalf("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
conn.Close()
}
listener.Close()
}
}
func TestHTTPListenerAcceptTLSError(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 1; i++ {
casePorts = append(casePorts, getNextPort())
}
errorFunc := func(err error, template string, args ...interface{}) {
msg := fmt.Sprintf("error: %v. ", err)
msg += fmt.Sprintf(template, args...)
fmt.Println(msg)
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
request string
}{
{[]string{"127.0.0.1:" + casePorts[0], nonLoopBackIP + ":" + casePorts[0]}, tlsConfig, "GET / HTTP/1.0\n"},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
errorFunc,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
for _, serverAddr := range testCase.serverAddrs {
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
if _, err = io.WriteString(conn, testCase.request); err != nil {
t.Fatalf("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
go func() {
serverConn, aerr := listener.Accept()
if aerr == nil {
t.Fatalf("Test %d: accept: expected = <error>, got = <nil>", i+1)
}
if serverConn != nil {
t.Fatalf("Test %d: accept: server expected = <nil>, got = %v", i+1, serverConn)
}
}()
buf := make([]byte, len(sslRequiredErrMsg))
n, err := io.ReadFull(conn, buf)
if err != nil {
t.Fatalf("Test %d: reply read: expected = <nil> got = %v", i+1, err)
} else if n != len(buf) {
t.Fatalf("Test %d: reply length: expected = %v got = %v", i+1, len(buf), n)
} else if !bytes.Equal(buf, sslRequiredErrMsg) {
t.Fatalf("Test %d: reply: expected = %v got = %v", i+1, string(sslRequiredErrMsg), string(buf))
}
conn.Close()
}
listener.Close()
}
}
func TestHTTPListenerAcceptError(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
var casePorts []string
for i := 0; i < 3; i++ {
casePorts = append(casePorts, getNextPort())
}
errorFunc := func(err error, template string, args ...interface{}) {
msg := fmt.Sprintf("error: %v. ", err)
msg += fmt.Sprintf(template, args...)
fmt.Println(msg)
}
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
secureClient bool
request string
}{
{[]string{"127.0.0.1:" + casePorts[0], nonLoopBackIP + ":" + casePorts[0]}, nil, false, "CONNECTION"},
{[]string{"127.0.0.1:" + casePorts[1], nonLoopBackIP + ":" + casePorts[1]}, tlsConfig, false, "CONNECTION"},
{[]string{"127.0.0.1:" + casePorts[2], nonLoopBackIP + ":" + casePorts[2]}, tlsConfig, true, "CONNECTION"},
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
errorFunc,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
for _, serverAddr := range testCase.serverAddrs {
var conn net.Conn
var err error
if testCase.secureClient {
conn, err = tls.Dial("tcp", serverAddr, &tls.Config{InsecureSkipVerify: true})
} else {
conn, err = net.Dial("tcp", serverAddr)
}
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
if _, err = io.WriteString(conn, testCase.request); err != nil {
t.Fatalf("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
go func() {
serverConn, aerr := listener.Accept()
if aerr == nil {
t.Fatalf("Test %d: accept: expected = <error>, got = <nil>", i+1)
}
if serverConn != nil {
t.Fatalf("Test %d: accept: server expected = <nil>, got = %v", i+1, serverConn)
}
}()
_, err = bufio.NewReader(conn).ReadString('\n')
if err == nil {
t.Fatalf("Test %d: reply read: expected = EOF got = <nil>", i+1)
} else if err.Error() != "EOF" {
t.Fatalf("Test %d: reply read: expected = EOF got = %v", i+1, err)
}
conn.Close()
}
listener.Close()
}
}
func TestHTTPListenerAcceptParallel(t *testing.T) {
tlsConfig := getTLSConfig(t)
nonLoopBackIP := getNonLoopBackIP(t)
case1Port := getNextPort()
case2Port := getNextPort()
testCases := []struct {
serverAddrs []string
tlsConfig *tls.Config
reply string
}{
{[]string{"127.0.0.1:" + case1Port, nonLoopBackIP + ":" + case1Port}, nil, "200 OK\n"},
{[]string{"127.0.0.1:" + case2Port, nonLoopBackIP + ":" + case2Port}, tlsConfig, "200 OK\n"},
}
// As t.Fatalf() is not goroutine safe, use this closure.
fail := func(template string, args ...interface{}) {
fmt.Printf(template, args...)
fmt.Println()
t.Fail()
}
connect := func(i int, serverAddr string, secure bool, delay bool, request, reply string) {
var conn net.Conn
var err error
if secure {
conn, err = tls.Dial("tcp", serverAddr, &tls.Config{InsecureSkipVerify: true})
} else {
conn, err = net.Dial("tcp", serverAddr)
}
if err != nil {
fail("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
if delay {
if _, err = io.WriteString(conn, request[:3]); err != nil {
fail("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
time.Sleep(1 * time.Second)
if _, err = io.WriteString(conn, request[3:]); err != nil {
fail("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
} else {
if _, err = io.WriteString(conn, request); err != nil {
fail("Test %d: request send: expected = <nil>, got = %v", i+1, err)
}
}
received, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
fail("Test %d: reply read: expected = <nil>, got = %v", i+1, err)
}
if received != reply {
fail("Test %d: reply: expected = %v, got = %v", i+1, reply, received)
}
conn.Close()
}
handleConnection := func(i int, wg *sync.WaitGroup, serverConn net.Conn, request, reply string) {
wg.Add(1)
defer wg.Done()
received, err := bufio.NewReader(serverConn).ReadString('\n')
if err != nil {
fail("Test %d: request read: expected = <nil>, got = %v", i+1, err)
}
if received != request {
fail("Test %d: request: expected = %v, got = %v", i+1, request, received)
}
if _, err := io.WriteString(serverConn, reply); err != nil {
fail("Test %d: reply send: expected = <nil>, got = %v", i+1, err)
}
serverConn.Close()
}
for i, testCase := range testCases {
listener, err := newHTTPListener(
testCase.serverAddrs,
testCase.tlsConfig,
time.Duration(0),
time.Duration(0),
time.Duration(0),
nil,
nil,
nil,
)
if err != nil {
t.Fatalf("Test %d: error: expected = <nil>, got = %v", i+1, err)
}
for _, serverAddr := range testCase.serverAddrs {
go connect(i, serverAddr, testCase.tlsConfig != nil, true, "GET /1 HTTP/1.0\n", testCase.reply)
go connect(i, serverAddr, testCase.tlsConfig != nil, false, "GET /2 HTTP/1.0\n", testCase.reply)
var wg sync.WaitGroup
serverConn, err := listener.Accept()
if err != nil {
t.Fatalf("Test %d: accept: expected = <nil>, got = %v", i+1, err)
}
go handleConnection(i, &wg, serverConn, "GET /2 HTTP/1.0\n", testCase.reply)
serverConn, err = listener.Accept()
if err != nil {
t.Fatalf("Test %d: accept: expected = <nil>, got = %v", i+1, err)
}
go handleConnection(i, &wg, serverConn, "GET /1 HTTP/1.0\n", testCase.reply)
wg.Wait()
}
listener.Close()
}
}

173
pkg/http/server.go Normal file
View File

@ -0,0 +1,173 @@
/*
* Minio Cloud Storage, (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package http
import (
"crypto/tls"
"errors"
"net/http"
"sync"
"sync/atomic"
"time"
humanize "github.com/dustin/go-humanize"
)
const (
serverShutdownPoll = 500 * time.Millisecond
// DefaultShutdownTimeout - default shutdown timeout used for graceful http server shutdown.
DefaultShutdownTimeout = 5 * time.Second
// DefaultTCPKeepAliveTimeout - default TCP keep alive timeout for accepted connection.
DefaultTCPKeepAliveTimeout = 10 * time.Second
// DefaultReadTimeout - default timout to read data from accepted connection.
DefaultReadTimeout = 30 * time.Second
// DefaultWriteTimeout - default timout to write data to accepted connection.
DefaultWriteTimeout = 30 * time.Second
// DefaultMaxHeaderBytes - default maximum HTTP header size in bytes.
DefaultMaxHeaderBytes = 1 * humanize.MiByte
)
// Server - extended http.Server supports multiple addresses to serve and enhanced connection handling.
type Server struct {
http.Server
Addrs []string // addresses on which the server listens for new connection.
ShutdownTimeout time.Duration // timeout used for graceful server shutdown.
TCPKeepAliveTimeout time.Duration // timeout used for underneath TCP connection.
UpdateBytesReadFunc func(int) // function to be called to update bytes read in bufConn.
UpdateBytesWrittenFunc func(int) // function to be called to update bytes written in bufConn.
ErrorLogFunc func(error, string, ...interface{}) // function to be called on errors.
listenerMutex *sync.Mutex // to guard 'listener' field.
listener *httpListener // HTTP listener for all 'Addrs' field.
inShutdown uint32 // indicates whether the server is in shutdown or not
requestCount int32 // counter holds no. of request in process.
}
// Start - start HTTP server
func (srv *Server) Start() (err error) {
// Take a copy of server fields.
tlsConfig := srv.TLSConfig
readTimeout := srv.ReadTimeout
writeTimeout := srv.WriteTimeout
handler := srv.Handler
addrs := srv.Addrs
tcpKeepAliveTimeout := srv.TCPKeepAliveTimeout
updateBytesReadFunc := srv.UpdateBytesReadFunc
updateBytesWrittenFunc := srv.UpdateBytesWrittenFunc
errorLogFunc := srv.ErrorLogFunc
// Create new HTTP listener.
var listener *httpListener
listener, err = newHTTPListener(
addrs,
tlsConfig,
tcpKeepAliveTimeout,
readTimeout,
writeTimeout,
updateBytesReadFunc,
updateBytesWrittenFunc,
errorLogFunc,
)
if err != nil {
return err
}
// Wrap given handler to do additional
// * return 503 (service unavailable) if the server in shutdown.
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&srv.requestCount, 1)
defer atomic.AddInt32(&srv.requestCount, -1)
// If server is in shutdown, return 503 (service unavailable)
if atomic.LoadUint32(&srv.inShutdown) != 0 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// Handle request using passed handler.
handler.ServeHTTP(w, r)
})
srv.listenerMutex.Lock()
srv.Handler = wrappedHandler
srv.listener = listener
srv.listenerMutex.Unlock()
// Start servicing with listener.
return srv.Server.Serve(listener)
}
// Shutdown - shuts down HTTP server.
func (srv *Server) Shutdown() error {
if atomic.AddUint32(&srv.inShutdown, 1) > 1 {
// shutdown in progress
return errors.New("http server already in shutdown")
}
// Close underneath HTTP listener.
srv.listenerMutex.Lock()
err := srv.listener.Close()
srv.listenerMutex.Unlock()
// Wait for opened connection to be closed up to Shutdown timeout.
shutdownTimeout := srv.ShutdownTimeout
shutdownTimer := time.NewTimer(shutdownTimeout)
ticker := time.NewTicker(serverShutdownPoll)
defer ticker.Stop()
for {
select {
case <-shutdownTimer.C:
return errors.New("timed out. some connections are still active. doing abnormal shutdown")
case <-ticker.C:
if atomic.LoadInt32(&srv.requestCount) <= 0 {
return err
}
}
}
}
// NewServer - creates new HTTP server using given arguments.
func NewServer(addrs []string, handler http.Handler, certificate *tls.Certificate) *Server {
var tlsConfig *tls.Config
if certificate != nil {
tlsConfig = &tls.Config{
PreferServerCipherSuites: true,
MinVersion: tls.VersionTLS12,
NextProtos: []string{"http/1.1", "h2"},
}
tlsConfig.Certificates = append(tlsConfig.Certificates, *certificate)
}
httpServer := &Server{
Addrs: addrs,
ShutdownTimeout: DefaultShutdownTimeout,
TCPKeepAliveTimeout: DefaultTCPKeepAliveTimeout,
listenerMutex: &sync.Mutex{},
}
httpServer.Handler = handler
httpServer.TLSConfig = tlsConfig
httpServer.ReadTimeout = DefaultReadTimeout
httpServer.WriteTimeout = DefaultWriteTimeout
httpServer.MaxHeaderBytes = DefaultMaxHeaderBytes
return httpServer
}

99
pkg/http/server_test.go Normal file
View File

@ -0,0 +1,99 @@
/*
* Minio Cloud Storage, (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package http
import (
"crypto/tls"
"fmt"
"net/http"
"reflect"
"testing"
)
func TestNewServer(t *testing.T) {
nonLoopBackIP := getNonLoopBackIP(t)
certificate, err := getTLSCert()
if err != nil {
t.Fatalf("Unable to parse private/certificate data. %v\n", err)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, world")
})
testCases := []struct {
addrs []string
handler http.Handler
certificate *tls.Certificate
}{
{[]string{"127.0.0.1:9000"}, handler, nil},
{[]string{nonLoopBackIP + ":9000"}, handler, nil},
{[]string{"127.0.0.1:9000", nonLoopBackIP + ":9000"}, handler, nil},
{[]string{"127.0.0.1:9000"}, handler, &certificate},
{[]string{nonLoopBackIP + ":9000"}, handler, &certificate},
{[]string{"127.0.0.1:9000", nonLoopBackIP + ":9000"}, handler, &certificate},
}
for i, testCase := range testCases {
server := NewServer(testCase.addrs, testCase.handler, testCase.certificate)
if server == nil {
t.Fatalf("Case %v: server: expected: <non-nil>, got: <nil>", (i + 1))
}
if !reflect.DeepEqual(server.Addrs, testCase.addrs) {
t.Fatalf("Case %v: server.Addrs: expected: %v, got: %v", (i + 1), testCase.addrs, server.Addrs)
}
// Interfaces are not comparable even with reflection.
// if !reflect.DeepEqual(server.Handler, testCase.handler) {
// t.Fatalf("Case %v: server.Handler: expected: %v, got: %v", (i + 1), testCase.handler, server.Handler)
// }
if testCase.certificate == nil {
if server.TLSConfig != nil {
t.Fatalf("Case %v: server.TLSConfig: expected: <nil>, got: %v", (i + 1), server.TLSConfig)
}
} else {
if server.TLSConfig == nil {
t.Fatalf("Case %v: server.TLSConfig: expected: <non-nil>, got: <nil>", (i + 1))
}
}
if server.ShutdownTimeout != DefaultShutdownTimeout {
t.Fatalf("Case %v: server.ShutdownTimeout: expected: %v, got: %v", (i + 1), DefaultShutdownTimeout, server.ShutdownTimeout)
}
if server.TCPKeepAliveTimeout != DefaultTCPKeepAliveTimeout {
t.Fatalf("Case %v: server.TCPKeepAliveTimeout: expected: %v, got: %v", (i + 1), DefaultTCPKeepAliveTimeout, server.TCPKeepAliveTimeout)
}
if server.listenerMutex == nil {
t.Fatalf("Case %v: server.listenerMutex: expected: <non-nil>, got: <nil>", (i + 1))
}
if server.ReadTimeout != DefaultReadTimeout {
t.Fatalf("Case %v: server.ReadTimeout: expected: %v, got: %v", (i + 1), DefaultReadTimeout, server.ReadTimeout)
}
if server.WriteTimeout != DefaultWriteTimeout {
t.Fatalf("Case %v: server.WriteTimeout: expected: %v, got: %v", (i + 1), DefaultWriteTimeout, server.WriteTimeout)
}
if server.MaxHeaderBytes != DefaultMaxHeaderBytes {
t.Fatalf("Case %v: server.MaxHeaderBytes: expected: %v, got: %v", (i + 1), DefaultMaxHeaderBytes, server.MaxHeaderBytes)
}
}
}