diff --git a/Makefile b/Makefile index bbe0cd0f7..29c4f043d 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ help: ## print this help getdeps: ## fetch necessary dependencies @mkdir -p ${GOPATH}/bin @echo "Installing golangci-lint" && curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOLANGCI_DIR) $(GOLANGCI_VERSION) - @echo "Installing msgp" && go install -v github.com/tinylib/msgp@v1.1.7 + @echo "Installing msgp" && go install -v github.com/tinylib/msgp@6ac204f0b4d48d17ab4fa442134c7fba13127a4e @echo "Installing stringer" && go install -v golang.org/x/tools/cmd/stringer@latest crosscompile: ## cross compile minio diff --git a/cmd/endpoint.go b/cmd/endpoint.go index dc57fa688..22d569a5a 100644 --- a/cmd/endpoint.go +++ b/cmd/endpoint.go @@ -96,6 +96,11 @@ func (endpoint Endpoint) HTTPS() bool { return endpoint.Scheme == "https" } +// GridHost returns the host to be used for grid connections. +func (endpoint Endpoint) GridHost() string { + return fmt.Sprintf("%s://%s", endpoint.Scheme, endpoint.Host) +} + // UpdateIsLocal - resolves the host and updates if it is local or not. func (endpoint *Endpoint) UpdateIsLocal() (err error) { if !endpoint.IsLocal { diff --git a/cmd/format-erasure.go b/cmd/format-erasure.go index 31384f99c..d96325f94 100644 --- a/cmd/format-erasure.go +++ b/cmd/format-erasure.go @@ -31,6 +31,7 @@ import ( "github.com/minio/minio/internal/color" "github.com/minio/minio/internal/config" "github.com/minio/minio/internal/config/storageclass" + "github.com/minio/minio/internal/grid" xioutil "github.com/minio/minio/internal/ioutil" "github.com/minio/minio/internal/logger" "github.com/minio/pkg/v2/sync/errgroup" @@ -388,6 +389,12 @@ func saveFormatErasure(disk StorageAPI, format *formatErasureV3, healID string) // loadFormatErasure - loads format.json from disk. func loadFormatErasure(disk StorageAPI) (format *formatErasureV3, err error) { + // Ensure that the grid is online. + if _, err := disk.DiskInfo(context.Background(), false); err != nil { + if errors.Is(err, grid.ErrDisconnected) { + return nil, err + } + } buf, err := disk.ReadAll(context.TODO(), minioMetaBucket, formatConfigFile) if err != nil { // 'file not found' and 'volume not found' as diff --git a/cmd/generic-handlers.go b/cmd/generic-handlers.go index 6226081a6..b55e99fa8 100644 --- a/cmd/generic-handlers.go +++ b/cmd/generic-handlers.go @@ -31,6 +31,7 @@ import ( "github.com/dustin/go-humanize" "github.com/minio/minio-go/v7/pkg/s3utils" "github.com/minio/minio-go/v7/pkg/set" + "github.com/minio/minio/internal/grid" xnet "github.com/minio/pkg/v2/net" "github.com/minio/minio/internal/amztime" @@ -240,6 +241,10 @@ func guessIsRPCReq(req *http.Request) bool { if req == nil { return false } + if req.Method == http.MethodGet && req.URL != nil && req.URL.Path == grid.RoutePath { + return true + } + return req.Method == http.MethodPost && strings.HasPrefix(req.URL.Path, minioReservedBucketPath+SlashSeparator) } diff --git a/cmd/generic-handlers_test.go b/cmd/generic-handlers_test.go index 70d7177b6..b76ec2910 100644 --- a/cmd/generic-handlers_test.go +++ b/cmd/generic-handlers_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/minio/minio/internal/crypto" + "github.com/minio/minio/internal/grid" xhttp "github.com/minio/minio/internal/http" ) @@ -54,6 +55,14 @@ func TestGuessIsRPC(t *testing.T) { if guessIsRPCReq(r) { t.Fatal("Test shouldn't report as net/rpc for a non net/rpc request.") } + r = &http.Request{ + Proto: "HTTP/1.1", + Method: http.MethodGet, + URL: &url.URL{Path: grid.RoutePath}, + } + if !guessIsRPCReq(r) { + t.Fatal("Grid RPC path not detected") + } } var isHTTPHeaderSizeTooLargeTests = []struct { diff --git a/cmd/grid.go b/cmd/grid.go new file mode 100644 index 000000000..bd7029c10 --- /dev/null +++ b/cmd/grid.go @@ -0,0 +1,83 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "context" + "crypto/tls" + "sync/atomic" + + "github.com/minio/minio-go/v7/pkg/set" + "github.com/minio/minio/internal/fips" + "github.com/minio/minio/internal/grid" + xhttp "github.com/minio/minio/internal/http" + "github.com/minio/minio/internal/rest" +) + +// globalGrid is the global grid manager. +var globalGrid atomic.Pointer[grid.Manager] + +// globalGridStart is a channel that will block startup of grid connections until closed. +var globalGridStart = make(chan struct{}) + +func initGlobalGrid(ctx context.Context, eps EndpointServerPools) error { + seenHosts := set.NewStringSet() + var hosts []string + var local string + for _, ep := range eps { + for _, endpoint := range ep.Endpoints { + u := endpoint.GridHost() + if seenHosts.Contains(u) { + continue + } + seenHosts.Add(u) + + // Set local endpoint + if endpoint.IsLocal { + local = u + } + hosts = append(hosts, u) + } + } + lookupHost := globalDNSCache.LookupHost + if IsKubernetes() || IsDocker() { + lookupHost = nil + } + g, err := grid.NewManager(ctx, grid.ManagerOptions{ + Dialer: grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions))), + Local: local, + Hosts: hosts, + AddAuth: newCachedAuthToken(), + AuthRequest: storageServerRequestValidate, + BlockConnect: globalGridStart, + TLSConfig: &tls.Config{ + RootCAs: globalRootCAs, + CipherSuites: fips.TLSCiphers(), + CurvePreferences: fips.TLSCurveIDs(), + }, + // Record incoming and outgoing bytes. + Incoming: globalConnStats.incInternodeInputBytes, + Outgoing: globalConnStats.incInternodeOutputBytes, + TraceTo: globalTrace, + }) + if err != nil { + return err + } + globalGrid.Store(g) + return nil +} diff --git a/cmd/handler-utils.go b/cmd/handler-utils.go index 830f5daac..d1c15f6e4 100644 --- a/cmd/handler-utils.go +++ b/cmd/handler-utils.go @@ -389,12 +389,6 @@ func errorResponseHandler(w http.ResponseWriter, r *http.Request) { Description: desc, HTTPStatusCode: http.StatusUpgradeRequired, }, r.URL) - case strings.HasPrefix(r.URL.Path, lockRESTPrefix): - writeErrorResponseString(r.Context(), w, APIError{ - Code: "XMinioLockVersionMismatch", - Description: desc, - HTTPStatusCode: http.StatusUpgradeRequired, - }, r.URL) case strings.HasPrefix(r.URL.Path, adminPathPrefix): var desc string version := extractAPIVersion(r) diff --git a/cmd/http-tracer.go b/cmd/http-tracer.go index 490a745a9..439549edd 100644 --- a/cmd/http-tracer.go +++ b/cmd/http-tracer.go @@ -55,6 +55,7 @@ func getOpName(name string) (op string) { op = strings.Replace(op, "(*peerRESTServer)", "peer", 1) op = strings.Replace(op, "(*lockRESTServer)", "lockR", 1) op = strings.Replace(op, "(*stsAPIHandlers)", "sts", 1) + op = strings.Replace(op, "(*peerS3Server)", "s3", 1) op = strings.Replace(op, "ClusterCheckHandler", "health.Cluster", 1) op = strings.Replace(op, "ClusterReadCheckHandler", "health.ClusterRead", 1) op = strings.Replace(op, "LivenessCheckHandler", "health.Liveness", 1) diff --git a/cmd/local-locker.go b/cmd/local-locker.go index dd609e759..37ea992ff 100644 --- a/cmd/local-locker.go +++ b/cmd/local-locker.go @@ -159,7 +159,7 @@ func (l *localLocker) removeEntry(name string, args dsync.LockArgs, lri *[]lockR } func (l *localLocker) RLock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - if len(args.Resources) > 1 { + if len(args.Resources) != 1 { return false, fmt.Errorf("internal error: localLocker.RLock called with more than one resource") } diff --git a/cmd/lock-rest-client.go b/cmd/lock-rest-client.go index 35b1623a5..22f4328e4 100644 --- a/cmd/lock-rest-client.go +++ b/cmd/lock-rest-client.go @@ -18,121 +18,86 @@ package cmd import ( - "bytes" "context" - "io" - "net/url" + "errors" "github.com/minio/minio/internal/dsync" - xhttp "github.com/minio/minio/internal/http" - "github.com/minio/minio/internal/rest" + "github.com/minio/minio/internal/grid" + "github.com/minio/minio/internal/logger" ) // lockRESTClient is authenticable lock REST client type lockRESTClient struct { - restClient *rest.Client - u *url.URL -} - -func toLockError(err error) error { - if err == nil { - return nil - } - - switch err.Error() { - case errLockConflict.Error(): - return errLockConflict - case errLockNotFound.Error(): - return errLockNotFound - } - return err -} - -// String stringer *dsync.NetLocker* interface compatible method. -func (client *lockRESTClient) String() string { - return client.u.String() -} - -// Wrapper to restClient.Call to handle network errors, in case of network error the connection is marked disconnected -// permanently. The only way to restore the connection is at the xl-sets layer by xlsets.monitorAndConnectEndpoints() -// after verifying format.json -func (client *lockRESTClient) callWithContext(ctx context.Context, method string, values url.Values, body io.Reader, length int64) (respBody io.ReadCloser, err error) { - if values == nil { - values = make(url.Values) - } - - respBody, err = client.restClient.Call(ctx, method, values, body, length) - if err == nil { - return respBody, nil - } - - return nil, toLockError(err) + connection *grid.Connection } // IsOnline - returns whether REST client failed to connect or not. -func (client *lockRESTClient) IsOnline() bool { - return client.restClient.IsOnline() +func (c *lockRESTClient) IsOnline() bool { + return c.connection.State() == grid.StateConnected } // Not a local locker -func (client *lockRESTClient) IsLocal() bool { +func (c *lockRESTClient) IsLocal() bool { return false } // Close - marks the client as closed. -func (client *lockRESTClient) Close() error { - client.restClient.Close() +func (c *lockRESTClient) Close() error { return nil } -// restCall makes a call to the lock REST server. -func (client *lockRESTClient) restCall(ctx context.Context, call string, args dsync.LockArgs) (reply bool, err error) { - argsBytes, err := args.MarshalMsg(metaDataPoolGet()[:0]) +// String - returns the remote host of the connection. +func (c *lockRESTClient) String() string { + return c.connection.Remote +} + +func (c *lockRESTClient) call(ctx context.Context, h *grid.SingleHandler[*dsync.LockArgs, *dsync.LockResp], args *dsync.LockArgs) (ok bool, err error) { + r, err := h.Call(ctx, c.connection, args) if err != nil { + logger.LogIfNot(ctx, err, grid.ErrDisconnected) return false, err } - defer metaDataPoolPut(argsBytes) - body := bytes.NewReader(argsBytes) - respBody, err := client.callWithContext(ctx, call, nil, body, body.Size()) - defer xhttp.DrainBody(respBody) - switch err { - case nil: - return true, nil - case errLockConflict, errLockNotFound: - return false, nil + defer h.PutResponse(r) + ok = r.Code == dsync.RespOK + switch r.Code { + case dsync.RespLockConflict, dsync.RespLockNotFound, dsync.RespOK: + // no error + case dsync.RespLockNotInitialized: + err = errLockNotInitialized default: - return false, err + err = errors.New(r.Err) } + return ok, err } // RLock calls read lock REST API. -func (client *lockRESTClient) RLock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - return client.restCall(ctx, lockRESTMethodRLock, args) +func (c *lockRESTClient) RLock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { + return c.call(ctx, lockRPCRLock, &args) } // Lock calls lock REST API. -func (client *lockRESTClient) Lock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - return client.restCall(ctx, lockRESTMethodLock, args) +func (c *lockRESTClient) Lock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { + return c.call(ctx, lockRPCLock, &args) } // RUnlock calls read unlock REST API. -func (client *lockRESTClient) RUnlock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - return client.restCall(ctx, lockRESTMethodRUnlock, args) +func (c *lockRESTClient) RUnlock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { + return c.call(ctx, lockRPCRUnlock, &args) } -// RUnlock calls read unlock REST API. -func (client *lockRESTClient) Refresh(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - return client.restCall(ctx, lockRESTMethodRefresh, args) +// Refresh calls Refresh REST API. +func (c *lockRESTClient) Refresh(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { + return c.call(ctx, lockRPCRefresh, &args) } // Unlock calls write unlock RPC. -func (client *lockRESTClient) Unlock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - return client.restCall(ctx, lockRESTMethodUnlock, args) +func (c *lockRESTClient) Unlock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { + return c.call(ctx, lockRPCUnlock, &args) } // ForceUnlock calls force unlock handler to forcibly unlock an active lock. -func (client *lockRESTClient) ForceUnlock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { - return client.restCall(ctx, lockRESTMethodForceUnlock, args) +func (c *lockRESTClient) ForceUnlock(ctx context.Context, args dsync.LockArgs) (reply bool, err error) { + return c.call(ctx, lockRPCForceUnlock, &args) } func newLockAPI(endpoint Endpoint) dsync.NetLocker { @@ -143,27 +108,6 @@ func newLockAPI(endpoint Endpoint) dsync.NetLocker { } // Returns a lock rest client. -func newlockRESTClient(endpoint Endpoint) *lockRESTClient { - serverURL := &url.URL{ - Scheme: endpoint.Scheme, - Host: endpoint.Host, - Path: pathJoin(lockRESTPrefix, lockRESTVersion), - } - - restClient := rest.NewClient(serverURL, globalInternodeTransport, newCachedAuthToken()) - // Use a separate client to avoid recursive calls. - healthClient := rest.NewClient(serverURL, globalInternodeTransport, newCachedAuthToken()) - healthClient.NoMetrics = true - restClient.HealthCheckFn = func() bool { - ctx, cancel := context.WithTimeout(context.Background(), restClient.HealthCheckTimeout) - defer cancel() - respBody, err := healthClient.Call(ctx, lockRESTMethodHealth, nil, nil, -1) - xhttp.DrainBody(respBody) - return !isNetworkError(err) - } - - return &lockRESTClient{u: &url.URL{ - Scheme: endpoint.Scheme, - Host: endpoint.Host, - }, restClient: restClient} +func newlockRESTClient(ep Endpoint) *lockRESTClient { + return &lockRESTClient{globalGrid.Load().Connection(ep.GridHost())} } diff --git a/cmd/lock-rest-client_test.go b/cmd/lock-rest-client_test.go index 8c441da82..10beb12a0 100644 --- a/cmd/lock-rest-client_test.go +++ b/cmd/lock-rest-client_test.go @@ -26,14 +26,27 @@ import ( // Tests lock rpc client. func TestLockRESTlient(t *testing.T) { - endpoint, err := NewEndpoint("http://localhost:9000") + // These should not be connectable. + endpoint, err := NewEndpoint("http://localhost:9876") if err != nil { t.Fatalf("unexpected error %v", err) } + endpointLocal, err := NewEndpoint("http://localhost:9012") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + endpointLocal.IsLocal = true + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = initGlobalGrid(ctx, []PoolEndpoints{{Endpoints: Endpoints{endpoint, endpointLocal}}}) + if err != nil { + t.Fatal(err) + } lkClient := newlockRESTClient(endpoint) - if !lkClient.IsOnline() { - t.Fatalf("unexpected error. connection failed") + if lkClient.IsOnline() { + t.Fatalf("unexpected result. connection was online") } // Attempt all calls. diff --git a/cmd/lock-rest-server-common.go b/cmd/lock-rest-server-common.go index dfc1938aa..ddaa6bb38 100644 --- a/cmd/lock-rest-server-common.go +++ b/cmd/lock-rest-server-common.go @@ -21,22 +21,6 @@ import ( "errors" ) -const ( - lockRESTVersion = "v7" // Add msgp for lockArgs - lockRESTVersionPrefix = SlashSeparator + lockRESTVersion - lockRESTPrefix = minioReservedBucketPath + "/lock" -) - -const ( - lockRESTMethodHealth = "/health" - lockRESTMethodRefresh = "/refresh" - lockRESTMethodLock = "/lock" - lockRESTMethodRLock = "/rlock" - lockRESTMethodUnlock = "/unlock" - lockRESTMethodRUnlock = "/runlock" - lockRESTMethodForceUnlock = "/force-unlock" -) - var ( errLockConflict = errors.New("lock conflict") errLockNotInitialized = errors.New("lock not initialized") diff --git a/cmd/lock-rest-server.go b/cmd/lock-rest-server.go index 5ab70fd7f..a0adea595 100644 --- a/cmd/lock-rest-server.go +++ b/cmd/lock-rest-server.go @@ -19,16 +19,132 @@ package cmd import ( "context" - "errors" - "io" - "net/http" "time" - "github.com/dustin/go-humanize" "github.com/minio/minio/internal/dsync" - "github.com/minio/mux" + "github.com/minio/minio/internal/grid" + "github.com/minio/minio/internal/logger" ) +// To abstract a node over network. +type lockRESTServer struct { + ll *localLocker +} + +// RefreshHandler - refresh the current lock +func (l *lockRESTServer) RefreshHandler(args *dsync.LockArgs) (*dsync.LockResp, *grid.RemoteErr) { + resp := lockRPCRefresh.NewResponse() + refreshed, err := l.ll.Refresh(context.Background(), *args) + if err != nil { + return l.makeResp(resp, err) + } + if !refreshed { + return l.makeResp(resp, errLockNotFound) + } + return l.makeResp(resp, err) +} + +// LockHandler - Acquires a lock. +func (l *lockRESTServer) LockHandler(args *dsync.LockArgs) (*dsync.LockResp, *grid.RemoteErr) { + resp := lockRPCLock.NewResponse() + success, err := l.ll.Lock(context.Background(), *args) + if err == nil && !success { + return l.makeResp(resp, errLockConflict) + } + return l.makeResp(resp, err) +} + +// UnlockHandler - releases the acquired lock. +func (l *lockRESTServer) UnlockHandler(args *dsync.LockArgs) (*dsync.LockResp, *grid.RemoteErr) { + resp := lockRPCUnlock.NewResponse() + _, err := l.ll.Unlock(context.Background(), *args) + // Ignore the Unlock() "reply" return value because if err == nil, "reply" is always true + // Consequently, if err != nil, reply is always false + return l.makeResp(resp, err) +} + +// RLockHandler - Acquires an RLock. +func (l *lockRESTServer) RLockHandler(args *dsync.LockArgs) (*dsync.LockResp, *grid.RemoteErr) { + resp := lockRPCRLock.NewResponse() + success, err := l.ll.RLock(context.Background(), *args) + if err == nil && !success { + err = errLockConflict + } + return l.makeResp(resp, err) +} + +// RUnlockHandler - releases the acquired read lock. +func (l *lockRESTServer) RUnlockHandler(args *dsync.LockArgs) (*dsync.LockResp, *grid.RemoteErr) { + resp := lockRPCRUnlock.NewResponse() + + // Ignore the RUnlock() "reply" return value because if err == nil, "reply" is always true. + // Consequently, if err != nil, reply is always false + _, err := l.ll.RUnlock(context.Background(), *args) + return l.makeResp(resp, err) +} + +// ForceUnlockHandler - query expired lock status. +func (l *lockRESTServer) ForceUnlockHandler(args *dsync.LockArgs) (*dsync.LockResp, *grid.RemoteErr) { + resp := lockRPCForceUnlock.NewResponse() + + _, err := l.ll.ForceUnlock(context.Background(), *args) + return l.makeResp(resp, err) +} + +var ( + // Static lock handlers. + // All have the same signature. + lockRPCForceUnlock = newLockHandler(grid.HandlerLockForceUnlock) + lockRPCRefresh = newLockHandler(grid.HandlerLockRefresh) + lockRPCLock = newLockHandler(grid.HandlerLockLock) + lockRPCUnlock = newLockHandler(grid.HandlerLockUnlock) + lockRPCRLock = newLockHandler(grid.HandlerLockRLock) + lockRPCRUnlock = newLockHandler(grid.HandlerLockRUnlock) +) + +func newLockHandler(h grid.HandlerID) *grid.SingleHandler[*dsync.LockArgs, *dsync.LockResp] { + return grid.NewSingleHandler[*dsync.LockArgs, *dsync.LockResp](h, func() *dsync.LockArgs { + return &dsync.LockArgs{} + }, func() *dsync.LockResp { + return &dsync.LockResp{} + }) +} + +// registerLockRESTHandlers - register lock rest router. +func registerLockRESTHandlers() { + lockServer := &lockRESTServer{ + ll: newLocker(), + } + + logger.FatalIf(lockRPCForceUnlock.Register(globalGrid.Load(), lockServer.ForceUnlockHandler), "unable to register handler") + logger.FatalIf(lockRPCRefresh.Register(globalGrid.Load(), lockServer.RefreshHandler), "unable to register handler") + logger.FatalIf(lockRPCLock.Register(globalGrid.Load(), lockServer.LockHandler), "unable to register handler") + logger.FatalIf(lockRPCUnlock.Register(globalGrid.Load(), lockServer.UnlockHandler), "unable to register handler") + logger.FatalIf(lockRPCRLock.Register(globalGrid.Load(), lockServer.RLockHandler), "unable to register handler") + logger.FatalIf(lockRPCRUnlock.Register(globalGrid.Load(), lockServer.RUnlockHandler), "unable to register handler") + + globalLockServer = lockServer.ll + + go lockMaintenance(GlobalContext) +} + +func (l *lockRESTServer) makeResp(dst *dsync.LockResp, err error) (*dsync.LockResp, *grid.RemoteErr) { + *dst = dsync.LockResp{Code: dsync.RespOK} + switch err { + case nil: + case errLockNotInitialized: + dst.Code = dsync.RespLockNotInitialized + case errLockConflict: + dst.Code = dsync.RespLockConflict + case errLockNotFound: + dst.Code = dsync.RespLockNotFound + default: + dst.Code = dsync.RespErr + dst.Err = err.Error() + } + return dst, nil +} + const ( // Lock maintenance interval. lockMaintenanceInterval = 1 * time.Minute @@ -37,185 +153,6 @@ const ( lockValidityDuration = 1 * time.Minute ) -// To abstract a node over network. -type lockRESTServer struct { - ll *localLocker -} - -func (l *lockRESTServer) writeErrorResponse(w http.ResponseWriter, err error) { - statusCode := http.StatusForbidden - switch err { - case errLockNotInitialized: - // Return 425 instead of 5xx, otherwise this node will be marked offline - statusCode = http.StatusTooEarly - case errLockConflict: - statusCode = http.StatusConflict - case errLockNotFound: - statusCode = http.StatusNotFound - } - w.WriteHeader(statusCode) - w.Write([]byte(err.Error())) -} - -// IsValid - To authenticate and verify the time difference. -func (l *lockRESTServer) IsValid(w http.ResponseWriter, r *http.Request) bool { - if l.ll == nil { - l.writeErrorResponse(w, errLockNotInitialized) - return false - } - - if err := storageServerRequestValidate(r); err != nil { - l.writeErrorResponse(w, err) - return false - } - return true -} - -func getLockArgs(r *http.Request) (args dsync.LockArgs, err error) { - dec := msgpNewReader(io.LimitReader(r.Body, 1000*humanize.KiByte)) - defer readMsgpReaderPoolPut(dec) - err = args.DecodeMsg(dec) - return args, err -} - -// HealthHandler returns success if request is authenticated. -func (l *lockRESTServer) HealthHandler(w http.ResponseWriter, r *http.Request) { - l.IsValid(w, r) -} - -// RefreshHandler - refresh the current lock -func (l *lockRESTServer) RefreshHandler(w http.ResponseWriter, r *http.Request) { - if !l.IsValid(w, r) { - l.writeErrorResponse(w, errors.New("invalid request")) - return - } - - args, err := getLockArgs(r) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - refreshed, err := l.ll.Refresh(r.Context(), args) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - if !refreshed { - l.writeErrorResponse(w, errLockNotFound) - return - } -} - -// LockHandler - Acquires a lock. -func (l *lockRESTServer) LockHandler(w http.ResponseWriter, r *http.Request) { - if !l.IsValid(w, r) { - l.writeErrorResponse(w, errors.New("invalid request")) - return - } - - args, err := getLockArgs(r) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - success, err := l.ll.Lock(r.Context(), args) - if err == nil && !success { - err = errLockConflict - } - if err != nil { - l.writeErrorResponse(w, err) - return - } -} - -// UnlockHandler - releases the acquired lock. -func (l *lockRESTServer) UnlockHandler(w http.ResponseWriter, r *http.Request) { - if !l.IsValid(w, r) { - l.writeErrorResponse(w, errors.New("invalid request")) - return - } - - args, err := getLockArgs(r) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - _, err = l.ll.Unlock(context.Background(), args) - // Ignore the Unlock() "reply" return value because if err == nil, "reply" is always true - // Consequently, if err != nil, reply is always false - if err != nil { - l.writeErrorResponse(w, err) - return - } -} - -// LockHandler - Acquires an RLock. -func (l *lockRESTServer) RLockHandler(w http.ResponseWriter, r *http.Request) { - if !l.IsValid(w, r) { - l.writeErrorResponse(w, errors.New("invalid request")) - return - } - - args, err := getLockArgs(r) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - success, err := l.ll.RLock(r.Context(), args) - if err == nil && !success { - err = errLockConflict - } - if err != nil { - l.writeErrorResponse(w, err) - return - } -} - -// RUnlockHandler - releases the acquired read lock. -func (l *lockRESTServer) RUnlockHandler(w http.ResponseWriter, r *http.Request) { - if !l.IsValid(w, r) { - l.writeErrorResponse(w, errors.New("invalid request")) - return - } - - args, err := getLockArgs(r) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - // Ignore the RUnlock() "reply" return value because if err == nil, "reply" is always true. - // Consequently, if err != nil, reply is always false - if _, err = l.ll.RUnlock(context.Background(), args); err != nil { - l.writeErrorResponse(w, err) - return - } -} - -// ForceUnlockHandler - query expired lock status. -func (l *lockRESTServer) ForceUnlockHandler(w http.ResponseWriter, r *http.Request) { - if !l.IsValid(w, r) { - l.writeErrorResponse(w, errors.New("invalid request")) - return - } - - args, err := getLockArgs(r) - if err != nil { - l.writeErrorResponse(w, err) - return - } - - if _, err = l.ll.ForceUnlock(r.Context(), args); err != nil { - l.writeErrorResponse(w, err) - return - } -} - // lockMaintenance loops over all locks and discards locks // that have not been refreshed for some time. func lockMaintenance(ctx context.Context) { @@ -241,27 +178,3 @@ func lockMaintenance(ctx context.Context) { } } } - -// registerLockRESTHandlers - register lock rest router. -func registerLockRESTHandlers(router *mux.Router) { - h := func(f http.HandlerFunc) http.HandlerFunc { - return collectInternodeStats(httpTraceHdrs(f)) - } - - lockServer := &lockRESTServer{ - ll: newLocker(), - } - - subrouter := router.PathPrefix(lockRESTPrefix).Subrouter() - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodHealth).HandlerFunc(h(lockServer.HealthHandler)) - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodRefresh).HandlerFunc(h(lockServer.RefreshHandler)) - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodLock).HandlerFunc(h(lockServer.LockHandler)) - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodRLock).HandlerFunc(h(lockServer.RLockHandler)) - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodUnlock).HandlerFunc(h(lockServer.UnlockHandler)) - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodRUnlock).HandlerFunc(h(lockServer.RUnlockHandler)) - subrouter.Methods(http.MethodPost).Path(lockRESTVersionPrefix + lockRESTMethodForceUnlock).HandlerFunc(h(lockServer.ForceUnlockHandler)) - - globalLockServer = lockServer.ll - - go lockMaintenance(GlobalContext) -} diff --git a/cmd/lock-rest-server_test.go b/cmd/lock-rest-server_test.go deleted file mode 100644 index ac3753483..000000000 --- a/cmd/lock-rest-server_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2015-2021 MinIO, Inc. -// -// This file is part of MinIO Object Storage stack -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package cmd - -import ( - "bufio" - "bytes" - "io" - "net/http" - "net/url" - "sort" - "strconv" - "testing" - - "github.com/minio/minio/internal/dsync" -) - -func BenchmarkLockArgs(b *testing.B) { - args := dsync.LockArgs{ - Owner: "minio", - UID: "uid", - Source: "lockArgs.go", - Quorum: 3, - Resources: []string{"obj.txt"}, - } - - argBytes, err := args.MarshalMsg(nil) - if err != nil { - b.Fatal(err) - } - - req := &http.Request{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - req.Body = io.NopCloser(bytes.NewReader(argBytes)) - getLockArgs(req) - } -} - -func BenchmarkLockArgsOld(b *testing.B) { - values := url.Values{} - values.Set("owner", "minio") - values.Set("uid", "uid") - values.Set("source", "lockArgs.go") - values.Set("quorum", "3") - - req := &http.Request{ - Form: values, - } - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - req.Body = io.NopCloser(bytes.NewReader([]byte(`obj.txt`))) - getLockArgsOld(req) - } -} - -func getLockArgsOld(r *http.Request) (args dsync.LockArgs, err error) { - values := r.Form - quorum, err := strconv.Atoi(values.Get("quorum")) - if err != nil { - return args, err - } - - args = dsync.LockArgs{ - Owner: values.Get("onwer"), - UID: values.Get("uid"), - Source: values.Get("source"), - Quorum: quorum, - } - - var resources []string - bio := bufio.NewScanner(r.Body) - for bio.Scan() { - resources = append(resources, bio.Text()) - } - - if err := bio.Err(); err != nil { - return args, err - } - - sort.Strings(resources) - args.Resources = resources - return args, nil -} diff --git a/cmd/metacache-walk.go b/cmd/metacache-walk.go index d8f35bf59..7ae3f7e12 100644 --- a/cmd/metacache-walk.go +++ b/cmd/metacache-walk.go @@ -19,21 +19,18 @@ package cmd import ( "context" - "fmt" "io" - "net/http" - "net/url" - "runtime/debug" "sort" - "strconv" "strings" - xhttp "github.com/minio/minio/internal/http" + "github.com/minio/minio/internal/grid" xioutil "github.com/minio/minio/internal/ioutil" "github.com/minio/minio/internal/logger" "github.com/valyala/bytebufferpool" ) +//go:generate msgp -file $GOFILE + // WalkDirOptions provides options for WalkDir operations. type WalkDirOptions struct { // Bucket to scanner @@ -57,6 +54,10 @@ type WalkDirOptions struct { // Limit the number of returned objects if > 0. Limit int + + // DiskID contains the disk ID of the disk. + // Leave empty to not check disk ID. + DiskID string } // WalkDir will traverse a directory and return all entries found. @@ -387,6 +388,9 @@ func (s *xlStorage) WalkDir(ctx context.Context, opts WalkDirOptions, wr io.Writ } func (p *xlStorageDiskIDCheck) WalkDir(ctx context.Context, opts WalkDirOptions, wr io.Writer) (err error) { + if err := p.checkID(opts.DiskID); err != nil { + return err + } ctx, done, err := p.TrackDiskHealth(ctx, storageMetricWalkDir, opts.Bucket, opts.BaseDir) if err != nil { return err @@ -399,59 +403,32 @@ func (p *xlStorageDiskIDCheck) WalkDir(ctx context.Context, opts WalkDirOptions, // WalkDir will traverse a directory and return all entries found. // On success a meta cache stream will be returned, that should be closed when done. func (client *storageRESTClient) WalkDir(ctx context.Context, opts WalkDirOptions, wr io.Writer) error { - values := make(url.Values) - values.Set(storageRESTVolume, opts.Bucket) - values.Set(storageRESTDirPath, opts.BaseDir) - values.Set(storageRESTRecursive, strconv.FormatBool(opts.Recursive)) - values.Set(storageRESTReportNotFound, strconv.FormatBool(opts.ReportNotFound)) - values.Set(storageRESTPrefixFilter, opts.FilterPrefix) - values.Set(storageRESTForwardFilter, opts.ForwardTo) - respBody, err := client.call(ctx, storageRESTMethodWalkDir, values, nil, -1) + // Ensure remote has the same disk ID. + opts.DiskID = client.diskID + b, err := opts.MarshalMsg(grid.GetByteBuffer()[:0]) if err != nil { - logger.LogIf(ctx, err) return err } - defer xhttp.DrainBody(respBody) - return waitForHTTPStream(respBody, wr) + + st, err := client.gridConn.NewStream(ctx, grid.HandlerWalkDir, b) + if err != nil { + return err + } + return toStorageErr(st.Results(func(in []byte) error { + _, err := wr.Write(in) + return err + })) } // WalkDirHandler - remote caller to list files and folders in a requested directory path. -func (s *storageRESTServer) WalkDirHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return - } - volume := r.Form.Get(storageRESTVolume) - dirPath := r.Form.Get(storageRESTDirPath) - recursive, err := strconv.ParseBool(r.Form.Get(storageRESTRecursive)) +func (s *storageRESTServer) WalkDirHandler(ctx context.Context, payload []byte, _ <-chan []byte, out chan<- []byte) (gerr *grid.RemoteErr) { + var opts WalkDirOptions + _, err := opts.UnmarshalMsg(payload) if err != nil { - s.writeErrorResponse(w, err) - return + return grid.NewRemoteErr(err) } - var reportNotFound bool - if v := r.Form.Get(storageRESTReportNotFound); v != "" { - reportNotFound, err = strconv.ParseBool(v) - if err != nil { - s.writeErrorResponse(w, err) - return - } - } - - prefix := r.Form.Get(storageRESTPrefixFilter) - forward := r.Form.Get(storageRESTForwardFilter) - writer := streamHTTPResponse(w) - defer func() { - if r := recover(); r != nil { - debug.PrintStack() - writer.CloseWithError(fmt.Errorf("panic: %v", r)) - } - }() - writer.CloseWithError(s.storage.WalkDir(r.Context(), WalkDirOptions{ - Bucket: volume, - BaseDir: dirPath, - Recursive: recursive, - ReportNotFound: reportNotFound, - FilterPrefix: prefix, - ForwardTo: forward, - }, writer)) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + return grid.NewRemoteErr(s.storage.WalkDir(ctx, opts, grid.WriterToChannel(ctx, out))) } diff --git a/cmd/metacache-walk_gen.go b/cmd/metacache-walk_gen.go new file mode 100644 index 000000000..e59cf64eb --- /dev/null +++ b/cmd/metacache-walk_gen.go @@ -0,0 +1,285 @@ +package cmd + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *WalkDirOptions) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Bucket": + z.Bucket, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Bucket") + return + } + case "BaseDir": + z.BaseDir, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "BaseDir") + return + } + case "Recursive": + z.Recursive, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "Recursive") + return + } + case "ReportNotFound": + z.ReportNotFound, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "ReportNotFound") + return + } + case "FilterPrefix": + z.FilterPrefix, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "FilterPrefix") + return + } + case "ForwardTo": + z.ForwardTo, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "ForwardTo") + return + } + case "Limit": + z.Limit, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "Limit") + return + } + case "DiskID": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *WalkDirOptions) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 8 + // write "Bucket" + err = en.Append(0x88, 0xa6, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74) + if err != nil { + return + } + err = en.WriteString(z.Bucket) + if err != nil { + err = msgp.WrapError(err, "Bucket") + return + } + // write "BaseDir" + err = en.Append(0xa7, 0x42, 0x61, 0x73, 0x65, 0x44, 0x69, 0x72) + if err != nil { + return + } + err = en.WriteString(z.BaseDir) + if err != nil { + err = msgp.WrapError(err, "BaseDir") + return + } + // write "Recursive" + err = en.Append(0xa9, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65) + if err != nil { + return + } + err = en.WriteBool(z.Recursive) + if err != nil { + err = msgp.WrapError(err, "Recursive") + return + } + // write "ReportNotFound" + err = en.Append(0xae, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, 0x64) + if err != nil { + return + } + err = en.WriteBool(z.ReportNotFound) + if err != nil { + err = msgp.WrapError(err, "ReportNotFound") + return + } + // write "FilterPrefix" + err = en.Append(0xac, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78) + if err != nil { + return + } + err = en.WriteString(z.FilterPrefix) + if err != nil { + err = msgp.WrapError(err, "FilterPrefix") + return + } + // write "ForwardTo" + err = en.Append(0xa9, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x54, 0x6f) + if err != nil { + return + } + err = en.WriteString(z.ForwardTo) + if err != nil { + err = msgp.WrapError(err, "ForwardTo") + return + } + // write "Limit" + err = en.Append(0xa5, 0x4c, 0x69, 0x6d, 0x69, 0x74) + if err != nil { + return + } + err = en.WriteInt(z.Limit) + if err != nil { + err = msgp.WrapError(err, "Limit") + return + } + // write "DiskID" + err = en.Append(0xa6, 0x44, 0x69, 0x73, 0x6b, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *WalkDirOptions) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 8 + // string "Bucket" + o = append(o, 0x88, 0xa6, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74) + o = msgp.AppendString(o, z.Bucket) + // string "BaseDir" + o = append(o, 0xa7, 0x42, 0x61, 0x73, 0x65, 0x44, 0x69, 0x72) + o = msgp.AppendString(o, z.BaseDir) + // string "Recursive" + o = append(o, 0xa9, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65) + o = msgp.AppendBool(o, z.Recursive) + // string "ReportNotFound" + o = append(o, 0xae, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, 0x64) + o = msgp.AppendBool(o, z.ReportNotFound) + // string "FilterPrefix" + o = append(o, 0xac, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78) + o = msgp.AppendString(o, z.FilterPrefix) + // string "ForwardTo" + o = append(o, 0xa9, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x54, 0x6f) + o = msgp.AppendString(o, z.ForwardTo) + // string "Limit" + o = append(o, 0xa5, 0x4c, 0x69, 0x6d, 0x69, 0x74) + o = msgp.AppendInt(o, z.Limit) + // string "DiskID" + o = append(o, 0xa6, 0x44, 0x69, 0x73, 0x6b, 0x49, 0x44) + o = msgp.AppendString(o, z.DiskID) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *WalkDirOptions) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Bucket": + z.Bucket, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Bucket") + return + } + case "BaseDir": + z.BaseDir, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "BaseDir") + return + } + case "Recursive": + z.Recursive, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Recursive") + return + } + case "ReportNotFound": + z.ReportNotFound, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ReportNotFound") + return + } + case "FilterPrefix": + z.FilterPrefix, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "FilterPrefix") + return + } + case "ForwardTo": + z.ForwardTo, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ForwardTo") + return + } + case "Limit": + z.Limit, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Limit") + return + } + case "DiskID": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *WalkDirOptions) Msgsize() (s int) { + s = 1 + 7 + msgp.StringPrefixSize + len(z.Bucket) + 8 + msgp.StringPrefixSize + len(z.BaseDir) + 10 + msgp.BoolSize + 15 + msgp.BoolSize + 13 + msgp.StringPrefixSize + len(z.FilterPrefix) + 10 + msgp.StringPrefixSize + len(z.ForwardTo) + 6 + msgp.IntSize + 7 + msgp.StringPrefixSize + len(z.DiskID) + return +} diff --git a/cmd/metacache-walk_gen_test.go b/cmd/metacache-walk_gen_test.go new file mode 100644 index 000000000..02c4a1ecc --- /dev/null +++ b/cmd/metacache-walk_gen_test.go @@ -0,0 +1,123 @@ +package cmd + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalWalkDirOptions(t *testing.T) { + v := WalkDirOptions{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgWalkDirOptions(b *testing.B) { + v := WalkDirOptions{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgWalkDirOptions(b *testing.B) { + v := WalkDirOptions{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalWalkDirOptions(b *testing.B) { + v := WalkDirOptions{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeWalkDirOptions(t *testing.T) { + v := WalkDirOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeWalkDirOptions Msgsize() is inaccurate") + } + + vn := WalkDirOptions{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeWalkDirOptions(b *testing.B) { + v := WalkDirOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeWalkDirOptions(b *testing.B) { + v := WalkDirOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/cmd/object-api-common.go b/cmd/object-api-common.go index 867f618bd..a4fe6a974 100644 --- a/cmd/object-api-common.go +++ b/cmd/object-api-common.go @@ -67,5 +67,5 @@ func newStorageAPI(endpoint Endpoint, opts storageOpts) (storage StorageAPI, err return newXLStorageDiskIDCheck(storage, opts.healthCheck), nil } - return newStorageRESTClient(endpoint, opts.healthCheck), nil + return newStorageRESTClient(endpoint, opts.healthCheck, globalGrid.Load()) } diff --git a/cmd/routers.go b/cmd/routers.go index 47d52df9c..91f47c739 100644 --- a/cmd/routers.go +++ b/cmd/routers.go @@ -20,13 +20,14 @@ package cmd import ( "net/http" + "github.com/minio/minio/internal/grid" "github.com/minio/mux" ) // Composed function registering routers for only distributed Erasure setup. func registerDistErasureRouters(router *mux.Router, endpointServerPools EndpointServerPools) { // Register storage REST router only if its a distributed setup. - registerStorageRESTHandlers(router, endpointServerPools) + registerStorageRESTHandlers(router, endpointServerPools, globalGrid.Load()) // Register peer REST router only if its a distributed setup. registerPeerRESTHandlers(router) @@ -38,7 +39,10 @@ func registerDistErasureRouters(router *mux.Router, endpointServerPools Endpoint registerBootstrapRESTHandlers(router) // Register distributed namespace lock routers. - registerLockRESTHandlers(router) + registerLockRESTHandlers() + + // Add grid to router + router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(), noGZFlag, noObjLayerFlag)) } // List of some generic middlewares which are applied for all incoming requests. diff --git a/cmd/server-main.go b/cmd/server-main.go index 99bf80ce7..f9735e61c 100644 --- a/cmd/server-main.go +++ b/cmd/server-main.go @@ -666,12 +666,19 @@ func serverMain(ctx *cli.Context) { getCert = globalTLSCerts.GetCertificate } + // Initialize grid + bootstrapTrace("initGrid", func() { + logger.FatalIf(initGlobalGrid(GlobalContext, globalEndpoints), "Unable to configure server grid RPC services") + }) + // Configure server. bootstrapTrace("configureServer", func() { handler, err := configureServerHandler(globalEndpoints) if err != nil { logger.Fatal(config.ErrUnexpectedError(err), "Unable to configure one of server's RPC services") } + // Allow grid to start after registering all services. + close(globalGridStart) httpServer := xhttp.NewServer(getServerListenAddrs()). UseHandler(setCriticalErrorHandler(corsHandler(handler))). diff --git a/cmd/storage-datatypes.go b/cmd/storage-datatypes.go index beab78045..d9cf9ff6e 100644 --- a/cmd/storage-datatypes.go +++ b/cmd/storage-datatypes.go @@ -22,11 +22,9 @@ import ( ) // DeleteOptions represents the disk level delete options available for the APIs -// -//msgp:ignore DeleteOptions type DeleteOptions struct { - Recursive bool - Force bool + Recursive bool `msg:"r"` + Force bool `msg:"f"` } //go:generate msgp -file=$GOFILE @@ -143,7 +141,7 @@ func (f *FileInfoVersions) findVersionIndex(v string) int { // Make sure to bump the internode version at storage-rest-common.go type RawFileInfo struct { // Content of entire xl.meta (may contain data depending on what was requested by the caller. - Buf []byte `msg:"b"` + Buf []byte `msg:"b,allownil"` // DiskMTime indicates the mtime of the xl.meta on disk // This is mainly used for detecting a particular issue @@ -349,3 +347,57 @@ type ReadMultipleResp struct { Data []byte // Contains all data of file. Modtime time.Time // Modtime of file on disk. } + +// DeleteVersionHandlerParams are parameters for DeleteVersionHandler +type DeleteVersionHandlerParams struct { + DiskID string `msg:"id"` + Volume string `msg:"v"` + FilePath string `msg:"fp"` + ForceDelMarker bool `msg:"fdm"` + FI FileInfo `msg:"fi"` +} + +// MetadataHandlerParams is request info for UpdateMetadataHandle and WriteMetadataHandler. +type MetadataHandlerParams struct { + DiskID string `msg:"id"` + Volume string `msg:"v"` + FilePath string `msg:"fp"` + UpdateOpts UpdateMetadataOpts `msg:"uo"` + FI FileInfo `msg:"fi"` +} + +// UpdateMetadataOpts provides an optional input to indicate if xl.meta updates need to be fully synced to disk. +type UpdateMetadataOpts struct { + NoPersistence bool `msg:"np"` +} + +// CheckPartsHandlerParams are parameters for CheckPartsHandler +type CheckPartsHandlerParams struct { + DiskID string `msg:"id"` + Volume string `msg:"v"` + FilePath string `msg:"fp"` + FI FileInfo `msg:"fi"` +} + +// DeleteFileHandlerParams are parameters for DeleteFileHandler +type DeleteFileHandlerParams struct { + DiskID string `msg:"id"` + Volume string `msg:"v"` + FilePath string `msg:"fp"` + Opts DeleteOptions `msg:"do"` +} + +// RenameDataHandlerParams are parameters for RenameDataHandler. +type RenameDataHandlerParams struct { + DiskID string `msg:"id"` + SrcVolume string `msg:"sv"` + SrcPath string `msg:"sp"` + DstVolume string `msg:"dv"` + DstPath string `msg:"dp"` + FI FileInfo `msg:"fi"` +} + +// RenameDataResp - RenameData()'s response. +type RenameDataResp struct { + Signature uint64 `msg:"sig"` +} diff --git a/cmd/storage-datatypes_gen.go b/cmd/storage-datatypes_gen.go index 508fbd3b1..720497d6e 100644 --- a/cmd/storage-datatypes_gen.go +++ b/cmd/storage-datatypes_gen.go @@ -6,6 +6,781 @@ import ( "github.com/tinylib/msgp/msgp" ) +// DecodeMsg implements msgp.Decodable +func (z *CheckPartsHandlerParams) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "fi": + err = z.FI.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *CheckPartsHandlerParams) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 4 + // write "id" + err = en.Append(0x84, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + // write "v" + err = en.Append(0xa1, 0x76) + if err != nil { + return + } + err = en.WriteString(z.Volume) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + // write "fp" + err = en.Append(0xa2, 0x66, 0x70) + if err != nil { + return + } + err = en.WriteString(z.FilePath) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + // write "fi" + err = en.Append(0xa2, 0x66, 0x69) + if err != nil { + return + } + err = z.FI.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *CheckPartsHandlerParams) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 4 + // string "id" + o = append(o, 0x84, 0xa2, 0x69, 0x64) + o = msgp.AppendString(o, z.DiskID) + // string "v" + o = append(o, 0xa1, 0x76) + o = msgp.AppendString(o, z.Volume) + // string "fp" + o = append(o, 0xa2, 0x66, 0x70) + o = msgp.AppendString(o, z.FilePath) + // string "fi" + o = append(o, 0xa2, 0x66, 0x69) + o, err = z.FI.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *CheckPartsHandlerParams) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "fi": + bts, err = z.FI.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *CheckPartsHandlerParams) Msgsize() (s int) { + s = 1 + 3 + msgp.StringPrefixSize + len(z.DiskID) + 2 + msgp.StringPrefixSize + len(z.Volume) + 3 + msgp.StringPrefixSize + len(z.FilePath) + 3 + z.FI.Msgsize() + return +} + +// DecodeMsg implements msgp.Decodable +func (z *DeleteFileHandlerParams) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "do": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "Opts") + return + } + for zb0002 > 0 { + zb0002-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "Opts") + return + } + switch msgp.UnsafeString(field) { + case "r": + z.Opts.Recursive, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "Opts", "Recursive") + return + } + case "f": + z.Opts.Force, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "Opts", "Force") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "Opts") + return + } + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *DeleteFileHandlerParams) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 4 + // write "id" + err = en.Append(0x84, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + // write "v" + err = en.Append(0xa1, 0x76) + if err != nil { + return + } + err = en.WriteString(z.Volume) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + // write "fp" + err = en.Append(0xa2, 0x66, 0x70) + if err != nil { + return + } + err = en.WriteString(z.FilePath) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + // write "do" + err = en.Append(0xa2, 0x64, 0x6f) + if err != nil { + return + } + // map header, size 2 + // write "r" + err = en.Append(0x82, 0xa1, 0x72) + if err != nil { + return + } + err = en.WriteBool(z.Opts.Recursive) + if err != nil { + err = msgp.WrapError(err, "Opts", "Recursive") + return + } + // write "f" + err = en.Append(0xa1, 0x66) + if err != nil { + return + } + err = en.WriteBool(z.Opts.Force) + if err != nil { + err = msgp.WrapError(err, "Opts", "Force") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *DeleteFileHandlerParams) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 4 + // string "id" + o = append(o, 0x84, 0xa2, 0x69, 0x64) + o = msgp.AppendString(o, z.DiskID) + // string "v" + o = append(o, 0xa1, 0x76) + o = msgp.AppendString(o, z.Volume) + // string "fp" + o = append(o, 0xa2, 0x66, 0x70) + o = msgp.AppendString(o, z.FilePath) + // string "do" + o = append(o, 0xa2, 0x64, 0x6f) + // map header, size 2 + // string "r" + o = append(o, 0x82, 0xa1, 0x72) + o = msgp.AppendBool(o, z.Opts.Recursive) + // string "f" + o = append(o, 0xa1, 0x66) + o = msgp.AppendBool(o, z.Opts.Force) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *DeleteFileHandlerParams) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "do": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Opts") + return + } + for zb0002 > 0 { + zb0002-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "Opts") + return + } + switch msgp.UnsafeString(field) { + case "r": + z.Opts.Recursive, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Opts", "Recursive") + return + } + case "f": + z.Opts.Force, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Opts", "Force") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "Opts") + return + } + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *DeleteFileHandlerParams) Msgsize() (s int) { + s = 1 + 3 + msgp.StringPrefixSize + len(z.DiskID) + 2 + msgp.StringPrefixSize + len(z.Volume) + 3 + msgp.StringPrefixSize + len(z.FilePath) + 3 + 1 + 2 + msgp.BoolSize + 2 + msgp.BoolSize + return +} + +// DecodeMsg implements msgp.Decodable +func (z *DeleteOptions) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "r": + z.Recursive, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "Recursive") + return + } + case "f": + z.Force, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "Force") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z DeleteOptions) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "r" + err = en.Append(0x82, 0xa1, 0x72) + if err != nil { + return + } + err = en.WriteBool(z.Recursive) + if err != nil { + err = msgp.WrapError(err, "Recursive") + return + } + // write "f" + err = en.Append(0xa1, 0x66) + if err != nil { + return + } + err = en.WriteBool(z.Force) + if err != nil { + err = msgp.WrapError(err, "Force") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z DeleteOptions) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "r" + o = append(o, 0x82, 0xa1, 0x72) + o = msgp.AppendBool(o, z.Recursive) + // string "f" + o = append(o, 0xa1, 0x66) + o = msgp.AppendBool(o, z.Force) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *DeleteOptions) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "r": + z.Recursive, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Recursive") + return + } + case "f": + z.Force, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Force") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z DeleteOptions) Msgsize() (s int) { + s = 1 + 2 + msgp.BoolSize + 2 + msgp.BoolSize + return +} + +// DecodeMsg implements msgp.Decodable +func (z *DeleteVersionHandlerParams) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "fdm": + z.ForceDelMarker, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "ForceDelMarker") + return + } + case "fi": + err = z.FI.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *DeleteVersionHandlerParams) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 5 + // write "id" + err = en.Append(0x85, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + // write "v" + err = en.Append(0xa1, 0x76) + if err != nil { + return + } + err = en.WriteString(z.Volume) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + // write "fp" + err = en.Append(0xa2, 0x66, 0x70) + if err != nil { + return + } + err = en.WriteString(z.FilePath) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + // write "fdm" + err = en.Append(0xa3, 0x66, 0x64, 0x6d) + if err != nil { + return + } + err = en.WriteBool(z.ForceDelMarker) + if err != nil { + err = msgp.WrapError(err, "ForceDelMarker") + return + } + // write "fi" + err = en.Append(0xa2, 0x66, 0x69) + if err != nil { + return + } + err = z.FI.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *DeleteVersionHandlerParams) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 5 + // string "id" + o = append(o, 0x85, 0xa2, 0x69, 0x64) + o = msgp.AppendString(o, z.DiskID) + // string "v" + o = append(o, 0xa1, 0x76) + o = msgp.AppendString(o, z.Volume) + // string "fp" + o = append(o, 0xa2, 0x66, 0x70) + o = msgp.AppendString(o, z.FilePath) + // string "fdm" + o = append(o, 0xa3, 0x66, 0x64, 0x6d) + o = msgp.AppendBool(o, z.ForceDelMarker) + // string "fi" + o = append(o, 0xa2, 0x66, 0x69) + o, err = z.FI.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *DeleteVersionHandlerParams) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "fdm": + z.ForceDelMarker, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ForceDelMarker") + return + } + case "fi": + bts, err = z.FI.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *DeleteVersionHandlerParams) Msgsize() (s int) { + s = 1 + 3 + msgp.StringPrefixSize + len(z.DiskID) + 2 + msgp.StringPrefixSize + len(z.Volume) + 3 + msgp.StringPrefixSize + len(z.FilePath) + 4 + msgp.BoolSize + 3 + z.FI.Msgsize() + return +} + // DecodeMsg implements msgp.Decodable func (z *DiskInfo) DecodeMsg(dc *msgp.Reader) (err error) { var zb0001 uint32 @@ -831,10 +1606,19 @@ func (z *FileInfo) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "ReplicationState") return } - z.Data, err = dc.ReadBytes(z.Data) - if err != nil { - err = msgp.WrapError(err, "Data") - return + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err) + return + } + z.Data = nil + } else { + z.Data, err = dc.ReadBytes(z.Data) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } } z.NumVersions, err = dc.ReadInt() if err != nil { @@ -861,10 +1645,19 @@ func (z *FileInfo) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "DiskMTime") return } - z.Checksum, err = dc.ReadBytes(z.Checksum) - if err != nil { - err = msgp.WrapError(err, "Checksum") - return + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err) + return + } + z.Checksum = nil + } else { + z.Checksum, err = dc.ReadBytes(z.Checksum) + if err != nil { + err = msgp.WrapError(err, "Checksum") + return + } } z.Versioned, err = dc.ReadBool() if err != nil { @@ -1005,10 +1798,17 @@ func (z *FileInfo) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "ReplicationState") return } - err = en.WriteBytes(z.Data) - if err != nil { - err = msgp.WrapError(err, "Data") - return + if z.Data == nil { // allownil: if nil + err = en.WriteNil() + if err != nil { + return + } + } else { + err = en.WriteBytes(z.Data) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } } err = en.WriteInt(z.NumVersions) if err != nil { @@ -1035,10 +1835,17 @@ func (z *FileInfo) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "DiskMTime") return } - err = en.WriteBytes(z.Checksum) - if err != nil { - err = msgp.WrapError(err, "Checksum") - return + if z.Checksum == nil { // allownil: if nil + err = en.WriteNil() + if err != nil { + return + } + } else { + err = en.WriteBytes(z.Checksum) + if err != nil { + err = msgp.WrapError(err, "Checksum") + return + } } err = en.WriteBool(z.Versioned) if err != nil { @@ -1093,13 +1900,21 @@ func (z *FileInfo) MarshalMsg(b []byte) (o []byte, err error) { err = msgp.WrapError(err, "ReplicationState") return } - o = msgp.AppendBytes(o, z.Data) + if z.Data == nil { // allownil: if nil + o = msgp.AppendNil(o) + } else { + o = msgp.AppendBytes(o, z.Data) + } o = msgp.AppendInt(o, z.NumVersions) o = msgp.AppendTime(o, z.SuccessorModTime) o = msgp.AppendBool(o, z.Fresh) o = msgp.AppendInt(o, z.Idx) o = msgp.AppendTime(o, z.DiskMTime) - o = msgp.AppendBytes(o, z.Checksum) + if z.Checksum == nil { // allownil: if nil + o = msgp.AppendNil(o) + } else { + o = msgp.AppendBytes(o, z.Checksum) + } o = msgp.AppendBool(o, z.Versioned) return } @@ -1258,10 +2073,15 @@ func (z *FileInfo) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "ReplicationState") return } - z.Data, bts, err = msgp.ReadBytesBytes(bts, z.Data) - if err != nil { - err = msgp.WrapError(err, "Data") - return + if msgp.IsNil(bts) { + bts = bts[1:] + z.Data = nil + } else { + z.Data, bts, err = msgp.ReadBytesBytes(bts, z.Data) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } } z.NumVersions, bts, err = msgp.ReadIntBytes(bts) if err != nil { @@ -1288,10 +2108,15 @@ func (z *FileInfo) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "DiskMTime") return } - z.Checksum, bts, err = msgp.ReadBytesBytes(bts, z.Checksum) - if err != nil { - err = msgp.WrapError(err, "Checksum") - return + if msgp.IsNil(bts) { + bts = bts[1:] + z.Checksum = nil + } else { + z.Checksum, bts, err = msgp.ReadBytesBytes(bts, z.Checksum) + if err != nil { + err = msgp.WrapError(err, "Checksum") + return + } } z.Versioned, bts, err = msgp.ReadBoolBytes(bts) if err != nil { @@ -1713,6 +2538,268 @@ func (z *FilesInfo) Msgsize() (s int) { return } +// DecodeMsg implements msgp.Decodable +func (z *MetadataHandlerParams) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "uo": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "UpdateOpts") + return + } + for zb0002 > 0 { + zb0002-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "UpdateOpts") + return + } + switch msgp.UnsafeString(field) { + case "np": + z.UpdateOpts.NoPersistence, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "UpdateOpts", "NoPersistence") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "UpdateOpts") + return + } + } + } + case "fi": + err = z.FI.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *MetadataHandlerParams) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 5 + // write "id" + err = en.Append(0x85, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + // write "v" + err = en.Append(0xa1, 0x76) + if err != nil { + return + } + err = en.WriteString(z.Volume) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + // write "fp" + err = en.Append(0xa2, 0x66, 0x70) + if err != nil { + return + } + err = en.WriteString(z.FilePath) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + // write "uo" + err = en.Append(0xa2, 0x75, 0x6f) + if err != nil { + return + } + // map header, size 1 + // write "np" + err = en.Append(0x81, 0xa2, 0x6e, 0x70) + if err != nil { + return + } + err = en.WriteBool(z.UpdateOpts.NoPersistence) + if err != nil { + err = msgp.WrapError(err, "UpdateOpts", "NoPersistence") + return + } + // write "fi" + err = en.Append(0xa2, 0x66, 0x69) + if err != nil { + return + } + err = z.FI.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *MetadataHandlerParams) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 5 + // string "id" + o = append(o, 0x85, 0xa2, 0x69, 0x64) + o = msgp.AppendString(o, z.DiskID) + // string "v" + o = append(o, 0xa1, 0x76) + o = msgp.AppendString(o, z.Volume) + // string "fp" + o = append(o, 0xa2, 0x66, 0x70) + o = msgp.AppendString(o, z.FilePath) + // string "uo" + o = append(o, 0xa2, 0x75, 0x6f) + // map header, size 1 + // string "np" + o = append(o, 0x81, 0xa2, 0x6e, 0x70) + o = msgp.AppendBool(o, z.UpdateOpts.NoPersistence) + // string "fi" + o = append(o, 0xa2, 0x66, 0x69) + o, err = z.FI.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *MetadataHandlerParams) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "v": + z.Volume, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Volume") + return + } + case "fp": + z.FilePath, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "FilePath") + return + } + case "uo": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "UpdateOpts") + return + } + for zb0002 > 0 { + zb0002-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "UpdateOpts") + return + } + switch msgp.UnsafeString(field) { + case "np": + z.UpdateOpts.NoPersistence, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "UpdateOpts", "NoPersistence") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "UpdateOpts") + return + } + } + } + case "fi": + bts, err = z.FI.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *MetadataHandlerParams) Msgsize() (s int) { + s = 1 + 3 + msgp.StringPrefixSize + len(z.DiskID) + 2 + msgp.StringPrefixSize + len(z.Volume) + 3 + msgp.StringPrefixSize + len(z.FilePath) + 3 + 1 + 3 + msgp.BoolSize + 3 + z.FI.Msgsize() + return +} + // DecodeMsg implements msgp.Decodable func (z *RawFileInfo) DecodeMsg(dc *msgp.Reader) (err error) { var field []byte @@ -1732,10 +2819,19 @@ func (z *RawFileInfo) DecodeMsg(dc *msgp.Reader) (err error) { } switch msgp.UnsafeString(field) { case "b": - z.Buf, err = dc.ReadBytes(z.Buf) - if err != nil { - err = msgp.WrapError(err, "Buf") - return + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err, "Buf") + return + } + z.Buf = nil + } else { + z.Buf, err = dc.ReadBytes(z.Buf) + if err != nil { + err = msgp.WrapError(err, "Buf") + return + } } case "dmt": z.DiskMTime, err = dc.ReadTime() @@ -1762,10 +2858,17 @@ func (z *RawFileInfo) EncodeMsg(en *msgp.Writer) (err error) { if err != nil { return } - err = en.WriteBytes(z.Buf) - if err != nil { - err = msgp.WrapError(err, "Buf") - return + if z.Buf == nil { // allownil: if nil + err = en.WriteNil() + if err != nil { + return + } + } else { + err = en.WriteBytes(z.Buf) + if err != nil { + err = msgp.WrapError(err, "Buf") + return + } } // write "dmt" err = en.Append(0xa3, 0x64, 0x6d, 0x74) @@ -1786,7 +2889,11 @@ func (z *RawFileInfo) MarshalMsg(b []byte) (o []byte, err error) { // map header, size 2 // string "b" o = append(o, 0x82, 0xa1, 0x62) - o = msgp.AppendBytes(o, z.Buf) + if z.Buf == nil { // allownil: if nil + o = msgp.AppendNil(o) + } else { + o = msgp.AppendBytes(o, z.Buf) + } // string "dmt" o = append(o, 0xa3, 0x64, 0x6d, 0x74) o = msgp.AppendTime(o, z.DiskMTime) @@ -1812,10 +2919,15 @@ func (z *RawFileInfo) UnmarshalMsg(bts []byte) (o []byte, err error) { } switch msgp.UnsafeString(field) { case "b": - z.Buf, bts, err = msgp.ReadBytesBytes(bts, z.Buf) - if err != nil { - err = msgp.WrapError(err, "Buf") - return + if msgp.IsNil(bts) { + bts = bts[1:] + z.Buf = nil + } else { + z.Buf, bts, err = msgp.ReadBytesBytes(bts, z.Buf) + if err != nil { + err = msgp.WrapError(err, "Buf") + return + } } case "dmt": z.DiskMTime, bts, err = msgp.ReadTimeBytes(bts) @@ -2387,6 +3499,444 @@ func (z *ReadMultipleResp) Msgsize() (s int) { return } +// DecodeMsg implements msgp.Decodable +func (z *RenameDataHandlerParams) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "sv": + z.SrcVolume, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "SrcVolume") + return + } + case "sp": + z.SrcPath, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "SrcPath") + return + } + case "dv": + z.DstVolume, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DstVolume") + return + } + case "dp": + z.DstPath, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DstPath") + return + } + case "fi": + err = z.FI.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *RenameDataHandlerParams) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 6 + // write "id" + err = en.Append(0x86, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + // write "sv" + err = en.Append(0xa2, 0x73, 0x76) + if err != nil { + return + } + err = en.WriteString(z.SrcVolume) + if err != nil { + err = msgp.WrapError(err, "SrcVolume") + return + } + // write "sp" + err = en.Append(0xa2, 0x73, 0x70) + if err != nil { + return + } + err = en.WriteString(z.SrcPath) + if err != nil { + err = msgp.WrapError(err, "SrcPath") + return + } + // write "dv" + err = en.Append(0xa2, 0x64, 0x76) + if err != nil { + return + } + err = en.WriteString(z.DstVolume) + if err != nil { + err = msgp.WrapError(err, "DstVolume") + return + } + // write "dp" + err = en.Append(0xa2, 0x64, 0x70) + if err != nil { + return + } + err = en.WriteString(z.DstPath) + if err != nil { + err = msgp.WrapError(err, "DstPath") + return + } + // write "fi" + err = en.Append(0xa2, 0x66, 0x69) + if err != nil { + return + } + err = z.FI.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *RenameDataHandlerParams) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 6 + // string "id" + o = append(o, 0x86, 0xa2, 0x69, 0x64) + o = msgp.AppendString(o, z.DiskID) + // string "sv" + o = append(o, 0xa2, 0x73, 0x76) + o = msgp.AppendString(o, z.SrcVolume) + // string "sp" + o = append(o, 0xa2, 0x73, 0x70) + o = msgp.AppendString(o, z.SrcPath) + // string "dv" + o = append(o, 0xa2, 0x64, 0x76) + o = msgp.AppendString(o, z.DstVolume) + // string "dp" + o = append(o, 0xa2, 0x64, 0x70) + o = msgp.AppendString(o, z.DstPath) + // string "fi" + o = append(o, 0xa2, 0x66, 0x69) + o, err = z.FI.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *RenameDataHandlerParams) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "sv": + z.SrcVolume, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "SrcVolume") + return + } + case "sp": + z.SrcPath, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "SrcPath") + return + } + case "dv": + z.DstVolume, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DstVolume") + return + } + case "dp": + z.DstPath, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DstPath") + return + } + case "fi": + bts, err = z.FI.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "FI") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *RenameDataHandlerParams) Msgsize() (s int) { + s = 1 + 3 + msgp.StringPrefixSize + len(z.DiskID) + 3 + msgp.StringPrefixSize + len(z.SrcVolume) + 3 + msgp.StringPrefixSize + len(z.SrcPath) + 3 + msgp.StringPrefixSize + len(z.DstVolume) + 3 + msgp.StringPrefixSize + len(z.DstPath) + 3 + z.FI.Msgsize() + return +} + +// DecodeMsg implements msgp.Decodable +func (z *RenameDataResp) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "sig": + z.Signature, err = dc.ReadUint64() + if err != nil { + err = msgp.WrapError(err, "Signature") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z RenameDataResp) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 1 + // write "sig" + err = en.Append(0x81, 0xa3, 0x73, 0x69, 0x67) + if err != nil { + return + } + err = en.WriteUint64(z.Signature) + if err != nil { + err = msgp.WrapError(err, "Signature") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z RenameDataResp) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 1 + // string "sig" + o = append(o, 0x81, 0xa3, 0x73, 0x69, 0x67) + o = msgp.AppendUint64(o, z.Signature) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *RenameDataResp) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "sig": + z.Signature, bts, err = msgp.ReadUint64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Signature") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z RenameDataResp) Msgsize() (s int) { + s = 1 + 4 + msgp.Uint64Size + return +} + +// DecodeMsg implements msgp.Decodable +func (z *UpdateMetadataOpts) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "np": + z.NoPersistence, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "NoPersistence") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z UpdateMetadataOpts) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 1 + // write "np" + err = en.Append(0x81, 0xa2, 0x6e, 0x70) + if err != nil { + return + } + err = en.WriteBool(z.NoPersistence) + if err != nil { + err = msgp.WrapError(err, "NoPersistence") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z UpdateMetadataOpts) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 1 + // string "np" + o = append(o, 0x81, 0xa2, 0x6e, 0x70) + o = msgp.AppendBool(o, z.NoPersistence) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *UpdateMetadataOpts) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "np": + z.NoPersistence, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "NoPersistence") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z UpdateMetadataOpts) Msgsize() (s int) { + s = 1 + 3 + msgp.BoolSize + return +} + // DecodeMsg implements msgp.Decodable func (z *VolInfo) DecodeMsg(dc *msgp.Reader) (err error) { var zb0001 uint32 diff --git a/cmd/storage-datatypes_gen_test.go b/cmd/storage-datatypes_gen_test.go index ffc08f690..6a4a5c554 100644 --- a/cmd/storage-datatypes_gen_test.go +++ b/cmd/storage-datatypes_gen_test.go @@ -9,6 +9,458 @@ import ( "github.com/tinylib/msgp/msgp" ) +func TestMarshalUnmarshalCheckPartsHandlerParams(t *testing.T) { + v := CheckPartsHandlerParams{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgCheckPartsHandlerParams(b *testing.B) { + v := CheckPartsHandlerParams{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgCheckPartsHandlerParams(b *testing.B) { + v := CheckPartsHandlerParams{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalCheckPartsHandlerParams(b *testing.B) { + v := CheckPartsHandlerParams{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeCheckPartsHandlerParams(t *testing.T) { + v := CheckPartsHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeCheckPartsHandlerParams Msgsize() is inaccurate") + } + + vn := CheckPartsHandlerParams{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeCheckPartsHandlerParams(b *testing.B) { + v := CheckPartsHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeCheckPartsHandlerParams(b *testing.B) { + v := CheckPartsHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalDeleteFileHandlerParams(t *testing.T) { + v := DeleteFileHandlerParams{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgDeleteFileHandlerParams(b *testing.B) { + v := DeleteFileHandlerParams{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgDeleteFileHandlerParams(b *testing.B) { + v := DeleteFileHandlerParams{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalDeleteFileHandlerParams(b *testing.B) { + v := DeleteFileHandlerParams{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeDeleteFileHandlerParams(t *testing.T) { + v := DeleteFileHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeDeleteFileHandlerParams Msgsize() is inaccurate") + } + + vn := DeleteFileHandlerParams{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeDeleteFileHandlerParams(b *testing.B) { + v := DeleteFileHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeDeleteFileHandlerParams(b *testing.B) { + v := DeleteFileHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalDeleteOptions(t *testing.T) { + v := DeleteOptions{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgDeleteOptions(b *testing.B) { + v := DeleteOptions{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgDeleteOptions(b *testing.B) { + v := DeleteOptions{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalDeleteOptions(b *testing.B) { + v := DeleteOptions{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeDeleteOptions(t *testing.T) { + v := DeleteOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeDeleteOptions Msgsize() is inaccurate") + } + + vn := DeleteOptions{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeDeleteOptions(b *testing.B) { + v := DeleteOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeDeleteOptions(b *testing.B) { + v := DeleteOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalDeleteVersionHandlerParams(t *testing.T) { + v := DeleteVersionHandlerParams{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgDeleteVersionHandlerParams(b *testing.B) { + v := DeleteVersionHandlerParams{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgDeleteVersionHandlerParams(b *testing.B) { + v := DeleteVersionHandlerParams{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalDeleteVersionHandlerParams(b *testing.B) { + v := DeleteVersionHandlerParams{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeDeleteVersionHandlerParams(t *testing.T) { + v := DeleteVersionHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeDeleteVersionHandlerParams Msgsize() is inaccurate") + } + + vn := DeleteVersionHandlerParams{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeDeleteVersionHandlerParams(b *testing.B) { + v := DeleteVersionHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeDeleteVersionHandlerParams(b *testing.B) { + v := DeleteVersionHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + func TestMarshalUnmarshalDiskInfo(t *testing.T) { v := DiskInfo{} bts, err := v.MarshalMsg(nil) @@ -574,6 +1026,119 @@ func BenchmarkDecodeFilesInfo(b *testing.B) { } } +func TestMarshalUnmarshalMetadataHandlerParams(t *testing.T) { + v := MetadataHandlerParams{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgMetadataHandlerParams(b *testing.B) { + v := MetadataHandlerParams{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgMetadataHandlerParams(b *testing.B) { + v := MetadataHandlerParams{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalMetadataHandlerParams(b *testing.B) { + v := MetadataHandlerParams{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeMetadataHandlerParams(t *testing.T) { + v := MetadataHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeMetadataHandlerParams Msgsize() is inaccurate") + } + + vn := MetadataHandlerParams{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeMetadataHandlerParams(b *testing.B) { + v := MetadataHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeMetadataHandlerParams(b *testing.B) { + v := MetadataHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + func TestMarshalUnmarshalRawFileInfo(t *testing.T) { v := RawFileInfo{} bts, err := v.MarshalMsg(nil) @@ -913,6 +1478,345 @@ func BenchmarkDecodeReadMultipleResp(b *testing.B) { } } +func TestMarshalUnmarshalRenameDataHandlerParams(t *testing.T) { + v := RenameDataHandlerParams{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgRenameDataHandlerParams(b *testing.B) { + v := RenameDataHandlerParams{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgRenameDataHandlerParams(b *testing.B) { + v := RenameDataHandlerParams{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalRenameDataHandlerParams(b *testing.B) { + v := RenameDataHandlerParams{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeRenameDataHandlerParams(t *testing.T) { + v := RenameDataHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeRenameDataHandlerParams Msgsize() is inaccurate") + } + + vn := RenameDataHandlerParams{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeRenameDataHandlerParams(b *testing.B) { + v := RenameDataHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeRenameDataHandlerParams(b *testing.B) { + v := RenameDataHandlerParams{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalRenameDataResp(t *testing.T) { + v := RenameDataResp{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgRenameDataResp(b *testing.B) { + v := RenameDataResp{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgRenameDataResp(b *testing.B) { + v := RenameDataResp{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalRenameDataResp(b *testing.B) { + v := RenameDataResp{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeRenameDataResp(t *testing.T) { + v := RenameDataResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeRenameDataResp Msgsize() is inaccurate") + } + + vn := RenameDataResp{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeRenameDataResp(b *testing.B) { + v := RenameDataResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeRenameDataResp(b *testing.B) { + v := RenameDataResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalUpdateMetadataOpts(t *testing.T) { + v := UpdateMetadataOpts{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgUpdateMetadataOpts(b *testing.B) { + v := UpdateMetadataOpts{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgUpdateMetadataOpts(b *testing.B) { + v := UpdateMetadataOpts{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalUpdateMetadataOpts(b *testing.B) { + v := UpdateMetadataOpts{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeUpdateMetadataOpts(t *testing.T) { + v := UpdateMetadataOpts{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeUpdateMetadataOpts Msgsize() is inaccurate") + } + + vn := UpdateMetadataOpts{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeUpdateMetadataOpts(b *testing.B) { + v := UpdateMetadataOpts{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeUpdateMetadataOpts(b *testing.B) { + v := UpdateMetadataOpts{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + func TestMarshalUnmarshalVolInfo(t *testing.T) { v := VolInfo{} bts, err := v.MarshalMsg(nil) diff --git a/cmd/storage-rest-client.go b/cmd/storage-rest-client.go index a59e936fe..9c4586d05 100644 --- a/cmd/storage-rest-client.go +++ b/cmd/storage-rest-client.go @@ -23,6 +23,7 @@ import ( "encoding/gob" "encoding/hex" "errors" + "fmt" "io" "net/http" "net/url" @@ -34,6 +35,7 @@ import ( "time" "github.com/minio/madmin-go/v3" + "github.com/minio/minio/internal/grid" xhttp "github.com/minio/minio/internal/http" "github.com/minio/minio/internal/logger" "github.com/minio/minio/internal/rest" @@ -52,7 +54,9 @@ func isNetworkError(err error) bool { return true } } - + if errors.Is(err, grid.ErrDisconnected) { + return true + } // More corner cases suitable for storage REST API switch { // A peer node can be in shut down phase and proactively @@ -139,6 +143,7 @@ type storageRESTClient struct { endpoint Endpoint restClient *rest.Client + gridConn *grid.Subroute diskID string // Indexes, will be -1 until assigned a set. @@ -184,7 +189,7 @@ func (client *storageRESTClient) String() string { // IsOnline - returns whether RPC client failed to connect or not. func (client *storageRESTClient) IsOnline() bool { - return client.restClient.IsOnline() + return client.restClient.IsOnline() && client.gridConn.State() == grid.StateConnected } // LastConn - returns when the disk is seen to be connected the last time @@ -213,57 +218,37 @@ func (client *storageRESTClient) Healing() *healingTracker { func (client *storageRESTClient) NSScanner(ctx context.Context, cache dataUsageCache, updates chan<- dataUsageEntry, scanMode madmin.HealScanMode) (dataUsageCache, error) { atomic.AddInt32(&client.scanning, 1) defer atomic.AddInt32(&client.scanning, -1) - defer close(updates) - pr, pw := io.Pipe() - go func() { - pw.CloseWithError(cache.serializeTo(pw)) - }() - vals := make(url.Values) - vals.Set(storageRESTScanMode, strconv.Itoa(int(scanMode))) - respBody, err := client.call(ctx, storageRESTMethodNSScanner, vals, pr, -1) - defer xhttp.DrainBody(respBody) - pr.CloseWithError(err) + + st, err := storageNSScannerHandler.Call(ctx, client.gridConn, &nsScannerOptions{ + DiskID: client.diskID, + ScanMode: int(scanMode), + Cache: &cache, + }) if err != nil { - return cache, err + return cache, toStorageErr(err) } - - rr, rw := io.Pipe() - go func() { - rw.CloseWithError(waitForHTTPStream(respBody, rw)) - }() - - ms := msgpNewReader(rr) - defer readMsgpReaderPoolPut(ms) - for { - // Read whether it is an update. - upd, err := ms.ReadBool() - if err != nil { - rr.CloseWithError(err) - return cache, err + var final *dataUsageCache + err = st.Results(func(resp *nsScannerResp) error { + if resp.Update != nil { + select { + case <-ctx.Done(): + case updates <- *resp.Update: + } } - if !upd { - // No more updates... New cache follows. - break - } - var update dataUsageEntry - err = update.DecodeMsg(ms) - if err != nil || err == io.EOF { - rr.CloseWithError(err) - return cache, err - } - select { - case <-ctx.Done(): - case updates <- update: + if resp.Final != nil { + final = resp.Final } + // We can't reuse the response since it is sent upstream. + return nil + }) + if err != nil { + return cache, toStorageErr(err) } - var newCache dataUsageCache - err = newCache.DecodeMsg(ms) - rr.CloseWithError(err) - if err == io.EOF { - err = nil + if final == nil { + return cache, errors.New("no final cache") } - return newCache, err + return *final, nil } func (client *storageRESTClient) GetDiskID() (string, error) { @@ -278,77 +263,44 @@ func (client *storageRESTClient) SetDiskID(id string) { client.diskID = id } -// DiskInfo - fetch disk information for a remote disk. -func (client *storageRESTClient) DiskInfo(_ context.Context, metrics bool) (info DiskInfo, err error) { - if !client.IsOnline() { +func (client *storageRESTClient) DiskInfo(ctx context.Context, metrics bool) (info DiskInfo, err error) { + if client.gridConn.State() != grid.StateConnected { // make sure to check if the disk is offline, since the underlying // value is cached we should attempt to invalidate it if such calls // were attempted. This can lead to false success under certain conditions // - this change attempts to avoid stale information if the underlying // transport is already down. - return info, errDiskNotFound + return info, grid.ErrDisconnected } - // Do not cache results from atomic variables - scanning := atomic.LoadInt32(&client.scanning) == 1 - if metrics { - client.diskInfoCacheMetrics.Once.Do(func() { - client.diskInfoCacheMetrics.TTL = time.Second - client.diskInfoCacheMetrics.Update = func() (interface{}, error) { - var info DiskInfo - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - vals := make(url.Values) - vals.Set(storageRESTMetrics, "true") - respBody, err := client.call(ctx, storageRESTMethodDiskInfo, vals, nil, -1) - if err != nil { - return info, err - } - defer xhttp.DrainBody(respBody) - if err = msgp.Decode(respBody, &info); err != nil { - return info, err - } - if info.Error != "" { - return info, toStorageErr(errors.New(info.Error)) - } - return info, nil + fetchDI := func(di *timedValue, metrics bool) { + di.TTL = time.Second + di.Update = func() (interface{}, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + info, err := storageDiskInfoHandler.Call(ctx, client.gridConn, grid.NewMSSWith(map[string]string{ + storageRESTDiskID: client.diskID, + // Always request metrics, since we are caching the result. + storageRESTMetrics: strconv.FormatBool(metrics), + })) + if err != nil { + return info, err } - }) - } else { - client.diskInfoCache.Once.Do(func() { - client.diskInfoCache.TTL = time.Second - client.diskInfoCache.Update = func() (interface{}, error) { - var info DiskInfo - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - vals := make(url.Values) - respBody, err := client.call(ctx, storageRESTMethodDiskInfo, vals, nil, -1) - if err != nil { - return info, err - } - defer xhttp.DrainBody(respBody) - if err = msgp.Decode(respBody, &info); err != nil { - return info, err - } - if info.Error != "" { - return info, toStorageErr(errors.New(info.Error)) - } - return info, nil + if info.Error != "" { + return info, toStorageErr(errors.New(info.Error)) } - }) + return info, nil + } } - - var val interface{} + // Fetch disk info from appropriate cache. + dic := &client.diskInfoCache if metrics { - val, err = client.diskInfoCacheMetrics.Get() - } else { - val, err = client.diskInfoCache.Get() + dic = &client.diskInfoCacheMetrics } - if val != nil { - info = val.(DiskInfo) + dic.Once.Do(func() { fetchDI(dic, metrics) }) + val, err := dic.Get() + if di, ok := val.(*DiskInfo); di != nil && ok { + info = *di } - info.Scanning = scanning return info, err } @@ -384,15 +336,16 @@ func (client *storageRESTClient) ListVols(ctx context.Context) (vols []VolInfo, // StatVol - get volume info over the network. func (client *storageRESTClient) StatVol(ctx context.Context, volume string) (vol VolInfo, err error) { - values := make(url.Values) - values.Set(storageRESTVolume, volume) - respBody, err := client.call(ctx, storageRESTMethodStatVol, values, nil, -1) + v, err := storageStatVolHandler.Call(ctx, client.gridConn, grid.NewMSSWith(map[string]string{ + storageRESTDiskID: client.diskID, + storageRESTVolume: volume, + })) if err != nil { - return + return vol, toStorageErr(err) } - defer xhttp.DrainBody(respBody) - err = msgp.Decode(respBody, &vol) - return vol, err + vol = *v + storageStatVolHandler.PutResponse(v) + return vol, nil } // DeleteVol - Deletes a volume over the network. @@ -433,50 +386,35 @@ func (client *storageRESTClient) CreateFile(ctx context.Context, volume, path st } func (client *storageRESTClient) WriteMetadata(ctx context.Context, volume, path string, fi FileInfo) error { - values := make(url.Values) - values.Set(storageRESTVolume, volume) - values.Set(storageRESTFilePath, path) - - var reader bytes.Buffer - if err := msgp.Encode(&reader, &fi); err != nil { - return err - } - - respBody, err := client.call(ctx, storageRESTMethodWriteMetadata, values, &reader, -1) - defer xhttp.DrainBody(respBody) - return err + _, err := storageWriteMetadataHandler.Call(ctx, client.gridConn, &MetadataHandlerParams{ + DiskID: client.diskID, + Volume: volume, + FilePath: path, + FI: fi, + }) + return toStorageErr(err) } func (client *storageRESTClient) UpdateMetadata(ctx context.Context, volume, path string, fi FileInfo, opts UpdateMetadataOpts) error { - values := make(url.Values) - values.Set(storageRESTVolume, volume) - values.Set(storageRESTFilePath, path) - values.Set(storageRESTNoPersistence, strconv.FormatBool(opts.NoPersistence)) - - var reader bytes.Buffer - if err := msgp.Encode(&reader, &fi); err != nil { - return err - } - - respBody, err := client.call(ctx, storageRESTMethodUpdateMetadata, values, &reader, -1) - defer xhttp.DrainBody(respBody) - return err + _, err := storageUpdateMetadataHandler.Call(ctx, client.gridConn, &MetadataHandlerParams{ + DiskID: client.diskID, + Volume: volume, + FilePath: path, + UpdateOpts: opts, + FI: fi, + }) + return toStorageErr(err) } func (client *storageRESTClient) DeleteVersion(ctx context.Context, volume, path string, fi FileInfo, forceDelMarker bool) error { - values := make(url.Values) - values.Set(storageRESTVolume, volume) - values.Set(storageRESTFilePath, path) - values.Set(storageRESTForceDelMarker, strconv.FormatBool(forceDelMarker)) - - var buffer bytes.Buffer - if err := msgp.Encode(&buffer, &fi); err != nil { - return err - } - - respBody, err := client.call(ctx, storageRESTMethodDeleteVersion, values, &buffer, -1) - defer xhttp.DrainBody(respBody) - return err + _, err := storageDeleteVersionHandler.Call(ctx, client.gridConn, &DeleteVersionHandlerParams{ + DiskID: client.diskID, + Volume: volume, + FilePath: path, + ForceDelMarker: forceDelMarker, + FI: fi, + }) + return toStorageErr(err) } // WriteAll - write all data to a file. @@ -491,51 +429,32 @@ func (client *storageRESTClient) WriteAll(ctx context.Context, volume string, pa // CheckParts - stat all file parts. func (client *storageRESTClient) CheckParts(ctx context.Context, volume string, path string, fi FileInfo) error { - values := make(url.Values) - values.Set(storageRESTVolume, volume) - values.Set(storageRESTFilePath, path) - - var reader bytes.Buffer - if err := msgp.Encode(&reader, &fi); err != nil { - logger.LogIf(context.Background(), err) - return err - } - - respBody, err := client.call(ctx, storageRESTMethodCheckParts, values, &reader, -1) - defer xhttp.DrainBody(respBody) - return err + _, err := storageCheckPartsHandler.Call(ctx, client.gridConn, &CheckPartsHandlerParams{ + DiskID: client.diskID, + Volume: volume, + FilePath: path, + FI: fi, + }) + return toStorageErr(err) } // RenameData - rename source path to destination path atomically, metadata and data file. func (client *storageRESTClient) RenameData(ctx context.Context, srcVolume, srcPath string, fi FileInfo, dstVolume, dstPath string) (sign uint64, err error) { - values := make(url.Values) - values.Set(storageRESTSrcVolume, srcVolume) - values.Set(storageRESTSrcPath, srcPath) - values.Set(storageRESTDstVolume, dstVolume) - values.Set(storageRESTDstPath, dstPath) - - var reader bytes.Buffer - if err = msgp.Encode(&reader, &fi); err != nil { - return 0, err - } - - respBody, err := client.call(ctx, storageRESTMethodRenameData, values, &reader, -1) - defer xhttp.DrainBody(respBody) + // Set a very long timeout for rename data. + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + resp, err := storageRenameDataHandler.Call(ctx, client.gridConn, &RenameDataHandlerParams{ + DiskID: client.diskID, + SrcVolume: srcVolume, + SrcPath: srcPath, + DstPath: dstPath, + DstVolume: dstVolume, + FI: fi, + }) if err != nil { - return 0, err + return 0, toStorageErr(err) } - - respReader, err := waitForHTTPResponse(respBody) - if err != nil { - return 0, err - } - - resp := &RenameDataResp{} - if err = gob.NewDecoder(respReader).Decode(resp); err != nil { - return 0, err - } - - return resp.Signature, toStorageErr(resp.Err) + return resp.Signature, nil } // where we keep old *Readers @@ -562,6 +481,21 @@ func readMsgpReaderPoolPut(r *msgp.Reader) { } func (client *storageRESTClient) ReadVersion(ctx context.Context, volume, path, versionID string, readData bool) (fi FileInfo, err error) { + // Use websocket when not reading data. + if !readData { + resp, err := storageReadVersionHandler.Call(ctx, client.gridConn, grid.NewMSSWith(map[string]string{ + storageRESTDiskID: client.diskID, + storageRESTVolume: volume, + storageRESTFilePath: path, + storageRESTVersionID: versionID, + storageRESTReadData: "false", + })) + if err != nil { + return fi, toStorageErr(err) + } + return *resp, nil + } + values := make(url.Values) values.Set(storageRESTVolume, volume) values.Set(storageRESTFilePath, path) @@ -583,13 +517,27 @@ func (client *storageRESTClient) ReadVersion(ctx context.Context, volume, path, // ReadXL - reads all contents of xl.meta of a file. func (client *storageRESTClient) ReadXL(ctx context.Context, volume string, path string, readData bool) (rf RawFileInfo, err error) { + // Use websocket when not reading data. + if !readData { + resp, err := storageReadXLHandler.Call(ctx, client.gridConn, grid.NewMSSWith(map[string]string{ + storageRESTDiskID: client.diskID, + storageRESTVolume: volume, + storageRESTFilePath: path, + storageRESTReadData: "false", + })) + if err != nil { + return rf, toStorageErr(err) + } + return *resp, nil + } + values := make(url.Values) values.Set(storageRESTVolume, volume) values.Set(storageRESTFilePath, path) values.Set(storageRESTReadData, strconv.FormatBool(readData)) respBody, err := client.call(ctx, storageRESTMethodReadXL, values, nil, -1) if err != nil { - return rf, err + return rf, toStorageErr(err) } defer xhttp.DrainBody(respBody) @@ -667,15 +615,13 @@ func (client *storageRESTClient) ListDir(ctx context.Context, volume, dirPath st // DeleteFile - deletes a file. func (client *storageRESTClient) Delete(ctx context.Context, volume string, path string, deleteOpts DeleteOptions) error { - values := make(url.Values) - values.Set(storageRESTVolume, volume) - values.Set(storageRESTFilePath, path) - values.Set(storageRESTRecursive, strconv.FormatBool(deleteOpts.Recursive)) - values.Set(storageRESTForceDelete, strconv.FormatBool(deleteOpts.Force)) - - respBody, err := client.call(ctx, storageRESTMethodDeleteFile, values, nil, -1) - defer xhttp.DrainBody(respBody) - return err + _, err := storageDeleteFileHandler.Call(ctx, client.gridConn, &DeleteFileHandlerParams{ + DiskID: client.diskID, + Volume: volume, + FilePath: path, + Opts: deleteOpts, + }) + return toStorageErr(err) } // DeleteVersions - deletes list of specified versions if present @@ -867,7 +813,7 @@ func (client *storageRESTClient) Close() error { } // Returns a storage rest client. -func newStorageRESTClient(endpoint Endpoint, healthCheck bool) *storageRESTClient { +func newStorageRESTClient(endpoint Endpoint, healthCheck bool, gm *grid.Manager) (*storageRESTClient, error) { serverURL := &url.URL{ Scheme: endpoint.Scheme, Host: endpoint.Host, @@ -888,6 +834,12 @@ func newStorageRESTClient(endpoint Endpoint, healthCheck bool) *storageRESTClien return toStorageErr(err) != errDiskNotFound } } - - return &storageRESTClient{endpoint: endpoint, restClient: restClient, poolIndex: -1, setIndex: -1, diskIndex: -1} + conn := gm.Connection(endpoint.GridHost()).Subroute(endpoint.Path) + if conn == nil { + return nil, fmt.Errorf("unable to find connection for %s in targets: %v", endpoint.GridHost(), gm.Targets()) + } + return &storageRESTClient{ + endpoint: endpoint, restClient: restClient, poolIndex: -1, setIndex: -1, diskIndex: -1, + gridConn: conn, + }, nil } diff --git a/cmd/storage-rest-common.go b/cmd/storage-rest-common.go index 326b80fff..8e9beb1b1 100644 --- a/cmd/storage-rest-common.go +++ b/cmd/storage-rest-common.go @@ -17,6 +17,8 @@ package cmd +//go:generate msgp -file $GOFILE -unexported + const ( storageRESTVersion = "v50" // Added DiskInfo metrics query storageRESTVersionPrefix = SlashSeparator + storageRESTVersion @@ -25,64 +27,58 @@ const ( const ( storageRESTMethodHealth = "/health" - storageRESTMethodDiskInfo = "/diskinfo" - storageRESTMethodNSScanner = "/nsscanner" storageRESTMethodMakeVol = "/makevol" storageRESTMethodMakeVolBulk = "/makevolbulk" - storageRESTMethodStatVol = "/statvol" storageRESTMethodDeleteVol = "/deletevol" storageRESTMethodListVols = "/listvols" storageRESTMethodAppendFile = "/appendfile" storageRESTMethodCreateFile = "/createfile" storageRESTMethodWriteAll = "/writeall" - storageRESTMethodWriteMetadata = "/writemetadata" - storageRESTMethodUpdateMetadata = "/updatemetadata" - storageRESTMethodDeleteVersion = "/deleteversion" storageRESTMethodReadVersion = "/readversion" storageRESTMethodReadXL = "/readxl" - storageRESTMethodRenameData = "/renamedata" - storageRESTMethodCheckParts = "/checkparts" storageRESTMethodReadAll = "/readall" storageRESTMethodReadFile = "/readfile" storageRESTMethodReadFileStream = "/readfilestream" storageRESTMethodListDir = "/listdir" - storageRESTMethodDeleteFile = "/deletefile" storageRESTMethodDeleteVersions = "/deleteverions" storageRESTMethodRenameFile = "/renamefile" storageRESTMethodVerifyFile = "/verifyfile" - storageRESTMethodWalkDir = "/walkdir" storageRESTMethodStatInfoFile = "/statfile" storageRESTMethodReadMultiple = "/readmultiple" storageRESTMethodCleanAbandoned = "/cleanabandoned" ) const ( - storageRESTVolume = "volume" - storageRESTVolumes = "volumes" - storageRESTDirPath = "dir-path" - storageRESTFilePath = "file-path" - storageRESTForceDelMarker = "force-delete-marker" - storageRESTVersionID = "version-id" - storageRESTReadData = "read-data" - storageRESTTotalVersions = "total-versions" - storageRESTSrcVolume = "source-volume" - storageRESTSrcPath = "source-path" - storageRESTDstVolume = "destination-volume" - storageRESTDstPath = "destination-path" - storageRESTOffset = "offset" - storageRESTLength = "length" - storageRESTCount = "count" - storageRESTPrefixFilter = "prefix" - storageRESTForwardFilter = "forward" - storageRESTRecursive = "recursive" - storageRESTReportNotFound = "report-notfound" - storageRESTBitrotAlgo = "bitrot-algo" - storageRESTBitrotHash = "bitrot-hash" - storageRESTDiskID = "disk-id" - storageRESTForceDelete = "force-delete" - storageRESTGlob = "glob" - storageRESTScanMode = "scan-mode" - storageRESTMetrics = "metrics" - storageRESTNoPersistence = "no-persistence" + storageRESTVolume = "volume" + storageRESTVolumes = "volumes" + storageRESTDirPath = "dir-path" + storageRESTFilePath = "file-path" + storageRESTVersionID = "version-id" + storageRESTReadData = "read-data" + storageRESTTotalVersions = "total-versions" + storageRESTSrcVolume = "source-volume" + storageRESTSrcPath = "source-path" + storageRESTDstVolume = "destination-volume" + storageRESTDstPath = "destination-path" + storageRESTOffset = "offset" + storageRESTLength = "length" + storageRESTCount = "count" + storageRESTBitrotAlgo = "bitrot-algo" + storageRESTBitrotHash = "bitrot-hash" + storageRESTDiskID = "disk-id" + storageRESTForceDelete = "force-delete" + storageRESTGlob = "glob" + storageRESTMetrics = "metrics" ) + +type nsScannerOptions struct { + DiskID string `msg:"id"` + ScanMode int `msg:"m"` + Cache *dataUsageCache `msg:"c"` +} + +type nsScannerResp struct { + Update *dataUsageEntry `msg:"u"` + Final *dataUsageCache `msg:"f"` +} diff --git a/cmd/storage-rest-common_gen.go b/cmd/storage-rest-common_gen.go new file mode 100644 index 000000000..f81f8c97e --- /dev/null +++ b/cmd/storage-rest-common_gen.go @@ -0,0 +1,418 @@ +package cmd + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *nsScannerOptions) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "m": + z.ScanMode, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "ScanMode") + return + } + case "c": + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err, "Cache") + return + } + z.Cache = nil + } else { + if z.Cache == nil { + z.Cache = new(dataUsageCache) + } + err = z.Cache.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "Cache") + return + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *nsScannerOptions) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 3 + // write "id" + err = en.Append(0x83, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteString(z.DiskID) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + // write "m" + err = en.Append(0xa1, 0x6d) + if err != nil { + return + } + err = en.WriteInt(z.ScanMode) + if err != nil { + err = msgp.WrapError(err, "ScanMode") + return + } + // write "c" + err = en.Append(0xa1, 0x63) + if err != nil { + return + } + if z.Cache == nil { + err = en.WriteNil() + if err != nil { + return + } + } else { + err = z.Cache.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "Cache") + return + } + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *nsScannerOptions) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 3 + // string "id" + o = append(o, 0x83, 0xa2, 0x69, 0x64) + o = msgp.AppendString(o, z.DiskID) + // string "m" + o = append(o, 0xa1, 0x6d) + o = msgp.AppendInt(o, z.ScanMode) + // string "c" + o = append(o, 0xa1, 0x63) + if z.Cache == nil { + o = msgp.AppendNil(o) + } else { + o, err = z.Cache.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "Cache") + return + } + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *nsScannerOptions) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.DiskID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "DiskID") + return + } + case "m": + z.ScanMode, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ScanMode") + return + } + case "c": + if msgp.IsNil(bts) { + bts, err = msgp.ReadNilBytes(bts) + if err != nil { + return + } + z.Cache = nil + } else { + if z.Cache == nil { + z.Cache = new(dataUsageCache) + } + bts, err = z.Cache.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "Cache") + return + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *nsScannerOptions) Msgsize() (s int) { + s = 1 + 3 + msgp.StringPrefixSize + len(z.DiskID) + 2 + msgp.IntSize + 2 + if z.Cache == nil { + s += msgp.NilSize + } else { + s += z.Cache.Msgsize() + } + return +} + +// DecodeMsg implements msgp.Decodable +func (z *nsScannerResp) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "u": + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err, "Update") + return + } + z.Update = nil + } else { + if z.Update == nil { + z.Update = new(dataUsageEntry) + } + err = z.Update.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "Update") + return + } + } + case "f": + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err, "Final") + return + } + z.Final = nil + } else { + if z.Final == nil { + z.Final = new(dataUsageCache) + } + err = z.Final.DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "Final") + return + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *nsScannerResp) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "u" + err = en.Append(0x82, 0xa1, 0x75) + if err != nil { + return + } + if z.Update == nil { + err = en.WriteNil() + if err != nil { + return + } + } else { + err = z.Update.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "Update") + return + } + } + // write "f" + err = en.Append(0xa1, 0x66) + if err != nil { + return + } + if z.Final == nil { + err = en.WriteNil() + if err != nil { + return + } + } else { + err = z.Final.EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "Final") + return + } + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *nsScannerResp) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "u" + o = append(o, 0x82, 0xa1, 0x75) + if z.Update == nil { + o = msgp.AppendNil(o) + } else { + o, err = z.Update.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "Update") + return + } + } + // string "f" + o = append(o, 0xa1, 0x66) + if z.Final == nil { + o = msgp.AppendNil(o) + } else { + o, err = z.Final.MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "Final") + return + } + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *nsScannerResp) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "u": + if msgp.IsNil(bts) { + bts, err = msgp.ReadNilBytes(bts) + if err != nil { + return + } + z.Update = nil + } else { + if z.Update == nil { + z.Update = new(dataUsageEntry) + } + bts, err = z.Update.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "Update") + return + } + } + case "f": + if msgp.IsNil(bts) { + bts, err = msgp.ReadNilBytes(bts) + if err != nil { + return + } + z.Final = nil + } else { + if z.Final == nil { + z.Final = new(dataUsageCache) + } + bts, err = z.Final.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "Final") + return + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *nsScannerResp) Msgsize() (s int) { + s = 1 + 2 + if z.Update == nil { + s += msgp.NilSize + } else { + s += z.Update.Msgsize() + } + s += 2 + if z.Final == nil { + s += msgp.NilSize + } else { + s += z.Final.Msgsize() + } + return +} diff --git a/cmd/storage-rest-common_gen_test.go b/cmd/storage-rest-common_gen_test.go new file mode 100644 index 000000000..8085a115c --- /dev/null +++ b/cmd/storage-rest-common_gen_test.go @@ -0,0 +1,236 @@ +package cmd + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalnsScannerOptions(t *testing.T) { + v := nsScannerOptions{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgnsScannerOptions(b *testing.B) { + v := nsScannerOptions{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgnsScannerOptions(b *testing.B) { + v := nsScannerOptions{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalnsScannerOptions(b *testing.B) { + v := nsScannerOptions{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodensScannerOptions(t *testing.T) { + v := nsScannerOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodensScannerOptions Msgsize() is inaccurate") + } + + vn := nsScannerOptions{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodensScannerOptions(b *testing.B) { + v := nsScannerOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodensScannerOptions(b *testing.B) { + v := nsScannerOptions{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalnsScannerResp(t *testing.T) { + v := nsScannerResp{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgnsScannerResp(b *testing.B) { + v := nsScannerResp{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgnsScannerResp(b *testing.B) { + v := nsScannerResp{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalnsScannerResp(b *testing.B) { + v := nsScannerResp{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodensScannerResp(t *testing.T) { + v := nsScannerResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodensScannerResp Msgsize() is inaccurate") + } + + vn := nsScannerResp{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodensScannerResp(b *testing.B) { + v := nsScannerResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodensScannerResp(b *testing.B) { + v := nsScannerResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/cmd/storage-rest-server.go b/cmd/storage-rest-server.go index 8210ce097..59adf69f0 100644 --- a/cmd/storage-rest-server.go +++ b/cmd/storage-rest-server.go @@ -33,8 +33,10 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" + "github.com/minio/minio/internal/grid" "github.com/tinylib/msgp/msgp" jwtreq "github.com/golang-jwt/jwt/v4/request" @@ -116,7 +118,7 @@ func storageServerRequestValidate(r *http.Request) error { return nil } -// IsValid - To authenticate and verify the time difference. +// IsAuthValid - To authenticate and verify the time difference. func (s *storageRESTServer) IsAuthValid(w http.ResponseWriter, r *http.Request) bool { if s.storage == nil { s.writeErrorResponse(w, errDiskNotFound) @@ -165,56 +167,63 @@ func (s *storageRESTServer) IsValid(w http.ResponseWriter, r *http.Request) bool return true } +// checkID - check if the disk-id in the request corresponds to the underlying disk. +func (s *storageRESTServer) checkID(wantID string) bool { + if s.storage == nil { + return false + } + if wantID == "" { + // Request sent empty disk-id, we allow the request + // as the peer might be coming up and trying to read format.json + // or create format.json + return true + } + + storedDiskID, err := s.storage.GetDiskID() + if err != nil { + return false + } + + return wantID == storedDiskID +} + // HealthHandler handler checks if disk is stale func (s *storageRESTServer) HealthHandler(w http.ResponseWriter, r *http.Request) { s.IsValid(w, r) } +// DiskInfo types. +// DiskInfo.Metrics elements are shared, so we cannot reuse. +var storageDiskInfoHandler = grid.NewSingleHandler[*grid.MSS, *DiskInfo](grid.HandlerDiskInfo, grid.NewMSS, func() *DiskInfo { return &DiskInfo{} }).WithSharedResponse() + // DiskInfoHandler - returns disk info. -func (s *storageRESTServer) DiskInfoHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsAuthValid(w, r) { - return +func (s *storageRESTServer) DiskInfoHandler(params *grid.MSS) (*DiskInfo, *grid.RemoteErr) { + if !s.checkID(params.Get(storageRESTDiskID)) { + return nil, grid.NewRemoteErr(errDiskNotFound) } - info, err := s.storage.DiskInfo(r.Context(), r.Form.Get(storageRESTMetrics) == "true") + withMetrics := params.Get(storageRESTMetrics) == "true" + info, err := s.storage.DiskInfo(context.Background(), withMetrics) if err != nil { info.Error = err.Error() } - logger.LogIf(r.Context(), msgp.Encode(w, &info)) + info.Scanning = s.storage != nil && s.storage.storage != nil && atomic.LoadInt32(&s.storage.storage.scanning) > 0 + return &info, nil } -func (s *storageRESTServer) NSScannerHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return +// scanner rpc handler. +var storageNSScannerHandler = grid.NewStream[*nsScannerOptions, grid.NoPayload, *nsScannerResp](grid.HandlerNSScanner, + func() *nsScannerOptions { return &nsScannerOptions{} }, + nil, + func() *nsScannerResp { return &nsScannerResp{} }) + +func (s *storageRESTServer) NSScannerHandler(ctx context.Context, params *nsScannerOptions, out chan<- *nsScannerResp) *grid.RemoteErr { + if !s.checkID(params.DiskID) { + return grid.NewRemoteErr(errDiskNotFound) } - - scanMode, err := strconv.Atoi(r.Form.Get(storageRESTScanMode)) - if err != nil { - logger.LogIf(r.Context(), err) - s.writeErrorResponse(w, err) - return + if params.Cache == nil { + return grid.NewRemoteErrString("NSScannerHandler: provided cache is nil") } - setEventStreamHeaders(w) - - var cache dataUsageCache - err = cache.deserialize(r.Body) - if err != nil { - logger.LogIf(r.Context(), err) - s.writeErrorResponse(w, err) - return - } - - ctx, cancel := context.WithCancel(r.Context()) - defer cancel() - resp := streamHTTPResponse(w) - defer func() { - if r := recover(); r != nil { - debug.PrintStack() - resp.CloseWithError(fmt.Errorf("panic: %v", r)) - } - }() - respW := msgp.NewWriter(resp) - // Collect updates, stream them before the full cache is sent. updates := make(chan dataUsageEntry, 1) var wg sync.WaitGroup @@ -222,36 +231,21 @@ func (s *storageRESTServer) NSScannerHandler(w http.ResponseWriter, r *http.Requ go func() { defer wg.Done() for update := range updates { - // Write true bool to indicate update. - var err error - if err = respW.WriteBool(true); err == nil { - err = update.EncodeMsg(respW) - } - respW.Flush() - if err != nil { - cancel() - resp.CloseWithError(err) - return - } + resp := storageNSScannerHandler.NewResponse() + resp.Update = &update + out <- resp } }() - usageInfo, err := s.storage.NSScanner(ctx, cache, updates, madmin.HealScanMode(scanMode)) - if err != nil { - respW.Flush() - resp.CloseWithError(err) - return - } - - // Write false bool to indicate we finished. + ui, err := s.storage.NSScanner(ctx, *params.Cache, updates, madmin.HealScanMode(params.ScanMode)) wg.Wait() - if err = respW.WriteBool(false); err == nil { - err = usageInfo.EncodeMsg(respW) - } if err != nil { - resp.CloseWithError(err) - return + return grid.NewRemoteErr(err) } - resp.CloseWithError(respW.Flush()) + // Send final response. + resp := storageNSScannerHandler.NewResponse() + resp.Final = &ui + out <- resp + return nil } // MakeVolHandler - make a volume. @@ -291,21 +285,22 @@ func (s *storageRESTServer) ListVolsHandler(w http.ResponseWriter, r *http.Reque logger.LogIf(r.Context(), msgp.Encode(w, VolsInfo(infos))) } +// statvol types. +var storageStatVolHandler = grid.NewSingleHandler[*grid.MSS, *VolInfo](grid.HandlerStatVol, grid.NewMSS, func() *VolInfo { return &VolInfo{} }) + // StatVolHandler - stat a volume. -func (s *storageRESTServer) StatVolHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return +func (s *storageRESTServer) StatVolHandler(params *grid.MSS) (*VolInfo, *grid.RemoteErr) { + if !s.checkID(params.Get(storageRESTDiskID)) { + return nil, grid.NewRemoteErr(errDiskNotFound) } - volume := r.Form.Get(storageRESTVolume) - info, err := s.storage.StatVol(r.Context(), volume) + info, err := s.storage.StatVol(context.Background(), params.Get(storageRESTVolume)) if err != nil { - s.writeErrorResponse(w, err) - return + return nil, grid.NewRemoteErr(err) } - logger.LogIf(r.Context(), msgp.Encode(w, &info)) + return &info, nil } -// DeleteVolumeHandler - delete a volume. +// DeleteVolHandler - delete a volume. func (s *storageRESTServer) DeleteVolHandler(w http.ResponseWriter, r *http.Request) { if !s.IsValid(w, r) { return @@ -357,37 +352,48 @@ func (s *storageRESTServer) CreateFileHandler(w http.ResponseWriter, r *http.Req done(s.storage.CreateFile(r.Context(), volume, filePath, int64(fileSize), body)) } -// DeleteVersion delete updated metadata. -func (s *storageRESTServer) DeleteVersionHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return - } - volume := r.Form.Get(storageRESTVolume) - filePath := r.Form.Get(storageRESTFilePath) - forceDelMarker, err := strconv.ParseBool(r.Form.Get(storageRESTForceDelMarker)) - if err != nil { - s.writeErrorResponse(w, errInvalidArgument) - return - } +var storageDeleteVersionHandler = grid.NewSingleHandler[*DeleteVersionHandlerParams, grid.NoPayload](grid.HandlerDeleteVersion, func() *DeleteVersionHandlerParams { + return &DeleteVersionHandlerParams{} +}, grid.NewNoPayload) - if r.ContentLength < 0 { - s.writeErrorResponse(w, errInvalidArgument) - return +// DeleteVersionHandler delete updated metadata. +func (s *storageRESTServer) DeleteVersionHandler(p *DeleteVersionHandlerParams) (np grid.NoPayload, gerr *grid.RemoteErr) { + if !s.checkID(p.DiskID) { + return np, grid.NewRemoteErr(errDiskNotFound) } + volume := p.Volume + filePath := p.FilePath + forceDelMarker := p.ForceDelMarker - var fi FileInfo - if err := msgp.Decode(r.Body, &fi); err != nil { - s.writeErrorResponse(w, err) - return - } - - err = s.storage.DeleteVersion(r.Context(), volume, filePath, fi, forceDelMarker) - if err != nil { - s.writeErrorResponse(w, err) - } + err := s.storage.DeleteVersion(context.Background(), volume, filePath, p.FI, forceDelMarker) + return np, grid.NewRemoteErr(err) } -// ReadVersion read metadata of versionID +var storageReadVersionHandler = grid.NewSingleHandler[*grid.MSS, *FileInfo](grid.HandlerReadVersion, grid.NewMSS, func() *FileInfo { + return &FileInfo{} +}) + +// ReadVersionHandlerWS read metadata of versionID +func (s *storageRESTServer) ReadVersionHandlerWS(params *grid.MSS) (*FileInfo, *grid.RemoteErr) { + if !s.checkID(params.Get(storageRESTDiskID)) { + return nil, grid.NewRemoteErr(errDiskNotFound) + } + volume := params.Get(storageRESTVolume) + filePath := params.Get(storageRESTFilePath) + versionID := params.Get(storageRESTVersionID) + readData, err := strconv.ParseBool(params.Get(storageRESTReadData)) + if err != nil { + return nil, grid.NewRemoteErr(err) + } + + fi, err := s.storage.ReadVersion(context.Background(), volume, filePath, versionID, readData) + if err != nil { + return nil, grid.NewRemoteErr(err) + } + return &fi, nil +} + +// ReadVersionHandler read metadata of versionID func (s *storageRESTServer) ReadVersionHandler(w http.ResponseWriter, r *http.Request) { if !s.IsValid(w, r) { return @@ -410,55 +416,35 @@ func (s *storageRESTServer) ReadVersionHandler(w http.ResponseWriter, r *http.Re logger.LogIf(r.Context(), msgp.Encode(w, &fi)) } -// WriteMetadata write new updated metadata. -func (s *storageRESTServer) WriteMetadataHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return - } - volume := r.Form.Get(storageRESTVolume) - filePath := r.Form.Get(storageRESTFilePath) +var storageWriteMetadataHandler = grid.NewSingleHandler[*MetadataHandlerParams, grid.NoPayload](grid.HandlerWriteMetadata, func() *MetadataHandlerParams { + return &MetadataHandlerParams{} +}, grid.NewNoPayload) - if r.ContentLength < 0 { - s.writeErrorResponse(w, errInvalidArgument) - return +// WriteMetadataHandler rpc handler to write new updated metadata. +func (s *storageRESTServer) WriteMetadataHandler(p *MetadataHandlerParams) (np grid.NoPayload, gerr *grid.RemoteErr) { + if !s.checkID(p.DiskID) { + return grid.NewNPErr(errDiskNotFound) } + volume := p.Volume + filePath := p.FilePath - var fi FileInfo - if err := msgp.Decode(r.Body, &fi); err != nil { - s.writeErrorResponse(w, err) - return - } - - err := s.storage.WriteMetadata(r.Context(), volume, filePath, fi) - if err != nil { - s.writeErrorResponse(w, err) - } + err := s.storage.WriteMetadata(context.Background(), volume, filePath, p.FI) + return np, grid.NewRemoteErr(err) } -// UpdateMetadata update new updated metadata. -func (s *storageRESTServer) UpdateMetadataHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return - } - volume := r.Form.Get(storageRESTVolume) - filePath := r.Form.Get(storageRESTFilePath) - noPersistence := r.Form.Get(storageRESTNoPersistence) == "true" +var storageUpdateMetadataHandler = grid.NewSingleHandler[*MetadataHandlerParams, grid.NoPayload](grid.HandlerUpdateMetadata, func() *MetadataHandlerParams { + return &MetadataHandlerParams{} +}, grid.NewNoPayload) - if r.ContentLength < 0 { - s.writeErrorResponse(w, errInvalidArgument) - return +// UpdateMetadataHandler update new updated metadata. +func (s *storageRESTServer) UpdateMetadataHandler(p *MetadataHandlerParams) (grid.NoPayload, *grid.RemoteErr) { + if !s.checkID(p.DiskID) { + return grid.NewNPErr(errDiskNotFound) } + volume := p.Volume + filePath := p.FilePath - var fi FileInfo - if err := msgp.Decode(r.Body, &fi); err != nil { - s.writeErrorResponse(w, err) - return - } - - err := s.storage.UpdateMetadata(r.Context(), volume, filePath, fi, UpdateMetadataOpts{NoPersistence: noPersistence}) - if err != nil { - s.writeErrorResponse(w, err) - } + return grid.NewNPErr(s.storage.UpdateMetadata(context.Background(), volume, filePath, p.FI, p.UpdateOpts)) } // WriteAllHandler - write to file all content. @@ -485,28 +471,18 @@ func (s *storageRESTServer) WriteAllHandler(w http.ResponseWriter, r *http.Reque } } +var storageCheckPartsHandler = grid.NewSingleHandler[*CheckPartsHandlerParams, grid.NoPayload](grid.HandlerCheckParts, func() *CheckPartsHandlerParams { + return &CheckPartsHandlerParams{} +}, grid.NewNoPayload) + // CheckPartsHandler - check if a file metadata exists. -func (s *storageRESTServer) CheckPartsHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return - } - volume := r.Form.Get(storageRESTVolume) - filePath := r.Form.Get(storageRESTFilePath) - - if r.ContentLength < 0 { - s.writeErrorResponse(w, errInvalidArgument) - return - } - - var fi FileInfo - if err := msgp.Decode(r.Body, &fi); err != nil { - s.writeErrorResponse(w, err) - return - } - - if err := s.storage.CheckParts(r.Context(), volume, filePath, fi); err != nil { - s.writeErrorResponse(w, err) +func (s *storageRESTServer) CheckPartsHandler(p *CheckPartsHandlerParams) (grid.NoPayload, *grid.RemoteErr) { + if !s.checkID(p.DiskID) { + return grid.NewNPErr(errDiskNotFound) } + volume := p.Volume + filePath := p.FilePath + return grid.NewNPErr(s.storage.CheckParts(context.Background(), volume, filePath, p.FI)) } // ReadAllHandler - read all the contents of a file. @@ -550,6 +526,30 @@ func (s *storageRESTServer) ReadXLHandler(w http.ResponseWriter, r *http.Request logger.LogIf(r.Context(), msgp.Encode(w, &rf)) } +var storageReadXLHandler = grid.NewSingleHandler[*grid.MSS, *RawFileInfo](grid.HandlerReadXL, grid.NewMSS, func() *RawFileInfo { + return &RawFileInfo{} +}) + +// ReadXLHandlerWS - read xl.meta for an object at path. +func (s *storageRESTServer) ReadXLHandlerWS(params *grid.MSS) (*RawFileInfo, *grid.RemoteErr) { + if !s.checkID(params.Get(storageRESTDiskID)) { + return nil, grid.NewRemoteErr(errDiskNotFound) + } + volume := params.Get(storageRESTVolume) + filePath := params.Get(storageRESTFilePath) + readData, err := strconv.ParseBool(params.Get(storageRESTReadData)) + if err != nil { + return nil, grid.NewRemoteErr(err) + } + + rf, err := s.storage.ReadXL(context.Background(), volume, filePath, readData) + if err != nil { + return nil, grid.NewRemoteErr(err) + } + + return &rf, nil +} + // ReadFileHandler - read section of a file. func (s *storageRESTServer) ReadFileHandler(w http.ResponseWriter, r *http.Request) { if !s.IsValid(w, r) { @@ -593,7 +593,7 @@ func (s *storageRESTServer) ReadFileHandler(w http.ResponseWriter, r *http.Reque w.Write(buf) } -// ReadFileHandler - read section of a file. +// ReadFileStreamHandler - read section of a file. func (s *storageRESTServer) ReadFileStreamHandler(w http.ResponseWriter, r *http.Request) { if !s.IsValid(w, r) { return @@ -666,30 +666,16 @@ func (s *storageRESTServer) ListDirHandler(w http.ResponseWriter, r *http.Reques gob.NewEncoder(w).Encode(&entries) } +var storageDeleteFileHandler = grid.NewSingleHandler[*DeleteFileHandlerParams, grid.NoPayload](grid.HandlerDeleteFile, func() *DeleteFileHandlerParams { + return &DeleteFileHandlerParams{} +}, grid.NewNoPayload) + // DeleteFileHandler - delete a file. -func (s *storageRESTServer) DeleteFileHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return - } - volume := r.Form.Get(storageRESTVolume) - filePath := r.Form.Get(storageRESTFilePath) - recursive, err := strconv.ParseBool(r.Form.Get(storageRESTRecursive)) - if err != nil { - s.writeErrorResponse(w, err) - return - } - force, err := strconv.ParseBool(r.Form.Get(storageRESTForceDelete)) - if err != nil { - s.writeErrorResponse(w, err) - return - } - err = s.storage.Delete(r.Context(), volume, filePath, DeleteOptions{ - Recursive: recursive, - Force: force, - }) - if err != nil { - s.writeErrorResponse(w, err) +func (s *storageRESTServer) DeleteFileHandler(p *DeleteFileHandlerParams) (grid.NoPayload, *grid.RemoteErr) { + if !s.checkID(p.DiskID) { + return grid.NewNPErr(errDiskNotFound) } + return grid.NewNPErr(s.storage.Delete(context.Background(), p.Volume, p.FilePath, p.Opts)) } // DeleteVersionsErrsResp - collection of delete errors @@ -737,48 +723,23 @@ func (s *storageRESTServer) DeleteVersionsHandler(w http.ResponseWriter, r *http encoder.Encode(dErrsResp) } -// RenameDataResp - RenameData()'s response. -type RenameDataResp struct { - Signature uint64 - Err error -} +var storageRenameDataHandler = grid.NewSingleHandler[*RenameDataHandlerParams, *RenameDataResp](grid.HandlerRenamedata, func() *RenameDataHandlerParams { + return &RenameDataHandlerParams{} +}, func() *RenameDataResp { + return &RenameDataResp{} +}) // RenameDataHandler - renames a meta object and data dir to destination. -func (s *storageRESTServer) RenameDataHandler(w http.ResponseWriter, r *http.Request) { - if !s.IsValid(w, r) { - return +func (s *storageRESTServer) RenameDataHandler(p *RenameDataHandlerParams) (*RenameDataResp, *grid.RemoteErr) { + if !s.checkID(p.DiskID) { + return nil, grid.NewRemoteErr(errDiskNotFound) } - srcVolume := r.Form.Get(storageRESTSrcVolume) - srcFilePath := r.Form.Get(storageRESTSrcPath) - dstVolume := r.Form.Get(storageRESTDstVolume) - dstFilePath := r.Form.Get(storageRESTDstPath) - - if r.ContentLength < 0 { - s.writeErrorResponse(w, errInvalidArgument) - return - } - - var fi FileInfo - if err := msgp.Decode(r.Body, &fi); err != nil { - s.writeErrorResponse(w, err) - return - } - - setEventStreamHeaders(w) - encoder := gob.NewEncoder(w) - done := keepHTTPResponseAlive(w) - - sign, err := s.storage.RenameData(r.Context(), srcVolume, srcFilePath, fi, dstVolume, dstFilePath) - done(nil) - + sign, err := s.storage.RenameData(context.Background(), p.SrcVolume, p.SrcPath, p.FI, p.DstVolume, p.DstPath) resp := &RenameDataResp{ Signature: sign, } - if err != nil { - resp.Err = StorageErr(err.Error()) - } - encoder.Encode(resp) + return resp, grid.NewRemoteErr(err) } // RenameFileHandler - rename a file. @@ -1349,8 +1310,8 @@ func (s *storageRESTServer) ReadMultiple(w http.ResponseWriter, r *http.Request) rw.CloseWithError(err) } -// registerStorageRPCRouter - register storage rpc router. -func registerStorageRESTHandlers(router *mux.Router, endpointServerPools EndpointServerPools) { +// registerStorageRESTHandlers - register storage rpc router. +func registerStorageRESTHandlers(router *mux.Router, endpointServerPools EndpointServerPools, gm *grid.Manager) { storageDisks := make([][]*xlStorage, len(endpointServerPools)) for poolIdx, ep := range endpointServerPools { storageDisks[poolIdx] = make([]*xlStorage, len(ep.Endpoints)) @@ -1394,38 +1355,44 @@ func registerStorageRESTHandlers(router *mux.Router, endpointServerPools Endpoin subrouter := router.PathPrefix(path.Join(storageRESTPrefix, endpoint.Path)).Subrouter() subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodHealth).HandlerFunc(h(server.HealthHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodDiskInfo).HandlerFunc(h(server.DiskInfoHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodNSScanner).HandlerFunc(h(server.NSScannerHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodMakeVol).HandlerFunc(h(server.MakeVolHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodMakeVolBulk).HandlerFunc(h(server.MakeVolBulkHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodStatVol).HandlerFunc(h(server.StatVolHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodDeleteVol).HandlerFunc(h(server.DeleteVolHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodListVols).HandlerFunc(h(server.ListVolsHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodAppendFile).HandlerFunc(h(server.AppendFileHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodWriteAll).HandlerFunc(h(server.WriteAllHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodWriteMetadata).HandlerFunc(h(server.WriteMetadataHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodUpdateMetadata).HandlerFunc(h(server.UpdateMetadataHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodDeleteVersion).HandlerFunc(h(server.DeleteVersionHandler)) + subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodReadVersion).HandlerFunc(h(server.ReadVersionHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodReadXL).HandlerFunc(h(server.ReadXLHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodRenameData).HandlerFunc(h(server.RenameDataHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodCreateFile).HandlerFunc(h(server.CreateFileHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodCheckParts).HandlerFunc(h(server.CheckPartsHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodReadAll).HandlerFunc(h(server.ReadAllHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodReadFile).HandlerFunc(h(server.ReadFileHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodReadFileStream).HandlerFunc(h(server.ReadFileStreamHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodListDir).HandlerFunc(h(server.ListDirHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodDeleteVersions).HandlerFunc(h(server.DeleteVersionsHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodDeleteFile).HandlerFunc(h(server.DeleteFileHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodRenameFile).HandlerFunc(h(server.RenameFileHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodVerifyFile).HandlerFunc(h(server.VerifyFileHandler)) - subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodWalkDir).HandlerFunc(h(server.WalkDirHandler)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodStatInfoFile).HandlerFunc(h(server.StatInfoFile)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodReadMultiple).HandlerFunc(h(server.ReadMultiple)) subrouter.Methods(http.MethodPost).Path(storageRESTVersionPrefix + storageRESTMethodCleanAbandoned).HandlerFunc(h(server.CleanAbandonedDataHandler)) + logger.FatalIf(storageRenameDataHandler.Register(gm, server.RenameDataHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageDeleteFileHandler.Register(gm, server.DeleteFileHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageCheckPartsHandler.Register(gm, server.CheckPartsHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageReadVersionHandler.Register(gm, server.ReadVersionHandlerWS, endpoint.Path), "unable to register handler") + logger.FatalIf(storageWriteMetadataHandler.Register(gm, server.WriteMetadataHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageUpdateMetadataHandler.Register(gm, server.UpdateMetadataHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageDeleteVersionHandler.Register(gm, server.DeleteVersionHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageReadXLHandler.Register(gm, server.ReadXLHandlerWS, endpoint.Path), "unable to register handler") + logger.FatalIf(storageNSScannerHandler.RegisterNoInput(gm, server.NSScannerHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageDiskInfoHandler.Register(gm, server.DiskInfoHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(storageStatVolHandler.Register(gm, server.StatVolHandler, endpoint.Path), "unable to register handler") + logger.FatalIf(gm.RegisterStreamingHandler(grid.HandlerWalkDir, grid.StreamHandler{ + Subroute: endpoint.Path, + Handle: server.WalkDirHandler, + OutCapacity: 1, + }), "unable to register handler") } } } diff --git a/cmd/storage-rest_test.go b/cmd/storage-rest_test.go index 1b56e8719..8a2c5719d 100644 --- a/cmd/storage-rest_test.go +++ b/cmd/storage-rest_test.go @@ -20,12 +20,11 @@ package cmd import ( "bytes" "context" - "net/http/httptest" "reflect" "runtime" "testing" - "github.com/minio/mux" + "github.com/minio/minio/internal/grid" xnet "github.com/minio/pkg/v2/net" ) @@ -437,17 +436,21 @@ func testStorageAPIRenameFile(t *testing.T, storage StorageAPI) { } } -func newStorageRESTHTTPServerClient(t *testing.T) *storageRESTClient { +func newStorageRESTHTTPServerClient(t testing.TB) *storageRESTClient { + // Grid with 2 hosts + tg, err := grid.SetupTestGrid(2) + if err != nil { + t.Fatalf("SetupTestGrid: %v", err) + } + t.Cleanup(tg.Cleanup) prevHost, prevPort := globalMinioHost, globalMinioPort defer func() { globalMinioHost, globalMinioPort = prevHost, prevPort }() + // tg[0] = local, tg[1] = remote - router := mux.NewRouter() - httpServer := httptest.NewServer(router) - t.Cleanup(httpServer.Close) - - url, err := xnet.ParseHTTPURL(httpServer.URL) + // Remote URL + url, err := xnet.ParseHTTPURL(tg.Servers[1].URL) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -464,11 +467,18 @@ func newStorageRESTHTTPServerClient(t *testing.T) *storageRESTClient { t.Fatalf("UpdateIsLocal failed %v", err) } - registerStorageRESTHandlers(router, []PoolEndpoints{{ + // Register handlers on newly created servers + registerStorageRESTHandlers(tg.Mux[0], []PoolEndpoints{{ Endpoints: Endpoints{endpoint}, - }}) + }}, tg.Managers[0]) + registerStorageRESTHandlers(tg.Mux[1], []PoolEndpoints{{ + Endpoints: Endpoints{endpoint}, + }}, tg.Managers[1]) - restClient := newStorageRESTClient(endpoint, false) + restClient, err := newStorageRESTClient(endpoint, false, tg.Managers[0]) + if err != nil { + t.Fatal(err) + } return restClient } diff --git a/cmd/xl-storage-disk-id-check.go b/cmd/xl-storage-disk-id-check.go index 86611aee8..bd66378f4 100644 --- a/cmd/xl-storage-disk-id-check.go +++ b/cmd/xl-storage-disk-id-check.go @@ -1121,6 +1121,21 @@ func (p *xlStorageDiskIDCheck) monitorDiskWritable(ctx context.Context) { } } +// checkID will check if the disk ID matches the provided ID. +func (p *xlStorageDiskIDCheck) checkID(wantID string) (err error) { + if wantID == "" { + return nil + } + id, err := p.storage.GetDiskID() + if err != nil { + return err + } + if id != wantID { + return fmt.Errorf("disk ID %s does not match. disk reports %s", wantID, id) + } + return nil +} + // diskHealthCheckOK will check if the provided error is nil // and update disk status if good. // For convenience a bool is returned to indicate any error state diff --git a/cmd/xl-storage.go b/cmd/xl-storage.go index 8edd7d2cf..82b79334f 100644 --- a/cmd/xl-storage.go +++ b/cmd/xl-storage.go @@ -1226,11 +1226,6 @@ func (s *xlStorage) DeleteVersion(ctx context.Context, volume, path string, fi F return s.deleteFile(volumeDir, filePath, true, false) } -// UpdateMetadataOpts provides an optional input to indicate if xl.meta updates need to be fully synced to disk. -type UpdateMetadataOpts struct { - NoPersistence bool -} - // Updates only metadata for a given version. func (s *xlStorage) UpdateMetadata(ctx context.Context, volume, path string, fi FileInfo, opts UpdateMetadataOpts) error { if len(fi.Metadata) == 0 { diff --git a/go.mod b/go.mod index 29d3704fc..38e541ecb 100644 --- a/go.mod +++ b/go.mod @@ -26,13 +26,14 @@ require ( github.com/go-ldap/ldap/v3 v3.4.6 github.com/go-openapi/loads v0.21.2 github.com/go-sql-driver/mysql v1.7.1 + github.com/gobwas/ws v1.3.1-0.20231030152437-516805a9f3b3 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/gomodule/redigo v1.8.9 github.com/google/uuid v1.3.1 github.com/hashicorp/golang-lru v1.0.2 github.com/inconshreveable/mousetrap v1.1.0 github.com/json-iterator/go v1.1.12 - github.com/klauspost/compress v1.17.1 + github.com/klauspost/compress v1.17.3 github.com/klauspost/cpuid/v2 v2.2.5 github.com/klauspost/filepathx v1.1.1 github.com/klauspost/pgzip v1.2.6 @@ -78,7 +79,7 @@ require ( github.com/secure-io/sio-go v0.3.1 github.com/shirou/gopsutil/v3 v3.23.9 github.com/tidwall/gjson v1.17.0 - github.com/tinylib/msgp v1.1.8 + github.com/tinylib/msgp v1.1.9-0.20230705140925-6ac204f0b4d4 github.com/valyala/bytebufferpool v1.0.0 github.com/xdg/scram v1.0.5 github.com/zeebo/xxh3 v1.0.2 @@ -138,6 +139,8 @@ require ( github.com/go-openapi/strfmt v0.21.7 // indirect github.com/go-openapi/swag v0.22.4 // indirect github.com/go-openapi/validate v0.22.1 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect diff --git a/go.sum b/go.sum index d00590787..2ceeaf620 100644 --- a/go.sum +++ b/go.sum @@ -235,6 +235,12 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.1-0.20231030152437-516805a9f3b3 h1:u5on5kZjHKikhx6d2IAGOxFf4BAcJhUb2v8VJFHBgFA= +github.com/gobwas/ws v1.3.1-0.20231030152437-516805a9f3b3/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -373,8 +379,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.17.1 h1:NE3C767s2ak2bweCZo3+rdP4U/HoyVXLv/X9f2gPS5g= -github.com/klauspost/compress v1.17.1/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= +github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -695,8 +701,8 @@ github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhso github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= -github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= -github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= +github.com/tinylib/msgp v1.1.9-0.20230705140925-6ac204f0b4d4 h1:IEP0iEIadHj1iwMq0eYNIY5RpthYfckHNx5zOc8oL/g= +github.com/tinylib/msgp v1.1.9-0.20230705140925-6ac204f0b4d4/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/internal/dsync/lock-args.go b/internal/dsync/lock-args.go index 66adead7f..8525618e9 100644 --- a/internal/dsync/lock-args.go +++ b/internal/dsync/lock-args.go @@ -38,3 +38,21 @@ type LockArgs struct { // Quorum represents the expected quorum for this lock type. Quorum int } + +// ResponseCode is the response code for a locking request. +type ResponseCode uint8 + +// Response codes for a locking request. +const ( + RespOK ResponseCode = iota + RespLockConflict + RespLockNotInitialized + RespLockNotFound + RespErr +) + +// LockResp is a locking request response. +type LockResp struct { + Code ResponseCode + Err string +} diff --git a/internal/dsync/lock-args_gen.go b/internal/dsync/lock-args_gen.go index 46bc4b475..1ac930ab6 100644 --- a/internal/dsync/lock-args_gen.go +++ b/internal/dsync/lock-args_gen.go @@ -248,3 +248,191 @@ func (z *LockArgs) Msgsize() (s int) { s += 7 + msgp.StringPrefixSize + len(z.Source) + 6 + msgp.StringPrefixSize + len(z.Owner) + 7 + msgp.IntSize return } + +// DecodeMsg implements msgp.Decodable +func (z *LockResp) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Code": + { + var zb0002 uint8 + zb0002, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err, "Code") + return + } + z.Code = ResponseCode(zb0002) + } + case "Err": + z.Err, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z LockResp) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "Code" + err = en.Append(0x82, 0xa4, 0x43, 0x6f, 0x64, 0x65) + if err != nil { + return + } + err = en.WriteUint8(uint8(z.Code)) + if err != nil { + err = msgp.WrapError(err, "Code") + return + } + // write "Err" + err = en.Append(0xa3, 0x45, 0x72, 0x72) + if err != nil { + return + } + err = en.WriteString(z.Err) + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z LockResp) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "Code" + o = append(o, 0x82, 0xa4, 0x43, 0x6f, 0x64, 0x65) + o = msgp.AppendUint8(o, uint8(z.Code)) + // string "Err" + o = append(o, 0xa3, 0x45, 0x72, 0x72) + o = msgp.AppendString(o, z.Err) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *LockResp) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Code": + { + var zb0002 uint8 + zb0002, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Code") + return + } + z.Code = ResponseCode(zb0002) + } + case "Err": + z.Err, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z LockResp) Msgsize() (s int) { + s = 1 + 5 + msgp.Uint8Size + 4 + msgp.StringPrefixSize + len(z.Err) + return +} + +// DecodeMsg implements msgp.Decodable +func (z *ResponseCode) DecodeMsg(dc *msgp.Reader) (err error) { + { + var zb0001 uint8 + zb0001, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = ResponseCode(zb0001) + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z ResponseCode) EncodeMsg(en *msgp.Writer) (err error) { + err = en.WriteUint8(uint8(z)) + if err != nil { + err = msgp.WrapError(err) + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z ResponseCode) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + o = msgp.AppendUint8(o, uint8(z)) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *ResponseCode) UnmarshalMsg(bts []byte) (o []byte, err error) { + { + var zb0001 uint8 + zb0001, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = ResponseCode(zb0001) + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z ResponseCode) Msgsize() (s int) { + s = msgp.Uint8Size + return +} diff --git a/internal/dsync/lock-args_gen_test.go b/internal/dsync/lock-args_gen_test.go index 7f9c93a21..d94a51525 100644 --- a/internal/dsync/lock-args_gen_test.go +++ b/internal/dsync/lock-args_gen_test.go @@ -121,3 +121,116 @@ func BenchmarkDecodeLockArgs(b *testing.B) { } } } + +func TestMarshalUnmarshalLockResp(t *testing.T) { + v := LockResp{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgLockResp(b *testing.B) { + v := LockResp{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgLockResp(b *testing.B) { + v := LockResp{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalLockResp(b *testing.B) { + v := LockResp{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeLockResp(t *testing.T) { + v := LockResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeLockResp Msgsize() is inaccurate") + } + + vn := LockResp{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeLockResp(b *testing.B) { + v := LockResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeLockResp(b *testing.B) { + v := LockResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/internal/grid/README.md b/internal/grid/README.md new file mode 100644 index 000000000..8356dd535 --- /dev/null +++ b/internal/grid/README.md @@ -0,0 +1,251 @@ +# MinIO Grid + +The MinIO Grid is a package that provides two-way communication between servers. +It uses a single two-way connection to send and receive messages between servers. + +It includes built in muxing of concurrent requests as well as congestion handling for streams. + +Requests can be "Single Payload" or "Streamed". + +Use the MinIO Grid for: + +* Small, frequent requests with low latency requirements. +* Long-running requests with small/medium payloads. + +Do *not* use the MinIO Grid for: + +* Large payloads. + +Only a single connection is ever made between two servers. +Likely this means that this connection will not be able to saturate network bandwidth. +Therefore, using this for large payloads will likely be slower than using a separate connection, +and other connections will be blocked while the large payload is being sent. + +## Handlers & Routes + +Handlers have a predefined Handler ID. +In addition, there can be several *static* subroutes used to differentiate between different handlers of the same ID. +A subroute on a client must match a subroute on the server. So routes cannot be used for dynamic routing, unlike HTTP. + +Handlers should remain backwards compatible. If a breaking API change is required, a new handler ID should be created. + +## Setup & Configuration + +A **Manager** is used to manage all incoming and outgoing connections to a server. + +On startup all remote servers must be specified. +From that individual connections will be spawned to each remote server, +or incoming requests will be hooked up to the appropriate connection. + +To get a connection to a specific server, use `Manager.Connection(host)` to get a connection to the specified host. +From this connection individual requests can be made. + +Each handler, with optional subroutes can be registered with the manager using +`Manager.RegisterXHandler(handlerID, handler, subroutes...)`. + +A `Handler()` function provides an HTTP handler, which should be hooked up to the appropriate route on the server. + +On startup, the manager will start connecting to remotes and also starts listening for incoming connections. +Until a connection is established, all outgoing requests will return `ErrDisconnected`. + +# Usage + +## Single Payload Requests + +Single payload requests are requests and responses that are sent in a single message. +In essence, they are `[]byte` -> `[]byte, error` functions. + +It is not possible to return *both* an error and a response. + +Handlers are registered on the manager using `(*Manager).RegisterSingleHandler(id HandlerID, h SingleHandlerFn, subroute ...string)`. + +The server handler function has this signature: `type SingleHandlerFn func(payload []byte) ([]byte, *RemoteErr)`. + +Sample handler: +```go + handler := func(payload []byte) ([]byte, *grid.RemoteErr) { + // Do something with payload + return []byte("response"), nil + } + + err := manager.RegisterSingleHandler(grid.HandlerDiskInfo, handler) +``` + +Sample call: +```go + // Get a connection to the remote host + conn := manager.Connection(host) + + payload := []byte("request") + response, err := conn.SingleRequest(ctx, grid.HandlerDiskInfo, payload) +``` + +If the error type is `*RemoteErr`, then the error was returned by the remote server. Otherwise it is a local error. + +Context timeouts are propagated, and a default timeout of 1 minute is added if none is specified. + +There is no cancellation propagation for single payload requests. +When the context is canceled, the request will return at once with an appropriate error. +However, the remote call will not see the cancellation - as can be seen from the 'missing' context on the handler. +The result will be discarded. + +### Typed handlers + +Typed handlers are handlers that have a specific type for the request and response payloads. +These must provide `msgp` serialization and deserialization. + +In the examples we use a `MSS` type, which is a `map[string]string` that is `msgp` serializable. + +```go + handler := func(request *grid.MSS) (*grid.MSS, *grid.RemoteErr) { + fmt.Println("Got request with field", request["myfield"]) + // Do something with payload + return NewMSSWith(map[string]string{"result": "ok"}), nil + } + + // Create a typed handler. + // Due to current generics limitations, a constructor of the empty type must be provided. + instance := grid.NewSingleHandler[*grid.MSS, *grid.MSS](h, grid.NewMSS, grid.NewMSS) + + // Register the handler on the manager + instance.Register(manager, handler) + + // The typed instance is also used for calls + conn := manager.Connection("host") + resp, err := instance.Call(ctx, conn, grid.NewMSSWith(map[string]string{"myfield": "myvalue"})) + if err == nil { + fmt.Println("Got response with field", resp["result"]) + } +``` + +The wrapper will handle all serialization and de-seralization of the request and response, +and furthermore provides re-use of the structs used for the request and response. + +Note that Responses sent for serialization are automatically reused for similar requests. +If the response contains shared data it will cause issues, since each unique response is reused. +To disable this behavior, use `(SingleHandler).WithSharedResponse()` to disable it. + +## Streaming Requests + +Streams consists of an initial request with payload and allows for full two-way communication between the client and server. + +The handler function has this signature. + +Sample handler: +```go + handler := func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr { + fmt.Println("Got request with initial payload", p, "from", GetCaller(ctx context.Context)) + fmt.Println("Subroute:", GetSubroute(ctx)) + for { + select { + case <-ctx.Done(): + return nil + case req, ok := <-in: + if !ok { + break + } + // Do something with payload + out <- []byte("response") + + // Return the request for reuse + grid.PutByteBuffer(req) + } + } + // out is closed by the caller and should never be closed by the handler. + return nil + } + + err := manager.RegisterStreamingHandler(grid.HandlerDiskInfo, StreamHandler{ + Handle: handler, + Subroute: "asubroute", + OutCapacity: 1, + InCapacity: 1, + }) +``` + +Sample call: +```go + // Get a connection to the remote host + conn := manager.Connection(host).Subroute("asubroute") + + payload := []byte("request") + stream, err := conn.NewStream(ctx, grid.HandlerDiskInfo, payload) + if err != nil { + return err + } + // Read results from the stream + err = stream.Results(func(result []byte) error { + fmt.Println("Got result", string(result)) + + // Return the response for reuse + grid.PutByteBuffer(result) + return nil + }) +``` + +Context cancellation and timeouts are propagated to the handler. +The client does not wait for the remote handler to finish before returning. +Returning any error will also cancel the stream remotely. + +CAREFUL: When utilizing two-way communication, it is important to ensure that the remote handler is not blocked on a send. +If the remote handler is blocked on a send, and the client is trying to send without the remote receiving, +the operation would become deadlocked if the channels are full. + +### Typed handlers + +Typed handlers are handlers that have a specific type for the request and response payloads. + +```go + // Create a typed handler. + handler := func(ctx context.Context, p *Payload, in <-chan *Req, out chan<- *Resp) *RemoteErr { + fmt.Println("Got request with initial payload", p, "from", GetCaller(ctx context.Context)) + fmt.Println("Subroute:", GetSubroute(ctx)) + for { + select { + case <-ctx.Done(): + return nil + case req, ok := <-in: + if !ok { + break + } + fmt.Println("Got request", in) + // Do something with payload + out <- Resp{"response"} + } + // out is closed by the caller and should never be closed by the handler. + return nil + } + + // Create a typed handler. + // Due to current generics limitations, a constructor of the empty type must be provided. + instance := grid.NewStream[*Payload, *Req, *Resp](h, newPayload, newReq, newResp) + + // Tweakable options + instance.WithPayload = true // default true when newPayload != nil + instance.OutCapacity = 1 // default + instance.InCapacity = 1 // default true when newReq != nil + + // Register the handler on the manager + instance.Register(manager, handler, "asubroute") + + // The typed instance is also used for calls + conn := manager.Connection("host").Subroute("asubroute") + stream, err := instance.Call(ctx, conn, &Payload{"request payload"}) + if err != nil { ... } + + // Read results from the stream + err = stream.Results(func(resp *Resp) error { + fmt.Println("Got result", resp) + // Return the response for reuse + instance.PutResponse(resp) + return nil + }) +``` + +There are handlers for requests with: + * No input stream: `RegisterNoInput`. + * No initial payload: `RegisterNoPayload`. + +Note that Responses sent for serialization are automatically reused for similar requests. +If the response contains shared data it will cause issues, since each unique response is reused. +To disable this behavior, use `(StreamTypeHandler).WithSharedResponse()` to disable it. diff --git a/internal/grid/benchmark_test.go b/internal/grid/benchmark_test.go new file mode 100644 index 000000000..54feb9aa2 --- /dev/null +++ b/internal/grid/benchmark_test.go @@ -0,0 +1,440 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "fmt" + "math/rand" + "runtime" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/minio/minio/internal/logger/target/testlogger" +) + +func BenchmarkRequests(b *testing.B) { + for n := 2; n <= 32; n *= 2 { + b.Run("servers="+strconv.Itoa(n), func(b *testing.B) { + benchmarkGridRequests(b, n) + }) + } +} + +func benchmarkGridRequests(b *testing.B, n int) { + defer testlogger.T.SetErrorTB(b)() + errFatal := func(err error) { + b.Helper() + if err != nil { + b.Fatal(err) + } + } + rpc := NewSingleHandler[*testRequest, *testResponse](handlerTest2, newTestRequest, newTestResponse) + grid, err := SetupTestGrid(n) + errFatal(err) + b.Cleanup(grid.Cleanup) + // Create n managers. + for _, remote := range grid.Managers { + // Register a single handler which echos the payload. + errFatal(remote.RegisterSingleHandler(handlerTest, func(payload []byte) ([]byte, *RemoteErr) { + defer PutByteBuffer(payload) + return append(GetByteBuffer()[:0], payload...), nil + })) + errFatal(rpc.Register(remote, func(req *testRequest) (resp *testResponse, err *RemoteErr) { + return &testResponse{ + OrgNum: req.Num, + OrgString: req.String, + Embedded: *req, + }, nil + })) + errFatal(err) + } + const payloadSize = 512 + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + payload := make([]byte, payloadSize) + _, err = rng.Read(payload) + errFatal(err) + + // Wait for all to connect + // Parallel writes per server. + b.Run("bytes", func(b *testing.B) { + for par := 1; par <= 32; par *= 2 { + b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) { + defer timeout(60 * time.Second)() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + b.ReportAllocs() + b.SetBytes(int64(len(payload) * 2)) + b.ResetTimer() + t := time.Now() + var ops int64 + var lat int64 + b.SetParallelism(par) + b.RunParallel(func(pb *testing.PB) { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + n := 0 + var latency int64 + managers := grid.Managers + hosts := grid.Hosts + for pb.Next() { + // Pick a random manager. + src, dst := rng.Intn(len(managers)), rng.Intn(len(managers)) + if src == dst { + dst = (dst + 1) % len(managers) + } + local := managers[src] + conn := local.Connection(hosts[dst]) + if conn == nil { + b.Fatal("No connection") + } + // Send the payload. + t := time.Now() + resp, err := conn.Request(ctx, handlerTest, payload) + latency += time.Since(t).Nanoseconds() + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + PutByteBuffer(resp) + n++ + } + atomic.AddInt64(&ops, int64(n)) + atomic.AddInt64(&lat, latency) + }) + spent := time.Since(t) + if spent > 0 && n > 0 { + // Since we are benchmarking n parallel servers we need to multiply by n. + // This will give an estimate of the total ops/s. + latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond) + b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s") + b.ReportMetric(latency/float64(ops), "ms/op") + } + }) + } + }) + b.Run("rpc", func(b *testing.B) { + for par := 1; par <= 32; par *= 2 { + b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) { + defer timeout(60 * time.Second)() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + b.ReportAllocs() + b.ResetTimer() + t := time.Now() + var ops int64 + var lat int64 + b.SetParallelism(par) + b.RunParallel(func(pb *testing.PB) { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + n := 0 + var latency int64 + managers := grid.Managers + hosts := grid.Hosts + req := testRequest{ + Num: rng.Int(), + String: "hello", + } + for pb.Next() { + // Pick a random manager. + src, dst := rng.Intn(len(managers)), rng.Intn(len(managers)) + if src == dst { + dst = (dst + 1) % len(managers) + } + local := managers[src] + conn := local.Connection(hosts[dst]) + if conn == nil { + b.Fatal("No connection") + } + // Send the payload. + t := time.Now() + resp, err := rpc.Call(ctx, conn, &req) + latency += time.Since(t).Nanoseconds() + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + rpc.PutResponse(resp) + n++ + } + atomic.AddInt64(&ops, int64(n)) + atomic.AddInt64(&lat, latency) + }) + spent := time.Since(t) + if spent > 0 && n > 0 { + // Since we are benchmarking n parallel servers we need to multiply by n. + // This will give an estimate of the total ops/s. + latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond) + b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s") + b.ReportMetric(latency/float64(ops), "ms/op") + } + }) + } + }) +} + +func BenchmarkStream(b *testing.B) { + tests := []struct { + name string + fn func(b *testing.B, n int) + }{ + {name: "request", fn: benchmarkGridStreamReqOnly}, + {name: "responses", fn: benchmarkGridStreamRespOnly}, + } + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + for n := 2; n <= 32; n *= 2 { + b.Run("servers="+strconv.Itoa(n), func(b *testing.B) { + test.fn(b, n) + }) + } + }) + } +} + +func benchmarkGridStreamRespOnly(b *testing.B, n int) { + defer testlogger.T.SetErrorTB(b)() + errFatal := func(err error) { + b.Helper() + if err != nil { + b.Fatal(err) + } + } + grid, err := SetupTestGrid(n) + errFatal(err) + b.Cleanup(grid.Cleanup) + const responses = 10 + // Create n managers. + for _, remote := range grid.Managers { + // Register a single handler which echos the payload. + errFatal(remote.RegisterStreamingHandler(handlerTest, StreamHandler{ + // Send 10x response. + Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, out chan<- []byte) *RemoteErr { + for i := 0; i < responses; i++ { + toSend := GetByteBuffer()[:0] + toSend = append(toSend, byte(i)) + toSend = append(toSend, payload...) + select { + case <-ctx.Done(): + return nil + case out <- toSend: + } + } + return nil + }, + + Subroute: "some-subroute", + OutCapacity: 1, // Only one message buffered. + InCapacity: 0, + })) + errFatal(err) + } + const payloadSize = 512 + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + payload := make([]byte, payloadSize) + _, err = rng.Read(payload) + errFatal(err) + + // Wait for all to connect + // Parallel writes per server. + for par := 1; par <= 32; par *= 2 { + b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) { + defer timeout(30 * time.Second)() + b.ReportAllocs() + b.SetBytes(int64(len(payload) * (responses + 1))) + b.ResetTimer() + t := time.Now() + var ops int64 + var lat int64 + b.SetParallelism(par) + b.RunParallel(func(pb *testing.PB) { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + n := 0 + var latency int64 + managers := grid.Managers + hosts := grid.Hosts + for pb.Next() { + // Pick a random manager. + src, dst := rng.Intn(len(managers)), rng.Intn(len(managers)) + if src == dst { + dst = (dst + 1) % len(managers) + } + local := managers[src] + conn := local.Connection(hosts[dst]).Subroute("some-subroute") + if conn == nil { + b.Fatal("No connection") + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // Send the payload. + t := time.Now() + st, err := conn.NewStream(ctx, handlerTest, payload) + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + got := 0 + err = st.Results(func(b []byte) error { + got++ + PutByteBuffer(b) + return nil + }) + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + latency += time.Since(t).Nanoseconds() + cancel() + n += got + } + atomic.AddInt64(&ops, int64(n)) + atomic.AddInt64(&lat, latency) + }) + spent := time.Since(t) + if spent > 0 && n > 0 { + // Since we are benchmarking n parallel servers we need to multiply by n. + // This will give an estimate of the total ops/s. + latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond) + b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s") + b.ReportMetric(latency/float64(ops), "ms/op") + } + }) + } +} + +func benchmarkGridStreamReqOnly(b *testing.B, n int) { + defer testlogger.T.SetErrorTB(b)() + errFatal := func(err error) { + b.Helper() + if err != nil { + b.Fatal(err) + } + } + grid, err := SetupTestGrid(n) + errFatal(err) + b.Cleanup(grid.Cleanup) + const requests = 10 + // Create n managers. + for _, remote := range grid.Managers { + // Register a single handler which echos the payload. + errFatal(remote.RegisterStreamingHandler(handlerTest, StreamHandler{ + // Send 10x requests. + Handle: func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr { + got := 0 + for b := range in { + PutByteBuffer(b) + got++ + } + if got != requests { + return NewRemoteErrf("wrong number of requests. want %d, got %d", requests, got) + } + return nil + }, + + Subroute: "some-subroute", + OutCapacity: 1, + InCapacity: 1, // Only one message buffered. + })) + errFatal(err) + } + const payloadSize = 512 + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + payload := make([]byte, payloadSize) + _, err = rng.Read(payload) + errFatal(err) + + // Wait for all to connect + // Parallel writes per server. + for par := 1; par <= 32; par *= 2 { + b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) { + defer timeout(30 * time.Second)() + b.ReportAllocs() + b.SetBytes(int64(len(payload) * (requests + 1))) + b.ResetTimer() + t := time.Now() + var ops int64 + var lat int64 + b.SetParallelism(par) + b.RunParallel(func(pb *testing.PB) { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + n := 0 + var latency int64 + managers := grid.Managers + hosts := grid.Hosts + for pb.Next() { + // Pick a random manager. + src, dst := rng.Intn(len(managers)), rng.Intn(len(managers)) + if src == dst { + dst = (dst + 1) % len(managers) + } + local := managers[src] + conn := local.Connection(hosts[dst]).Subroute("some-subroute") + if conn == nil { + b.Fatal("No connection") + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // Send the payload. + t := time.Now() + st, err := conn.NewStream(ctx, handlerTest, payload) + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + got := 0 + for i := 0; i < requests; i++ { + got++ + st.Requests <- append(GetByteBuffer()[:0], payload...) + } + close(st.Requests) + err = st.Results(func(b []byte) error { + return nil + }) + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + latency += time.Since(t).Nanoseconds() + cancel() + n += got + } + atomic.AddInt64(&ops, int64(n)) + atomic.AddInt64(&lat, latency) + }) + spent := time.Since(t) + if spent > 0 && n > 0 { + // Since we are benchmarking n parallel servers we need to multiply by n. + // This will give an estimate of the total ops/s. + latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond) + b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s") + b.ReportMetric(latency/float64(ops), "ms/op") + } + }) + } +} diff --git a/internal/grid/connection.go b/internal/grid/connection.go new file mode 100644 index 000000000..395a13fc9 --- /dev/null +++ b/internal/grid/connection.go @@ -0,0 +1,1604 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "math/rand" + "net" + "net/http" + "runtime/debug" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/google/uuid" + "github.com/minio/madmin-go/v3" + "github.com/minio/minio/internal/logger" + "github.com/minio/minio/internal/pubsub" + "github.com/tinylib/msgp/msgp" + "github.com/zeebo/xxh3" +) + +// A Connection is a remote connection. +// There is no distinction externally whether the connection was initiated from +// this server or from the remote. +type Connection struct { + // NextID is the next ID that can be used (atomic). + NextID uint64 + + // LastPong is last pong time (atomic) + // Only valid when StateConnected. + LastPong int64 + + // State of the connection (atomic) + state State + + // Non-atomic + Remote string + Local string + + // ID of this connection instance. + id uuid.UUID + + // Remote uuid, if we have been connected. + remoteID *uuid.UUID + + // Context for the server. + ctx context.Context + + // Active mux connections. + outgoing *lockedClientMap + + // Incoming streams + inStream *lockedServerMap + + // outQueue is the output queue + outQueue chan []byte + + // Client or serverside. + side ws.State + + // Transport for outgoing connections. + dialer ContextDialer + header http.Header + + handleMsgWg sync.WaitGroup + + // connChange will be signaled whenever State has been updated, or at regular intervals. + // Holding the lock allows safe reads of State, and guarantees that changes will be detected. + connChange *sync.Cond + handlers *handlers + + remote *RemoteClient + auth AuthFn + clientPingInterval time.Duration + connPingInterval time.Duration + tlsConfig *tls.Config + blockConnect chan struct{} + + incomingBytes func(n int64) // Record incoming bytes. + outgoingBytes func(n int64) // Record outgoing bytes. + trace *tracer // tracer for this connection. + baseFlags Flags + + // For testing only + debugInConn net.Conn + debugOutConn net.Conn + addDeadline time.Duration + connMu sync.Mutex +} + +// Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID. +type Subroute struct { + *Connection + trace *tracer + route string + subID subHandlerID +} + +// String returns a string representation of the connection. +func (c *Connection) String() string { + return fmt.Sprintf("%s->%s", c.Local, c.Remote) +} + +// StringReverse returns a string representation of the reverse connection. +func (c *Connection) StringReverse() string { + return fmt.Sprintf("%s->%s", c.Remote, c.Local) +} + +// State is a connection state. +type State uint32 + +// MANUAL go:generate stringer -type=State -output=state_string.go -trimprefix=State $GOFILE + +const ( + // StateUnconnected is the initial state of a connection. + // When the first message is sent it will attempt to connect. + StateUnconnected = iota + + // StateConnecting is the state from StateUnconnected while the connection is attempted to be established. + // After this connection will be StateConnected or StateConnectionError. + StateConnecting + + // StateConnected is the state when the connection has been established and is considered stable. + // If the connection is lost, state will switch to StateConnecting. + StateConnected + + // StateConnectionError is the state once a connection attempt has been made, and it failed. + // The connection will remain in this stat until the connection has been successfully re-established. + StateConnectionError + + // StateShutdown is the state when the server has been shut down. + // This will not be used under normal operation. + StateShutdown + + // MaxDeadline is the maximum deadline allowed, + // Approx 49 days. + MaxDeadline = time.Duration(math.MaxUint32) * time.Millisecond +) + +// ContextDialer is a dialer that can be used to dial a remote. +type ContextDialer func(ctx context.Context, network, address string) (net.Conn, error) + +// DialContext implements the Dialer interface. +func (c ContextDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return c(ctx, network, address) +} + +const ( + defaultOutQueue = 10000 + readBufferSize = 16 << 10 + writeBufferSize = 16 << 10 + defaultDialTimeout = 2 * time.Second + connPingInterval = 10 * time.Second +) + +type connectionParams struct { + ctx context.Context + id uuid.UUID + local, remote string + dial ContextDialer + handlers *handlers + auth AuthFn + tlsConfig *tls.Config + incomingBytes func(n int64) // Record incoming bytes. + outgoingBytes func(n int64) // Record outgoing bytes. + publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] + + blockConnect chan struct{} +} + +// newConnection will create an unconnected connection to a remote. +func newConnection(o connectionParams) *Connection { + c := &Connection{ + state: StateUnconnected, + Remote: o.remote, + Local: o.local, + id: o.id, + ctx: o.ctx, + outgoing: &lockedClientMap{m: make(map[uint64]*muxClient, 1000)}, + inStream: &lockedServerMap{m: make(map[uint64]*muxServer, 1000)}, + outQueue: make(chan []byte, defaultOutQueue), + dialer: o.dial, + side: ws.StateServerSide, + connChange: &sync.Cond{L: &sync.Mutex{}}, + handlers: o.handlers, + auth: o.auth, + header: make(http.Header, 1), + remote: &RemoteClient{Name: o.remote}, + clientPingInterval: clientPingInterval, + connPingInterval: connPingInterval, + tlsConfig: o.tlsConfig, + incomingBytes: o.incomingBytes, + outgoingBytes: o.outgoingBytes, + } + if debugPrint { + // Random Mux ID + c.NextID = rand.Uint64() + } + if !strings.HasPrefix(o.remote, "https://") && !strings.HasPrefix(o.remote, "wss://") { + c.baseFlags |= FlagCRCxxh3 + } + if !strings.HasPrefix(o.local, "https://") && !strings.HasPrefix(o.local, "wss://") { + c.baseFlags |= FlagCRCxxh3 + } + if o.publisher != nil { + c.traceRequests(o.publisher) + } + if o.local == o.remote { + panic("equal hosts") + } + if c.shouldConnect() { + c.side = ws.StateClientSide + + go func() { + if o.blockConnect != nil { + <-o.blockConnect + } + c.connect() + }() + } + if debugPrint { + fmt.Println(c.Local, "->", c.Remote, "Should local connect:", c.shouldConnect(), "side:", c.side) + } + if debugReqs { + fmt.Println("Created connection", c.String()) + } + return c +} + +// Subroute returns a static subroute for the connection. +func (c *Connection) Subroute(s string) *Subroute { + if c == nil { + return nil + } + return &Subroute{ + Connection: c, + route: s, + subID: makeSubHandlerID(0, s), + trace: c.trace.subroute(s), + } +} + +// Subroute adds a subroute to the subroute. +// The subroutes are combined with '/'. +func (c *Subroute) Subroute(s string) *Subroute { + route := strings.Join([]string{c.route, s}, "/") + return &Subroute{ + Connection: c.Connection, + route: route, + subID: makeSubHandlerID(0, route), + trace: c.trace.subroute(route), + } +} + +// newMuxClient returns a mux client for manual use. +func (c *Connection) newMuxClient(ctx context.Context) (*muxClient, error) { + client := newMuxClient(ctx, atomic.AddUint64(&c.NextID, 1), c) + if dl, ok := ctx.Deadline(); ok { + client.deadline = getDeadline(time.Until(dl)) + if client.deadline == 0 { + return nil, context.DeadlineExceeded + } + } + for { + // Handle the extremely unlikely scenario that we wrapped. + if _, loaded := c.outgoing.LoadOrStore(client.MuxID, client); client.MuxID != 0 && !loaded { + if debugReqs { + _, found := c.outgoing.Load(client.MuxID) + fmt.Println(client.MuxID, c.String(), "Connection.newMuxClient: RELOADED MUX. loaded:", loaded, "found:", found) + } + return client, nil + } + client.MuxID = atomic.AddUint64(&c.NextID, 1) + } +} + +// newMuxClient returns a mux client for manual use. +func (c *Subroute) newMuxClient(ctx context.Context) (*muxClient, error) { + cl, err := c.Connection.newMuxClient(ctx) + if err != nil { + return nil, err + } + cl.subroute = &c.subID + return cl, nil +} + +// Request allows to do a single remote request. +// 'req' will not be used after the call and caller can reuse. +// If no deadline is set on ctx, a 1-minute deadline will be added. +func (c *Connection) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) { + if !h.valid() { + return nil, ErrUnknownHandler + } + if c.State() != StateConnected { + return nil, ErrDisconnected + } + handler := c.handlers.single[h] + if handler == nil { + return nil, ErrUnknownHandler + } + client, err := c.newMuxClient(ctx) + if err != nil { + return nil, err + } + defer func() { + if debugReqs { + _, ok := c.outgoing.Load(client.MuxID) + fmt.Println(client.MuxID, c.String(), "Connection.Request: DELETING MUX. Exists:", ok) + } + c.outgoing.Delete(client.MuxID) + }() + return client.traceRoundtrip(c.trace, h, req) +} + +// Request allows to do a single remote request. +// 'req' will not be used after the call and caller can reuse. +// If no deadline is set on ctx, a 1-minute deadline will be added. +func (c *Subroute) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) { + if !h.valid() { + return nil, ErrUnknownHandler + } + if c.State() != StateConnected { + return nil, ErrDisconnected + } + handler := c.handlers.subSingle[makeZeroSubHandlerID(h)] + if handler == nil { + return nil, ErrUnknownHandler + } + client, err := c.newMuxClient(ctx) + if err != nil { + return nil, err + } + client.subroute = &c.subID + defer func() { + if debugReqs { + fmt.Println(client.MuxID, c.String(), "Subroute.Request: DELETING MUX") + } + c.outgoing.Delete(client.MuxID) + }() + return client.traceRoundtrip(c.trace, h, req) +} + +// NewStream creates a new stream. +// Initial payload can be reused by the caller. +func (c *Connection) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) { + if !h.valid() { + return nil, ErrUnknownHandler + } + if c.State() != StateConnected { + return nil, ErrDisconnected + } + handler := c.handlers.streams[h] + if handler == nil { + return nil, ErrUnknownHandler + } + + var requests chan []byte + var responses chan Response + if handler.InCapacity > 0 { + requests = make(chan []byte, handler.InCapacity) + } + if handler.OutCapacity > 0 { + responses = make(chan Response, handler.OutCapacity) + } else { + responses = make(chan Response, 1) + } + + cl, err := c.newMuxClient(ctx) + if err != nil { + return nil, err + } + + return cl.RequestStream(h, payload, requests, responses) +} + +// NewStream creates a new stream. +// Initial payload can be reused by the caller. +func (c *Subroute) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) { + if !h.valid() { + return nil, ErrUnknownHandler + } + if c.State() != StateConnected { + return nil, ErrDisconnected + } + handler := c.handlers.subStreams[makeZeroSubHandlerID(h)] + if handler == nil { + if debugPrint { + fmt.Println("want", makeZeroSubHandlerID(h), c.route, "got", c.handlers.subStreams) + } + return nil, ErrUnknownHandler + } + + var requests chan []byte + var responses chan Response + if handler.InCapacity > 0 { + requests = make(chan []byte, handler.InCapacity) + } + if handler.OutCapacity > 0 { + responses = make(chan Response, handler.OutCapacity) + } else { + responses = make(chan Response, 1) + } + + cl, err := c.newMuxClient(ctx) + if err != nil { + return nil, err + } + cl.subroute = &c.subID + + return cl.RequestStream(h, payload, requests, responses) +} + +// WaitForConnect will block until a connection has been established or +// the context is canceled, in which case the context error is returned. +func (c *Connection) WaitForConnect(ctx context.Context) error { + if debugPrint { + fmt.Println(c.Local, "->", c.Remote, "WaitForConnect") + defer fmt.Println(c.Local, "->", c.Remote, "WaitForConnect done") + } + c.connChange.L.Lock() + if atomic.LoadUint32((*uint32)(&c.state)) == StateConnected { + c.connChange.L.Unlock() + // Happy path. + return nil + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + changed := make(chan State, 1) + go func() { + defer close(changed) + for { + c.connChange.Wait() + newState := c.State() + select { + case changed <- newState: + if newState == StateConnected || newState == StateShutdown { + c.connChange.L.Unlock() + return + } + case <-ctx.Done(): + c.connChange.L.Unlock() + return + } + } + }() + + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + case newState := <-changed: + if newState == StateConnected { + return nil + } + } + } +} + +/* +var ErrDone = errors.New("done for now") + +var ErrRemoteRestart = errors.New("remote restarted") + + +// Stateless connects to the remote handler and return all packets sent back. +// If the remote is restarted will return ErrRemoteRestart. +// If nil will be returned remote call sent EOF or ErrDone is returned by the callback. +// If ErrDone is returned on cb nil will be returned. +func (c *Connection) Stateless(ctx context.Context, h HandlerID, req []byte, cb func([]byte) error) error { + client, err := c.newMuxClient(ctx) + if err != nil { + return err + } + defer c.outgoing.Delete(client.MuxID) + resp := make(chan Response, 10) + client.RequestStateless(h, req, resp) + + for r := range resp { + if r.Err != nil { + return r.Err + } + if len(r.Msg) > 0 { + err := cb(r.Msg) + if err != nil { + if errors.Is(err, ErrDone) { + break + } + return err + } + } + } + return nil +} +*/ + +// shouldConnect returns a deterministic bool whether the local should initiate the connection. +// It should be 50% chance of any host initiating the connection. +func (c *Connection) shouldConnect() bool { + // The remote should have the opposite result. + h0 := xxh3.HashString(c.Local + c.Remote) + h1 := xxh3.HashString(c.Remote + c.Local) + if h0 == h1 { + return c.Local < c.Remote + } + return h0 < h1 +} + +func (c *Connection) send(msg []byte) error { + select { + case <-c.ctx.Done(): + return context.Cause(c.ctx) + case c.outQueue <- msg: + return nil + } +} + +// queueMsg queues a message, with an optional payload. +// sender should not reference msg.Payload +func (c *Connection) queueMsg(msg message, payload sender) error { + msg.Flags |= c.baseFlags + if payload != nil { + if cap(msg.Payload) < payload.Msgsize() { + old := msg.Payload + msg.Payload = GetByteBuffer()[:0] + PutByteBuffer(old) + } + var err error + msg.Payload, err = payload.MarshalMsg(msg.Payload[:0]) + msg.Op = payload.Op() + if err != nil { + return err + } + } + defer PutByteBuffer(msg.Payload) + dst := GetByteBuffer()[:0] + dst, err := msg.MarshalMsg(dst) + if err != nil { + return err + } + if msg.Flags&FlagCRCxxh3 != 0 { + h := xxh3.Hash(dst) + dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) + } + return c.send(dst) +} + +// sendMsg will send +func (c *Connection) sendMsg(conn net.Conn, msg message, payload msgp.MarshalSizer) error { + if payload != nil { + if cap(msg.Payload) < payload.Msgsize() { + PutByteBuffer(msg.Payload) + msg.Payload = GetByteBuffer()[:0] + } + var err error + msg.Payload, err = payload.MarshalMsg(msg.Payload) + if err != nil { + return err + } + defer PutByteBuffer(msg.Payload) + } + dst := GetByteBuffer()[:0] + dst, err := msg.MarshalMsg(dst) + if err != nil { + return err + } + if msg.Flags&FlagCRCxxh3 != 0 { + h := xxh3.Hash(dst) + dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) + } + if debugPrint { + fmt.Println(c.Local, "sendMsg: Sending", msg.Op, "as", len(dst), "bytes") + } + if c.outgoingBytes != nil { + c.outgoingBytes(int64(len(dst))) + } + return wsutil.WriteMessage(conn, c.side, ws.OpBinary, dst) +} + +func (c *Connection) connect() { + c.updateState(StateConnecting) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + // Runs until the server is shut down. + for { + if c.State() == StateShutdown { + return + } + toDial := strings.Replace(c.Remote, "http://", "ws://", 1) + toDial = strings.Replace(toDial, "https://", "wss://", 1) + toDial += RoutePath + + dialer := ws.DefaultDialer + dialer.ReadBufferSize = readBufferSize + dialer.WriteBufferSize = writeBufferSize + dialer.Timeout = defaultDialTimeout + if c.dialer != nil { + dialer.NetDial = c.dialer.DialContext + } + if c.header == nil { + c.header = make(http.Header, 2) + } + c.header.Set("Authorization", "Bearer "+c.auth("")) + c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339)) + + if len(c.header) > 0 { + dialer.Header = ws.HandshakeHeaderHTTP(c.header) + } + dialer.TLSConfig = c.tlsConfig + dialStarted := time.Now() + if debugPrint { + fmt.Println(c.Local, "Connecting to ", toDial) + } + conn, br, _, err := dialer.Dial(c.ctx, toDial) + if br != nil { + ws.PutReader(br) + } + c.connMu.Lock() + c.debugOutConn = conn + c.connMu.Unlock() + retry := func(err error) { + if debugPrint { + fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err) + } + sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout))) + next := dialStarted.Add(sleep) + sleep = time.Until(next).Round(time.Millisecond) + if sleep < 0 { + sleep = 0 + } + gotState := c.State() + if gotState == StateShutdown { + return + } + if gotState != StateConnecting { + // Don't print error on first attempt, + // and after that only once per hour. + cHour := strconv.FormatInt(time.Now().Unix()/60/60, 10) + logger.LogOnceIf(c.ctx, fmt.Errorf("grid: %s connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), c.Local+toDial+cHour+err.Error()) + } + c.updateState(StateConnectionError) + time.Sleep(sleep) + } + if err != nil { + retry(err) + continue + } + // Send connect message. + m := message{ + Op: OpConnect, + } + req := connectReq{ + Host: c.Local, + ID: c.id, + } + err = c.sendMsg(conn, m, &req) + if err != nil { + retry(err) + continue + } + // Wait for response + var r connectResp + err = c.receive(conn, &r) + if err != nil { + if debugPrint { + fmt.Println(c.Local, "receive err:", err, "side:", c.side) + } + retry(err) + continue + } + if debugPrint { + fmt.Println(c.Local, "Got connectResp:", r) + } + if !r.Accepted { + retry(fmt.Errorf("connection rejected: %s", r.RejectedReason)) + continue + } + remoteUUID := uuid.UUID(r.ID) + if c.remoteID != nil { + c.reconnected() + } + c.remoteID = &remoteUUID + if debugPrint { + fmt.Println(c.Local, "Connected Waiting for Messages") + } + c.updateState(StateConnected) + go c.handleMessages(c.ctx, conn) + // Monitor state changes and reconnect if needed. + c.connChange.L.Lock() + for { + newState := c.State() + if newState != StateConnected { + c.connChange.L.Unlock() + if newState == StateShutdown { + conn.Close() + return + } + if debugPrint { + fmt.Println(c.Local, "Disconnected") + } + // Reconnect + break + } + // Unlock and wait for state change. + c.connChange.Wait() + } + } +} + +func (c *Connection) disconnected() { + c.outgoing.Range(func(key uint64, client *muxClient) bool { + if !client.stateless { + client.cancelFn(ErrDisconnected) + } + return true + }) + if debugReqs { + fmt.Println(c.String(), "Disconnected. Clearing outgoing.") + } + c.outgoing.Clear() + c.inStream.Range(func(key uint64, client *muxServer) bool { + client.cancel() + return true + }) + c.inStream.Clear() +} + +func (c *Connection) receive(conn net.Conn, r receiver) error { + b, op, err := wsutil.ReadData(conn, c.side) + if err != nil { + return err + } + if op != ws.OpBinary { + return fmt.Errorf("unexpected connect response type %v", op) + } + var m message + _, _, err = m.parse(b) + if err != nil { + return err + } + if m.Op != r.Op() { + return fmt.Errorf("unexpected response OP, want %v, got %v", r.Op(), m.Op) + } + _, err = r.UnmarshalMsg(m.Payload) + return err +} + +func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req connectReq) error { + c.connMu.Lock() + c.debugInConn = conn + c.connMu.Unlock() + if c.blockConnect != nil { + // Block until we are allowed to connect. + <-c.blockConnect + } + if req.Host != c.Remote { + err := fmt.Errorf("expected remote '%s', got '%s'", c.Remote, req.Host) + if debugPrint { + fmt.Println(err) + } + return err + } + if c.shouldConnect() { + if debugPrint { + fmt.Println("expected to be client side, not server side") + } + return errors.New("expected to be client side, not server side") + } + msg := message{ + Op: OpConnectResponse, + } + + if c.remoteID != nil { + c.reconnected() + } + rid := uuid.UUID(req.ID) + c.remoteID = &rid + resp := connectResp{ + ID: c.id, + Accepted: true, + } + err := c.sendMsg(conn, msg, &resp) + if debugPrint { + fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side) + } + if err == nil { + c.updateState(StateConnected) + c.handleMessages(ctx, conn) + } + return err +} + +func (c *Connection) reconnected() { + c.updateState(StateConnectionError) + // Close all active requests. + if debugReqs { + fmt.Println(c.String(), "Reconnected. Clearing outgoing.") + } + c.outgoing.Range(func(key uint64, client *muxClient) bool { + client.close() + return true + }) + c.inStream.Range(func(key uint64, value *muxServer) bool { + value.close() + return true + }) + + c.inStream.Clear() + c.outgoing.Clear() + + // Wait for existing to exit + c.handleMsgWg.Wait() +} + +func (c *Connection) updateState(s State) { + c.connChange.L.Lock() + defer c.connChange.L.Unlock() + + // We may have reads that aren't locked, so update atomically. + gotState := atomic.LoadUint32((*uint32)(&c.state)) + if gotState == StateShutdown || State(gotState) == s { + return + } + if s == StateConnected { + atomic.StoreInt64(&c.LastPong, time.Now().UnixNano()) + } + atomic.StoreUint32((*uint32)(&c.state), uint32(s)) + if debugPrint { + fmt.Println(c.Local, "updateState:", gotState, "->", s) + } + c.connChange.Broadcast() +} + +func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { + // Read goroutine + c.handleMsgWg.Add(2) + ctx, cancel := context.WithCancelCause(ctx) + go func() { + defer func() { + if rec := recover(); rec != nil { + logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) + debug.PrintStack() + } + c.connChange.L.Lock() + if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { + c.connChange.Broadcast() + } + c.connChange.L.Unlock() + conn.Close() + c.handleMsgWg.Done() + }() + + controlHandler := wsutil.ControlFrameHandler(conn, c.side) + wsReader := wsutil.Reader{ + Source: conn, + State: c.side, + CheckUTF8: true, + SkipHeaderCheck: false, + OnIntermediate: controlHandler, + } + readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) { + dst = dst[:0] + for { + hdr, err := wsReader.NextFrame() + if err != nil { + return nil, err + } + if hdr.OpCode.IsControl() { + if err := controlHandler(hdr, &wsReader); err != nil { + return nil, err + } + continue + } + if hdr.OpCode&want == 0 { + if err := wsReader.Discard(); err != nil { + return nil, err + } + continue + } + + if int64(cap(dst)) < hdr.Length+1 { + dst = make([]byte, 0, hdr.Length+hdr.Length>>3) + } + return readAllInto(dst[:0], &wsReader) + } + } + + // Keep reusing the same buffer. + var msg []byte + for { + if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { + cancel(ErrDisconnected) + return + } + + var err error + msg, err = readDataInto(msg, conn, c.side, ws.OpBinary) + if err != nil { + cancel(ErrDisconnected) + logger.LogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF) + return + } + if c.incomingBytes != nil { + c.incomingBytes(int64(len(msg))) + } + // Parse the received message + var m message + subID, remain, err := m.parse(msg) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws parse package: %w", err)) + cancel(ErrDisconnected) + return + } + if debugPrint { + fmt.Printf("%s Got msg: %v\n", c.Local, m) + } + if m.Op != OpMerged { + c.handleMsg(ctx, m, subID) + continue + } + // Handle merged messages. + messages := int(m.Seq) + for i := 0; i < messages; i++ { + if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { + cancel(ErrDisconnected) + return + } + var next []byte + next, remain, err = msgp.ReadBytesZC(remain) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws read merged: %w", err)) + cancel(ErrDisconnected) + return + } + + m.Payload = nil + subID, _, err = m.parse(next) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws parse merged: %w", err)) + cancel(ErrDisconnected) + return + } + c.handleMsg(ctx, m, subID) + } + } + }() + + // Write function. + defer func() { + if rec := recover(); rec != nil { + logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) + debug.PrintStack() + } + if debugPrint { + fmt.Println("handleMessages: write goroutine exited") + } + cancel(ErrDisconnected) + c.connChange.L.Lock() + if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { + c.connChange.Broadcast() + } + c.disconnected() + c.connChange.L.Unlock() + + conn.Close() + c.handleMsgWg.Done() + }() + + c.connMu.Lock() + connPingInterval := c.connPingInterval + c.connMu.Unlock() + ping := time.NewTicker(connPingInterval) + pingFrame := message{ + Op: OpPing, + DeadlineMS: 5000, + } + + defer ping.Stop() + queue := make([][]byte, 0, maxMergeMessages) + merged := make([]byte, 0, writeBufferSize) + var queueSize int + var buf bytes.Buffer + var wsw wsWriter + for { + var toSend []byte + select { + case <-ctx.Done(): + return + case <-ping.C: + if c.State() != StateConnected { + continue + } + lastPong := atomic.LoadInt64(&c.LastPong) + if lastPong > 0 { + lastPongTime := time.Unix(lastPong, 0) + if d := time.Since(lastPongTime); d > connPingInterval*2 { + logger.LogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond))) + cancel(ErrDisconnected) + return + } + } + var err error + toSend, err = pingFrame.MarshalMsg(GetByteBuffer()[:0]) + if err != nil { + logger.LogIf(ctx, err) + // Fake it... + atomic.StoreInt64(&c.LastPong, time.Now().Unix()) + continue + } + case toSend = <-c.outQueue: + if len(toSend) == 0 { + continue + } + } + if len(queue) < maxMergeMessages && queueSize+len(toSend) < writeBufferSize-1024 && len(c.outQueue) > 0 { + queue = append(queue, toSend) + queueSize += len(toSend) + continue + } + c.connChange.L.Lock() + for { + state := c.State() + if state == StateConnected { + break + } + if debugPrint { + fmt.Println(c.Local, "Waiting for connection ->", c.Remote, "state: ", state) + } + if state == StateShutdown || state == StateConnectionError { + c.connChange.L.Unlock() + return + } + c.connChange.Wait() + select { + case <-ctx.Done(): + c.connChange.L.Unlock() + return + default: + } + } + c.connChange.L.Unlock() + if len(queue) == 0 { + // Combine writes. + buf.Reset() + err := wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) + cancel(ErrDisconnected) + return + } + PutByteBuffer(toSend) + _, err = buf.WriteTo(conn) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) + cancel(ErrDisconnected) + return + } + continue + } + + // Merge entries and send + queue = append(queue, toSend) + if debugPrint { + fmt.Println("Merging", len(queue), "messages") + } + + toSend = merged[:0] + m := message{Op: OpMerged, Seq: uint32(len(queue))} + var err error + toSend, err = m.MarshalMsg(toSend) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("msg.MarshalMsg: %w", err)) + cancel(ErrDisconnected) + return + } + // Append as byte slices. + for _, q := range queue { + toSend = msgp.AppendBytes(toSend, q) + PutByteBuffer(q) + } + queue = queue[:0] + queueSize = 0 + + // Combine writes. + // Consider avoiding buffer copy. + buf.Reset() + err = wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) + cancel(ErrDisconnected) + return + } + // Tosend is our local buffer, so we can reuse it. + _, err = buf.WriteTo(conn) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) + cancel(ErrDisconnected) + return + } + + if buf.Cap() > writeBufferSize*4 { + // Reset buffer if it gets too big, so we don't keep it around. + buf = bytes.Buffer{} + } + } +} + +func (c *Connection) handleMsg(ctx context.Context, m message, subID *subHandlerID) { + switch m.Op { + case OpMuxServerMsg: + c.handleMuxServerMsg(ctx, m) + case OpResponse: + c.handleResponse(m) + case OpMuxClientMsg: + c.handleMuxClientMsg(ctx, m) + case OpUnblockSrvMux: + c.handleUnblockSrvMux(m) + case OpUnblockClMux: + c.handleUnblockClMux(m) + case OpDisconnectServerMux: + c.handleDisconnectServerMux(m) + case OpDisconnectClientMux: + c.handleDisconnectClientMux(m) + case OpPing: + c.handlePing(ctx, m) + case OpPong: + c.handlePong(ctx, m) + case OpRequest: + c.handleRequest(ctx, m, subID) + case OpAckMux: + c.handleAckMux(ctx, m) + case OpConnectMux: + c.handleConnectMux(ctx, m, subID) + default: + logger.LogIf(ctx, fmt.Errorf("unknown message type: %v", m.Op)) + } +} + +func (c *Connection) handleConnectMux(ctx context.Context, m message, subID *subHandlerID) { + // Stateless stream: + if m.Flags&FlagStateless != 0 { + // Reject for now, so we can safely add it later. + if true { + logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Stateless streams not supported"})) + return + } + + var handler *StatelessHandler + if subID == nil { + handler = c.handlers.stateless[m.Handler] + } else { + handler = c.handlers.subStateless[*subID] + } + if handler == nil { + logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"})) + return + } + _, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer { + return newMuxStateless(ctx, m, c, *handler) + }) + } else { + // Stream: + var handler *StreamHandler + if subID == nil { + if !m.Handler.valid() { + logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"})) + return + } + handler = c.handlers.streams[m.Handler] + } else { + handler = c.handlers.subStreams[*subID] + } + if handler == nil { + logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"})) + return + } + + // Start a new server handler if none exists. + _, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer { + return newMuxStream(ctx, m, c, *handler) + }) + } +} + +func (c *Connection) handleAckMux(ctx context.Context, m message) { + PutByteBuffer(m.Payload) + v, ok := c.outgoing.Load(m.MuxID) + if !ok { + if m.Flags&FlagEOF == 0 { + logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) + } + return + } + if debugPrint { + fmt.Println(c.Local, "Mux", m.MuxID, "Acknowledged") + } + v.ack(m.Seq) +} + +func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) { + if !m.Handler.valid() { + logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"})) + return + } + if debugReqs { + fmt.Println(m.MuxID, c.StringReverse(), "INCOMING") + } + // Singleshot message + var handler SingleHandlerFn + if subID == nil { + handler = c.handlers.single[m.Handler] + } else { + handler = c.handlers.subSingle[*subID] + } + if handler == nil { + logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"})) + return + } + + // TODO: This causes allocations, but escape analysis doesn't really show the cause. + // If another faithful engineer wants to take a stab, feel free. + go func(m message) { + var start time.Time + if m.DeadlineMS > 0 { + start = time.Now() + } + var b []byte + var err *RemoteErr + func() { + defer func() { + if rec := recover(); rec != nil { + err = NewRemoteErrString(fmt.Sprintf("handleMessages: panic recovered: %v", rec)) + debug.PrintStack() + logger.LogIf(ctx, err) + } + }() + b, err = handler(m.Payload) + if debugPrint { + fmt.Println(c.Local, "Handler returned payload:", bytesOrLength(b), "err:", err) + } + }() + + if m.DeadlineMS > 0 && time.Since(start).Milliseconds()+c.addDeadline.Milliseconds() > int64(m.DeadlineMS) { + if debugReqs { + fmt.Println(m.MuxID, c.StringReverse(), "DEADLINE EXCEEDED") + } + // No need to return result + PutByteBuffer(b) + return + } + if debugReqs { + fmt.Println(m.MuxID, c.StringReverse(), "RESPONDING") + } + m = message{ + MuxID: m.MuxID, + Seq: m.Seq, + Op: OpResponse, + Flags: FlagEOF, + } + if err != nil { + m.Flags |= FlagPayloadIsErr + m.Payload = []byte(*err) + } else { + m.Payload = b + m.setZeroPayloadFlag() + } + logger.LogIf(ctx, c.queueMsg(m, nil)) + }(m) +} + +func (c *Connection) handlePong(ctx context.Context, m message) { + var pong pongMsg + _, err := pong.UnmarshalMsg(m.Payload) + PutByteBuffer(m.Payload) + logger.LogIf(ctx, err) + if m.MuxID == 0 { + atomic.StoreInt64(&c.LastPong, time.Now().Unix()) + return + } + if v, ok := c.outgoing.Load(m.MuxID); ok { + v.pong(pong) + } else { + // We don't care if the client was removed in the meantime, + // but we send a disconnect message to the server just in case. + logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) + } +} + +func (c *Connection) handlePing(ctx context.Context, m message) { + if m.MuxID == 0 { + logger.LogIf(ctx, c.queueMsg(m, &pongMsg{})) + return + } + // Single calls do not support pinging. + if v, ok := c.inStream.Load(m.MuxID); ok { + pong := v.ping(m.Seq) + logger.LogIf(ctx, c.queueMsg(m, &pong)) + } else { + pong := pongMsg{NotFound: true} + logger.LogIf(ctx, c.queueMsg(m, &pong)) + } + return +} + +func (c *Connection) handleDisconnectClientMux(m message) { + if v, ok := c.outgoing.Load(m.MuxID); ok { + if m.Flags&FlagPayloadIsErr != 0 { + v.error(RemoteErr(m.Payload)) + } else { + v.error("remote disconnected") + } + return + } + PutByteBuffer(m.Payload) +} + +func (c *Connection) handleDisconnectServerMux(m message) { + if debugPrint { + fmt.Println(c.Local, "Disconnect server mux:", m.MuxID) + } + PutByteBuffer(m.Payload) + m.Payload = nil + if v, ok := c.inStream.Load(m.MuxID); ok { + v.close() + } +} + +func (c *Connection) handleUnblockClMux(m message) { + PutByteBuffer(m.Payload) + m.Payload = nil + v, ok := c.outgoing.Load(m.MuxID) + if !ok { + if debugPrint { + fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID) + } + // We can expect to receive unblocks for closed muxes + return + } + v.unblockSend(m.Seq) +} + +func (c *Connection) handleUnblockSrvMux(m message) { + if m.Payload != nil { + PutByteBuffer(m.Payload) + } + m.Payload = nil + if v, ok := c.inStream.Load(m.MuxID); ok { + v.unblockSend(m.Seq) + return + } + // We can expect to receive unblocks for closed muxes + if debugPrint { + fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID) + } +} + +func (c *Connection) handleMuxClientMsg(ctx context.Context, m message) { + v, ok := c.inStream.Load(m.MuxID) + if !ok { + if debugPrint { + fmt.Println(c.Local, "OpMuxClientMsg: Unknown Mux:", m.MuxID) + } + logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) + PutByteBuffer(m.Payload) + return + } + v.message(m) +} + +func (c *Connection) handleResponse(m message) { + if debugPrint { + fmt.Printf("%s Got mux response: %v\n", c.Local, m) + } + v, ok := c.outgoing.Load(m.MuxID) + if !ok { + if debugReqs { + fmt.Println(m.MuxID, c.String(), "Got response for unknown mux") + } + PutByteBuffer(m.Payload) + return + } + if m.Flags&FlagPayloadIsErr != 0 { + v.response(m.Seq, Response{ + Msg: nil, + Err: RemoteErr(m.Payload), + }) + PutByteBuffer(m.Payload) + } else { + v.response(m.Seq, Response{ + Msg: m.Payload, + Err: nil, + }) + } + v.close() + if debugReqs { + fmt.Println(m.MuxID, c.String(), "handleResponse: closing mux") + } +} + +func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) { + if debugPrint { + fmt.Printf("%s Got mux msg: %v\n", c.Local, m) + } + v, ok := c.outgoing.Load(m.MuxID) + if !ok { + if m.Flags&FlagEOF == 0 { + logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) + } + PutByteBuffer(m.Payload) + return + } + if m.Flags&FlagPayloadIsErr != 0 { + v.response(m.Seq, Response{ + Msg: nil, + Err: RemoteErr(m.Payload), + }) + PutByteBuffer(m.Payload) + } else if m.Payload != nil { + v.response(m.Seq, Response{ + Msg: m.Payload, + Err: nil, + }) + } + if m.Flags&FlagEOF != 0 { + v.close() + if debugReqs { + fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX") + } + c.outgoing.Delete(m.MuxID) + } +} + +func (c *Connection) deleteMux(incoming bool, muxID uint64) { + if incoming { + if debugPrint { + fmt.Println("deleteMux: disconnect incoming mux", muxID) + } + v, loaded := c.inStream.LoadAndDelete(muxID) + if loaded && v != nil { + logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: muxID}, nil)) + v.close() + } + } else { + if debugPrint { + fmt.Println("deleteMux: disconnect outgoing mux", muxID) + } + v, loaded := c.outgoing.LoadAndDelete(muxID) + if loaded && v != nil { + if debugReqs { + fmt.Println(muxID, c.String(), "deleteMux: DELETING MUX") + } + v.close() + logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectServerMux, MuxID: muxID}, nil)) + } + } +} + +// State returns the current connection status. +func (c *Connection) State() State { + return State(atomic.LoadUint32((*uint32)(&c.state))) +} + +// Stats returns the current connection stats. +func (c *Connection) Stats() ConnectionStats { + return ConnectionStats{ + IncomingStreams: c.inStream.Size(), + OutgoingStreams: c.outgoing.Size(), + } +} + +func (c *Connection) debugMsg(d debugMsg, args ...any) { + if debugPrint { + fmt.Println("debug: sending message", d, args) + } + + switch d { + case debugShutdown: + c.updateState(StateShutdown) + case debugKillInbound: + c.connMu.Lock() + defer c.connMu.Unlock() + if c.debugInConn != nil { + if debugPrint { + fmt.Println("debug: closing inbound connection") + } + c.debugInConn.Close() + } + case debugKillOutbound: + c.connMu.Lock() + defer c.connMu.Unlock() + if c.debugInConn != nil { + if debugPrint { + fmt.Println("debug: closing outgoing connection") + } + c.debugInConn.Close() + } + case debugWaitForExit: + c.handleMsgWg.Wait() + case debugSetConnPingDuration: + c.connMu.Lock() + defer c.connMu.Unlock() + c.connPingInterval = args[0].(time.Duration) + case debugSetClientPingDuration: + c.clientPingInterval = args[0].(time.Duration) + case debugAddToDeadline: + c.addDeadline = args[0].(time.Duration) + } +} + +// wsWriter writes websocket messages. +type wsWriter struct { + tmp [ws.MaxHeaderSize]byte +} + +// writeMessage writes a message to w without allocations. +func (ww *wsWriter) writeMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error { + const fin = true + var frame ws.Frame + if s.ClientSide() { + // We do not need to copy the payload, since we own it. + payload := p + + frame = ws.NewFrame(op, fin, payload) + frame = ws.MaskFrameInPlace(frame) + } else { + frame = ws.NewFrame(op, fin, p) + } + + return ww.writeFrame(w, frame) +} + +// writeFrame writes frame binary representation into w. +func (ww *wsWriter) writeFrame(w io.Writer, f ws.Frame) error { + const ( + bit0 = 0x80 + len7 = int64(125) + len16 = int64(^(uint16(0))) + len64 = int64(^(uint64(0)) >> 1) + ) + + bts := ww.tmp[:] + if f.Header.Fin { + bts[0] |= bit0 + } + bts[0] |= f.Header.Rsv << 4 + bts[0] |= byte(f.Header.OpCode) + + var n int + switch { + case f.Header.Length <= len7: + bts[1] = byte(f.Header.Length) + n = 2 + + case f.Header.Length <= len16: + bts[1] = 126 + binary.BigEndian.PutUint16(bts[2:4], uint16(f.Header.Length)) + n = 4 + + case f.Header.Length <= len64: + bts[1] = 127 + binary.BigEndian.PutUint64(bts[2:10], uint64(f.Header.Length)) + n = 10 + + default: + return ws.ErrHeaderLengthUnexpected + } + + if f.Header.Masked { + bts[1] |= bit0 + n += copy(bts[n:], f.Header.Mask[:]) + } + + if _, err := w.Write(bts[:n]); err != nil { + return err + } + + _, err := w.Write(f.Payload) + return err +} diff --git a/internal/grid/connection_test.go b/internal/grid/connection_test.go new file mode 100644 index 000000000..f95b122e1 --- /dev/null +++ b/internal/grid/connection_test.go @@ -0,0 +1,209 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/minio/minio/internal/logger/target/testlogger" +) + +func TestDisconnect(t *testing.T) { + defer testlogger.T.SetLogTB(t)() + defer timeout(10 * time.Second)() + hosts, listeners, _ := getHosts(2) + dialer := &net.Dialer{ + Timeout: 1 * time.Second, + } + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + wrapServer := func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Got a %s request for: %v", r.Method, r.URL) + handler.ServeHTTP(w, r) + }) + } + connReady := make(chan struct{}) + // We fake a local and remote server. + localHost := hosts[0] + remoteHost := hosts[1] + local, err := NewManager(context.Background(), ManagerOptions{ + Dialer: dialer.DialContext, + Local: localHost, + Hosts: hosts, + AddAuth: func(aud string) string { return aud }, + AuthRequest: dummyRequestValidate, + BlockConnect: connReady, + }) + errFatal(err) + + // 1: Echo + errFatal(local.RegisterSingleHandler(handlerTest, func(payload []byte) ([]byte, *RemoteErr) { + t.Log("1: server payload: ", len(payload), "bytes.") + return append([]byte{}, payload...), nil + })) + // 2: Return as error + errFatal(local.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) { + t.Log("2: server payload: ", len(payload), "bytes.") + err := RemoteErr(payload) + return nil, &err + })) + + remote, err := NewManager(context.Background(), ManagerOptions{ + Dialer: dialer.DialContext, + Local: remoteHost, + Hosts: hosts, + AddAuth: func(aud string) string { return aud }, + AuthRequest: dummyRequestValidate, + BlockConnect: connReady, + }) + errFatal(err) + + localServer := startServer(t, listeners[0], wrapServer(local.Handler())) + remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler())) + close(connReady) + + defer func() { + local.debugMsg(debugShutdown) + remote.debugMsg(debugShutdown) + remoteServer.Close() + localServer.Close() + remote.debugMsg(debugWaitForExit) + local.debugMsg(debugWaitForExit) + }() + + cleanReqs := make(chan struct{}) + gotCall := make(chan struct{}) + defer close(cleanReqs) + // 1: Block forever + h1 := func(payload []byte) ([]byte, *RemoteErr) { + gotCall <- struct{}{} + <-cleanReqs + return nil, nil + } + // 2: Also block, but with streaming. + h2 := StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + gotCall <- struct{}{} + select { + case <-ctx.Done(): + gotCall <- struct{}{} + case <-cleanReqs: + panic("should not be called") + } + return nil + }, + OutCapacity: 1, + InCapacity: 1, + } + errFatal(remote.RegisterSingleHandler(handlerTest, h1)) + errFatal(remote.RegisterStreamingHandler(handlerTest2, h2)) + errFatal(local.RegisterSingleHandler(handlerTest, h1)) + errFatal(local.RegisterStreamingHandler(handlerTest2, h2)) + + // local to remote + remoteConn := local.Connection(remoteHost) + errFatal(remoteConn.WaitForConnect(context.Background())) + const testPayload = "Hello Grid World!" + + gotResp := make(chan struct{}) + go func() { + start := time.Now() + t.Log("Roundtrip: sending request") + resp, err := remoteConn.Request(context.Background(), handlerTest, []byte(testPayload)) + t.Log("Roundtrip:", time.Since(start), resp, err) + gotResp <- struct{}{} + }() + <-gotCall + remote.debugMsg(debugKillInbound) + local.debugMsg(debugKillInbound) + <-gotResp + + // Must reconnect + errFatal(remoteConn.WaitForConnect(context.Background())) + + stream, err := remoteConn.NewStream(context.Background(), handlerTest2, []byte(testPayload)) + errFatal(err) + go func() { + for resp := range stream.responses { + t.Log("Resp:", resp, err) + } + gotResp <- struct{}{} + }() + + <-gotCall + remote.debugMsg(debugKillOutbound) + local.debugMsg(debugKillOutbound) + errFatal(remoteConn.WaitForConnect(context.Background())) + + <-gotResp + // Killing should cancel the context on the request. + <-gotCall +} + +func dummyRequestValidate(r *http.Request) error { + return nil +} + +func TestShouldConnect(t *testing.T) { + var c Connection + var cReverse Connection + hosts := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "x", "y", "z", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} + for x := range hosts { + should := 0 + for y := range hosts { + if x == y { + continue + } + c.Local = hosts[x] + c.Remote = hosts[y] + cReverse.Local = hosts[y] + cReverse.Remote = hosts[x] + if c.shouldConnect() == cReverse.shouldConnect() { + t.Errorf("shouldConnect(%q, %q) != shouldConnect(%q, %q)", hosts[x], hosts[y], hosts[y], hosts[x]) + } + if c.shouldConnect() { + should++ + } + } + if should < 10 { + t.Errorf("host %q only connects to %d hosts", hosts[x], should) + } + t.Logf("host %q should connect to %d hosts", hosts[x], should) + } +} + +func startServer(t testing.TB, listener net.Listener, handler http.Handler) (server *httptest.Server) { + t.Helper() + server = httptest.NewUnstartedServer(handler) + server.Config.Addr = listener.Addr().String() + server.Listener = listener + server.Start() + // t.Cleanup(server.Close) + t.Log("Started server on", server.Config.Addr, "URL:", server.URL) + return server +} diff --git a/internal/grid/debug.go b/internal/grid/debug.go new file mode 100644 index 000000000..2f7fee487 --- /dev/null +++ b/internal/grid/debug.go @@ -0,0 +1,163 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "time" + + "github.com/minio/mux" +) + +//go:generate stringer -type=debugMsg $GOFILE + +// debugMsg is a debug message for testing purposes. +// may only be used for tests. +type debugMsg int + +const ( + debugPrint = false + debugReqs = false +) + +const ( + debugShutdown debugMsg = iota + debugKillInbound + debugKillOutbound + debugWaitForExit + debugSetConnPingDuration + debugSetClientPingDuration + debugAddToDeadline +) + +// TestGrid contains a grid of servers for testing purposes. +type TestGrid struct { + Servers []*httptest.Server + Listeners []net.Listener + Managers []*Manager + Mux []*mux.Router + Hosts []string + cleanupOnce sync.Once + cancel context.CancelFunc +} + +// SetupTestGrid creates a new grid for testing purposes. +// Select the number of hosts to create. +// Call (TestGrid).Cleanup() when done. +func SetupTestGrid(n int) (*TestGrid, error) { + hosts, listeners, err := getHosts(n) + if err != nil { + return nil, err + } + dialer := &net.Dialer{ + Timeout: 5 * time.Second, + } + var res TestGrid + res.Hosts = hosts + ready := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + res.cancel = cancel + for i, host := range hosts { + manager, err := NewManager(ctx, ManagerOptions{ + Dialer: dialer.DialContext, + Local: host, + Hosts: hosts, + AuthRequest: func(r *http.Request) error { + return nil + }, + AddAuth: func(aud string) string { return aud }, + BlockConnect: ready, + }) + if err != nil { + return nil, err + } + m := mux.NewRouter() + m.Handle(RoutePath, manager.Handler()) + res.Managers = append(res.Managers, manager) + res.Servers = append(res.Servers, startHTTPServer(listeners[i], m)) + res.Listeners = append(res.Listeners, listeners[i]) + res.Mux = append(res.Mux, m) + } + close(ready) + for _, m := range res.Managers { + for _, remote := range m.Targets() { + if err := m.Connection(remote).WaitForConnect(ctx); err != nil { + return nil, err + } + } + } + return &res, nil +} + +// Cleanup will clean up the test grid. +func (t *TestGrid) Cleanup() { + t.cancel() + t.cleanupOnce.Do(func() { + for _, manager := range t.Managers { + manager.debugMsg(debugShutdown) + } + for _, server := range t.Servers { + server.Close() + } + for _, listener := range t.Listeners { + listener.Close() + } + }) +} + +// WaitAllConnect will wait for all connections to be established. +func (t *TestGrid) WaitAllConnect(ctx context.Context) { + for _, manager := range t.Managers { + for _, remote := range manager.Targets() { + if manager.HostName() == remote { + continue + } + if err := manager.Connection(remote).WaitForConnect(ctx); err != nil { + panic(err) + } + } + } +} + +func getHosts(n int) (hosts []string, listeners []net.Listener, err error) { + for i := 0; i < n; i++ { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + return nil, nil, fmt.Errorf("httptest: failed to listen on a port: %v", err) + } + } + addr := l.Addr() + hosts = append(hosts, "http://"+addr.String()) + listeners = append(listeners, l) + } + return +} + +func startHTTPServer(listener net.Listener, handler http.Handler) (server *httptest.Server) { + server = httptest.NewUnstartedServer(handler) + server.Config.Addr = listener.Addr().String() + server.Listener = listener + server.Start() + return server +} diff --git a/internal/grid/debugmsg_string.go b/internal/grid/debugmsg_string.go new file mode 100644 index 000000000..4c8676e39 --- /dev/null +++ b/internal/grid/debugmsg_string.go @@ -0,0 +1,29 @@ +// Code generated by "stringer -type=debugMsg debug.go"; DO NOT EDIT. + +package grid + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[debugShutdown-0] + _ = x[debugKillInbound-1] + _ = x[debugKillOutbound-2] + _ = x[debugWaitForExit-3] + _ = x[debugSetConnPingDuration-4] + _ = x[debugSetClientPingDuration-5] + _ = x[debugAddToDeadline-6] +} + +const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadline" + +var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130} + +func (i debugMsg) String() string { + if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) { + return "debugMsg(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _debugMsg_name[_debugMsg_index[i]:_debugMsg_index[i+1]] +} diff --git a/internal/grid/errors.go b/internal/grid/errors.go new file mode 100644 index 000000000..dbf62f7e2 --- /dev/null +++ b/internal/grid/errors.go @@ -0,0 +1,43 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "errors" + "fmt" +) + +var ( + // ErrUnknownHandler is returned when an unknown handler is requested. + ErrUnknownHandler = errors.New("unknown mux handler") + + // ErrHandlerAlreadyExists is returned when a handler is already registered. + ErrHandlerAlreadyExists = errors.New("mux handler already exists") + + // ErrIncorrectSequence is returned when an out-of-sequence item is received. + ErrIncorrectSequence = errors.New("out-of-sequence item received") +) + +// ErrResponse is a remote error response. +type ErrResponse struct { + msg string +} + +func (e ErrResponse) Error() string { + return fmt.Sprintf("remote: %s", e.msg) +} diff --git a/internal/grid/grid.go b/internal/grid/grid.go new file mode 100644 index 000000000..1cf9129d1 --- /dev/null +++ b/internal/grid/grid.go @@ -0,0 +1,287 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package grid provides single-connection two-way grid communication. +package grid + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/gobwas/ws/wsutil" +) + +// ErrDisconnected is returned when the connection to the remote has been lost during the call. +var ErrDisconnected = errors.New("remote disconnected") + +const ( + // minBufferSize is the minimum buffer size. + // Buffers below this is not reused. + minBufferSize = 1 << 10 + + // defaultBufferSize is the default buffer allocation size. + defaultBufferSize = 4 << 10 + + // maxBufferSize is the maximum buffer size. + // Buffers larger than this is not reused. + maxBufferSize = 64 << 10 + + // If there is a queue, merge up to this many messages. + maxMergeMessages = 20 + + // clientPingInterval will ping the remote handler every 15 seconds. + // Clients disconnect when we exceed 2 intervals. + clientPingInterval = 15 * time.Second + + // Deadline for single (non-streaming) requests to complete. + // Used if no deadline is provided on context. + defaultSingleRequestTimeout = time.Minute +) + +var internalByteBuffer = sync.Pool{ + New: func() any { + m := make([]byte, 0, defaultBufferSize) + return &m + }, +} + +// GetByteBuffer can be replaced with a function that returns a small +// byte buffer. +// When replacing PutByteBuffer should also be replaced +// There is no minimum size. +var GetByteBuffer = func() []byte { + b := *internalByteBuffer.Get().(*[]byte) + return b[:0] +} + +// PutByteBuffer is for returning byte buffers. +var PutByteBuffer = func(b []byte) { + if cap(b) >= minBufferSize && cap(b) < maxBufferSize { + internalByteBuffer.Put(&b) + } +} + +// readAllInto reads from r and appends to b until an error or EOF and returns the data it read. +// A successful call returns err == nil, not err == EOF. Because readAllInto is +// defined to read from src until EOF, it does not treat an EOF from Read +// as an error to be reported. +func readAllInto(b []byte, r *wsutil.Reader) ([]byte, error) { + for { + if len(b) == cap(b) { + // Add more capacity (let append pick how much). + b = append(b, 0)[:len(b)] + } + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + } + return b, err + } + } +} + +// getDeadline will truncate the deadline so it is at least 1ms and at most MaxDeadline. +func getDeadline(d time.Duration) time.Duration { + if d < time.Millisecond { + return 0 + } + if d > MaxDeadline { + return MaxDeadline + } + return d +} + +type writerWrapper struct { + ch chan<- []byte + ctx context.Context +} + +func (w *writerWrapper) Write(p []byte) (n int, err error) { + buf := GetByteBuffer() + if cap(buf) < len(p) { + PutByteBuffer(buf) + buf = make([]byte, len(p)) + } + buf = buf[:len(p)] + copy(buf, p) + select { + case w.ch <- buf: + return len(p), nil + case <-w.ctx.Done(): + return 0, context.Cause(w.ctx) + } +} + +// WriterToChannel will return an io.Writer that writes to the given channel. +// The context both allows returning errors on writes and to ensure that +// this isn't abandoned if the channel is no longer being read from. +func WriterToChannel(ctx context.Context, ch chan<- []byte) io.Writer { + return &writerWrapper{ch: ch, ctx: ctx} +} + +// bytesOrLength returns small (<=100b) byte slices as string, otherwise length. +func bytesOrLength(b []byte) string { + if len(b) > 100 { + return fmt.Sprintf("%d bytes", len(b)) + } + return fmt.Sprint(b) +} + +type lockedClientMap struct { + m map[uint64]*muxClient + mu sync.Mutex +} + +func (m *lockedClientMap) Load(id uint64) (*muxClient, bool) { + m.mu.Lock() + v, ok := m.m[id] + m.mu.Unlock() + return v, ok +} + +func (m *lockedClientMap) LoadAndDelete(id uint64) (*muxClient, bool) { + m.mu.Lock() + v, ok := m.m[id] + if ok { + delete(m.m, id) + } + m.mu.Unlock() + return v, ok +} + +func (m *lockedClientMap) Size() int { + m.mu.Lock() + v := len(m.m) + m.mu.Unlock() + return v +} + +func (m *lockedClientMap) Delete(id uint64) { + m.mu.Lock() + delete(m.m, id) + m.mu.Unlock() +} + +func (m *lockedClientMap) Range(fn func(key uint64, value *muxClient) bool) { + m.mu.Lock() + for k, v := range m.m { + if !fn(k, v) { + break + } + } + m.mu.Unlock() +} + +func (m *lockedClientMap) Clear() { + m.mu.Lock() + m.m = map[uint64]*muxClient{} + m.mu.Unlock() +} + +func (m *lockedClientMap) LoadOrStore(id uint64, v *muxClient) (*muxClient, bool) { + m.mu.Lock() + v2, ok := m.m[id] + if ok { + m.mu.Unlock() + return v2, true + } + m.m[id] = v + m.mu.Unlock() + return v, false +} + +type lockedServerMap struct { + m map[uint64]*muxServer + mu sync.Mutex +} + +func (m *lockedServerMap) Load(id uint64) (*muxServer, bool) { + m.mu.Lock() + v, ok := m.m[id] + m.mu.Unlock() + return v, ok +} + +func (m *lockedServerMap) LoadAndDelete(id uint64) (*muxServer, bool) { + m.mu.Lock() + v, ok := m.m[id] + if ok { + delete(m.m, id) + } + m.mu.Unlock() + return v, ok +} + +func (m *lockedServerMap) Size() int { + m.mu.Lock() + v := len(m.m) + m.mu.Unlock() + return v +} + +func (m *lockedServerMap) Delete(id uint64) { + m.mu.Lock() + delete(m.m, id) + m.mu.Unlock() +} + +func (m *lockedServerMap) Range(fn func(key uint64, value *muxServer) bool) { + m.mu.Lock() + for k, v := range m.m { + if !fn(k, v) { + break + } + } + m.mu.Unlock() +} + +func (m *lockedServerMap) Clear() { + m.mu.Lock() + m.m = map[uint64]*muxServer{} + m.mu.Unlock() +} + +func (m *lockedServerMap) LoadOrStore(id uint64, v *muxServer) (*muxServer, bool) { + m.mu.Lock() + v2, ok := m.m[id] + if ok { + m.mu.Unlock() + return v2, true + } + m.m[id] = v + m.mu.Unlock() + return v, false +} + +func (m *lockedServerMap) LoadOrCompute(id uint64, fn func() *muxServer) (*muxServer, bool) { + m.mu.Lock() + v2, ok := m.m[id] + if ok { + m.mu.Unlock() + return v2, true + } + v := fn() + m.m[id] = v + m.mu.Unlock() + return v, false +} diff --git a/internal/grid/grid_test.go b/internal/grid/grid_test.go new file mode 100644 index 000000000..b8262fecb --- /dev/null +++ b/internal/grid/grid_test.go @@ -0,0 +1,893 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "runtime" + "strconv" + "strings" + "testing" + "time" + + "github.com/minio/minio/internal/logger/target/testlogger" +) + +func TestSingleRoundtrip(t *testing.T) { + defer testlogger.T.SetLogTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + grid, err := SetupTestGrid(2) + errFatal(err) + remoteHost := grid.Hosts[1] + local := grid.Managers[0] + + // 1: Echo + errFatal(local.RegisterSingleHandler(handlerTest, func(payload []byte) ([]byte, *RemoteErr) { + t.Log("1: server payload: ", len(payload), "bytes.") + return append([]byte{}, payload...), nil + })) + // 2: Return as error + errFatal(local.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) { + t.Log("2: server payload: ", len(payload), "bytes.") + err := RemoteErr(payload) + return nil, &err + })) + + remote := grid.Managers[1] + + // 1: Echo + errFatal(remote.RegisterSingleHandler(handlerTest, func(payload []byte) ([]byte, *RemoteErr) { + t.Log("1: server payload: ", len(payload), "bytes.") + return append([]byte{}, payload...), nil + })) + // 2: Return as error + errFatal(remote.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) { + t.Log("2: server payload: ", len(payload), "bytes.") + err := RemoteErr(payload) + return nil, &err + })) + + // local to remote + remoteConn := local.Connection(remoteHost) + remoteConn.WaitForConnect(context.Background()) + defer testlogger.T.SetErrorTB(t)() + + t.Run("localToRemote", func(t *testing.T) { + const testPayload = "Hello Grid World!" + + start := time.Now() + resp, err := remoteConn.Request(context.Background(), handlerTest, []byte(testPayload)) + errFatal(err) + if string(resp) != testPayload { + t.Errorf("want %q, got %q", testPayload, string(resp)) + } + t.Log("Roundtrip:", time.Since(start)) + }) + + t.Run("localToRemoteErr", func(t *testing.T) { + const testPayload = "Hello Grid World!" + start := time.Now() + resp, err := remoteConn.Request(context.Background(), handlerTest2, []byte(testPayload)) + t.Log("Roundtrip:", time.Since(start)) + if len(resp) != 0 { + t.Errorf("want nil, got %q", string(resp)) + } + if err != RemoteErr(testPayload) { + t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err) + } + t.Log("Roundtrip:", time.Since(start)) + }) + + t.Run("localToRemoteHuge", func(t *testing.T) { + testPayload := bytes.Repeat([]byte("?"), 1<<20) + + start := time.Now() + resp, err := remoteConn.Request(context.Background(), handlerTest, testPayload) + errFatal(err) + if string(resp) != string(testPayload) { + t.Errorf("want %q, got %q", testPayload, string(resp)) + } + t.Log("Roundtrip:", time.Since(start)) + }) + + t.Run("localToRemoteErrHuge", func(t *testing.T) { + testPayload := bytes.Repeat([]byte("!"), 1<<10) + + start := time.Now() + resp, err := remoteConn.Request(context.Background(), handlerTest2, testPayload) + if len(resp) != 0 { + t.Errorf("want nil, got %q", string(resp)) + } + if err != RemoteErr(testPayload) { + t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err) + } + t.Log("Roundtrip:", time.Since(start)) + }) +} + +func TestSingleRoundtripGenerics(t *testing.T) { + defer testlogger.T.SetLogTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + grid, err := SetupTestGrid(2) + errFatal(err) + remoteHost := grid.Hosts[1] + local := grid.Managers[0] + remote := grid.Managers[1] + + // 1: Echo + h1 := NewSingleHandler[*testRequest, *testResponse](handlerTest, func() *testRequest { + return &testRequest{} + }, func() *testResponse { + return &testResponse{} + }) + // Handles incoming requests, returns a response + handler1 := func(req *testRequest) (resp *testResponse, err *RemoteErr) { + resp = h1.NewResponse() + *resp = testResponse{ + OrgNum: req.Num, + OrgString: req.String, + Embedded: *req, + } + return resp, nil + } + // Return error + h2 := NewSingleHandler[*testRequest, *testResponse](handlerTest2, newTestRequest, newTestResponse) + handler2 := func(req *testRequest) (resp *testResponse, err *RemoteErr) { + r := RemoteErr(req.String) + return nil, &r + } + errFatal(h1.Register(local, handler1)) + errFatal(h2.Register(local, handler2)) + + errFatal(h1.Register(remote, handler1)) + errFatal(h2.Register(remote, handler2)) + + // local to remote connection + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + start := time.Now() + req := testRequest{Num: 1, String: testPayload} + resp, err := h1.Call(context.Background(), remoteConn, &req) + errFatal(err) + if resp.OrgString != testPayload { + t.Errorf("want %q, got %q", testPayload, resp.OrgString) + } + t.Log("Roundtrip:", time.Since(start)) + + start = time.Now() + resp, err = h2.Call(context.Background(), remoteConn, &testRequest{Num: 1, String: testPayload}) + t.Log("Roundtrip:", time.Since(start)) + if err != RemoteErr(testPayload) { + t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err) + } + if resp != nil { + t.Errorf("want nil, got %q", resp) + } + t.Log("Roundtrip:", time.Since(start)) +} + +func TestStreamSuite(t *testing.T) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + grid, err := SetupTestGrid(2) + errFatal(err) + t.Cleanup(grid.Cleanup) + + local := grid.Managers[0] + localHost := grid.Hosts[0] + remote := grid.Managers[1] + remoteHost := grid.Hosts[1] + + connLocalToRemote := local.Connection(remoteHost) + connRemoteLocal := remote.Connection(localHost) + + t.Run("testStreamRoundtrip", func(t *testing.T) { + defer timeout(5 * time.Second)() + testStreamRoundtrip(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testStreamCancel", func(t *testing.T) { + defer timeout(5 * time.Second)() + testStreamCancel(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testStreamDeadline", func(t *testing.T) { + defer timeout(5 * time.Second)() + testStreamDeadline(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerOutCongestion", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerOutCongestion(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerInCongestion", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerInCongestion(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testGenericsStreamRoundtrip", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testGenericsStreamRoundtrip(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testGenericsStreamRoundtripSubroute", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testGenericsStreamRoundtripSubroute(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) +} + +func testStreamRoundtrip(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + defer timeout(5 * time.Second)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Echo + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + for in := range request { + b := append([]byte{}, payload...) + b = append(b, in...) + resp <- b + } + t.Log(GetCaller(ctx).Name, "Handler done") + return nil + }, + OutCapacity: 1, + InCapacity: 1, + })) + // 2: Return as error + errFatal(manager.RegisterStreamingHandler(handlerTest2, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + for in := range request { + t.Log("2: Got err request", string(in)) + err := RemoteErr(append(payload, in...)) + return &err + } + return nil + }, + OutCapacity: 1, + InCapacity: 1, + })) + } + register(local) + register(remote) + + // local to remote + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + start := time.Now() + stream, err := remoteConn.NewStream(context.Background(), handlerTest, []byte(testPayload)) + errFatal(err) + var n int + stream.Requests <- []byte(strconv.Itoa(n)) + for resp := range stream.responses { + errFatal(resp.Err) + t.Logf("got resp: %+v", string(resp.Msg)) + if string(resp.Msg) != testPayload+strconv.Itoa(n) { + t.Errorf("want %q, got %q", testPayload+strconv.Itoa(n), string(resp.Msg)) + } + if n == 10 { + close(stream.Requests) + continue + } + n++ + t.Log("sending new client request") + stream.Requests <- []byte(strconv.Itoa(n)) + } + t.Log("EOF. 10 Roundtrips:", time.Since(start)) +} + +func testStreamCancel(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Echo + serverCanceled := make(chan struct{}) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + <-ctx.Done() + serverCanceled <- struct{}{} + t.Log(GetCaller(ctx).Name, "Server Context canceled") + return nil + }, + OutCapacity: 1, + InCapacity: 0, + })) + errFatal(manager.RegisterStreamingHandler(handlerTest2, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + <-ctx.Done() + serverCanceled <- struct{}{} + t.Log(GetCaller(ctx).Name, "Server Context canceled") + return nil + }, + OutCapacity: 1, + InCapacity: 1, + })) + } + register(local) + register(remote) + + // local to remote + testHandler := func(t *testing.T, handler HandlerID) { + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + ctx, cancel := context.WithCancel(context.Background()) + st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + errFatal(err) + clientCanceled := make(chan time.Time, 1) + err = nil + go func(t *testing.T) { + for resp := range st.responses { + t.Log("got resp:", string(resp.Msg), "err:", resp.Err) + if err != nil { + t.Log("ERROR: got second error:", resp.Err, "first:", err) + continue + } + err = resp.Err + } + t.Log("Client Context canceled. err state:", err) + clientCanceled <- time.Now() + }(t) + start := time.Now() + cancel() + <-serverCanceled + t.Log("server cancel time:", time.Since(start)) + clientEnd := <-clientCanceled + if !errors.Is(err, context.Canceled) { + t.Error("expected context.Canceled, got", err) + } + t.Log("client after", time.Since(clientEnd)) + } + // local to remote, unbuffered + t.Run("unbuffered", func(t *testing.T) { + testHandler(t, handlerTest) + }) + + t.Run("buffered", func(t *testing.T) { + testHandler(t, handlerTest2) + }) +} + +// testStreamDeadline will test if server +func testStreamDeadline(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + const wantDL = 50 * time.Millisecond + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Echo + serverCanceled := make(chan time.Duration, 1) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + started := time.Now() + dl, _ := ctx.Deadline() + if testing.Verbose() { + fmt.Println(GetCaller(ctx).Name, "Server deadline:", time.Until(dl)) + } + <-ctx.Done() + serverCanceled <- time.Since(started) + if testing.Verbose() { + fmt.Println(GetCaller(ctx).Name, "Server Context canceled with", ctx.Err(), "after", time.Since(started)) + } + return nil + }, + OutCapacity: 1, + InCapacity: 0, + })) + errFatal(manager.RegisterStreamingHandler(handlerTest2, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + started := time.Now() + dl, _ := ctx.Deadline() + if testing.Verbose() { + fmt.Println(GetCaller(ctx).Name, "Server deadline:", time.Until(dl)) + } + <-ctx.Done() + serverCanceled <- time.Since(started) + if testing.Verbose() { + fmt.Println(GetCaller(ctx).Name, "Server Context canceled with", ctx.Err(), "after", time.Since(started)) + } + return nil + }, + OutCapacity: 1, + InCapacity: 1, + })) + } + register(local) + register(remote) + // Double remote DL + local.debugMsg(debugAddToDeadline, wantDL) + defer local.debugMsg(debugAddToDeadline, time.Duration(0)) + remote.debugMsg(debugAddToDeadline, wantDL) + defer remote.debugMsg(debugAddToDeadline, time.Duration(0)) + + testHandler := func(t *testing.T, handler HandlerID) { + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + ctx, cancel := context.WithTimeout(context.Background(), wantDL) + defer cancel() + st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload)) + errFatal(err) + clientCanceled := make(chan time.Duration, 1) + go func() { + started := time.Now() + for resp := range st.responses { + err = resp.Err + } + clientCanceled <- time.Since(started) + t.Log("Client Context canceled") + }() + serverEnd := <-serverCanceled + clientEnd := <-clientCanceled + t.Log("server cancel time:", serverEnd) + t.Log("client cancel time:", clientEnd) + if !errors.Is(err, context.DeadlineExceeded) { + t.Error("expected context.DeadlineExceeded, got", err) + } + } + // local to remote, unbuffered + t.Run("unbuffered", func(t *testing.T) { + testHandler(t, handlerTest) + }) + + t.Run("buffered", func(t *testing.T) { + testHandler(t, handlerTest2) + }) +} + +func testServerOutCongestion(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Echo + serverSent := make(chan struct{}) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + // Send many responses. + // Test that this doesn't block. + for i := byte(0); i < 100; i++ { + select { + case resp <- []byte{i}: + // ok + case <-ctx.Done(): + return NewRemoteErr(ctx.Err()) + } + if i == 0 { + close(serverSent) + } + } + return nil + }, + OutCapacity: 1, + InCapacity: 0, + })) + errFatal(manager.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) { + // Simple roundtrip + return append([]byte{}, payload...), nil + })) + } + register(local) + register(remote) + + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + errFatal(err) + + // Wait for the server to send the first response. + <-serverSent + + // Now do 100 other requests to ensure that the server doesn't block. + for i := 0; i < 100; i++ { + _, err := remoteConn.Request(ctx, handlerTest2, []byte(testPayload)) + errFatal(err) + } + // Drain responses + got := 0 + for resp := range st.responses { + // t.Log("got response", resp) + errFatal(resp.Err) + if resp.Msg[0] != byte(got) { + t.Error("expected response", got, "got", resp.Msg[0]) + } + got++ + } + if got != 100 { + t.Error("expected 100 responses, got", got) + } +} + +func testServerInCongestion(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Echo + processHandler := make(chan struct{}) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr { + // Block incoming requests. + var n byte + <-processHandler + for { + select { + case in, ok := <-request: + if !ok { + return nil + } + if in[0] != n { + return NewRemoteErrString(fmt.Sprintf("expected incoming %d, got %d", n, in[0])) + } + n++ + resp <- append([]byte{}, in...) + case <-ctx.Done(): + return NewRemoteErr(ctx.Err()) + } + } + }, + OutCapacity: 5, + InCapacity: 5, + })) + errFatal(manager.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) { + // Simple roundtrip + return append([]byte{}, payload...), nil + })) + } + register(local) + register(remote) + + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + errFatal(err) + + // Start sending requests. + go func() { + for i := byte(0); i < 100; i++ { + st.Requests <- []byte{i} + } + close(st.Requests) + }() + // Now do 100 other requests to ensure that the server doesn't block. + for i := 0; i < 100; i++ { + _, err := remoteConn.Request(ctx, handlerTest2, []byte(testPayload)) + errFatal(err) + } + // Start processing requests. + close(processHandler) + + // Drain responses + got := 0 + for resp := range st.responses { + // t.Log("got response", resp) + errFatal(resp.Err) + if resp.Msg[0] != byte(got) { + t.Error("expected response", got, "got", resp.Msg[0]) + } + got++ + } + if got != 100 { + t.Error("expected 100 responses, got", got) + } +} + +func testGenericsStreamRoundtrip(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + defer timeout(5 * time.Second)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + handler := NewStream[*testRequest, *testRequest, *testResponse](handlerTest, newTestRequest, newTestRequest, newTestResponse) + handler.InCapacity = 1 + handler.OutCapacity = 1 + const payloads = 10 + + // 1: Echo + register := func(manager *Manager) { + errFatal(handler.Register(manager, func(ctx context.Context, pp *testRequest, in <-chan *testRequest, out chan<- *testResponse) *RemoteErr { + n := 0 + for i := range in { + if n > payloads { + panic("too many requests") + } + + // t.Log("Got request:", *i) + out <- &testResponse{ + OrgNum: i.Num + pp.Num, + OrgString: pp.String + i.String, + Embedded: *i, + } + n++ + } + return nil + })) + } + register(local) + register(remote) + + // local to remote + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + start := time.Now() + stream, err := handler.Call(context.Background(), remoteConn, &testRequest{Num: 1, String: testPayload}) + errFatal(err) + go func() { + defer close(stream.Requests) + for i := 0; i < payloads; i++ { + // t.Log("sending new client request") + stream.Requests <- &testRequest{Num: i, String: testPayload} + } + }() + var n int + err = stream.Results(func(resp *testResponse) error { + const wantString = testPayload + testPayload + if resp.OrgString != testPayload+testPayload { + t.Errorf("want %q, got %q", wantString, resp.OrgString) + } + if resp.OrgNum != n+1 { + t.Errorf("want %d, got %d", n+1, resp.OrgNum) + } + handler.PutResponse(resp) + n++ + return nil + }) + errFatal(err) + t.Log("EOF.", payloads, " Roundtrips:", time.Since(start)) +} + +func testGenericsStreamRoundtripSubroute(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + defer timeout(5 * time.Second)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + handler := NewStream[*testRequest, *testRequest, *testResponse](handlerTest, newTestRequest, newTestRequest, newTestResponse) + handler.InCapacity = 1 + handler.OutCapacity = 1 + const payloads = 10 + + // 1: Echo + register := func(manager *Manager) { + errFatal(handler.Register(manager, func(ctx context.Context, pp *testRequest, in <-chan *testRequest, out chan<- *testResponse) *RemoteErr { + sub := GetSubroute(ctx) + if sub != "subroute/1" { + t.Fatal("expected subroute/1, got", sub) + } + n := 0 + for i := range in { + if n > payloads { + panic("too many requests") + } + + // t.Log("Got request:", *i) + out <- &testResponse{ + OrgNum: i.Num + pp.Num, + OrgString: pp.String + i.String, + Embedded: *i, + } + n++ + } + return nil + }, "subroute", "1")) + } + register(local) + register(remote) + + // local to remote + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + // Add subroute + remoteSub := remoteConn.Subroute(strings.Join([]string{"subroute", "1"}, "/")) + + start := time.Now() + stream, err := handler.Call(context.Background(), remoteSub, &testRequest{Num: 1, String: testPayload}) + errFatal(err) + go func() { + defer close(stream.Requests) + for i := 0; i < payloads; i++ { + // t.Log("sending new client request") + stream.Requests <- &testRequest{Num: i, String: testPayload} + } + }() + var n int + err = stream.Results(func(resp *testResponse) error { + // t.Logf("got resp: %+v", *resp.Msg) + const wantString = testPayload + testPayload + if resp.OrgString != testPayload+testPayload { + t.Errorf("want %q, got %q", wantString, resp.OrgString) + } + if resp.OrgNum != n+1 { + t.Errorf("want %d, got %d", n+1, resp.OrgNum) + } + handler.PutResponse(resp) + n++ + return nil + }) + + errFatal(err) + t.Log("EOF.", payloads, " Roundtrips:", time.Since(start)) +} + +func timeout(after time.Duration) (cancel func()) { + c := time.After(after) + cc := make(chan struct{}) + go func() { + select { + case <-cc: + return + case <-c: + buf := make([]byte, 1<<20) + stacklen := runtime.Stack(buf, true) + fmt.Printf("=== Timeout, assuming deadlock ===\n*** goroutine dump...\n%s\n*** end\n", string(buf[:stacklen])) + os.Exit(2) + } + }() + return func() { + close(cc) + } +} + +func assertNoActive(t *testing.T, c *Connection) { + t.Helper() + // Tiny bit racy for tests, but we try to play nice. + for i := 10; i >= 0; i-- { + runtime.Gosched() + stats := c.Stats() + if stats.IncomingStreams != 0 { + if i > 0 { + time.Sleep(100 * time.Millisecond) + continue + } + var found []uint64 + c.inStream.Range(func(key uint64, value *muxServer) bool { + found = append(found, key) + return true + }) + t.Errorf("expected no active streams, got %d incoming: %v", stats.IncomingStreams, found) + } + if stats.OutgoingStreams != 0 { + if i > 0 { + time.Sleep(100 * time.Millisecond) + continue + } + var found []uint64 + c.outgoing.Range(func(key uint64, value *muxClient) bool { + found = append(found, key) + return true + }) + t.Errorf("expected no active streams, got %d outgoing: %v", stats.OutgoingStreams, found) + } + return + } +} + +// Inserted manually. +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[StateUnconnected-0] + _ = x[StateConnecting-1] + _ = x[StateConnected-2] + _ = x[StateConnectionError-3] + _ = x[StateShutdown-4] +} + +const stateName = "UnconnectedConnectingConnectedConnectionErrorShutdown" + +var stateIndex = [...]uint8{0, 11, 21, 30, 45, 53} + +func (i State) String() string { + if i >= State(len(stateIndex)-1) { + return "State(" + strconv.FormatInt(int64(i), 10) + ")" + } + return stateName[stateIndex[i]:stateIndex[i+1]] +} diff --git a/internal/grid/grid_types_msgp_test.go b/internal/grid/grid_types_msgp_test.go new file mode 100644 index 000000000..7252ae360 --- /dev/null +++ b/internal/grid/grid_types_msgp_test.go @@ -0,0 +1,368 @@ +package grid + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *testRequest) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Num": + z.Num, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "Num") + return + } + case "String": + z.String, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "String") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z testRequest) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "Num" + err = en.Append(0x82, 0xa3, 0x4e, 0x75, 0x6d) + if err != nil { + return + } + err = en.WriteInt(z.Num) + if err != nil { + err = msgp.WrapError(err, "Num") + return + } + // write "String" + err = en.Append(0xa6, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67) + if err != nil { + return + } + err = en.WriteString(z.String) + if err != nil { + err = msgp.WrapError(err, "String") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z testRequest) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "Num" + o = append(o, 0x82, 0xa3, 0x4e, 0x75, 0x6d) + o = msgp.AppendInt(o, z.Num) + // string "String" + o = append(o, 0xa6, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67) + o = msgp.AppendString(o, z.String) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *testRequest) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Num": + z.Num, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Num") + return + } + case "String": + z.String, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "String") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z testRequest) Msgsize() (s int) { + s = 1 + 4 + msgp.IntSize + 7 + msgp.StringPrefixSize + len(z.String) + return +} + +// DecodeMsg implements msgp.Decodable +func (z *testResponse) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "OrgNum": + z.OrgNum, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "OrgNum") + return + } + case "OrgString": + z.OrgString, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "OrgString") + return + } + case "Embedded": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "Embedded") + return + } + for zb0002 > 0 { + zb0002-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "Embedded") + return + } + switch msgp.UnsafeString(field) { + case "Num": + z.Embedded.Num, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "Embedded", "Num") + return + } + case "String": + z.Embedded.String, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Embedded", "String") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "Embedded") + return + } + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *testResponse) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 3 + // write "OrgNum" + err = en.Append(0x83, 0xa6, 0x4f, 0x72, 0x67, 0x4e, 0x75, 0x6d) + if err != nil { + return + } + err = en.WriteInt(z.OrgNum) + if err != nil { + err = msgp.WrapError(err, "OrgNum") + return + } + // write "OrgString" + err = en.Append(0xa9, 0x4f, 0x72, 0x67, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67) + if err != nil { + return + } + err = en.WriteString(z.OrgString) + if err != nil { + err = msgp.WrapError(err, "OrgString") + return + } + // write "Embedded" + err = en.Append(0xa8, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64) + if err != nil { + return + } + // map header, size 2 + // write "Num" + err = en.Append(0x82, 0xa3, 0x4e, 0x75, 0x6d) + if err != nil { + return + } + err = en.WriteInt(z.Embedded.Num) + if err != nil { + err = msgp.WrapError(err, "Embedded", "Num") + return + } + // write "String" + err = en.Append(0xa6, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67) + if err != nil { + return + } + err = en.WriteString(z.Embedded.String) + if err != nil { + err = msgp.WrapError(err, "Embedded", "String") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *testResponse) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 3 + // string "OrgNum" + o = append(o, 0x83, 0xa6, 0x4f, 0x72, 0x67, 0x4e, 0x75, 0x6d) + o = msgp.AppendInt(o, z.OrgNum) + // string "OrgString" + o = append(o, 0xa9, 0x4f, 0x72, 0x67, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67) + o = msgp.AppendString(o, z.OrgString) + // string "Embedded" + o = append(o, 0xa8, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64) + // map header, size 2 + // string "Num" + o = append(o, 0x82, 0xa3, 0x4e, 0x75, 0x6d) + o = msgp.AppendInt(o, z.Embedded.Num) + // string "String" + o = append(o, 0xa6, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67) + o = msgp.AppendString(o, z.Embedded.String) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *testResponse) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "OrgNum": + z.OrgNum, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "OrgNum") + return + } + case "OrgString": + z.OrgString, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "OrgString") + return + } + case "Embedded": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Embedded") + return + } + for zb0002 > 0 { + zb0002-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "Embedded") + return + } + switch msgp.UnsafeString(field) { + case "Num": + z.Embedded.Num, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Embedded", "Num") + return + } + case "String": + z.Embedded.String, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Embedded", "String") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "Embedded") + return + } + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *testResponse) Msgsize() (s int) { + s = 1 + 7 + msgp.IntSize + 10 + msgp.StringPrefixSize + len(z.OrgString) + 9 + 1 + 4 + msgp.IntSize + 7 + msgp.StringPrefixSize + len(z.Embedded.String) + return +} diff --git a/internal/grid/grid_types_test.go b/internal/grid/grid_types_test.go new file mode 100644 index 000000000..1c3b3bf74 --- /dev/null +++ b/internal/grid/grid_types_test.go @@ -0,0 +1,39 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +//go:generate msgp -unexported -file=$GOFILE -tests=false -o=grid_types_msgp_test.go + +type testRequest struct { + Num int + String string +} + +type testResponse struct { + OrgNum int + OrgString string + Embedded testRequest +} + +func newTestRequest() *testRequest { + return &testRequest{} +} + +func newTestResponse() *testResponse { + return &testResponse{} +} diff --git a/internal/grid/handlers.go b/internal/grid/handlers.go new file mode 100644 index 000000000..b705ba340 --- /dev/null +++ b/internal/grid/handlers.go @@ -0,0 +1,697 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "strings" + "sync" + + "github.com/minio/minio/internal/hash/sha256" + "github.com/minio/minio/internal/logger" + "github.com/tinylib/msgp/msgp" +) + +//go:generate stringer -type=HandlerID -output=handlers_string.go -trimprefix=Handler msg.go $GOFILE + +// HandlerID is a handler identifier. +// It is used to determine request routing on the server. +// Handlers can be registered with a static subroute. +const ( + // handlerInvalid is reserved to check for uninitialized values. + handlerInvalid HandlerID = iota + HandlerLockLock + HandlerLockRLock + HandlerLockUnlock + HandlerLockRUnlock + HandlerLockRefresh + HandlerLockForceUnlock + HandlerWalkDir + HandlerStatVol + HandlerDiskInfo + HandlerNSScanner + HandlerReadXL + HandlerReadVersion + HandlerDeleteFile + HandlerDeleteVersion + HandlerUpdateMetadata + HandlerWriteMetadata + HandlerCheckParts + HandlerRenamedata + + // Add more above here ^^^ + // If all handlers are used, the type of Handler can be changed. + // Handlers have no versioning, so non-compatible handler changes must result in new IDs. + handlerTest + handlerTest2 + handlerLast +) + +func init() { + // Static check if we exceed 255 handler ids. + // Extend the type to uint16 when hit. + if handlerLast > 255 { + panic(fmt.Sprintf("out of handler IDs. %d > %d", handlerLast, 255)) + } +} + +func (h HandlerID) valid() bool { + return h != handlerInvalid && h < handlerLast +} + +func (h HandlerID) isTestHandler() bool { + return h >= handlerTest && h <= handlerTest2 +} + +// RemoteErr is a remote error type. +// Any error seen on a remote will be returned like this. +type RemoteErr string + +// NewRemoteErr creates a new remote error. +// The error type is not preserved. +func NewRemoteErr(err error) *RemoteErr { + if err == nil { + return nil + } + r := RemoteErr(err.Error()) + return &r +} + +// NewRemoteErrf creates a new remote error from a format string. +func NewRemoteErrf(format string, a ...any) *RemoteErr { + r := RemoteErr(fmt.Sprintf(format, a...)) + return &r +} + +// NewNPErr is a helper to no payload and optional remote error. +// The error type is not preserved. +func NewNPErr(err error) (NoPayload, *RemoteErr) { + if err == nil { + return NoPayload{}, nil + } + r := RemoteErr(err.Error()) + return NoPayload{}, &r +} + +// NewRemoteErrString creates a new remote error from a string. +func NewRemoteErrString(msg string) *RemoteErr { + r := RemoteErr(msg) + return &r +} + +func (r RemoteErr) Error() string { + return string(r) +} + +// Is returns if the string representation matches. +func (r *RemoteErr) Is(other error) bool { + if r == nil || other == nil { + return r == other + } + var o RemoteErr + if errors.As(other, &o) { + return r == &o + } + return false +} + +// IsRemoteErr returns the value if the error is a RemoteErr. +func IsRemoteErr(err error) *RemoteErr { + var r RemoteErr + if errors.As(err, &r) { + return &r + } + return nil +} + +type ( + // SingleHandlerFn is handlers for one to one requests. + // A non-nil error value will be returned as RemoteErr(msg) to client. + // No client information or cancellation (deadline) is available. + // Include this in payload if needed. + // Payload should be recycled with PutByteBuffer if not needed after the call. + SingleHandlerFn func(payload []byte) ([]byte, *RemoteErr) + + // StatelessHandlerFn must handle incoming stateless request. + // A non-nil error value will be returned as RemoteErr(msg) to client. + StatelessHandlerFn func(ctx context.Context, payload []byte, resp chan<- []byte) *RemoteErr + + // StatelessHandler is handlers for one to many requests, + // where responses may be dropped. + // Stateless requests provide no incoming stream and there is no flow control + // on outgoing messages. + StatelessHandler struct { + Handle StatelessHandlerFn + // OutCapacity is the output capacity on the caller. + // If <= 0 capacity will be 1. + OutCapacity int + } + + // StreamHandlerFn must process a request with an optional initial payload. + // It must keep consuming from 'in' until it returns. + // 'in' and 'out' are independent. + // The handler should never close out. + // Buffers received from 'in' can be recycled with PutByteBuffer. + // Buffers sent on out can not be referenced once sent. + StreamHandlerFn func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr + + // StreamHandler handles fully bidirectional streams, + // There is flow control in both directions. + StreamHandler struct { + // Handle an incoming request. Initial payload is sent. + // Additional input packets (if any) are streamed to request. + // Upstream will block when request channel is full. + // Response packets can be sent at any time. + // Any non-nil error sent as response means no more responses are sent. + Handle StreamHandlerFn + + // Subroute for handler. + // Subroute must be static and clients should specify a matching subroute. + // Should not be set unless there are different handlers for the same HandlerID. + Subroute string + + // OutCapacity is the output capacity. If <= 0 capacity will be 1. + OutCapacity int + + // InCapacity is the output capacity. + // If == 0 no input is expected + InCapacity int + } +) + +type subHandlerID [32]byte + +func makeSubHandlerID(id HandlerID, subRoute string) subHandlerID { + b := subHandlerID(sha256.Sum256([]byte(subRoute))) + b[0] = byte(id) + b[1] = 0 // Reserved + return b +} + +func (s subHandlerID) withHandler(id HandlerID) subHandlerID { + s[0] = byte(id) + s[1] = 0 // Reserved + return s +} + +func (s *subHandlerID) String() string { + if s == nil { + return "" + } + return hex.EncodeToString(s[:]) +} + +func makeZeroSubHandlerID(id HandlerID) subHandlerID { + return subHandlerID{byte(id)} +} + +type handlers struct { + single [handlerLast]SingleHandlerFn + stateless [handlerLast]*StatelessHandler + streams [handlerLast]*StreamHandler + + subSingle map[subHandlerID]SingleHandlerFn + subStateless map[subHandlerID]*StatelessHandler + subStreams map[subHandlerID]*StreamHandler +} + +func (h *handlers) init() { + h.subSingle = make(map[subHandlerID]SingleHandlerFn) + h.subStateless = make(map[subHandlerID]*StatelessHandler) + h.subStreams = make(map[subHandlerID]*StreamHandler) +} + +func (h *handlers) hasAny(id HandlerID) bool { + if !id.valid() { + return false + } + return h.single[id] != nil || h.stateless[id] != nil || h.streams[id] != nil +} + +func (h *handlers) hasSubhandler(id subHandlerID) bool { + return h.subSingle[id] != nil || h.subStateless[id] != nil || h.subStreams[id] != nil +} + +// RoundTripper provides an interface for type roundtrip serialization. +type RoundTripper interface { + msgp.Unmarshaler + msgp.Marshaler + msgp.Sizer + + comparable +} + +// SingleHandler is a type safe handler for single roundtrip requests. +type SingleHandler[Req, Resp RoundTripper] struct { + id HandlerID + sharedResponse bool + + reqPool sync.Pool + respPool sync.Pool + + nilReq Req + nilResp Resp +} + +// NewSingleHandler creates a typed handler that can provide Marshal/Unmarshal. +// Use Register to register a server handler. +// Use Call to initiate a clientside call. +func NewSingleHandler[Req, Resp RoundTripper](h HandlerID, newReq func() Req, newResp func() Resp) *SingleHandler[Req, Resp] { + s := SingleHandler[Req, Resp]{id: h} + s.reqPool.New = func() interface{} { + return newReq() + } + s.respPool.New = func() interface{} { + return newResp() + } + return &s +} + +// PutResponse will accept a response for reuse. +// These should be returned by the caller. +func (h *SingleHandler[Req, Resp]) PutResponse(r Resp) { + if r != h.nilResp { + h.respPool.Put(r) + } +} + +// WithSharedResponse indicates it is unsafe to reuse the response. +// Typically this is used when the response sharing part of its data structure. +func (h *SingleHandler[Req, Resp]) WithSharedResponse() *SingleHandler[Req, Resp] { + h.sharedResponse = true + return h +} + +// NewResponse creates a new response. +// Handlers can use this to create a reusable response. +// The response may be reused, so caller should clear any fields. +func (h *SingleHandler[Req, Resp]) NewResponse() Resp { + return h.respPool.Get().(Resp) +} + +// putRequest will accept a request for reuse. +// This is not exported, since it shouldn't be needed. +func (h *SingleHandler[Req, Resp]) putRequest(r Req) { + if r != h.nilReq { + h.reqPool.Put(r) + } +} + +// NewRequest creates a new request. +// Handlers can use this to create a reusable request. +// The request may be reused, so caller should clear any fields. +func (h *SingleHandler[Req, Resp]) NewRequest() Req { + return h.reqPool.Get().(Req) +} + +// Register a handler for a Req -> Resp roundtrip. +func (h *SingleHandler[Req, Resp]) Register(m *Manager, handle func(req Req) (resp Resp, err *RemoteErr), subroute ...string) error { + return m.RegisterSingleHandler(h.id, func(payload []byte) ([]byte, *RemoteErr) { + req := h.NewRequest() + _, err := req.UnmarshalMsg(payload) + if err != nil { + PutByteBuffer(payload) + r := RemoteErr(err.Error()) + return nil, &r + } + resp, rerr := handle(req) + h.putRequest(req) + if rerr != nil { + PutByteBuffer(payload) + return nil, rerr + } + payload, err = resp.MarshalMsg(payload[:0]) + if !h.sharedResponse { + h.PutResponse(resp) + } + if err != nil { + PutByteBuffer(payload) + r := RemoteErr(err.Error()) + return nil, &r + } + return payload, nil + }, subroute...) +} + +// Requester is able to send requests to a remote. +type Requester interface { + Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) +} + +// Call the remote with the request and return the response. +// The response should be returned with PutResponse when no error. +// If no deadline is set, a 1-minute deadline is added. +func (h *SingleHandler[Req, Resp]) Call(ctx context.Context, c Requester, req Req) (resp Resp, err error) { + payload, err := req.MarshalMsg(GetByteBuffer()[:0]) + if err != nil { + return resp, err + } + res, err := c.Request(ctx, h.id, payload) + PutByteBuffer(payload) + if err != nil { + return resp, err + } + r := h.NewResponse() + _, err = r.UnmarshalMsg(res) + if err != nil { + h.PutResponse(r) + return resp, err + } + PutByteBuffer(res) + return r, err +} + +// RemoteClient contains information about the caller. +type RemoteClient struct { + Name string +} + +type ( + ctxCallerKey = struct{} + ctxSubrouteKey = struct{} +) + +// GetCaller returns caller information from contexts provided to handlers. +func GetCaller(ctx context.Context) *RemoteClient { + val, _ := ctx.Value(ctxCallerKey{}).(*RemoteClient) + return val +} + +// GetSubroute returns caller information from contexts provided to handlers. +func GetSubroute(ctx context.Context) string { + val, _ := ctx.Value(ctxSubrouteKey{}).(string) + return val +} + +func setCaller(ctx context.Context, cl *RemoteClient) context.Context { + return context.WithValue(ctx, ctxCallerKey{}, cl) +} + +func setSubroute(ctx context.Context, s string) context.Context { + return context.WithValue(ctx, ctxSubrouteKey{}, s) +} + +// StreamTypeHandler is a type safe handler for streaming requests. +type StreamTypeHandler[Payload, Req, Resp RoundTripper] struct { + WithPayload bool + + // Override the default capacities (1) + OutCapacity int + + // Set to 0 if no input is expected. + // Will be 0 if newReq is nil. + InCapacity int + + reqPool sync.Pool + respPool sync.Pool + id HandlerID + newPayload func() Payload + nilReq Req + nilResp Resp + sharedResponse bool +} + +// NewStream creates a typed handler that can provide Marshal/Unmarshal. +// Use Register to register a server handler. +// Use Call to initiate a clientside call. +// newPayload can be nil. In that case payloads will always be nil. +// newReq can be nil. In that case no input stream is expected and the handler will be called with nil 'in' channel. +func NewStream[Payload, Req, Resp RoundTripper](h HandlerID, newPayload func() Payload, newReq func() Req, newResp func() Resp) *StreamTypeHandler[Payload, Req, Resp] { + if newResp == nil { + panic("newResp missing in NewStream") + } + + s := newStreamHandler[Payload, Req, Resp](h) + if newReq != nil { + s.reqPool.New = func() interface{} { + return newReq() + } + } else { + s.InCapacity = 0 + } + s.respPool.New = func() interface{} { + return newResp() + } + s.newPayload = newPayload + s.WithPayload = newPayload != nil + return s +} + +// WithSharedResponse indicates it is unsafe to reuse the response. +// Typically this is used when the response sharing part of its data structure. +func (h *StreamTypeHandler[Payload, Req, Resp]) WithSharedResponse() *StreamTypeHandler[Payload, Req, Resp] { + h.sharedResponse = true + return h +} + +// NewPayload creates a new payload. +func (h *StreamTypeHandler[Payload, Req, Resp]) NewPayload() Payload { + return h.newPayload() +} + +// NewRequest creates a new request. +// The struct may be reused, so caller should clear any fields. +func (h *StreamTypeHandler[Payload, Req, Resp]) NewRequest() Req { + return h.reqPool.Get().(Req) +} + +// PutRequest will accept a request for reuse. +// These should be returned by the handler. +func (h *StreamTypeHandler[Payload, Req, Resp]) PutRequest(r Req) { + if r != h.nilReq { + h.reqPool.Put(r) + } +} + +// PutResponse will accept a response for reuse. +// These should be returned by the caller. +func (h *StreamTypeHandler[Payload, Req, Resp]) PutResponse(r Resp) { + if r != h.nilResp { + h.respPool.Put(r) + } +} + +// NewResponse creates a new response. +// Handlers can use this to create a reusable response. +func (h *StreamTypeHandler[Payload, Req, Resp]) NewResponse() Resp { + return h.respPool.Get().(Resp) +} + +func newStreamHandler[Payload, Req, Resp RoundTripper](h HandlerID) *StreamTypeHandler[Payload, Req, Resp] { + return &StreamTypeHandler[Payload, Req, Resp]{id: h, InCapacity: 1, OutCapacity: 1} +} + +// Register a handler for two-way streaming with payload, input stream and output stream. +// An optional subroute can be given. Multiple entries are joined with '/'. +func (h *StreamTypeHandler[Payload, Req, Resp]) Register(m *Manager, handle func(ctx context.Context, p Payload, in <-chan Req, out chan<- Resp) *RemoteErr, subroute ...string) error { + return h.register(m, handle, subroute...) +} + +// RegisterNoInput a handler for one-way streaming with payload and output stream. +// An optional subroute can be given. Multiple entries are joined with '/'. +func (h *StreamTypeHandler[Payload, Req, Resp]) RegisterNoInput(m *Manager, handle func(ctx context.Context, p Payload, out chan<- Resp) *RemoteErr, subroute ...string) error { + h.InCapacity = 0 + return h.register(m, func(ctx context.Context, p Payload, in <-chan Req, out chan<- Resp) *RemoteErr { + return handle(ctx, p, out) + }, subroute...) +} + +// RegisterNoPayload a handler for one-way streaming with payload and output stream. +// An optional subroute can be given. Multiple entries are joined with '/'. +func (h *StreamTypeHandler[Payload, Req, Resp]) RegisterNoPayload(m *Manager, handle func(ctx context.Context, in <-chan Req, out chan<- Resp) *RemoteErr, subroute ...string) error { + h.WithPayload = false + return h.register(m, func(ctx context.Context, p Payload, in <-chan Req, out chan<- Resp) *RemoteErr { + return handle(ctx, in, out) + }, subroute...) +} + +// Register a handler for two-way streaming with optional payload and input stream. +func (h *StreamTypeHandler[Payload, Req, Resp]) register(m *Manager, handle func(ctx context.Context, p Payload, in <-chan Req, out chan<- Resp) *RemoteErr, subroute ...string) error { + return m.RegisterStreamingHandler(h.id, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr { + var plT Payload + if h.WithPayload { + plT = h.NewPayload() + _, err := plT.UnmarshalMsg(payload) + PutByteBuffer(payload) + if err != nil { + r := RemoteErr(err.Error()) + return &r + } + } + + var inT chan Req + if h.InCapacity > 0 { + // Don't add extra buffering + inT = make(chan Req) + go func() { + defer close(inT) + for { + select { + case <-ctx.Done(): + return + case v, ok := <-in: + if !ok { + return + } + input := h.NewRequest() + _, err := input.UnmarshalMsg(v) + if err != nil { + logger.LogOnceIf(ctx, err, err.Error()) + } + PutByteBuffer(v) + // Send input + select { + case <-ctx.Done(): + return + case inT <- input: + } + } + } + }() + } + outT := make(chan Resp) + outDone := make(chan struct{}) + go func() { + defer close(outDone) + dropOutput := false + for v := range outT { + if dropOutput { + continue + } + dst := GetByteBuffer() + dst, err := v.MarshalMsg(dst[:0]) + if err != nil { + logger.LogOnceIf(ctx, err, err.Error()) + } + if !h.sharedResponse { + h.PutResponse(v) + } + select { + case <-ctx.Done(): + dropOutput = true + case out <- dst: + } + } + }() + rErr := handle(ctx, plT, inT, outT) + close(outT) + <-outDone + return rErr + }, OutCapacity: h.OutCapacity, InCapacity: h.InCapacity, Subroute: strings.Join(subroute, "/"), + }) +} + +// TypedStream is a stream with specific types. +type TypedStream[Req, Resp RoundTripper] struct { + // responses from the remote server. + // Channel will be closed after error or when remote closes. + // responses *must* be read to either an error is returned or the channel is closed. + responses *Stream + newResp func() Resp + + // Requests sent to the server. + // If the handler is defined with 0 incoming capacity this will be nil. + // Channel *must* be closed to signal the end of the stream. + // If the request context is canceled, the stream will no longer process requests. + Requests chan<- Req +} + +// Results returns the results from the remote server one by one. +// If any error is returned by the callback, the stream will be canceled. +// If the context is canceled, the stream will be canceled. +func (s *TypedStream[Req, Resp]) Results(next func(resp Resp) error) (err error) { + return s.responses.Results(func(b []byte) error { + resp := s.newResp() + _, err := resp.UnmarshalMsg(b) + if err != nil { + return err + } + return next(resp) + }) +} + +// Streamer creates a stream. +type Streamer interface { + NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) +} + +// Call the remove with the request and +func (h *StreamTypeHandler[Payload, Req, Resp]) Call(ctx context.Context, c Streamer, payload Payload) (st *TypedStream[Req, Resp], err error) { + var payloadB []byte + if h.WithPayload { + var err error + payloadB, err = payload.MarshalMsg(GetByteBuffer()[:0]) + if err != nil { + return nil, err + } + } + stream, err := c.NewStream(ctx, h.id, payloadB) + PutByteBuffer(payloadB) + if err != nil { + return nil, err + } + + // respT := make(chan TypedResponse[Resp]) + var reqT chan Req + if h.InCapacity > 0 { + reqT = make(chan Req) + // Request handler + go func() { + defer close(stream.Requests) + for req := range reqT { + b, err := req.MarshalMsg(GetByteBuffer()[:0]) + if err != nil { + logger.LogOnceIf(ctx, err, err.Error()) + } + h.PutRequest(req) + stream.Requests <- b + } + }() + } else if stream.Requests != nil { + close(stream.Requests) + } + + return &TypedStream[Req, Resp]{responses: stream, newResp: h.NewResponse, Requests: reqT}, nil +} + +// NoPayload is a type that can be used for handlers that do not use a payload. +type NoPayload struct{} + +// Msgsize returns 0. +func (p NoPayload) Msgsize() int { + return 0 +} + +// UnmarshalMsg satisfies the interface, but is a no-op. +func (NoPayload) UnmarshalMsg(bytes []byte) ([]byte, error) { + return bytes, nil +} + +// MarshalMsg satisfies the interface, but is a no-op. +func (NoPayload) MarshalMsg(bytes []byte) ([]byte, error) { + return bytes, nil +} + +// NewNoPayload returns an empty NoPayload struct. +func NewNoPayload() NoPayload { + return NoPayload{} +} diff --git a/internal/grid/handlers_string.go b/internal/grid/handlers_string.go new file mode 100644 index 000000000..160920aea --- /dev/null +++ b/internal/grid/handlers_string.go @@ -0,0 +1,44 @@ +// Code generated by "stringer -type=HandlerID -output=handlers_string.go -trimprefix=Handler msg.go handlers.go"; DO NOT EDIT. + +package grid + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[handlerInvalid-0] + _ = x[HandlerLockLock-1] + _ = x[HandlerLockRLock-2] + _ = x[HandlerLockUnlock-3] + _ = x[HandlerLockRUnlock-4] + _ = x[HandlerLockRefresh-5] + _ = x[HandlerLockForceUnlock-6] + _ = x[HandlerWalkDir-7] + _ = x[HandlerStatVol-8] + _ = x[HandlerDiskInfo-9] + _ = x[HandlerNSScanner-10] + _ = x[HandlerReadXL-11] + _ = x[HandlerReadVersion-12] + _ = x[HandlerDeleteFile-13] + _ = x[HandlerDeleteVersion-14] + _ = x[HandlerUpdateMetadata-15] + _ = x[HandlerWriteMetadata-16] + _ = x[HandlerCheckParts-17] + _ = x[HandlerRenamedata-18] + _ = x[handlerTest-19] + _ = x[handlerTest2-20] + _ = x[handlerLast-21] +} + +const _HandlerID_name = "handlerInvalidLockLockLockRLockLockUnlockLockRUnlockLockRefreshLockForceUnlockWalkDirStatVolDiskInfoNSScannerReadXLReadVersionDeleteFileDeleteVersionUpdateMetadataWriteMetadataCheckPartsRenamedatahandlerTesthandlerTest2handlerLast" + +var _HandlerID_index = [...]uint8{0, 14, 22, 31, 41, 52, 63, 78, 85, 92, 100, 109, 115, 126, 136, 149, 163, 176, 186, 196, 207, 219, 230} + +func (i HandlerID) String() string { + if i >= HandlerID(len(_HandlerID_index)-1) { + return "HandlerID(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _HandlerID_name[_HandlerID_index[i]:_HandlerID_index[i+1]] +} diff --git a/internal/grid/manager.go b/internal/grid/manager.go new file mode 100644 index 000000000..4cbbe3f2c --- /dev/null +++ b/internal/grid/manager.go @@ -0,0 +1,321 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "runtime/debug" + "strings" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/google/uuid" + "github.com/minio/madmin-go/v3" + "github.com/minio/minio/internal/logger" + "github.com/minio/minio/internal/pubsub" + "github.com/minio/mux" +) + +const ( + // apiVersion is a major version of the entire api. + // Bumping this should only be done when overall, + // incompatible changes are made, not when adding a new handler + // or changing an existing handler. + apiVersion = "v1" + + // RoutePath is the remote path to connect to. + RoutePath = "/minio/grid/" + apiVersion +) + +// Manager will contain all the connections to the grid. +// It also handles incoming requests and routes them to the appropriate connection. +type Manager struct { + // ID is an instance ID, that will change whenever the server restarts. + // This allows remotes to keep track of whether state is preserved. + ID uuid.UUID + + // Immutable after creation, so no locks. + targets map[string]*Connection + + // serverside handlers. + handlers handlers + + // local host name. + local string + + // Validate incoming requests. + authRequest func(r *http.Request) error +} + +// ManagerOptions are options for creating a new grid manager. +type ManagerOptions struct { + Dialer ContextDialer // Outgoing dialer. + Local string // Local host name. + Hosts []string // All hosts, including local in the grid. + AddAuth AuthFn // Add authentication to the given audience. + AuthRequest func(r *http.Request) error // Validate incoming requests. + TLSConfig *tls.Config // TLS to apply to the connnections. + Incoming func(n int64) // Record incoming bytes. + Outgoing func(n int64) // Record outgoing bytes. + BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed. + TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] +} + +// NewManager creates a new grid manager +func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) { + found := false + if o.AuthRequest == nil { + return nil, fmt.Errorf("grid: AuthRequest must be set") + } + m := &Manager{ + ID: uuid.New(), + targets: make(map[string]*Connection, len(o.Hosts)), + local: o.Local, + authRequest: o.AuthRequest, + } + m.handlers.init() + if ctx == nil { + ctx = context.Background() + } + for _, host := range o.Hosts { + if host == o.Local { + if found { + return nil, fmt.Errorf("grid: local host found multiple times") + } + found = true + // No connection to local. + continue + } + m.targets[host] = newConnection(connectionParams{ + ctx: ctx, + id: m.ID, + local: o.Local, + remote: host, + dial: o.Dialer, + handlers: &m.handlers, + auth: o.AddAuth, + blockConnect: o.BlockConnect, + tlsConfig: o.TLSConfig, + publisher: o.TraceTo, + }) + } + if !found { + return nil, fmt.Errorf("grid: local host not found") + } + + return m, nil +} + +// AddToMux will add the grid manager to the given mux. +func (m *Manager) AddToMux(router *mux.Router) { + router.Handle(RoutePath, m.Handler()) +} + +// Handler returns a handler that can be used to serve grid requests. +// This should be connected on RoutePath to the main server. +func (m *Manager) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + defer func() { + if debugPrint { + fmt.Printf("grid: Handler returning from: %v %v\n", req.Method, req.URL) + } + if r := recover(); r != nil { + debug.PrintStack() + err := fmt.Errorf("grid: panic: %v\n", r) + logger.LogIf(context.Background(), err, err.Error()) + w.WriteHeader(http.StatusInternalServerError) + } + }() + if debugPrint { + fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL) + } + ctx := req.Context() + if err := m.authRequest(req); err != nil { + logger.LogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr+err.Error()) + w.WriteHeader(http.StatusForbidden) + return + } + conn, _, _, err := ws.UpgradeHTTP(req, w) + if err != nil { + if debugPrint { + fmt.Printf("grid: Unable to upgrade: %v. http.ResponseWriter is type %T\n", err, w) + } + w.WriteHeader(http.StatusUpgradeRequired) + return + } + defer conn.Close() + if debugPrint { + fmt.Printf("grid: Upgraded request: %v\n", req.URL) + } + + msg, _, err := wsutil.ReadClientData(conn) + if err != nil { + logger.LogIf(ctx, fmt.Errorf("grid: reading connect: %w", err)) + return + } + if debugPrint { + fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg)) + } + + var message message + _, _, err = message.parse(msg) + if err != nil { + if debugPrint { + fmt.Println("parse err:", err) + } + logger.LogIf(ctx, fmt.Errorf("handleMessages: parsing connect: %w", err)) + return + } + if message.Op != OpConnect { + if debugPrint { + fmt.Println("op err:", message.Op) + } + logger.LogIf(ctx, fmt.Errorf("handler: unexpected op: %v", message.Op)) + return + } + var cReq connectReq + _, err = cReq.UnmarshalMsg(message.Payload) + if err != nil { + if debugPrint { + fmt.Println("handler: creq err:", err) + } + logger.LogIf(ctx, fmt.Errorf("handleMessages: parsing ConnectReq: %w", err)) + return + } + remote := m.targets[cReq.Host] + if remote == nil { + if debugPrint { + fmt.Printf("%s: handler: unknown host: %v. Have %v\n", m.local, cReq.Host, m.targets) + } + logger.LogIf(ctx, fmt.Errorf("handler: unknown host: %v", cReq.Host)) + return + } + if debugPrint { + fmt.Printf("handler: Got Connect Req %+v\n", cReq) + } + + logger.LogIf(ctx, remote.handleIncoming(ctx, conn, cReq)) + } +} + +// AuthFn should provide an authentication string for the given aud. +type AuthFn func(aud string) string + +// Connection will return the connection for the specified host. +// If the host does not exist nil will be returned. +func (m *Manager) Connection(host string) *Connection { + return m.targets[host] +} + +// RegisterSingleHandler will register a stateless handler that serves +// []byte -> ([]byte, error) requests. +// subroutes are joined with "/" to a single subroute. +func (m *Manager) RegisterSingleHandler(id HandlerID, h SingleHandlerFn, subroute ...string) error { + if !id.valid() { + return ErrUnknownHandler + } + s := strings.Join(subroute, "/") + if debugPrint { + fmt.Println("RegisterSingleHandler: ", id.String(), "subroute:", s) + } + + if len(subroute) == 0 { + if m.handlers.hasAny(id) && !id.isTestHandler() { + return ErrHandlerAlreadyExists + } + + m.handlers.single[id] = h + return nil + } + subID := makeSubHandlerID(id, s) + if m.handlers.hasSubhandler(subID) && !id.isTestHandler() { + return ErrHandlerAlreadyExists + } + m.handlers.subSingle[subID] = h + // Copy so clients can also pick it up for other subpaths. + m.handlers.subSingle[makeZeroSubHandlerID(id)] = h + return nil +} + +/* +// RegisterStateless will register a stateless handler that serves +// []byte -> stream of ([]byte, error) requests. +func (m *Manager) RegisterStateless(id HandlerID, h StatelessHandler) error { + if !id.valid() { + return ErrUnknownHandler + } + if m.handlers.hasAny(id) && !id.isTestHandler() { + return ErrHandlerAlreadyExists + } + + m.handlers.stateless[id] = &h + return nil +} +*/ + +// RegisterStreamingHandler will register a stateless handler that serves +// two-way streaming requests. +func (m *Manager) RegisterStreamingHandler(id HandlerID, h StreamHandler) error { + if !id.valid() { + return ErrUnknownHandler + } + if debugPrint { + fmt.Println("RegisterStreamingHandler: subroute:", h.Subroute) + } + if h.Subroute == "" { + if m.handlers.hasAny(id) && !id.isTestHandler() { + return ErrHandlerAlreadyExists + } + m.handlers.streams[id] = &h + return nil + } + subID := makeSubHandlerID(id, h.Subroute) + if m.handlers.hasSubhandler(subID) && !id.isTestHandler() { + return ErrHandlerAlreadyExists + } + m.handlers.subStreams[subID] = &h + // Copy so clients can also pick it up for other subpaths. + m.handlers.subStreams[makeZeroSubHandlerID(id)] = &h + return nil +} + +// HostName returns the name of the local host. +func (m *Manager) HostName() string { + return m.local +} + +// Targets returns the names of all remote targets. +func (m *Manager) Targets() []string { + var res []string + for k := range m.targets { + res = append(res, k) + } + return res +} + +// debugMsg should *only* be used by tests. +// +//lint:ignore U1000 This is used by tests. +func (m *Manager) debugMsg(d debugMsg, args ...any) { + for _, c := range m.targets { + c.debugMsg(d, args...) + } +} diff --git a/internal/grid/msg.go b/internal/grid/msg.go new file mode 100644 index 000000000..86e20fb06 --- /dev/null +++ b/internal/grid/msg.go @@ -0,0 +1,281 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "encoding/binary" + "fmt" + "strings" + + "github.com/tinylib/msgp/msgp" + "github.com/zeebo/xxh3" +) + +// Op is operation type. +// +//go:generate msgp -unexported -file=$GOFILE +//go:generate stringer -type=Op -output=msg_string.go -trimprefix=Op $GOFILE + +// Op is operation type messages. +type Op uint8 + +// HandlerID is the ID for the handler of a specific type. +type HandlerID uint8 + +const ( + // OpConnect is a connect request. + OpConnect Op = iota + 1 + + // OpConnectResponse is a response to a connect request. + OpConnectResponse + + // OpPing is a ping request. + // If a mux id is specified that mux is pinged. + // Clients send ping requests. + OpPing + + // OpPong is a OpPing response returned by the server. + OpPong + + // OpConnectMux will connect a new mux with optional payload. + OpConnectMux + + // OpMuxConnectError is an error while connecting a mux. + OpMuxConnectError + + // OpDisconnectClientMux instructs a client to disconnect a mux + OpDisconnectClientMux + + // OpDisconnectServerMux instructs a server to disconnect (cancel) a server mux + OpDisconnectServerMux + + // OpMuxClientMsg contains a message to a client Mux + OpMuxClientMsg + + // OpMuxServerMsg contains a message to a server Mux + OpMuxServerMsg + + // OpUnblockSrvMux contains a message that a server mux is unblocked with one. + // Only Stateful streams has flow control. + OpUnblockSrvMux + + // OpUnblockClMux contains a message that a client mux is unblocked with one. + // Only Stateful streams has flow control. + OpUnblockClMux + + // OpAckMux acknowledges a mux was created. + OpAckMux + + // OpRequest is a single request + response. + // MuxID is returned in response. + OpRequest + + // OpResponse is a response to a single request. + // FlagPayloadIsErr is used to signify that the payload is a string error converted to byte slice. + // When a response is received, the mux is already removed from the remote. + OpResponse + + // OpDisconnect instructs that remote wants to disconnect + OpDisconnect + + // OpMerged is several operations merged into one. + OpMerged +) + +const ( + // FlagCRCxxh3 indicates that, the lower 32 bits of xxhash3 of the serialized + // message will be sent after the serialized message as little endian. + FlagCRCxxh3 Flags = 1 << iota + + // FlagEOF the stream (either direction) is at EOF. + FlagEOF + + // FlagStateless indicates the message is stateless. + // This will retain clients across reconnections or + // if sequence numbers are unexpected. + FlagStateless + + // FlagPayloadIsErr can be used by individual ops to signify that + // The payload is a string error converted to byte slice. + FlagPayloadIsErr + + // FlagPayloadIsZero means that payload is 0-length slice and not nil. + FlagPayloadIsZero + + // FlagSubroute indicates that the message has subroute. + // Subroute will be 32 bytes long and added before any CRC. + FlagSubroute +) + +// This struct cannot be changed and retain backwards compatibility. +// If changed, endpoint version must be bumped. +// +//msgp:tuple message +type message struct { + MuxID uint64 // Mux to receive message if any. + Seq uint32 // Sequence number. + DeadlineMS uint32 // If non-zero, milliseconds until deadline (max 1193h2m47.295s, ~49 days) + Handler HandlerID // ID of handler if invoking a remote handler. + Op Op // Operation. Other fields change based on this value. + Flags Flags // Optional flags. + Payload []byte // Optional payload. +} + +// Flags is a set of flags set on a message. +type Flags uint8 + +func (m message) String() string { + var res []string + if m.MuxID != 0 { + res = append(res, fmt.Sprintf("MuxID: %v", m.MuxID)) + } + if m.Seq != 0 { + res = append(res, fmt.Sprintf("Seq: %v", m.Seq)) + } + if m.DeadlineMS != 0 { + res = append(res, fmt.Sprintf("Deadline: %vms", m.DeadlineMS)) + } + if m.Handler != handlerInvalid { + res = append(res, fmt.Sprintf("Handler: %v", m.Handler)) + } + if m.Op != 0 { + res = append(res, fmt.Sprintf("Op: %v", m.Op)) + } + res = append(res, fmt.Sprintf("Flags: %s", m.Flags.String())) + if len(m.Payload) != 0 { + res = append(res, fmt.Sprintf("Payload: %v", bytesOrLength(m.Payload))) + } + return "{" + strings.Join(res, ", ") + "}" +} + +func (f Flags) String() string { + var res []string + if f&FlagCRCxxh3 != 0 { + res = append(res, "CRC") + } + if f&FlagEOF != 0 { + res = append(res, "EOF") + } + if f&FlagStateless != 0 { + res = append(res, "SL") + } + if f&FlagPayloadIsErr != 0 { + res = append(res, "ERR") + } + if f&FlagPayloadIsZero != 0 { + res = append(res, "ZERO") + } + if f&FlagSubroute != 0 { + res = append(res, "SUB") + } + return "[" + strings.Join(res, ",") + "]" +} + +// parse an incoming message. +func (m *message) parse(b []byte) (*subHandlerID, []byte, error) { + var sub *subHandlerID + if m.Payload == nil { + m.Payload = GetByteBuffer()[:0] + } + h, err := m.UnmarshalMsg(b) + if err != nil { + return nil, nil, fmt.Errorf("read write: %v", err) + } + if len(m.Payload) == 0 && m.Flags&FlagPayloadIsZero == 0 { + PutByteBuffer(m.Payload) + m.Payload = nil + } + if m.Flags&FlagCRCxxh3 != 0 { + const hashLen = 4 + if len(h) < hashLen { + return nil, nil, fmt.Errorf("want crc len 4, got %v", len(h)) + } + got := uint32(xxh3.Hash(b[:len(b)-hashLen])) + want := binary.LittleEndian.Uint32(h[len(h)-hashLen:]) + if got != want { + return nil, nil, fmt.Errorf("crc mismatch: 0x%08x (given) != 0x%08x (bytes)", want, got) + } + h = h[:len(h)-hashLen] + } + // Extract subroute if any. + if m.Flags&FlagSubroute != 0 { + if len(h) < 32 { + return nil, nil, fmt.Errorf("want subroute len 32, got %v", len(h)) + } + subID := (*[32]byte)(h[len(h)-32:]) + sub = (*subHandlerID)(subID) + // Add if more modifications to h is needed + h = h[:len(h)-32] + } + return sub, h, nil +} + +// setZeroPayloadFlag will clear or set the FlagPayloadIsZero if +// m.Payload is length 0, but not nil. +func (m *message) setZeroPayloadFlag() { + m.Flags &^= FlagPayloadIsZero + if len(m.Payload) == 0 && m.Payload != nil { + m.Flags |= FlagPayloadIsZero + } +} + +type receiver interface { + msgp.Unmarshaler + Op() Op +} + +type sender interface { + msgp.MarshalSizer + Op() Op +} + +type connectReq struct { + ID [16]byte + Host string +} + +func (connectReq) Op() Op { + return OpConnect +} + +type connectResp struct { + ID [16]byte + Accepted bool + RejectedReason string +} + +func (connectResp) Op() Op { + return OpConnectResponse +} + +type muxConnectError struct { + Error string +} + +func (muxConnectError) Op() Op { + return OpMuxConnectError +} + +type pongMsg struct { + NotFound bool `msg:"nf"` + Err *string `msg:"e,allownil"` +} + +func (pongMsg) Op() Op { + return OpPong +} diff --git a/internal/grid/msg_gen.go b/internal/grid/msg_gen.go new file mode 100644 index 000000000..15f2a58f9 --- /dev/null +++ b/internal/grid/msg_gen.go @@ -0,0 +1,905 @@ +package grid + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *Flags) DecodeMsg(dc *msgp.Reader) (err error) { + { + var zb0001 uint8 + zb0001, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = Flags(zb0001) + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z Flags) EncodeMsg(en *msgp.Writer) (err error) { + err = en.WriteUint8(uint8(z)) + if err != nil { + err = msgp.WrapError(err) + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z Flags) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + o = msgp.AppendUint8(o, uint8(z)) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *Flags) UnmarshalMsg(bts []byte) (o []byte, err error) { + { + var zb0001 uint8 + zb0001, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = Flags(zb0001) + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z Flags) Msgsize() (s int) { + s = msgp.Uint8Size + return +} + +// DecodeMsg implements msgp.Decodable +func (z *HandlerID) DecodeMsg(dc *msgp.Reader) (err error) { + { + var zb0001 uint8 + zb0001, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = HandlerID(zb0001) + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z HandlerID) EncodeMsg(en *msgp.Writer) (err error) { + err = en.WriteUint8(uint8(z)) + if err != nil { + err = msgp.WrapError(err) + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z HandlerID) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + o = msgp.AppendUint8(o, uint8(z)) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *HandlerID) UnmarshalMsg(bts []byte) (o []byte, err error) { + { + var zb0001 uint8 + zb0001, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = HandlerID(zb0001) + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z HandlerID) Msgsize() (s int) { + s = msgp.Uint8Size + return +} + +// DecodeMsg implements msgp.Decodable +func (z *Op) DecodeMsg(dc *msgp.Reader) (err error) { + { + var zb0001 uint8 + zb0001, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = Op(zb0001) + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z Op) EncodeMsg(en *msgp.Writer) (err error) { + err = en.WriteUint8(uint8(z)) + if err != nil { + err = msgp.WrapError(err) + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z Op) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + o = msgp.AppendUint8(o, uint8(z)) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *Op) UnmarshalMsg(bts []byte) (o []byte, err error) { + { + var zb0001 uint8 + zb0001, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + (*z) = Op(zb0001) + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z Op) Msgsize() (s int) { + s = msgp.Uint8Size + return +} + +// DecodeMsg implements msgp.Decodable +func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "ID": + err = dc.ReadExactBytes((z.ID)[:]) + if err != nil { + err = msgp.WrapError(err, "ID") + return + } + case "Host": + z.Host, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Host") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "ID" + err = en.Append(0x82, 0xa2, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteBytes((z.ID)[:]) + if err != nil { + err = msgp.WrapError(err, "ID") + return + } + // write "Host" + err = en.Append(0xa4, 0x48, 0x6f, 0x73, 0x74) + if err != nil { + return + } + err = en.WriteString(z.Host) + if err != nil { + err = msgp.WrapError(err, "Host") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "ID" + o = append(o, 0x82, 0xa2, 0x49, 0x44) + o = msgp.AppendBytes(o, (z.ID)[:]) + // string "Host" + o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74) + o = msgp.AppendString(o, z.Host) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "ID": + bts, err = msgp.ReadExactBytes(bts, (z.ID)[:]) + if err != nil { + err = msgp.WrapError(err, "ID") + return + } + case "Host": + z.Host, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Host") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *connectReq) Msgsize() (s int) { + s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) + return +} + +// DecodeMsg implements msgp.Decodable +func (z *connectResp) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "ID": + err = dc.ReadExactBytes((z.ID)[:]) + if err != nil { + err = msgp.WrapError(err, "ID") + return + } + case "Accepted": + z.Accepted, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "Accepted") + return + } + case "RejectedReason": + z.RejectedReason, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "RejectedReason") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *connectResp) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 3 + // write "ID" + err = en.Append(0x83, 0xa2, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteBytes((z.ID)[:]) + if err != nil { + err = msgp.WrapError(err, "ID") + return + } + // write "Accepted" + err = en.Append(0xa8, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x65, 0x64) + if err != nil { + return + } + err = en.WriteBool(z.Accepted) + if err != nil { + err = msgp.WrapError(err, "Accepted") + return + } + // write "RejectedReason" + err = en.Append(0xae, 0x52, 0x65, 0x6a, 0x65, 0x63, 0x74, 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e) + if err != nil { + return + } + err = en.WriteString(z.RejectedReason) + if err != nil { + err = msgp.WrapError(err, "RejectedReason") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *connectResp) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 3 + // string "ID" + o = append(o, 0x83, 0xa2, 0x49, 0x44) + o = msgp.AppendBytes(o, (z.ID)[:]) + // string "Accepted" + o = append(o, 0xa8, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x65, 0x64) + o = msgp.AppendBool(o, z.Accepted) + // string "RejectedReason" + o = append(o, 0xae, 0x52, 0x65, 0x6a, 0x65, 0x63, 0x74, 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e) + o = msgp.AppendString(o, z.RejectedReason) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *connectResp) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "ID": + bts, err = msgp.ReadExactBytes(bts, (z.ID)[:]) + if err != nil { + err = msgp.WrapError(err, "ID") + return + } + case "Accepted": + z.Accepted, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Accepted") + return + } + case "RejectedReason": + z.RejectedReason, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "RejectedReason") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *connectResp) Msgsize() (s int) { + s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 9 + msgp.BoolSize + 15 + msgp.StringPrefixSize + len(z.RejectedReason) + return +} + +// DecodeMsg implements msgp.Decodable +func (z *message) DecodeMsg(dc *msgp.Reader) (err error) { + var zb0001 uint32 + zb0001, err = dc.ReadArrayHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + if zb0001 != 7 { + err = msgp.ArrayError{Wanted: 7, Got: zb0001} + return + } + z.MuxID, err = dc.ReadUint64() + if err != nil { + err = msgp.WrapError(err, "MuxID") + return + } + z.Seq, err = dc.ReadUint32() + if err != nil { + err = msgp.WrapError(err, "Seq") + return + } + z.DeadlineMS, err = dc.ReadUint32() + if err != nil { + err = msgp.WrapError(err, "DeadlineMS") + return + } + { + var zb0002 uint8 + zb0002, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err, "Handler") + return + } + z.Handler = HandlerID(zb0002) + } + { + var zb0003 uint8 + zb0003, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err, "Op") + return + } + z.Op = Op(zb0003) + } + { + var zb0004 uint8 + zb0004, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err, "Flags") + return + } + z.Flags = Flags(zb0004) + } + z.Payload, err = dc.ReadBytes(z.Payload) + if err != nil { + err = msgp.WrapError(err, "Payload") + return + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *message) EncodeMsg(en *msgp.Writer) (err error) { + // array header, size 7 + err = en.Append(0x97) + if err != nil { + return + } + err = en.WriteUint64(z.MuxID) + if err != nil { + err = msgp.WrapError(err, "MuxID") + return + } + err = en.WriteUint32(z.Seq) + if err != nil { + err = msgp.WrapError(err, "Seq") + return + } + err = en.WriteUint32(z.DeadlineMS) + if err != nil { + err = msgp.WrapError(err, "DeadlineMS") + return + } + err = en.WriteUint8(uint8(z.Handler)) + if err != nil { + err = msgp.WrapError(err, "Handler") + return + } + err = en.WriteUint8(uint8(z.Op)) + if err != nil { + err = msgp.WrapError(err, "Op") + return + } + err = en.WriteUint8(uint8(z.Flags)) + if err != nil { + err = msgp.WrapError(err, "Flags") + return + } + err = en.WriteBytes(z.Payload) + if err != nil { + err = msgp.WrapError(err, "Payload") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *message) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // array header, size 7 + o = append(o, 0x97) + o = msgp.AppendUint64(o, z.MuxID) + o = msgp.AppendUint32(o, z.Seq) + o = msgp.AppendUint32(o, z.DeadlineMS) + o = msgp.AppendUint8(o, uint8(z.Handler)) + o = msgp.AppendUint8(o, uint8(z.Op)) + o = msgp.AppendUint8(o, uint8(z.Flags)) + o = msgp.AppendBytes(o, z.Payload) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *message) UnmarshalMsg(bts []byte) (o []byte, err error) { + var zb0001 uint32 + zb0001, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + if zb0001 != 7 { + err = msgp.ArrayError{Wanted: 7, Got: zb0001} + return + } + z.MuxID, bts, err = msgp.ReadUint64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "MuxID") + return + } + z.Seq, bts, err = msgp.ReadUint32Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Seq") + return + } + z.DeadlineMS, bts, err = msgp.ReadUint32Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "DeadlineMS") + return + } + { + var zb0002 uint8 + zb0002, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Handler") + return + } + z.Handler = HandlerID(zb0002) + } + { + var zb0003 uint8 + zb0003, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Op") + return + } + z.Op = Op(zb0003) + } + { + var zb0004 uint8 + zb0004, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Flags") + return + } + z.Flags = Flags(zb0004) + } + z.Payload, bts, err = msgp.ReadBytesBytes(bts, z.Payload) + if err != nil { + err = msgp.WrapError(err, "Payload") + return + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *message) Msgsize() (s int) { + s = 1 + msgp.Uint64Size + msgp.Uint32Size + msgp.Uint32Size + msgp.Uint8Size + msgp.Uint8Size + msgp.Uint8Size + msgp.BytesPrefixSize + len(z.Payload) + return +} + +// DecodeMsg implements msgp.Decodable +func (z *muxConnectError) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Error": + z.Error, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z muxConnectError) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 1 + // write "Error" + err = en.Append(0x81, 0xa5, 0x45, 0x72, 0x72, 0x6f, 0x72) + if err != nil { + return + } + err = en.WriteString(z.Error) + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z muxConnectError) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 1 + // string "Error" + o = append(o, 0x81, 0xa5, 0x45, 0x72, 0x72, 0x6f, 0x72) + o = msgp.AppendString(o, z.Error) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *muxConnectError) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Error": + z.Error, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z muxConnectError) Msgsize() (s int) { + s = 1 + 6 + msgp.StringPrefixSize + len(z.Error) + return +} + +// DecodeMsg implements msgp.Decodable +func (z *pongMsg) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "nf": + z.NotFound, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "NotFound") + return + } + case "e": + if dc.IsNil() { + err = dc.ReadNil() + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + z.Err = nil + } else { + if z.Err == nil { + z.Err = new(string) + } + *z.Err, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *pongMsg) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "nf" + err = en.Append(0x82, 0xa2, 0x6e, 0x66) + if err != nil { + return + } + err = en.WriteBool(z.NotFound) + if err != nil { + err = msgp.WrapError(err, "NotFound") + return + } + // write "e" + err = en.Append(0xa1, 0x65) + if err != nil { + return + } + if z.Err == nil { + err = en.WriteNil() + if err != nil { + return + } + } else { + err = en.WriteString(*z.Err) + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *pongMsg) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "nf" + o = append(o, 0x82, 0xa2, 0x6e, 0x66) + o = msgp.AppendBool(o, z.NotFound) + // string "e" + o = append(o, 0xa1, 0x65) + if z.Err == nil { + o = msgp.AppendNil(o) + } else { + o = msgp.AppendString(o, *z.Err) + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *pongMsg) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "nf": + z.NotFound, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "NotFound") + return + } + case "e": + if msgp.IsNil(bts) { + bts, err = msgp.ReadNilBytes(bts) + if err != nil { + return + } + z.Err = nil + } else { + if z.Err == nil { + z.Err = new(string) + } + *z.Err, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Err") + return + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *pongMsg) Msgsize() (s int) { + s = 1 + 3 + msgp.BoolSize + 2 + if z.Err == nil { + s += msgp.NilSize + } else { + s += msgp.StringPrefixSize + len(*z.Err) + } + return +} diff --git a/internal/grid/msg_gen_test.go b/internal/grid/msg_gen_test.go new file mode 100644 index 000000000..a3170c811 --- /dev/null +++ b/internal/grid/msg_gen_test.go @@ -0,0 +1,575 @@ +package grid + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalconnectReq(t *testing.T) { + v := connectReq{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgconnectReq(b *testing.B) { + v := connectReq{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgconnectReq(b *testing.B) { + v := connectReq{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalconnectReq(b *testing.B) { + v := connectReq{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeconnectReq(t *testing.T) { + v := connectReq{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeconnectReq Msgsize() is inaccurate") + } + + vn := connectReq{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeconnectReq(b *testing.B) { + v := connectReq{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeconnectReq(b *testing.B) { + v := connectReq{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalconnectResp(t *testing.T) { + v := connectResp{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgconnectResp(b *testing.B) { + v := connectResp{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgconnectResp(b *testing.B) { + v := connectResp{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalconnectResp(b *testing.B) { + v := connectResp{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeconnectResp(t *testing.T) { + v := connectResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeconnectResp Msgsize() is inaccurate") + } + + vn := connectResp{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeconnectResp(b *testing.B) { + v := connectResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeconnectResp(b *testing.B) { + v := connectResp{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalmessage(t *testing.T) { + v := message{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgmessage(b *testing.B) { + v := message{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgmessage(b *testing.B) { + v := message{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalmessage(b *testing.B) { + v := message{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodemessage(t *testing.T) { + v := message{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodemessage Msgsize() is inaccurate") + } + + vn := message{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodemessage(b *testing.B) { + v := message{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodemessage(b *testing.B) { + v := message{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalmuxConnectError(t *testing.T) { + v := muxConnectError{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgmuxConnectError(b *testing.B) { + v := muxConnectError{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgmuxConnectError(b *testing.B) { + v := muxConnectError{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalmuxConnectError(b *testing.B) { + v := muxConnectError{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodemuxConnectError(t *testing.T) { + v := muxConnectError{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodemuxConnectError Msgsize() is inaccurate") + } + + vn := muxConnectError{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodemuxConnectError(b *testing.B) { + v := muxConnectError{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodemuxConnectError(b *testing.B) { + v := muxConnectError{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalpongMsg(t *testing.T) { + v := pongMsg{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgpongMsg(b *testing.B) { + v := pongMsg{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgpongMsg(b *testing.B) { + v := pongMsg{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalpongMsg(b *testing.B) { + v := pongMsg{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodepongMsg(t *testing.T) { + v := pongMsg{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodepongMsg Msgsize() is inaccurate") + } + + vn := pongMsg{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodepongMsg(b *testing.B) { + v := pongMsg{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodepongMsg(b *testing.B) { + v := pongMsg{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/internal/grid/msg_string.go b/internal/grid/msg_string.go new file mode 100644 index 000000000..f89986b31 --- /dev/null +++ b/internal/grid/msg_string.go @@ -0,0 +1,40 @@ +// Code generated by "stringer -type=Op -output=msg_string.go -trimprefix=Op msg.go"; DO NOT EDIT. + +package grid + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[OpConnect-1] + _ = x[OpConnectResponse-2] + _ = x[OpPing-3] + _ = x[OpPong-4] + _ = x[OpConnectMux-5] + _ = x[OpMuxConnectError-6] + _ = x[OpDisconnectClientMux-7] + _ = x[OpDisconnectServerMux-8] + _ = x[OpMuxClientMsg-9] + _ = x[OpMuxServerMsg-10] + _ = x[OpUnblockSrvMux-11] + _ = x[OpUnblockClMux-12] + _ = x[OpAckMux-13] + _ = x[OpRequest-14] + _ = x[OpResponse-15] + _ = x[OpDisconnect-16] + _ = x[OpMerged-17] +} + +const _Op_name = "ConnectConnectResponsePingPongConnectMuxMuxConnectErrorDisconnectClientMuxDisconnectServerMuxMuxClientMsgMuxServerMsgUnblockSrvMuxUnblockClMuxAckMuxRequestResponseDisconnectMerged" + +var _Op_index = [...]uint8{0, 7, 22, 26, 30, 40, 55, 74, 93, 105, 117, 130, 142, 148, 155, 163, 173, 179} + +func (i Op) String() string { + i -= 1 + if i >= Op(len(_Op_index)-1) { + return "Op(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _Op_name[_Op_index[i]:_Op_index[i+1]] +} diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go new file mode 100644 index 000000000..a04122c30 --- /dev/null +++ b/internal/grid/muxclient.go @@ -0,0 +1,539 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/minio/minio/internal/logger" + "github.com/zeebo/xxh3" +) + +// muxClient is a stateful connection to a remote. +type muxClient struct { + MuxID uint64 + SendSeq, RecvSeq uint32 + LastPong int64 + BaseFlags Flags + ctx context.Context + cancelFn context.CancelCauseFunc + parent *Connection + respWait chan<- Response + respMu sync.Mutex + singleResp bool + closed bool + stateless bool + acked bool + init bool + deadline time.Duration + outBlock chan struct{} + subroute *subHandlerID +} + +// Response is a response from the server. +type Response struct { + Msg []byte + Err error +} + +func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient { + ctx, cancelFn := context.WithCancelCause(ctx) + return &muxClient{ + MuxID: muxID, + ctx: ctx, + cancelFn: cancelFn, + parent: parent, + LastPong: time.Now().Unix(), + BaseFlags: parent.baseFlags, + } +} + +// roundtrip performs a roundtrip, returning the first response. +// This cannot be used concurrently. +func (m *muxClient) roundtrip(h HandlerID, req []byte) ([]byte, error) { + if m.init { + return nil, errors.New("mux client already used") + } + m.init = true + m.singleResp = true + msg := message{ + Op: OpRequest, + MuxID: m.MuxID, + Handler: h, + Flags: m.BaseFlags | FlagEOF, + Payload: req, + DeadlineMS: uint32(m.deadline.Milliseconds()), + } + if m.subroute != nil { + msg.Flags |= FlagSubroute + } + ch := make(chan Response, 1) + m.respWait = ch + ctx := m.ctx + + // Add deadline if none. + if msg.DeadlineMS == 0 { + msg.DeadlineMS = uint32(defaultSingleRequestTimeout / time.Millisecond) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, defaultSingleRequestTimeout) + defer cancel() + } + // Send... (no need for lock yet) + if err := m.sendLocked(msg); err != nil { + return nil, err + } + if debugReqs { + fmt.Println(m.MuxID, m.parent.String(), "SEND") + } + // Wait for response or context. + select { + case v, ok := <-ch: + if !ok { + return nil, ErrDisconnected + } + if debugReqs && v.Err != nil { + v.Err = fmt.Errorf("%d %s RESP ERR: %w", m.MuxID, m.parent.String(), v.Err) + } + return v.Msg, v.Err + case <-ctx.Done(): + if debugReqs { + return nil, fmt.Errorf("%d %s ERR: %w", m.MuxID, m.parent.String(), context.Cause(ctx)) + } + return nil, context.Cause(ctx) + } +} + +// send the message. msg.Seq and msg.MuxID will be set +func (m *muxClient) send(msg message) error { + m.respMu.Lock() + defer m.respMu.Unlock() + if m.closed { + return errors.New("mux client closed") + } + return m.sendLocked(msg) +} + +// sendLocked the message. msg.Seq and msg.MuxID will be set. +// m.respMu must be held. +func (m *muxClient) sendLocked(msg message) error { + dst := GetByteBuffer()[:0] + msg.Seq = m.SendSeq + msg.MuxID = m.MuxID + msg.Flags |= m.BaseFlags + if debugPrint { + fmt.Println("Client sending", &msg, "to", m.parent.Remote) + } + m.SendSeq++ + + dst, err := msg.MarshalMsg(dst) + if err != nil { + return err + } + if msg.Flags&FlagSubroute != 0 { + if m.subroute == nil { + return fmt.Errorf("internal error: subroute not defined on client") + } + hid := m.subroute.withHandler(msg.Handler) + before := len(dst) + dst = append(dst, hid[:]...) + if debugPrint { + fmt.Println("Added subroute", hid.String(), "to message", msg, "len", len(dst)-before) + } + } + if msg.Flags&FlagCRCxxh3 != 0 { + h := xxh3.Hash(dst) + dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) + } + return m.parent.send(dst) +} + +// RequestStateless will send a single payload request and stream back results. +// req may not be read/written to after calling. +// TODO: Not implemented +func (m *muxClient) RequestStateless(h HandlerID, req []byte, out chan<- Response) { + if m.init { + out <- Response{Err: errors.New("mux client already used")} + } + m.init = true + + // Try to grab an initial block. + m.singleResp = false + msg := message{ + Op: OpConnectMux, + Handler: h, + Flags: FlagEOF, + Payload: req, + DeadlineMS: uint32(m.deadline.Milliseconds()), + } + msg.setZeroPayloadFlag() + if m.subroute != nil { + msg.Flags |= FlagSubroute + } + + // Send... + err := m.send(msg) + if err != nil { + out <- Response{Err: err} + return + } + + // Route directly to output. + m.respWait = out +} + +// RequestStream will send a single payload request and stream back results. +// 'requests' can be nil, in which case only req is sent as input. +// It will however take less resources. +func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []byte, responses chan Response) (*Stream, error) { + if m.init { + return nil, errors.New("mux client already used") + } + if responses == nil { + return nil, errors.New("RequestStream: responses channel is nil") + } + m.init = true + m.respWait = responses // Route directly to output. + + // Try to grab an initial block. + m.singleResp = false + m.RecvSeq = m.SendSeq // Sync + if cap(requests) > 0 { + m.outBlock = make(chan struct{}, cap(requests)) + } + msg := message{ + Op: OpConnectMux, + Handler: h, + Payload: payload, + DeadlineMS: uint32(m.deadline.Milliseconds()), + } + msg.setZeroPayloadFlag() + if requests == nil { + msg.Flags |= FlagEOF + } + if m.subroute != nil { + msg.Flags |= FlagSubroute + } + + // Send... + err := m.send(msg) + if err != nil { + return nil, err + } + if debugPrint { + fmt.Println("Connecting Mux", m.MuxID, ",to", m.parent.Remote) + } + + // Space for one message and an error. + responseCh := make(chan Response, 1) + + // Spawn simple disconnect + if requests == nil { + start := time.Now() + go m.handleOneWayStream(start, responseCh, responses) + return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn}, nil + } + + // Deliver responses and send unblocks back to the server. + go m.handleTwowayResponses(responseCh, responses) + go m.handleTwowayRequests(responses, requests) + + return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn}, nil +} + +func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Response, respServer <-chan Response) { + if debugPrint { + defer func() { + fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) + }() + } + defer close(respHandler) + var pingTimer <-chan time.Time + if m.deadline == 0 || m.deadline > clientPingInterval { + ticker := time.NewTicker(clientPingInterval) + defer ticker.Stop() + pingTimer = ticker.C + atomic.StoreInt64(&m.LastPong, time.Now().Unix()) + } + defer m.parent.deleteMux(false, m.MuxID) + for { + select { + case <-m.ctx.Done(): + if debugPrint { + fmt.Println("Client sending disconnect to mux", m.MuxID) + } + m.respMu.Lock() + defer m.respMu.Unlock() // We always return in this path. + if !m.closed { + respHandler <- Response{Err: context.Cause(m.ctx)} + logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) + m.closeLocked() + } + return + case resp, ok := <-respServer: + if !ok { + return + } + select { + case respHandler <- resp: + m.respMu.Lock() + if !m.closed { + logger.LogIf(m.ctx, m.sendLocked(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})) + } + m.respMu.Unlock() + case <-m.ctx.Done(): + // Client canceled. Don't block. + // Next loop will catch it. + } + case <-pingTimer: + if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 { + m.respMu.Lock() + defer m.respMu.Unlock() // We always return in this path. + if !m.closed { + respHandler <- Response{Err: ErrDisconnected} + logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) + m.closeLocked() + } + return + } + // Send new ping. + logger.LogIf(m.ctx, m.send(message{Op: OpPing, MuxID: m.MuxID})) + } + } +} + +func (m *muxClient) handleTwowayResponses(responseCh chan Response, responses chan Response) { + defer m.parent.deleteMux(false, m.MuxID) + defer close(responseCh) + for resp := range responses { + responseCh <- resp + m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID}) + } +} + +func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests chan []byte) { + var errState bool + start := time.Now() + if debugPrint { + defer func() { + fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) + }() + } + + // Listen for client messages. + for { + select { + case <-m.ctx.Done(): + if debugPrint { + fmt.Println("Client sending disconnect to mux", m.MuxID) + } + m.respMu.Lock() + defer m.respMu.Unlock() + logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) + if !m.closed { + responses <- Response{Err: context.Cause(m.ctx)} + m.closeLocked() + } + return + case req, ok := <-requests: + if !ok { + // Done send EOF + if debugPrint { + fmt.Println("Client done, sending EOF to mux", m.MuxID) + } + msg := message{ + Op: OpMuxClientMsg, + MuxID: m.MuxID, + Seq: 1, + Flags: FlagEOF, + } + msg.setZeroPayloadFlag() + err := m.send(msg) + if err != nil { + m.respMu.Lock() + responses <- Response{Err: err} + m.closeLocked() + m.respMu.Unlock() + } + return + } + if errState { + continue + } + // Grab a send token. + select { + case <-m.ctx.Done(): + errState = true + continue + case <-m.outBlock: + } + msg := message{ + Op: OpMuxClientMsg, + MuxID: m.MuxID, + Seq: 1, + Payload: req, + } + msg.setZeroPayloadFlag() + err := m.send(msg) + PutByteBuffer(req) + if err != nil { + responses <- Response{Err: err} + m.close() + errState = true + continue + } + msg.Seq++ + } + } +} + +// checkSeq will check if sequence number is correct and increment it by 1. +func (m *muxClient) checkSeq(seq uint32) (ok bool) { + if seq != m.RecvSeq { + if debugPrint { + fmt.Printf("MuxID: %d client, expected sequence %d, got %d\n", m.MuxID, m.RecvSeq, seq) + } + m.addResponse(Response{Err: ErrIncorrectSequence}) + return false + } + m.RecvSeq++ + return true +} + +// response will send handleIncoming response to client. +// may never block. +// Should return whether the next call would block. +func (m *muxClient) response(seq uint32, r Response) { + if debugReqs { + fmt.Println(m.MuxID, m.parent.String(), "RESP") + } + if debugPrint { + fmt.Printf("mux %d: got msg seqid %d, payload length: %d, err:%v\n", m.MuxID, seq, len(r.Msg), r.Err) + } + if !m.checkSeq(seq) { + if debugReqs { + fmt.Println(m.MuxID, m.parent.String(), "CHECKSEQ FAIL", m.RecvSeq, seq) + } + PutByteBuffer(r.Msg) + r.Err = ErrIncorrectSequence + m.addResponse(r) + return + } + atomic.StoreInt64(&m.LastPong, time.Now().Unix()) + ok := m.addResponse(r) + if !ok { + PutByteBuffer(r.Msg) + } +} + +// error is a message from the server to disconnect. +func (m *muxClient) error(err RemoteErr) { + if debugPrint { + fmt.Printf("mux %d: got remote err:%v\n", m.MuxID, string(err)) + } + m.addResponse(Response{Err: &err}) +} + +func (m *muxClient) ack(seq uint32) { + if !m.checkSeq(seq) { + return + } + if m.acked || m.outBlock == nil { + return + } + available := cap(m.outBlock) + for i := 0; i < available; i++ { + m.outBlock <- struct{}{} + } + m.acked = true +} + +func (m *muxClient) unblockSend(seq uint32) { + if !m.checkSeq(seq) { + return + } + select { + case m.outBlock <- struct{}{}: + default: + logger.LogIf(m.ctx, errors.New("output unblocked overflow")) + } +} + +func (m *muxClient) pong(msg pongMsg) { + if msg.NotFound || msg.Err != nil { + err := errors.New("remote terminated call") + if msg.Err != nil { + err = fmt.Errorf("remove pong failed: %v", &msg.Err) + } + m.addResponse(Response{Err: err}) + return + } + atomic.StoreInt64(&m.LastPong, time.Now().Unix()) +} + +// addResponse will add a response to the response channel. +// This function will never block +func (m *muxClient) addResponse(r Response) (ok bool) { + m.respMu.Lock() + defer m.respMu.Unlock() + if m.closed { + return false + } + select { + case m.respWait <- r: + if r.Err != nil { + if debugPrint { + fmt.Println("Closing mux", m.MuxID, "due to error:", r.Err) + } + m.closeLocked() + } + return true + default: + if m.stateless { + // Drop message if not stateful. + return + } + err := errors.New("INTERNAL ERROR: Response was blocked") + logger.LogIf(m.ctx, err) + m.closeLocked() + return false + } +} + +func (m *muxClient) close() { + if debugPrint { + fmt.Println("closing outgoing mux", m.MuxID) + } + m.respMu.Lock() + defer m.respMu.Unlock() + m.closeLocked() +} + +func (m *muxClient) closeLocked() { + if m.closed { + return + } + close(m.respWait) + m.respWait = nil + m.closed = true +} diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go new file mode 100644 index 000000000..6b36394c8 --- /dev/null +++ b/internal/grid/muxserver.go @@ -0,0 +1,331 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/minio/minio/internal/logger" +) + +const lastPingThreshold = 4 * clientPingInterval + +type muxServer struct { + ID uint64 + LastPing int64 + SendSeq, RecvSeq uint32 + Resp chan []byte + BaseFlags Flags + ctx context.Context + cancel context.CancelFunc + inbound chan []byte + parent *Connection + sendMu sync.Mutex + recvMu sync.Mutex + outBlock chan struct{} +} + +func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer { + var cancel context.CancelFunc + ctx = setCaller(ctx, c.remote) + if msg.DeadlineMS > 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond) + } else { + ctx, cancel = context.WithCancel(ctx) + } + m := muxServer{ + ID: msg.MuxID, + RecvSeq: msg.Seq + 1, + SendSeq: msg.Seq, + ctx: ctx, + cancel: cancel, + parent: c, + LastPing: time.Now().Unix(), + BaseFlags: c.baseFlags, + } + go func() { + // TODO: Handle + }() + + return &m +} + +func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer { + var cancel context.CancelFunc + ctx = setCaller(ctx, c.remote) + if len(handler.Subroute) > 0 { + ctx = setSubroute(ctx, handler.Subroute) + } + if msg.DeadlineMS > 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond+c.addDeadline) + } else { + ctx, cancel = context.WithCancel(ctx) + } + + send := make(chan []byte) + inboundCap, outboundCap := handler.InCapacity, handler.OutCapacity + if outboundCap <= 0 { + outboundCap = 1 + } + + m := muxServer{ + ID: msg.MuxID, + RecvSeq: msg.Seq + 1, + SendSeq: msg.Seq, + ctx: ctx, + cancel: cancel, + parent: c, + inbound: nil, + outBlock: make(chan struct{}, outboundCap), + LastPing: time.Now().Unix(), + BaseFlags: c.baseFlags, + } + // Acknowledge Mux created. + var ack message + ack.Op = OpAckMux + ack.Flags = m.BaseFlags + ack.MuxID = m.ID + m.send(ack) + if debugPrint { + fmt.Println("connected stream mux:", ack.MuxID) + } + + // Data inbound to the handler + var handlerIn chan []byte + if inboundCap > 0 { + m.inbound = make(chan []byte, inboundCap) + handlerIn = make(chan []byte, 1) + go func(inbound <-chan []byte) { + defer close(handlerIn) + // Send unblocks when we have delivered the message to the handler. + for in := range inbound { + handlerIn <- in + m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags}) + } + }(m.inbound) + } + for i := 0; i < outboundCap; i++ { + m.outBlock <- struct{}{} + } + + // Handler goroutine. + var handlerErr *RemoteErr + go func() { + start := time.Now() + defer func() { + if debugPrint { + fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond)) + } + if r := recover(); r != nil { + logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r)) + err := RemoteErr(fmt.Sprintf("panic: %v", r)) + handlerErr = &err + } + if debugPrint { + fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr) + } + close(send) + }() + // handlerErr is guarded by 'send' channel. + handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send) + }() + // Response sender gorutine... + go func(outBlock <-chan struct{}) { + defer m.parent.deleteMux(true, m.ID) + for { + // Process outgoing message. + var payload []byte + var ok bool + select { + case payload, ok = <-send: + case <-ctx.Done(): + return + } + select { + case <-ctx.Done(): + return + case <-outBlock: + } + msg := message{ + MuxID: m.ID, + Op: OpMuxServerMsg, + Flags: c.baseFlags, + } + if !ok { + if debugPrint { + fmt.Println("muxServer: Mux", m.ID, "send EOF", handlerErr) + } + msg.Flags |= FlagEOF + if handlerErr != nil { + msg.Flags |= FlagPayloadIsErr + msg.Payload = []byte(*handlerErr) + } + msg.setZeroPayloadFlag() + m.send(msg) + return + } + msg.Payload = payload + msg.setZeroPayloadFlag() + m.send(msg) + } + }(m.outBlock) + + // Remote aliveness check. + if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { + go func() { + t := time.NewTicker(lastPingThreshold / 4) + defer t.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-t.C: + last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0)) + if last > lastPingThreshold { + logger.LogIf(m.ctx, fmt.Errorf("canceling remote mux %d not seen for %v", m.ID, last)) + m.close() + return + } + } + } + }() + } + return &m +} + +// checkSeq will check if sequence number is correct and increment it by 1. +func (m *muxServer) checkSeq(seq uint32) (ok bool) { + if seq != m.RecvSeq { + if debugPrint { + fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq) + } + m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq)) + return false + } + m.RecvSeq++ + return true +} + +func (m *muxServer) message(msg message) { + if debugPrint { + fmt.Printf("muxServer: recevied message %d, length %d\n", msg.Seq, len(msg.Payload)) + } + m.recvMu.Lock() + defer m.recvMu.Unlock() + if cap(m.inbound) == 0 { + m.disconnect("did not expect inbound message") + return + } + if !m.checkSeq(msg.Seq) { + return + } + // Note, on EOF no value can be sent. + if msg.Flags&FlagEOF != 0 { + if len(msg.Payload) > 0 { + logger.LogIf(m.ctx, fmt.Errorf("muxServer: EOF message with payload")) + } + close(m.inbound) + m.inbound = nil + return + } + + select { + case <-m.ctx.Done(): + case m.inbound <- msg.Payload: + if debugPrint { + fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq) + } + default: + m.disconnect("handler blocked") + } +} + +func (m *muxServer) unblockSend(seq uint32) { + if !m.checkSeq(seq) { + return + } + m.recvMu.Lock() + defer m.recvMu.Unlock() + if m.outBlock == nil { + // Closed + return + } + select { + case m.outBlock <- struct{}{}: + default: + logger.LogIf(m.ctx, errors.New("output unblocked overflow")) + } +} + +func (m *muxServer) ping(seq uint32) pongMsg { + if !m.checkSeq(seq) { + msg := fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq) + return pongMsg{Err: &msg} + } + select { + case <-m.ctx.Done(): + err := context.Cause(m.ctx).Error() + return pongMsg{Err: &err} + default: + atomic.StoreInt64(&m.LastPing, time.Now().Unix()) + return pongMsg{} + } +} + +func (m *muxServer) disconnect(msg string) { + if debugPrint { + fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg) + } + if msg != "" { + m.send(message{Op: OpMuxServerMsg, MuxID: m.ID, Flags: FlagPayloadIsErr | FlagEOF, Payload: []byte(msg)}) + } else { + m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID}) + } + m.parent.deleteMux(true, m.ID) +} + +func (m *muxServer) send(msg message) { + m.sendMu.Lock() + defer m.sendMu.Unlock() + msg.MuxID = m.ID + msg.Seq = m.SendSeq + m.SendSeq++ + if debugPrint { + fmt.Printf("Mux %d, Sending %+v\n", m.ID, msg) + } + logger.LogIf(m.ctx, m.parent.queueMsg(msg, nil)) +} + +func (m *muxServer) close() { + m.cancel() + m.recvMu.Lock() + defer m.recvMu.Unlock() + if m.inbound != nil { + close(m.inbound) + m.inbound = nil + } + if m.outBlock != nil { + close(m.outBlock) + m.outBlock = nil + } +} diff --git a/internal/grid/stats.go b/internal/grid/stats.go new file mode 100644 index 000000000..2d21d4bdc --- /dev/null +++ b/internal/grid/stats.go @@ -0,0 +1,24 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +// ConnectionStats contains connection statistics. +type ConnectionStats struct { + OutgoingStreams int + IncomingStreams int +} diff --git a/internal/grid/stream.go b/internal/grid/stream.go new file mode 100644 index 000000000..abebd2229 --- /dev/null +++ b/internal/grid/stream.go @@ -0,0 +1,93 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "context" + "errors" +) + +// A Stream is a two-way stream. +// All responses *must* be read by the caller. +// If the call is canceled through the context, +// the appropriate error will be returned. +type Stream struct { + // responses from the remote server. + // Channel will be closed after error or when remote closes. + // All responses *must* be read by the caller until either an error is returned or the channel is closed. + // Canceling the context will cause the context cancellation error to be returned. + responses <-chan Response + cancel context.CancelCauseFunc + + // Requests sent to the server. + // If the handler is defined with 0 incoming capacity this will be nil. + // Channel *must* be closed to signal the end of the stream. + // If the request context is canceled, the stream will no longer process requests. + // Requests sent cannot be used any further by the called. + Requests chan<- []byte + + ctx context.Context +} + +// Send a payload to the remote server. +func (s *Stream) Send(b []byte) error { + if s.Requests == nil { + return errors.New("stream does not accept requests") + } + select { + case s.Requests <- b: + return nil + case <-s.ctx.Done(): + return context.Cause(s.ctx) + } +} + +// Results returns the results from the remote server one by one. +// If any error is returned by the callback, the stream will be canceled. +// If the context is canceled, the stream will be canceled. +func (s *Stream) Results(next func(b []byte) error) (err error) { + done := false + defer func() { + if !done { + if s.cancel != nil { + s.cancel(err) + } + // Drain channel. + for range s.responses { + } + } + }() + for { + select { + case <-s.ctx.Done(): + return context.Cause(s.ctx) + case resp, ok := <-s.responses: + if !ok { + done = true + return nil + } + if resp.Err != nil { + return resp.Err + } + err = next(resp.Msg) + if err != nil { + return err + } + } + } +} diff --git a/internal/grid/trace.go b/internal/grid/trace.go new file mode 100644 index 000000000..58c7ef116 --- /dev/null +++ b/internal/grid/trace.go @@ -0,0 +1,110 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "net/http" + "time" + + "github.com/minio/madmin-go/v3" + "github.com/minio/minio/internal/pubsub" +) + +// traceRequests adds request tracing to the connection. +func (c *Connection) traceRequests(p *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]) { + c.trace = &tracer{ + Publisher: p, + TraceType: madmin.TraceInternal, + Prefix: "grid.", + Local: c.Local, + Remote: c.Remote, + Subroute: "", + } +} + +// subroute adds a specific subroute to the request. +func (c *tracer) subroute(subroute string) *tracer { + if c == nil { + return nil + } + c2 := *c + c2.Subroute = subroute + return &c2 +} + +type tracer struct { + Publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] + TraceType madmin.TraceType + Prefix string + Local string + Remote string + Subroute string +} + +func (c *muxClient) traceRoundtrip(t *tracer, h HandlerID, req []byte) ([]byte, error) { + if t == nil || t.Publisher.NumSubscribers(t.TraceType) == 0 { + return c.roundtrip(h, req) + } + start := time.Now() + body := bytesOrLength(req) + resp, err := c.roundtrip(h, req) + end := time.Now() + status := http.StatusOK + errString := "" + if err != nil { + errString = err.Error() + if IsRemoteErr(err) == nil { + status = http.StatusInternalServerError + } else { + status = http.StatusBadRequest + } + } + trace := madmin.TraceInfo{ + TraceType: t.TraceType, + FuncName: t.Prefix + h.String(), + NodeName: t.Local, + Time: start, + Duration: end.Sub(start), + Path: t.Subroute, + Error: errString, + HTTP: &madmin.TraceHTTPStats{ + ReqInfo: madmin.TraceRequestInfo{ + Time: start, + Proto: "grid", + Method: "REQ", + Client: t.Remote, + Headers: nil, + Path: t.Subroute, + Body: []byte(body), + }, + RespInfo: madmin.TraceResponseInfo{ + Time: end, + Headers: nil, + StatusCode: status, + Body: []byte(bytesOrLength(resp)), + }, + CallStats: madmin.TraceCallStats{ + InputBytes: len(req), + OutputBytes: len(resp), + TimeToFirstByte: end.Sub(start), + }, + }, + } + t.Publisher.Publish(trace) + return resp, err +} diff --git a/internal/grid/types.go b/internal/grid/types.go new file mode 100644 index 000000000..c9dee4abe --- /dev/null +++ b/internal/grid/types.go @@ -0,0 +1,172 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "errors" + + "github.com/tinylib/msgp/msgp" +) + +// MSS is a map[string]string that can be serialized. +// It is not very efficient, but it is only used for easy parameter passing. +type MSS map[string]string + +// Get returns the value for the given key. +func (m *MSS) Get(key string) string { + if m == nil { + return "" + } + return (*m)[key] +} + +// UnmarshalMsg deserializes m from the provided byte slice and returns the +// remainder of bytes. +func (m *MSS) UnmarshalMsg(bts []byte) (o []byte, err error) { + if m == nil { + return bts, errors.New("MSS: UnmarshalMsg on nil pointer") + } + if msgp.IsNil(bts) { + bts = bts[1:] + *m = nil + return bts, nil + } + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Values") + return + } + dst := *m + if dst == nil { + dst = make(map[string]string, zb0002) + } else if len(dst) > 0 { + for key := range dst { + delete(dst, key) + } + } + for zb0002 > 0 { + var za0001 string + var za0002 string + zb0002-- + za0001, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Values") + return + } + za0002, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Values", za0001) + return + } + dst[za0001] = za0002 + } + *m = dst + return bts, nil +} + +// MarshalMsg appends the bytes representation of b to the provided byte slice. +func (m *MSS) MarshalMsg(bytes []byte) (o []byte, err error) { + if m == nil || *m == nil { + return msgp.AppendNil(bytes), nil + } + o = msgp.AppendMapHeader(bytes, uint32(len(*m))) + for za0001, za0002 := range *m { + o = msgp.AppendString(o, za0001) + o = msgp.AppendString(o, za0002) + } + return o, nil +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message. +func (m *MSS) Msgsize() int { + if m == nil || *m == nil { + return msgp.NilSize + } + s := msgp.MapHeaderSize + for za0001, za0002 := range *m { + s += msgp.StringPrefixSize + len(za0001) + msgp.StringPrefixSize + len(za0002) + } + return s +} + +// NewMSS returns a new MSS. +func NewMSS() *MSS { + m := MSS(make(map[string]string)) + return &m +} + +// NewMSSWith returns a new MSS with the given map. +func NewMSSWith(m map[string]string) *MSS { + m2 := MSS(m) + return &m2 +} + +// NewBytes returns a new Bytes. +func NewBytes() *Bytes { + b := Bytes(GetByteBuffer()[:0]) + return &b +} + +// NewBytesWith returns a new Bytes with the provided content. +func NewBytesWith(b []byte) *Bytes { + bb := Bytes(b) + return &bb +} + +// Bytes provides a byte slice that can be serialized. +type Bytes []byte + +// UnmarshalMsg deserializes b from the provided byte slice and returns the +// remainder of bytes. +func (b *Bytes) UnmarshalMsg(bytes []byte) ([]byte, error) { + if b == nil { + return bytes, errors.New("Bytes: UnmarshalMsg on nil pointer") + } + if bytes, err := msgp.ReadNilBytes(bytes); err == nil { + *b = nil + return bytes, nil + } + val, bytes, err := msgp.ReadBytesZC(bytes) + if err != nil { + return bytes, err + } + if cap(*b) >= len(val) { + *b = (*b)[:len(val)] + copy(*b, val) + } else { + *b = append(make([]byte, 0, len(val)), val...) + } + return bytes, nil +} + +// MarshalMsg appends the bytes representation of b to the provided byte slice. +func (b *Bytes) MarshalMsg(bytes []byte) ([]byte, error) { + if b == nil || *b == nil { + return msgp.AppendNil(bytes), nil + } + return msgp.AppendBytes(bytes, *b), nil +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message. +func (b *Bytes) Msgsize() int { + if b == nil || *b == nil { + return msgp.NilSize + } + return msgp.ArrayHeaderSize + len(*b) +} diff --git a/internal/grid/types_test.go b/internal/grid/types_test.go new file mode 100644 index 000000000..43899de28 --- /dev/null +++ b/internal/grid/types_test.go @@ -0,0 +1,168 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package grid + +import ( + "reflect" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalMSS(t *testing.T) { + v := MSS{"abc": "def", "ghi": "jkl"} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + var v2 MSS + left, err := v2.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) != 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } + if !reflect.DeepEqual(v, v2) { + t.Errorf("MSS: %v != %v", v, v2) + } +} + +func TestMarshalUnmarshalMSSNil(t *testing.T) { + v := MSS(nil) + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + v2 := MSS(make(map[string]string, 1)) + left, err := v2.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) != 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } + if !reflect.DeepEqual(v, v2) { + t.Errorf("MSS: %v != %v", v, v2) + } +} + +func BenchmarkMarshalMsgMSS(b *testing.B) { + v := MSS{"abc": "def", "ghi": "jkl"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgMSS(b *testing.B) { + v := MSS{"abc": "def", "ghi": "jkl"} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalMSS(b *testing.B) { + v := MSS{"abc": "def", "ghi": "jkl"} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalBytes(t *testing.T) { + v := Bytes([]byte("abc123123123")) + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + var v2 Bytes + left, err := v2.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) != 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } + if !reflect.DeepEqual(v, v2) { + t.Errorf("MSS: %v != %v", v, v2) + } +} + +func TestMarshalUnmarshalBytesNil(t *testing.T) { + v := Bytes(nil) + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + v2 := Bytes(make([]byte, 1)) + left, err := v2.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) != 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } + if !reflect.DeepEqual(v, v2) { + t.Errorf("MSS: %v != %v", v, v2) + } +} diff --git a/internal/http/response-recorder.go b/internal/http/response-recorder.go index 9ab47d5e8..72622601a 100644 --- a/internal/http/response-recorder.go +++ b/internal/http/response-recorder.go @@ -18,10 +18,12 @@ package http import ( + "bufio" "bytes" "errors" "fmt" "io" + "net" "net/http" "time" ) @@ -50,6 +52,15 @@ type ResponseRecorder struct { headersLogged bool } +// Hijack - hijacks the underlying connection +func (lrw *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := lrw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("response writer does not support hijacking. Type is %T", lrw.ResponseWriter) + } + return hj.Hijack() +} + // NewResponseRecorder - returns a wrapped response writer to trap // http status codes for auditing purposes. func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 7727fb830..e5e3abe02 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -20,6 +20,7 @@ package logger import ( "context" "encoding/hex" + "errors" "fmt" "go/build" "path/filepath" @@ -258,6 +259,20 @@ func LogIf(ctx context.Context, err error, errKind ...interface{}) { logIf(ctx, err, errKind...) } +// LogIfNot prints a detailed error message during +// the execution of the server, if it is not an ignored error (either internal or given). +func LogIfNot(ctx context.Context, err error, ignored ...error) { + if logIgnoreError(err) { + return + } + for _, ignore := range ignored { + if errors.Is(err, ignore) { + return + } + } + logIf(ctx, err) +} + func errToEntry(ctx context.Context, err error, errKind ...interface{}) log.Entry { logKind := madmin.LogKindAll if len(errKind) > 0 { diff --git a/internal/logger/target/testlogger/testlogger.go b/internal/logger/target/testlogger/testlogger.go new file mode 100644 index 000000000..35f5b3da6 --- /dev/null +++ b/internal/logger/target/testlogger/testlogger.go @@ -0,0 +1,170 @@ +// Copyright (c) 2015-2023 MinIO, Inc. +// +// This file is part of MinIO Object Storage stack +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package testlogger contains an autoregistering logger that can be used to capture logging events +// for individual tests. +// This package should only be included by test files. +// To enable logging for a test, use: +// +// func TestSomething(t *testing.T) { +// defer testlogger.T.SetLogTB(t)() +// +// This cannot be used for parallel tests. +package testlogger + +import ( + "context" + "fmt" + "os" + "strings" + "sync/atomic" + "testing" + + "github.com/minio/minio/internal/logger" + "github.com/minio/minio/internal/logger/target/types" + "github.com/minio/pkg/v2/logger/message/log" +) + +const ( + logMessage = iota + errorMessage + fatalMessage +) + +// T is the test logger. +var T = &testLogger{} + +func init() { + logger.AddSystemTarget(context.Background(), T) +} + +type testLogger struct { + current atomic.Pointer[testing.TB] + action atomic.Int32 +} + +// SetLogTB will set the logger to output to tb. +// Call the returned function to disable logging. +func (t *testLogger) SetLogTB(tb testing.TB) func() { + return t.setTB(tb, logMessage) +} + +// SetErrorTB will set the logger to output to tb.Error. +// Call the returned function to disable logging. +func (t *testLogger) SetErrorTB(tb testing.TB) func() { + return t.setTB(tb, errorMessage) +} + +// SetFatalTB will set the logger to output to tb.Panic. +// Call the returned function to disable logging. +func (t *testLogger) SetFatalTB(tb testing.TB) func() { + return t.setTB(tb, fatalMessage) +} + +func (t *testLogger) setTB(tb testing.TB, action int32) func() { + old := t.action.Swap(action) + t.current.Store(&tb) + return func() { + t.current.Store(nil) + t.action.Store(old) + } +} + +func (t *testLogger) String() string { + tb := t.current.Load() + if tb != nil { + tbb := *tb + return tbb.Name() + } + return "" +} + +func (t *testLogger) Endpoint() string { + return "" +} + +func (t *testLogger) Stats() types.TargetStats { + return types.TargetStats{} +} + +func (t *testLogger) Init(ctx context.Context) error { + return nil +} + +func (t *testLogger) IsOnline(ctx context.Context) bool { + return t.current.Load() != nil +} + +func (t *testLogger) Cancel() { + t.current.Store(nil) +} + +func (t *testLogger) Send(ctx context.Context, entry interface{}) error { + tb := t.current.Load() + var logf func(format string, args ...any) + if tb != nil { + tbb := *tb + tbb.Helper() + switch t.action.Load() { + case errorMessage: + logf = tbb.Errorf + case fatalMessage: + logf = tbb.Fatalf + default: + logf = tbb.Logf + } + } else { + switch t.action.Load() { + case errorMessage: + logf = func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format+"\n", args...) + } + case fatalMessage: + logf = func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format+"\n", args...) + } + defer os.Exit(1) + default: + logf = func(format string, args ...any) { + fmt.Fprintf(os.Stdout, format+"\n", args...) + } + } + } + + switch v := entry.(type) { + case log.Entry: + if v.Trace == nil { + logf("%s: %s", v.Level, v.Message) + } else { + msg := fmt.Sprintf("%s: %+v", v.Level, v.Trace.Message) + for i, m := range v.Trace.Source { + if i == 0 && strings.Contains(m, "logger.go:") { + continue + } + msg += fmt.Sprintf("\n%s", m) + } + logf("%s", msg) + } + default: + logf("%+v (%T)", v, v) + } + return nil +} + +func (t *testLogger) Type() types.TargetType { + return types.TargetConsole +} diff --git a/internal/rest/client.go b/internal/rest/client.go index b71dd8823..5a9d9baf9 100644 --- a/internal/rest/client.go +++ b/internal/rest/client.go @@ -448,6 +448,7 @@ func (c *Client) MarkOffline(err error) bool { c.Lock() c.lastErr = err c.lastErrTime = time.Now() + atomic.StoreInt64(&c.lastConn, time.Now().UnixNano()) c.Unlock() // Start goroutine that will attempt to reconnect. // If server is already trying to reconnect this will have no effect.