diff --git a/cmd/rpc/client.go b/cmd/rpc/client.go index 385a866de..44025986d 100644 --- a/cmd/rpc/client.go +++ b/cmd/rpc/client.go @@ -17,7 +17,6 @@ package rpc import ( - "bytes" "context" "crypto/tls" "encoding/gob" @@ -48,22 +47,25 @@ func (client *Client) Call(serviceMethod string, args, reply interface{}) error return fmt.Errorf("rpc reply must be a pointer type, but found %v", replyKind) } - data, err := gobEncode(args) - if err != nil { + argBuf := bufPool.Get() + defer bufPool.Put(argBuf) + + if err := gobEncodeBuf(args, argBuf); err != nil { return err } callRequest := CallRequest{ Method: serviceMethod, - ArgBytes: data, + ArgBytes: argBuf.Bytes(), } - var buf bytes.Buffer - if err = gob.NewEncoder(&buf).Encode(callRequest); err != nil { + reqBuf := bufPool.Get() + defer bufPool.Put(reqBuf) + if err := gob.NewEncoder(reqBuf).Encode(callRequest); err != nil { return err } - response, err := client.httpClient.Post(client.serviceURL.String(), "", &buf) + response, err := client.httpClient.Post(client.serviceURL.String(), "", reqBuf) if err != nil { return err } diff --git a/cmd/rpc/pool.go b/cmd/rpc/pool.go new file mode 100644 index 000000000..84dcbf6e1 --- /dev/null +++ b/cmd/rpc/pool.go @@ -0,0 +1,48 @@ +/* + * Minio Cloud Storage, (C) 2018 Minio, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rpc + +import ( + "bytes" + "sync" +) + +// A Pool is a type-safe wrapper around a sync.Pool. +type Pool struct { + p *sync.Pool +} + +// NewPool constructs a new Pool. +func NewPool() Pool { + return Pool{p: &sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, + }} +} + +// Get retrieves a bytes.Buffer from the pool, creating one if necessary. +func (p Pool) Get() *bytes.Buffer { + buf := p.p.Get().(*bytes.Buffer) + return buf +} + +// Put - returns a bytes.Buffer to the pool. +func (p Pool) Put(buf *bytes.Buffer) { + buf.Reset() + p.p.Put(buf) +} diff --git a/cmd/rpc/server.go b/cmd/rpc/server.go index 675c26352..ba6c98107 100644 --- a/cmd/rpc/server.go +++ b/cmd/rpc/server.go @@ -40,14 +40,10 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem() // reflect.Type of Authenticator interface. var authenticatorType = reflect.TypeOf((*Authenticator)(nil)).Elem() -func gobEncode(e interface{}) ([]byte, error) { - var buf bytes.Buffer +var bufPool = NewPool() - if err := gob.NewEncoder(&buf).Encode(e); err != nil { - return nil, err - } - - return buf.Bytes(), nil +func gobEncodeBuf(e interface{}, buf *bytes.Buffer) error { + return gob.NewEncoder(buf).Encode(e) } func gobDecode(data []byte, e interface{}) error { @@ -146,21 +142,21 @@ func (server *Server) RegisterName(name string, receiver interface{}) error { } // call - call service method in receiver. -func (server *Server) call(serviceMethod string, argBytes []byte) (replyBytes []byte, err error) { +func (server *Server) call(serviceMethod string, argBytes []byte, replyBytes *bytes.Buffer) (err error) { tokens := strings.SplitN(serviceMethod, ".", 2) if len(tokens) != 2 { - return nil, fmt.Errorf("invalid service/method request ill-formed %v", serviceMethod) + return fmt.Errorf("invalid service/method request ill-formed %v", serviceMethod) } serviceName := tokens[0] if serviceName != server.serviceName { - return nil, fmt.Errorf("can't find service %v", serviceName) + return fmt.Errorf("can't find service %v", serviceName) } methodName := tokens[1] method, found := server.methodMap[methodName] if !found { - return nil, fmt.Errorf("can't find method %v", methodName) + return fmt.Errorf("can't find method %v", methodName) } var argv reflect.Value @@ -175,7 +171,7 @@ func (server *Server) call(serviceMethod string, argBytes []byte) (replyBytes [] } if err = gobDecode(argBytes, argv.Interface()); err != nil { - return nil, err + return err } if argIsValue { @@ -193,7 +189,7 @@ func (server *Server) call(serviceMethod string, argBytes []byte) (replyBytes [] err = errInter.(error) } if err != nil { - return nil, err + return err } replyv := reflect.New(method.Type.In(2).Elem()) @@ -211,10 +207,10 @@ func (server *Server) call(serviceMethod string, argBytes []byte) (replyBytes [] err = errInter.(error) } if err != nil { - return nil, err + return err } - return gobEncode(replyv.Interface()) + return gobEncodeBuf(replyv.Interface(), replyBytes) } // CallRequest - RPC call request parameters. @@ -242,20 +238,18 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - var callResponse CallResponse - var err error - callResponse.ReplyBytes, err = server.call(callRequest.Method, callRequest.ArgBytes) - if err != nil { + callResponse := CallResponse{} + buf := bufPool.Get() + defer bufPool.Put(buf) + + if err := server.call(callRequest.Method, callRequest.ArgBytes, buf); err != nil { callResponse.Error = err.Error() } + callResponse.ReplyBytes = buf.Bytes() - data, err := gobEncode(callResponse) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } + gob.NewEncoder(w).Encode(callResponse) - w.Write(data) + w.(http.Flusher).Flush() } // NewServer - returns new RPC server. diff --git a/cmd/rpc/server_test.go b/cmd/rpc/server_test.go index a1931024d..49b6c4396 100644 --- a/cmd/rpc/server_test.go +++ b/cmd/rpc/server_test.go @@ -18,6 +18,7 @@ package rpc import ( "bytes" + "encoding/gob" "errors" "net/http" "net/http/httptest" @@ -25,6 +26,12 @@ import ( "testing" ) +func gobEncode(e interface{}) ([]byte, error) { + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(e) + return buf.Bytes(), err +} + type Args struct { A, B int } @@ -251,7 +258,10 @@ func TestServerCall(t *testing.T) { } for i, testCase := range testCases { - result, err := testCase.server.call(testCase.serviceMethod, testCase.argBytes) + buf := bufPool.Get() + defer bufPool.Put(buf) + + err := testCase.server.call(testCase.serviceMethod, testCase.argBytes, buf) expectErr := (err != nil) if expectErr != testCase.expectErr { @@ -259,8 +269,8 @@ func TestServerCall(t *testing.T) { } if !testCase.expectErr { - if !reflect.DeepEqual(result, testCase.expectedResult) { - t.Fatalf("case %v: result: expected: %v, got: %v\n", i+1, testCase.expectedResult, result) + if !reflect.DeepEqual(buf.Bytes(), testCase.expectedResult) { + t.Fatalf("case %v: result: expected: %v, got: %v\n", i+1, testCase.expectedResult, buf.Bytes()) } } }