From 0d0b0aa599dc1022d9eb8ebbe15432f688467047 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 8 Jul 2024 14:44:00 -0700 Subject: [PATCH] Abstract grid connections (#20038) Add `ConnDialer` to abstract connection creation. - `IncomingConn(ctx context.Context, conn net.Conn)` is provided as an entry point for incoming custom connections. - `ConnectWS` is provided to create web socket connections. --- cmd/grid.go | 18 +-- cmd/routers.go | 2 +- cmd/storage-rest-server.go | 33 +++-- internal/grid/connection.go | 55 ++------- internal/grid/connection_test.go | 24 ++-- internal/grid/debug.go | 31 +++-- internal/grid/grid.go | 47 +++++++ internal/grid/manager.go | 206 ++++++++++++++++++------------- internal/grid/msg.go | 17 ++- internal/grid/msg_gen.go | 60 ++++++++- 10 files changed, 313 insertions(+), 180 deletions(-) diff --git a/cmd/grid.go b/cmd/grid.go index 81c9d07fe..125dda952 100644 --- a/cmd/grid.go +++ b/cmd/grid.go @@ -41,17 +41,19 @@ func initGlobalGrid(ctx context.Context, eps EndpointServerPools) error { // Pass Dialer for websocket grid, make sure we do not // provide any DriveOPTimeout() function, as that is not // useful over persistent connections. - Dialer: grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))), + Dialer: grid.ConnectWS( + grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))), + newCachedAuthToken(), + &tls.Config{ + RootCAs: globalRootCAs, + CipherSuites: fips.TLSCiphers(), + CurvePreferences: fips.TLSCurveIDs(), + }), Local: local, Hosts: hosts, - AddAuth: newCachedAuthToken(), - AuthRequest: storageServerRequestValidate, + AuthToken: validateStorageRequestToken, + AuthFn: newCachedAuthToken(), BlockConnect: globalGridStart, - TLSConfig: &tls.Config{ - RootCAs: globalRootCAs, - CipherSuites: fips.TLSCiphers(), - CurvePreferences: fips.TLSCurveIDs(), - }, // Record incoming and outgoing bytes. Incoming: globalConnStats.incInternodeInputBytes, Outgoing: globalConnStats.incInternodeOutputBytes, diff --git a/cmd/routers.go b/cmd/routers.go index d5e77ddf2..e0cafdbea 100644 --- a/cmd/routers.go +++ b/cmd/routers.go @@ -39,7 +39,7 @@ func registerDistErasureRouters(router *mux.Router, endpointServerPools Endpoint registerLockRESTHandlers() // Add grid to router - router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(), noGZFlag, noObjLayerFlag)) + router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(storageServerRequestValidate), noGZFlag, noObjLayerFlag)) } // List of some generic middlewares which are applied for all incoming requests. diff --git a/cmd/storage-rest-server.go b/cmd/storage-rest-server.go index 99eeea4d8..201ffc5dd 100644 --- a/cmd/storage-rest-server.go +++ b/cmd/storage-rest-server.go @@ -109,6 +109,24 @@ func (s *storageRESTServer) writeErrorResponse(w http.ResponseWriter, err error) // DefaultSkewTime - skew time is 15 minutes between minio peers. const DefaultSkewTime = 15 * time.Minute +// validateStorageRequestToken will validate the token against the provided audience. +func validateStorageRequestToken(token, audience string) error { + claims := xjwt.NewStandardClaims() + if err := xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil { + return errAuthentication + } + + owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey + if !owner { + return errAuthentication + } + + if claims.Audience != audience { + return errAuthentication + } + return nil +} + // Authenticates storage client's requests and validates for skewed time. func storageServerRequestValidate(r *http.Request) error { token, err := jwtreq.AuthorizationHeaderExtractor.ExtractToken(r) @@ -118,19 +136,8 @@ func storageServerRequestValidate(r *http.Request) error { } return errMalformedAuth } - - claims := xjwt.NewStandardClaims() - if err = xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil { - return errAuthentication - } - - owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey - if !owner { - return errAuthentication - } - - if claims.Audience != r.URL.RawQuery { - return errAuthentication + if err = validateStorageRequestToken(token, r.URL.RawQuery); err != nil { + return err } requestTimeStr := r.Header.Get("X-Minio-Time") diff --git a/internal/grid/connection.go b/internal/grid/connection.go index f3ccf222e..eae23c40a 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -20,7 +20,6 @@ package grid import ( "bytes" "context" - "crypto/tls" "encoding/binary" "errors" "fmt" @@ -28,7 +27,6 @@ import ( "math" "math/rand" "net" - "net/http" "runtime/debug" "strings" "sync" @@ -100,9 +98,9 @@ type Connection struct { // Client or serverside. side ws.State - // Transport for outgoing connections. - dialer ContextDialer - header http.Header + // Dialer for outgoing connections. + dial ConnDialer + authFn AuthFn handleMsgWg sync.WaitGroup @@ -112,10 +110,8 @@ type Connection struct { handlers *handlers remote *RemoteClient - auth AuthFn clientPingInterval time.Duration connPingInterval time.Duration - tlsConfig *tls.Config blockConnect chan struct{} incomingBytes func(n int64) // Record incoming bytes. @@ -205,13 +201,12 @@ type connectionParams struct { ctx context.Context id uuid.UUID local, remote string - dial ContextDialer handlers *handlers - auth AuthFn - tlsConfig *tls.Config incomingBytes func(n int64) // Record incoming bytes. outgoingBytes func(n int64) // Record outgoing bytes. publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] + dialer ConnDialer + authFn AuthFn blockConnect chan struct{} } @@ -227,16 +222,14 @@ func newConnection(o connectionParams) *Connection { outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000), inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000), outQueue: make(chan []byte, defaultOutQueue), - dialer: o.dial, side: ws.StateServerSide, connChange: &sync.Cond{L: &sync.Mutex{}}, handlers: o.handlers, - auth: o.auth, - header: make(http.Header, 1), remote: &RemoteClient{Name: o.remote}, clientPingInterval: clientPingInterval, connPingInterval: connPingInterval, - tlsConfig: o.tlsConfig, + dial: o.dialer, + authFn: o.authFn, } if debugPrint { // Random Mux ID @@ -648,41 +641,17 @@ func (c *Connection) connect() { if c.State() == StateShutdown { return } - toDial := strings.Replace(c.Remote, "http://", "ws://", 1) - toDial = strings.Replace(toDial, "https://", "wss://", 1) - toDial += RoutePath - - dialer := ws.DefaultDialer - dialer.ReadBufferSize = readBufferSize - dialer.WriteBufferSize = writeBufferSize - dialer.Timeout = defaultDialTimeout - if c.dialer != nil { - dialer.NetDial = c.dialer.DialContext - } - if c.header == nil { - c.header = make(http.Header, 2) - } - c.header.Set("Authorization", "Bearer "+c.auth("")) - c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339)) - - if len(c.header) > 0 { - dialer.Header = ws.HandshakeHeaderHTTP(c.header) - } - dialer.TLSConfig = c.tlsConfig dialStarted := time.Now() if debugPrint { - fmt.Println(c.Local, "Connecting to ", toDial) - } - conn, br, _, err := dialer.Dial(c.ctx, toDial) - if br != nil { - ws.PutReader(br) + fmt.Println(c.Local, "Connecting to ", c.Remote) } + conn, err := c.dial(c.ctx, c.Remote) c.connMu.Lock() c.debugOutConn = conn c.connMu.Unlock() retry := func(err error) { if debugPrint { - fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err) + fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, c.Remote, err) } sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout))) next := dialStarted.Add(sleep / 2) @@ -696,7 +665,7 @@ func (c *Connection) connect() { } if gotState != StateConnecting { // Don't print error on first attempt, and after that only once per hour. - gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), toDial) + gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, c.Remote, err, err, sleep, gotState), c.Remote) } c.updateState(StateConnectionError) time.Sleep(sleep) @@ -712,7 +681,9 @@ func (c *Connection) connect() { req := connectReq{ Host: c.Local, ID: c.id, + Time: time.Now(), } + req.addToken(c.authFn) err = c.sendMsg(conn, m, &req) if err != nil { retry(err) diff --git a/internal/grid/connection_test.go b/internal/grid/connection_test.go index f95b122e1..aae0d8b7c 100644 --- a/internal/grid/connection_test.go +++ b/internal/grid/connection_test.go @@ -52,11 +52,13 @@ func TestDisconnect(t *testing.T) { localHost := hosts[0] remoteHost := hosts[1] local, err := NewManager(context.Background(), ManagerOptions{ - Dialer: dialer.DialContext, + Dialer: ConnectWS(dialer.DialContext, + dummyNewToken, + nil), Local: localHost, Hosts: hosts, - AddAuth: func(aud string) string { return aud }, - AuthRequest: dummyRequestValidate, + AuthFn: dummyNewToken, + AuthToken: dummyTokenValidate, BlockConnect: connReady, }) errFatal(err) @@ -74,17 +76,19 @@ func TestDisconnect(t *testing.T) { })) remote, err := NewManager(context.Background(), ManagerOptions{ - Dialer: dialer.DialContext, + Dialer: ConnectWS(dialer.DialContext, + dummyNewToken, + nil), Local: remoteHost, Hosts: hosts, - AddAuth: func(aud string) string { return aud }, - AuthRequest: dummyRequestValidate, + AuthFn: dummyNewToken, + AuthToken: dummyTokenValidate, BlockConnect: connReady, }) errFatal(err) - localServer := startServer(t, listeners[0], wrapServer(local.Handler())) - remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler())) + localServer := startServer(t, listeners[0], wrapServer(local.Handler(dummyRequestValidate))) + remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler(dummyRequestValidate))) close(connReady) defer func() { @@ -165,10 +169,6 @@ func TestDisconnect(t *testing.T) { <-gotCall } -func dummyRequestValidate(r *http.Request) error { - return nil -} - func TestShouldConnect(t *testing.T) { var c Connection var cReverse Connection diff --git a/internal/grid/debug.go b/internal/grid/debug.go index 8110acb65..8d02bb7fe 100644 --- a/internal/grid/debug.go +++ b/internal/grid/debug.go @@ -82,20 +82,20 @@ func SetupTestGrid(n int) (*TestGrid, error) { res.cancel = cancel for i, host := range hosts { manager, err := NewManager(ctx, ManagerOptions{ - Dialer: dialer.DialContext, - Local: host, - Hosts: hosts, - AuthRequest: func(r *http.Request) error { - return nil - }, - AddAuth: func(aud string) string { return aud }, + Dialer: ConnectWS(dialer.DialContext, + dummyNewToken, + nil), + Local: host, + Hosts: hosts, + AuthFn: dummyNewToken, + AuthToken: dummyTokenValidate, BlockConnect: ready, }) if err != nil { return nil, err } m := mux.NewRouter() - m.Handle(RoutePath, manager.Handler()) + m.Handle(RoutePath, manager.Handler(dummyRequestValidate)) res.Managers = append(res.Managers, manager) res.Servers = append(res.Servers, startHTTPServer(listeners[i], m)) res.Listeners = append(res.Listeners, listeners[i]) @@ -164,3 +164,18 @@ func startHTTPServer(listener net.Listener, handler http.Handler) (server *httpt server.Start() return server } + +func dummyRequestValidate(r *http.Request) error { + return nil +} + +func dummyTokenValidate(token, audience string) error { + if token == audience { + return nil + } + return fmt.Errorf("invalid token. want %s, got %s", audience, token) +} + +func dummyNewToken(audience string) string { + return audience +} diff --git a/internal/grid/grid.go b/internal/grid/grid.go index 447dae25a..8ff7aaa82 100644 --- a/internal/grid/grid.go +++ b/internal/grid/grid.go @@ -20,12 +20,17 @@ package grid import ( "context" + "crypto/tls" "errors" "fmt" "io" + "net" + "net/http" + "strings" "sync" "time" + "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" ) @@ -179,3 +184,45 @@ func bytesOrLength(b []byte) string { } return fmt.Sprint(b) } + +// ConnDialer is a function that dials a connection to the given address. +// There should be no retries in this function, +// and should have a timeout of something like 2 seconds. +// The returned net.Conn should also have quick disconnect on errors. +// The net.Conn must support all features as described by the net.Conn interface. +type ConnDialer func(ctx context.Context, address string) (net.Conn, error) + +// ConnectWS returns a function that dials a websocket connection to the given address. +// Route and auth are added to the connection. +func ConnectWS(dial ContextDialer, auth AuthFn, tls *tls.Config) func(ctx context.Context, remote string) (net.Conn, error) { + return func(ctx context.Context, remote string) (net.Conn, error) { + toDial := strings.Replace(remote, "http://", "ws://", 1) + toDial = strings.Replace(toDial, "https://", "wss://", 1) + toDial += RoutePath + + dialer := ws.DefaultDialer + dialer.ReadBufferSize = readBufferSize + dialer.WriteBufferSize = writeBufferSize + dialer.Timeout = defaultDialTimeout + if dial != nil { + dialer.NetDial = dial + } + header := make(http.Header, 2) + header.Set("Authorization", "Bearer "+auth("")) + header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339)) + + if len(header) > 0 { + dialer.Header = ws.HandshakeHeaderHTTP(header) + } + dialer.TLSConfig = tls + + conn, br, _, err := dialer.Dial(ctx, toDial) + if br != nil { + ws.PutReader(br) + } + return conn, err + } +} + +// ValidateTokenFn must validate the token and return an error if it is invalid. +type ValidateTokenFn func(token, audience string) error diff --git a/internal/grid/manager.go b/internal/grid/manager.go index a90f9c402..b9e199e4d 100644 --- a/internal/grid/manager.go +++ b/internal/grid/manager.go @@ -19,13 +19,14 @@ package grid import ( "context" - "crypto/tls" "errors" "fmt" "io" + "net" "net/http" "runtime/debug" "strings" + "time" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" @@ -62,40 +63,48 @@ type Manager struct { // local host name. local string - // Validate incoming requests. - authRequest func(r *http.Request) error + // authToken is a function that will validate a token. + authToken ValidateTokenFn } // ManagerOptions are options for creating a new grid manager. type ManagerOptions struct { - Dialer ContextDialer // Outgoing dialer. - Local string // Local host name. - Hosts []string // All hosts, including local in the grid. - AddAuth AuthFn // Add authentication to the given audience. - AuthRequest func(r *http.Request) error // Validate incoming requests. - TLSConfig *tls.Config // TLS to apply to the connections. - Incoming func(n int64) // Record incoming bytes. - Outgoing func(n int64) // Record outgoing bytes. - BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed. + Local string // Local host name. + Hosts []string // All hosts, including local in the grid. + Incoming func(n int64) // Record incoming bytes. + Outgoing func(n int64) // Record outgoing bytes. + BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed. TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] + Dialer ConnDialer + // Sign a token for the given audience. + AuthFn AuthFn + // Callbacks to validate incoming connections. + AuthToken ValidateTokenFn } // NewManager creates a new grid manager func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) { found := false - if o.AuthRequest == nil { - return nil, fmt.Errorf("grid: AuthRequest must be set") + if o.AuthToken == nil { + return nil, fmt.Errorf("grid: AuthToken not set") + } + if o.Dialer == nil { + return nil, fmt.Errorf("grid: Dialer not set") + } + if o.AuthFn == nil { + return nil, fmt.Errorf("grid: AuthFn not set") } m := &Manager{ - ID: uuid.New(), - targets: make(map[string]*Connection, len(o.Hosts)), - local: o.Local, - authRequest: o.AuthRequest, + ID: uuid.New(), + targets: make(map[string]*Connection, len(o.Hosts)), + local: o.Local, + authToken: o.AuthToken, } m.handlers.init() if ctx == nil { ctx = context.Background() } + for _, host := range o.Hosts { if host == o.Local { if found { @@ -110,14 +119,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) { id: m.ID, local: o.Local, remote: host, - dial: o.Dialer, handlers: &m.handlers, - auth: o.AddAuth, blockConnect: o.BlockConnect, - tlsConfig: o.TLSConfig, publisher: o.TraceTo, incomingBytes: o.Incoming, outgoingBytes: o.Outgoing, + dialer: o.Dialer, + authFn: o.AuthFn, }) } if !found { @@ -128,13 +136,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) { } // AddToMux will add the grid manager to the given mux. -func (m *Manager) AddToMux(router *mux.Router) { - router.Handle(RoutePath, m.Handler()) +func (m *Manager) AddToMux(router *mux.Router, authReq func(r *http.Request) error) { + router.Handle(RoutePath, m.Handler(authReq)) } // Handler returns a handler that can be used to serve grid requests. // This should be connected on RoutePath to the main server. -func (m *Manager) Handler() http.HandlerFunc { +func (m *Manager) Handler(authReq func(r *http.Request) error) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { defer func() { if debugPrint { @@ -151,7 +159,7 @@ func (m *Manager) Handler() http.HandlerFunc { fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL) } ctx := req.Context() - if err := m.authRequest(req); err != nil { + if err := authReq(req); err != nil { gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr) w.WriteHeader(http.StatusForbidden) return @@ -164,76 +172,96 @@ func (m *Manager) Handler() http.HandlerFunc { w.WriteHeader(http.StatusUpgradeRequired) return } - // will write an OpConnectResponse message to the remote and log it once locally. - writeErr := func(err error) { - if err == nil { - return - } - if errors.Is(err, io.EOF) { - return - } - gridLogOnceIf(ctx, err, req.RemoteAddr) - resp := connectResp{ - ID: m.ID, - Accepted: false, - RejectedReason: err.Error(), - } - if b, err := resp.MarshalMsg(nil); err == nil { - msg := message{ - Op: OpConnectResponse, - Payload: b, - } - if b, err := msg.MarshalMsg(nil); err == nil { - wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b) - } - } - } - defer conn.Close() - if debugPrint { - fmt.Printf("grid: Upgraded request: %v\n", req.URL) - } - - msg, _, err := wsutil.ReadClientData(conn) - if err != nil { - writeErr(fmt.Errorf("reading connect: %w", err)) - w.WriteHeader(http.StatusForbidden) - return - } - if debugPrint { - fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg)) - } - - var message message - _, _, err = message.parse(msg) - if err != nil { - writeErr(fmt.Errorf("error parsing grid connect: %w", err)) - return - } - if message.Op != OpConnect { - writeErr(fmt.Errorf("unexpected connect op: %v", message.Op)) - return - } - var cReq connectReq - _, err = cReq.UnmarshalMsg(message.Payload) - if err != nil { - writeErr(fmt.Errorf("error parsing connectReq: %w", err)) - return - } - remote := m.targets[cReq.Host] - if remote == nil { - writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host)) - return - } - if debugPrint { - fmt.Printf("handler: Got Connect Req %+v\n", cReq) - } - writeErr(remote.handleIncoming(ctx, conn, cReq)) + m.IncomingConn(ctx, conn) } } +// IncomingConn will handle an incoming connection. +// This should be called with the incoming connection after accept. +// Auth is handled internally, as well as disconnecting any connections from the same host. +func (m *Manager) IncomingConn(ctx context.Context, conn net.Conn) { + remoteAddr := conn.RemoteAddr().String() + // will write an OpConnectResponse message to the remote and log it once locally. + defer conn.Close() + writeErr := func(err error) { + if err == nil { + return + } + if errors.Is(err, io.EOF) { + return + } + gridLogOnceIf(ctx, err, remoteAddr) + resp := connectResp{ + ID: m.ID, + Accepted: false, + RejectedReason: err.Error(), + } + if b, err := resp.MarshalMsg(nil); err == nil { + msg := message{ + Op: OpConnectResponse, + Payload: b, + } + if b, err := msg.MarshalMsg(nil); err == nil { + wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b) + } + } + } + defer conn.Close() + if debugPrint { + fmt.Printf("grid: Upgraded request: %v\n", remoteAddr) + } + + msg, _, err := wsutil.ReadClientData(conn) + if err != nil { + writeErr(fmt.Errorf("reading connect: %w", err)) + return + } + if debugPrint { + fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg)) + } + + var message message + _, _, err = message.parse(msg) + if err != nil { + writeErr(fmt.Errorf("error parsing grid connect: %w", err)) + return + } + if message.Op != OpConnect { + writeErr(fmt.Errorf("unexpected connect op: %v", message.Op)) + return + } + var cReq connectReq + _, err = cReq.UnmarshalMsg(message.Payload) + if err != nil { + writeErr(fmt.Errorf("error parsing connectReq: %w", err)) + return + } + remote := m.targets[cReq.Host] + if remote == nil { + writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host)) + return + } + if time.Since(cReq.Time).Abs() > 5*time.Minute { + writeErr(fmt.Errorf("time difference too large between servers: %v", time.Since(cReq.Time).Abs())) + return + } + if err := m.authToken(cReq.Token, cReq.audience()); err != nil { + writeErr(fmt.Errorf("auth token: %w", err)) + return + } + + if debugPrint { + fmt.Printf("handler: Got Connect Req %+v\n", cReq) + } + writeErr(remote.handleIncoming(ctx, conn, cReq)) +} + // AuthFn should provide an authentication string for the given aud. type AuthFn func(aud string) string +// ValidateAuthFn should check authentication for the given aud. +type ValidateAuthFn func(auth, aud string) string + // Connection will return the connection for the specified host. // If the host does not exist nil will be returned. func (m *Manager) Connection(host string) *Connection { diff --git a/internal/grid/msg.go b/internal/grid/msg.go index f55230f40..5fa8dc49d 100644 --- a/internal/grid/msg.go +++ b/internal/grid/msg.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "fmt" "strings" + "time" "github.com/tinylib/msgp/msgp" "github.com/zeebo/xxh3" @@ -255,8 +256,20 @@ type sender interface { } type connectReq struct { - ID [16]byte - Host string + ID [16]byte + Host string + Time time.Time + Token string +} + +// audience returns the audience for the connect call. +func (c *connectReq) audience() string { + return fmt.Sprintf("%s-%d", c.Host, c.Time.Unix()) +} + +// addToken will add the token to the connect request. +func (c *connectReq) addToken(fn AuthFn) { + c.Token = fn(c.audience()) } func (connectReq) Op() Op { diff --git a/internal/grid/msg_gen.go b/internal/grid/msg_gen.go index 15f2a58f9..14e88c740 100644 --- a/internal/grid/msg_gen.go +++ b/internal/grid/msg_gen.go @@ -192,6 +192,18 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "Host") return } + case "Time": + z.Time, err = dc.ReadTime() + if err != nil { + err = msgp.WrapError(err, "Time") + return + } + case "Token": + z.Token, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Token") + return + } default: err = dc.Skip() if err != nil { @@ -205,9 +217,9 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 2 + // map header, size 4 // write "ID" - err = en.Append(0x82, 0xa2, 0x49, 0x44) + err = en.Append(0x84, 0xa2, 0x49, 0x44) if err != nil { return } @@ -226,19 +238,45 @@ func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "Host") return } + // write "Time" + err = en.Append(0xa4, 0x54, 0x69, 0x6d, 0x65) + if err != nil { + return + } + err = en.WriteTime(z.Time) + if err != nil { + err = msgp.WrapError(err, "Time") + return + } + // write "Token" + err = en.Append(0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e) + if err != nil { + return + } + err = en.WriteString(z.Token) + if err != nil { + err = msgp.WrapError(err, "Token") + return + } return } // MarshalMsg implements msgp.Marshaler func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 2 + // map header, size 4 // string "ID" - o = append(o, 0x82, 0xa2, 0x49, 0x44) + o = append(o, 0x84, 0xa2, 0x49, 0x44) o = msgp.AppendBytes(o, (z.ID)[:]) // string "Host" o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74) o = msgp.AppendString(o, z.Host) + // string "Time" + o = append(o, 0xa4, 0x54, 0x69, 0x6d, 0x65) + o = msgp.AppendTime(o, z.Time) + // string "Token" + o = append(o, 0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e) + o = msgp.AppendString(o, z.Token) return } @@ -272,6 +310,18 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "Host") return } + case "Time": + z.Time, bts, err = msgp.ReadTimeBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Time") + return + } + case "Token": + z.Token, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Token") + return + } default: bts, err = msgp.Skip(bts) if err != nil { @@ -286,7 +336,7 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *connectReq) Msgsize() (s int) { - s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) + s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) + 5 + msgp.TimeSize + 6 + msgp.StringPrefixSize + len(z.Token) return }