rpcclient: fix leaky connection (#3471)

Previously, when more than one goroutine called RPCClient.dial(), each
goroutine got a new rpc.Client, but only one such client was stored
in the RPCClient object.  The others were never closed, which leaked
connections on the server side.  This is fixed by taking the lock at
the top of dial() and releasing it on return.
This commit is contained in:
Bala FA 2016-12-18 07:47:40 +05:30 committed by Harshavardhana
parent 9c9f390350
commit 1875a47495
5 changed files with 51 additions and 69 deletions

View File

@ -111,7 +111,7 @@ func newAuthClient(cfg *authConfig) *AuthRPCClient {
// Save the config. // Save the config.
config: cfg, config: cfg,
// Initialize a new reconnectable rpc client. // Initialize a new reconnectable rpc client.
rpc: newClient(cfg.address, cfg.path, cfg.secureConn), rpc: newRPCClient(cfg.address, cfg.path, cfg.secureConn),
// Allocated auth client not logged in yet. // Allocated auth client not logged in yet.
isLoggedIn: false, isLoggedIn: false,
} }

View File

@ -70,7 +70,7 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) {
// Validate for invalid token. // Validate for invalid token.
args := SetAuthPeerArgs{Creds: creds} args := SetAuthPeerArgs{Creds: creds}
args.Token = "garbage" args.Token = "garbage"
rclient := newClient(s.testAuthConf.address, s.testAuthConf.path, false) rclient := newRPCClient(s.testAuthConf.address, s.testAuthConf.path, false)
defer rclient.Close() defer rclient.Close()
err := rclient.Call("BrowserPeer.SetAuthPeer", &args, &GenericReply{}) err := rclient.Call("BrowserPeer.SetAuthPeer", &args, &GenericReply{})
if err != nil { if err != nil {
@ -89,7 +89,7 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) {
} }
// Validate for failure in login handler with previous credentials. // Validate for failure in login handler with previous credentials.
rclient = newClient(s.testAuthConf.address, s.testAuthConf.path, false) rclient = newRPCClient(s.testAuthConf.address, s.testAuthConf.path, false)
defer rclient.Close() defer rclient.Close()
rargs := &RPCLoginArgs{ rargs := &RPCLoginArgs{
Username: s.testAuthConf.accessKey, Username: s.testAuthConf.accessKey,

View File

@ -279,7 +279,7 @@ func (l *lockServer) lockMaintenance(interval time.Duration) {
// Validate if long lived locks are indeed clean. // Validate if long lived locks are indeed clean.
for _, nlrip := range nlripLongLived { for _, nlrip := range nlripLongLived {
// Initialize client based on the long live locks. // Initialize client based on the long live locks.
c := newClient(nlrip.lri.node, nlrip.lri.rpcPath, isSSL()) c := newRPCClient(nlrip.lri.node, nlrip.lri.rpcPath, isSSL())
var expired bool var expired bool

View File

@ -30,19 +30,21 @@ import (
"time" "time"
) )
// defaultDialTimeout is used for non-secure connection.
const defaultDialTimeout = 3 * time.Second
// RPCClient is a wrapper type for rpc.Client which provides reconnect on first failure. // RPCClient is a wrapper type for rpc.Client which provides reconnect on first failure.
type RPCClient struct { type RPCClient struct {
mu sync.Mutex mu sync.Mutex
rpcPrivate *rpc.Client netRPCClient *rpc.Client
node string node string
rpcPath string rpcPath string
secureConn bool secureConn bool
} }
// newClient constructs a RPCClient object with node and rpcPath initialized. // newClient constructs a RPCClient object with node and rpcPath initialized.
// It _doesn't_ connect to the remote endpoint. See Call method to see when the // It does lazy connect to the remote endpoint on Call().
// connect happens. func newRPCClient(node, rpcPath string, secureConn bool) *RPCClient {
func newClient(node, rpcPath string, secureConn bool) *RPCClient {
return &RPCClient{ return &RPCClient{
node: node, node: node,
rpcPath: rpcPath, rpcPath: rpcPath,
@ -50,34 +52,19 @@ func newClient(node, rpcPath string, secureConn bool) *RPCClient {
} }
} }
// clearRPCClient clears the pointer to the rpc.Client object in a safe manner // dial tries to establish a connection to the server in a safe manner.
func (rpcClient *RPCClient) clearRPCClient() { // If there is a valid rpc.Client, it returns that else creates a new one.
func (rpcClient *RPCClient) dial() (*rpc.Client, error) {
rpcClient.mu.Lock() rpcClient.mu.Lock()
rpcClient.rpcPrivate = nil defer rpcClient.mu.Unlock()
rpcClient.mu.Unlock()
}
// getRPCClient gets the pointer to the rpc.Client object in a safe manner // Nothing to do as we already have valid connection.
func (rpcClient *RPCClient) getRPCClient() *rpc.Client { if rpcClient.netRPCClient != nil {
rpcClient.mu.Lock() return rpcClient.netRPCClient, nil
rpcLocalStack := rpcClient.rpcPrivate
rpcClient.mu.Unlock()
return rpcLocalStack
}
// dialRPCClient tries to establish a connection to the server in a safe manner
func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) {
rpcClient.mu.Lock()
// After acquiring lock, check whether another thread may not have already dialed and established connection
if rpcClient.rpcPrivate != nil {
rpcClient.mu.Unlock()
return rpcClient.rpcPrivate, nil
} }
rpcClient.mu.Unlock()
var err error var err error
var conn net.Conn var conn net.Conn
if rpcClient.secureConn { if rpcClient.secureConn {
hostname, _, splitErr := net.SplitHostPort(rpcClient.node) hostname, _, splitErr := net.SplitHostPort(rpcClient.node)
if splitErr != nil { if splitErr != nil {
@ -92,14 +79,14 @@ func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) {
// ServerName in tls.Config needs to be specified to support SNI certificates // ServerName in tls.Config needs to be specified to support SNI certificates
conn, err = tls.Dial("tcp", rpcClient.node, &tls.Config{ServerName: hostname, RootCAs: globalRootCAs}) conn, err = tls.Dial("tcp", rpcClient.node, &tls.Config{ServerName: hostname, RootCAs: globalRootCAs})
} else { } else {
// Have a dial timeout with 3 secs. // Dial with 3 seconds timeout.
conn, err = net.DialTimeout("tcp", rpcClient.node, 3*time.Second) conn, err = net.DialTimeout("tcp", rpcClient.node, defaultDialTimeout)
} }
if err != nil { if err != nil {
// Print RPC connection errors that are worthy to display in log // Print RPC connection errors that are worthy to display in log
switch err.(type) { switch err.(type) {
case x509.HostnameError: case x509.HostnameError:
errorIf(err, "Unable to establish RPC to %s", rpcClient.node) errorIf(err, "Unable to establish secure connection to %s", rpcClient.node)
} }
return nil, &net.OpError{ return nil, &net.OpError{
Op: "dial-http", Op: "dial-http",
@ -108,25 +95,27 @@ func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) {
Err: err, Err: err,
} }
} }
io.WriteString(conn, "CONNECT "+rpcClient.rpcPath+" HTTP/1.0\n\n") io.WriteString(conn, "CONNECT "+rpcClient.rpcPath+" HTTP/1.0\n\n")
// Require successful HTTP response before switching to RPC protocol. // Require successful HTTP response before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == "200 Connected to Go RPC" { if err == nil && resp.Status == "200 Connected to Go RPC" {
rpc := rpc.NewClient(conn) netRPCClient := rpc.NewClient(conn)
if rpc == nil { if netRPCClient == nil {
return nil, &net.OpError{ return nil, &net.OpError{
Op: "dial-http", Op: "dial-http",
Net: rpcClient.node + " " + rpcClient.rpcPath, Net: rpcClient.node + " " + rpcClient.rpcPath,
Addr: nil, Addr: nil,
Err: fmt.Errorf("Unable to initialize new rpcClient, %s", errUnexpected), Err: fmt.Errorf("Unable to initialize new rpc.Client, %s", errUnexpected),
} }
} }
rpcClient.mu.Lock()
rpcClient.rpcPrivate = rpc rpcClient.netRPCClient = netRPCClient
rpcClient.mu.Unlock()
return rpc, nil return netRPCClient, nil
} }
if err == nil { if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status) err = errors.New("unexpected HTTP response: " + resp.Status)
} }
@ -141,38 +130,31 @@ func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) {
// Call makes a RPC call to the remote endpoint using the default codec, namely encoding/gob. // Call makes a RPC call to the remote endpoint using the default codec, namely encoding/gob.
func (rpcClient *RPCClient) Call(serviceMethod string, args interface{}, reply interface{}) error { func (rpcClient *RPCClient) Call(serviceMethod string, args interface{}, reply interface{}) error {
// Make a copy below so that we can safely (continue to) work with the rpc.Client. // Get a new or existing rpc.Client.
// Even in the case the two threads would simultaneously find that the connection is not initialised, netRPCClient, err := rpcClient.dial()
// they would both attempt to dial and only one of them would succeed in doing so. if err != nil {
rpcLocalStack := rpcClient.getRPCClient() return err
// If the rpc.Client is nil, we attempt to (re)connect with the remote endpoint.
if rpcLocalStack == nil {
var err error
rpcLocalStack, err = rpcClient.dialRPCClient()
if err != nil {
return err
}
} }
// If the RPC fails due to a network-related error return netRPCClient.Call(serviceMethod, args, reply)
return rpcLocalStack.Call(serviceMethod, args, reply)
} }
// Close closes the underlying socket file descriptor. // Close closes underlying rpc.Client.
func (rpcClient *RPCClient) Close() error { func (rpcClient *RPCClient) Close() error {
// See comment above for making a copy on local stack rpcClient.mu.Lock()
rpcLocalStack := rpcClient.getRPCClient()
// If rpc client has not connected yet there is nothing to close. if rpcClient.netRPCClient != nil {
if rpcLocalStack == nil { // We make a copy of rpc.Client and unlock it immediately so that another
return nil // goroutine could try to dial or close in parallel.
netRPCClient := rpcClient.netRPCClient
rpcClient.netRPCClient = nil
rpcClient.mu.Unlock()
return netRPCClient.Close()
} }
// Reset rpcClient.rpc to allow for subsequent calls to use a new rpcClient.mu.Unlock()
// (socket) connection. return nil
rpcClient.clearRPCClient()
return rpcLocalStack.Close()
} }
// Node returns the node (network address) of the connection // Node returns the node (network address) of the connection

View File

@ -63,7 +63,7 @@ func TestS3PeerRPC(t *testing.T) {
func (s *TestRPCS3PeerSuite) testS3PeerRPC(t *testing.T) { func (s *TestRPCS3PeerSuite) testS3PeerRPC(t *testing.T) {
// Validate for invalid token. // Validate for invalid token.
args := GenericArgs{Token: "garbage", Timestamp: time.Now().UTC()} args := GenericArgs{Token: "garbage", Timestamp: time.Now().UTC()}
rclient := newClient(s.testAuthConf.address, s.testAuthConf.path, false) rclient := newRPCClient(s.testAuthConf.address, s.testAuthConf.path, false)
defer rclient.Close() defer rclient.Close()
err := rclient.Call("S3.SetBucketNotificationPeer", &args, &GenericReply{}) err := rclient.Call("S3.SetBucketNotificationPeer", &args, &GenericReply{})
if err != nil { if err != nil {