Remove requirement for custom RPCClient (#5405)

This change is a simplification over existing
code since it is not required to have a separate
RPCClient structure instead keep authRPCClient can
do the same job.

There is no code which directly uses netRPCClient(),
keeping authRPCClient is better and simpler. This
simplication also allows for removal of multiple
levels of locking code per object.

Observed in #5160
This commit is contained in:
Harshavardhana 2018-01-19 16:38:47 -08:00 committed by kannappanr
parent 7f99cc9768
commit e19eddd759
6 changed files with 251 additions and 205 deletions

View File

@ -1,5 +1,5 @@
/*
* Minio Cloud Storage, (C) 2016, 2017 Minio, Inc.
* Minio Cloud Storage, (C) 2016, 2017, 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,14 @@
package cmd
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/rpc"
"sync"
"time"
@ -52,10 +60,10 @@ type authConfig struct {
// AuthRPCClient is a authenticated RPC client which does authentication before doing Call().
type AuthRPCClient struct {
sync.RWMutex // Mutex to lock this object.
rpcClient *RPCClient // Reconnectable RPC client to make any RPC call.
config authConfig // Authentication configuration information.
authToken string // Authentication token.
sync.RWMutex // Mutex to lock this object.
rpcClient *rpc.Client // RPC Client to make any RPC call.
config authConfig // Authentication configuration information.
authToken string // Authentication token.
}
// newAuthRPCClient - returns a JWT based authenticated (go) rpc client, which does automatic reconnect.
@ -73,8 +81,7 @@ func newAuthRPCClient(config authConfig) *AuthRPCClient {
}
return &AuthRPCClient{
rpcClient: newRPCClient(config.serverAddr, config.serviceEndpoint, config.secureConn),
config: config,
config: config,
}
}
@ -99,23 +106,38 @@ func (authClient *AuthRPCClient) Login() (err error) {
// Attempt to login if not logged in already.
if authClient.authToken == "" {
authClient.authToken, err = authenticateNode(authClient.config.accessKey, authClient.config.secretKey)
var authToken string
authToken, err = authenticateNode(authClient.config.accessKey, authClient.config.secretKey)
if err != nil {
return err
}
// Login to authenticate your token.
var (
loginMethod = authClient.config.serviceName + loginMethodName
loginArgs = LoginRPCArgs{
AuthToken: authClient.authToken,
AuthToken: authToken,
Version: Version,
RequestTime: UTCNow(),
}
)
if err = authClient.rpcClient.Call(loginMethod, &loginArgs, &LoginRPCReply{}); err != nil {
// Re-dial after we have disconnected or if its a fresh run.
var rpcClient *rpc.Client
rpcClient, err = rpcDial(authClient.config.serverAddr, authClient.config.serviceEndpoint, authClient.config.secureConn)
if err != nil {
return err
}
if err = rpcClient.Call(loginMethod, &loginArgs, &LoginRPCReply{}); err != nil {
return err
}
// Initialize rpc client and auth token after a successful login.
authClient.authToken = authToken
authClient.rpcClient = rpcClient
}
return nil
}
@ -127,10 +149,10 @@ func (authClient *AuthRPCClient) call(serviceMethod string, args interface {
return err
} // On successful login, execute RPC call.
authClient.RLock()
// Set token before the rpc call.
authClient.RLock()
defer authClient.RUnlock()
args.SetAuthToken(authClient.authToken)
authClient.RUnlock()
// Do an RPC call.
return authClient.rpcClient.Call(serviceMethod, args, reply)
@ -169,6 +191,10 @@ func (authClient *AuthRPCClient) Close() error {
authClient.Lock()
defer authClient.Unlock()
if authClient.rpcClient == nil {
return nil
}
authClient.authToken = ""
return authClient.rpcClient.Close()
}
@ -182,3 +208,87 @@ func (authClient *AuthRPCClient) ServerAddr() string {
func (authClient *AuthRPCClient) ServiceEndpoint() string {
return authClient.config.serviceEndpoint
}
// default Dial timeout for RPC connections.
const defaultDialTimeout = 3 * time.Second
// Connect success message required from rpc server.
const connectSuccessMessage = "200 Connected to Go RPC"
// dial tries to establish a connection to serverAddr in a safe manner.
// If there is a valid rpc.Cliemt, it returns that else creates a new one.
func rpcDial(serverAddr, serviceEndpoint string, secureConn bool) (netRPCClient *rpc.Client, err error) {
if serverAddr == "" || serviceEndpoint == "" {
return nil, errInvalidArgument
}
d := &net.Dialer{
Timeout: defaultDialTimeout,
}
var conn net.Conn
if secureConn {
var hostname string
if hostname, _, err = net.SplitHostPort(serverAddr); err != nil {
return nil, &net.OpError{
Op: "dial-http",
Net: serverAddr + serviceEndpoint,
Addr: nil,
Err: fmt.Errorf("Unable to parse server address <%s>: %s", serverAddr, err),
}
}
// ServerName in tls.Config needs to be specified to support SNI certificates.
conn, err = tls.DialWithDialer(d, "tcp", serverAddr, &tls.Config{
ServerName: hostname,
RootCAs: globalRootCAs,
})
} else {
conn, err = d.Dial("tcp", serverAddr)
}
if err != nil {
// Print RPC connection errors that are worthy to display in log.
switch err.(type) {
case x509.HostnameError:
errorIf(err, "Unable to establish secure connection to %s", serverAddr)
}
return nil, &net.OpError{
Op: "dial-http",
Net: serverAddr + serviceEndpoint,
Addr: nil,
Err: err,
}
}
// Check for network errors writing over the dialed conn.
if _, err = io.WriteString(conn, "CONNECT "+serviceEndpoint+" HTTP/1.0\n\n"); err != nil {
conn.Close()
return nil, &net.OpError{
Op: "dial-http",
Net: serverAddr + serviceEndpoint,
Addr: nil,
Err: err,
}
}
// Attempt to read the HTTP response for the HTTP method CONNECT, upon
// success return the RPC connection instance.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{
Method: http.MethodConnect,
})
if err != nil {
conn.Close()
return nil, &net.OpError{
Op: "dial-http",
Net: serverAddr + serviceEndpoint,
Addr: nil,
Err: err,
}
}
if resp.Status != connectSuccessMessage {
conn.Close()
return nil, errors.New("unexpected HTTP response: " + resp.Status)
}
// Initialize rpc client.
return rpc.NewClient(conn), nil
}

View File

@ -16,7 +16,12 @@
package cmd
import "testing"
import (
"crypto/x509"
"os"
"path"
"testing"
)
// Tests authorized RPC client.
func TestAuthRPCClient(t *testing.T) {
@ -53,3 +58,81 @@ func TestAuthRPCClient(t *testing.T) {
t.Fatalf("Unexpected node value %s, but expected %s", authRPC.ServiceEndpoint(), authCfg.serviceEndpoint)
}
}
// Test rpc dial test.
func TestRPCDial(t *testing.T) {
prevRootCAs := globalRootCAs
defer func() {
globalRootCAs = prevRootCAs
}()
rootPath, err := newTestConfig(globalMinioDefaultRegion)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rootPath)
testServer := StartTestServer(t, "")
defer testServer.Stop()
cert, key, err := generateTLSCertKey("127.0.0.1")
if err != nil {
t.Fatal(err)
}
// Set global root CAs.
globalRootCAs = x509.NewCertPool()
globalRootCAs.AppendCertsFromPEM(cert)
testServerTLS := StartTestTLSServer(t, "", cert, key)
defer testServerTLS.Stop()
adminEndpoint := path.Join(minioReservedBucketPath, adminPath)
testCases := []struct {
serverAddr string
serverEndpoint string
success bool
secure bool
}{
// Empty server addr should fail.
{
serverAddr: "",
serverEndpoint: adminEndpoint,
success: false,
},
// Unexpected server addr should fail.
{
serverAddr: "example.com",
serverEndpoint: adminEndpoint,
success: false,
},
// Server addr connects but fails for CONNECT call.
{
serverAddr: "example.com:80",
serverEndpoint: "/",
success: false,
},
// Successful connecting to insecure RPC server.
{
serverAddr: testServer.Server.Listener.Addr().String(),
serverEndpoint: adminEndpoint,
success: true,
},
// Successful connecting to secure RPC server.
{
serverAddr: testServerTLS.Server.Listener.Addr().String(),
serverEndpoint: adminEndpoint,
success: true,
secure: true,
},
}
for i, testCase := range testCases {
_, err = rpcDial(testCase.serverAddr, testCase.serverEndpoint, testCase.secure)
if err != nil && testCase.success {
t.Errorf("Test %d: Expected success but found failure instead %s", i+1, err)
}
if err == nil && !testCase.success {
t.Errorf("Test %d: Expected failure but found success instead", i+1)
}
}
}

View File

@ -70,11 +70,13 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) {
// Validate for invalid token.
args := SetAuthPeerArgs{Creds: creds}
args.AuthToken = "garbage"
rclient := newRPCClient(s.testAuthConf.serverAddr, s.testAuthConf.serviceEndpoint, false)
rclient := newAuthRPCClient(s.testAuthConf)
defer rclient.Close()
err = rclient.Call("BrowserPeer.SetAuthPeer", &args, &AuthRPCReply{})
if err != nil {
if err = rclient.Login(); err != nil {
t.Fatal(err)
}
rclient.authToken = "garbage"
if err = rclient.Call("BrowserPeer.SetAuthPeer", &args, &AuthRPCReply{}); err != nil {
if err.Error() != errInvalidToken.Error() {
t.Fatal(err)
}
@ -90,20 +92,14 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) {
}
// Validate for failure in login handler with previous credentials.
rclient = newRPCClient(s.testAuthConf.serverAddr, s.testAuthConf.serviceEndpoint, false)
rclient = newAuthRPCClient(s.testAuthConf)
defer rclient.Close()
token, err := authenticateNode(creds.AccessKey, creds.SecretKey)
if err != nil {
t.Fatal(err)
}
rargs := &LoginRPCArgs{
AuthToken: token,
Version: Version,
RequestTime: UTCNow(),
}
rreply := &LoginRPCReply{}
err = rclient.Call("BrowserPeer"+loginMethodName, rargs, rreply)
if err != nil {
rclient.authToken = token
if err = rclient.Login(); err != nil {
if err.Error() != errInvalidAccessKeyID.Error() {
t.Fatal(err)
}
@ -113,14 +109,8 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// Validate for success in loing handled with valid credetnails.
rargs = &LoginRPCArgs{
AuthToken: token,
Version: Version,
RequestTime: UTCNow(),
}
rreply = &LoginRPCReply{}
if err = rclient.Call("BrowserPeer"+loginMethodName, rargs, rreply); err != nil {
rclient.authToken = token
if err = rclient.Login(); err != nil {
t.Fatal(err)
}
}

View File

@ -1,164 +0,0 @@
/*
* Minio Cloud Storage, (C) 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cmd
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/rpc"
"sync"
"time"
)
// defaultDialTimeout is used for non-secure connection.
const defaultDialTimeout = 3 * time.Second
// RPCClient is a reconnectable RPC client on Call().
type RPCClient struct {
sync.Mutex // Mutex to lock net rpc client.
netRPCClient *rpc.Client // Base RPC client to make any RPC call.
serverAddr string // RPC server address.
serviceEndpoint string // Endpoint on the server to make any RPC call.
secureConn bool // Make TLS connection to RPC server or not.
}
// newRPCClient returns new RPCClient object with given serverAddr and serviceEndpoint.
// It does lazy connect to the remote endpoint on Call().
func newRPCClient(serverAddr, serviceEndpoint string, secureConn bool) *RPCClient {
return &RPCClient{
serverAddr: serverAddr,
serviceEndpoint: serviceEndpoint,
secureConn: secureConn,
}
}
// dial tries to establish a connection to serverAddr in a safe manner.
// If there is a valid rpc.Cliemt, it returns that else creates a new one.
func (rpcClient *RPCClient) dial() (netRPCClient *rpc.Client, err error) {
rpcClient.Lock()
defer rpcClient.Unlock()
// Nothing to do as we already have valid connection.
if rpcClient.netRPCClient != nil {
return rpcClient.netRPCClient, nil
}
var conn net.Conn
if rpcClient.secureConn {
var hostname string
if hostname, _, err = net.SplitHostPort(rpcClient.serverAddr); err != nil {
err = &net.OpError{
Op: "dial-http",
Net: rpcClient.serverAddr + rpcClient.serviceEndpoint,
Addr: nil,
Err: fmt.Errorf("Unable to parse server address <%s>: %s", rpcClient.serverAddr, err.Error()),
}
return nil, err
}
// ServerName in tls.Config needs to be specified to support SNI certificates.
conn, err = tls.Dial("tcp", rpcClient.serverAddr, &tls.Config{ServerName: hostname, RootCAs: globalRootCAs})
} else {
// Dial with a timeout.
conn, err = net.DialTimeout("tcp", rpcClient.serverAddr, defaultDialTimeout)
}
if err != nil {
// Print RPC connection errors that are worthy to display in log.
switch err.(type) {
case x509.HostnameError:
errorIf(err, "Unable to establish secure connection to %s", rpcClient.serverAddr)
}
return nil, &net.OpError{
Op: "dial-http",
Net: rpcClient.serverAddr + rpcClient.serviceEndpoint,
Addr: nil,
Err: err,
}
}
io.WriteString(conn, "CONNECT "+rpcClient.serviceEndpoint+" HTTP/1.0\n\n")
// Require successful HTTP response before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == "200 Connected to Go RPC" {
netRPCClient := rpc.NewClient(conn)
if netRPCClient == nil {
return nil, &net.OpError{
Op: "dial-http",
Net: rpcClient.serverAddr + rpcClient.serviceEndpoint,
Addr: nil,
Err: fmt.Errorf("Unable to initialize new rpc.Client, %s", errUnexpected),
}
}
rpcClient.netRPCClient = netRPCClient
return netRPCClient, nil
}
conn.Close()
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
return nil, &net.OpError{
Op: "dial-http",
Net: rpcClient.serverAddr + rpcClient.serviceEndpoint,
Addr: nil,
Err: err,
}
}
// Call makes a RPC call to the remote endpoint using the default codec, namely encoding/gob.
func (rpcClient *RPCClient) Call(serviceMethod string, args interface{}, reply interface{}) error {
// Get a new or existing rpc.Client.
netRPCClient, err := rpcClient.dial()
if err != nil {
return err
}
return netRPCClient.Call(serviceMethod, args, reply)
}
// Close closes underlying rpc.Client.
func (rpcClient *RPCClient) Close() error {
rpcClient.Lock()
if rpcClient.netRPCClient != nil {
// We make a copy of rpc.Client and unlock it immediately so that another
// goroutine could try to dial or close in parallel.
netRPCClient := rpcClient.netRPCClient
rpcClient.netRPCClient = nil
rpcClient.Unlock()
return netRPCClient.Close()
}
rpcClient.Unlock()
return nil
}

View File

@ -19,6 +19,7 @@ package cmd
import (
"net/http"
"net/http/httptest"
"os"
"testing"
router "github.com/gorilla/mux"
@ -32,10 +33,11 @@ type ArithReply struct {
C int
}
type Arith int
type Arith struct {
AuthRPCServer
}
// Some of Arith's methods have value args, some have pointer args. That's deliberate.
func (t *Arith) Add(args ArithArgs, reply *ArithReply) error {
reply.C = args.A + args.B
return nil
@ -43,7 +45,9 @@ func (t *Arith) Add(args ArithArgs, reply *ArithReply) error {
func TestGoHTTPRPC(t *testing.T) {
newServer := newRPCServer()
newServer.Register(new(Arith))
newServer.Register(&Arith{
AuthRPCServer: AuthRPCServer{},
})
mux := router.NewRouter().SkipClean(true)
mux.Path("/foo").Handler(newServer)
@ -51,13 +55,30 @@ func TestGoHTTPRPC(t *testing.T) {
httpServer := httptest.NewServer(mux)
defer httpServer.Close()
client := newRPCClient(httpServer.Listener.Addr().String(), "/foo", false)
rootPath, err := newTestConfig("us-east-1")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rootPath)
creds := globalServerConfig.GetCredential()
client := newAuthRPCClient(authConfig{
serverAddr: httpServer.Listener.Addr().String(),
serviceName: "Arith",
serviceEndpoint: "/foo",
accessKey: creds.AccessKey,
secretKey: creds.SecretKey,
})
defer client.Close()
if err = client.Login(); err != nil {
t.Fatal(err)
}
// Synchronous calls
args := &ArithArgs{7, 8}
reply := new(ArithReply)
if err := client.Call("Arith.Add", args, reply); err != nil {
if err = client.rpcClient.Call("Arith.Add", args, reply); err != nil {
t.Errorf("Add: expected no error but got string %v", err)
}

View File

@ -62,9 +62,15 @@ func TestS3PeerRPC(t *testing.T) {
// Test S3 RPC handlers
func (s *TestRPCS3PeerSuite) testS3PeerRPC(t *testing.T) {
// Validate for invalid token.
args := AuthRPCArgs{AuthToken: "garbage"}
rclient := newRPCClient(s.testAuthConf.serverAddr, s.testAuthConf.serviceEndpoint, false)
args := AuthRPCArgs{}
rclient := newAuthRPCClient(s.testAuthConf)
defer rclient.Close()
if err := rclient.Login(); err != nil {
t.Fatal(err)
}
rclient.authToken = "garbage"
err := rclient.Call("S3.SetBucketNotificationPeer", &args, &AuthRPCReply{})
if err != nil {
if err.Error() != errInvalidToken.Error() {