Abstract grid connections (#20038)

Add `ConnDialer` to abstract connection creation.

- `IncomingConn(ctx context.Context, conn net.Conn)` is provided as an entry point for 
   incoming custom connections.

- `ConnectWS` is provided to create web socket connections.
This commit is contained in:
Klaus Post 2024-07-08 14:44:00 -07:00 committed by GitHub
parent b433bf14ba
commit 0d0b0aa599
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 313 additions and 180 deletions

View File

@ -41,17 +41,19 @@ func initGlobalGrid(ctx context.Context, eps EndpointServerPools) error {
// Pass Dialer for websocket grid, make sure we do not // Pass Dialer for websocket grid, make sure we do not
// provide any DriveOPTimeout() function, as that is not // provide any DriveOPTimeout() function, as that is not
// useful over persistent connections. // useful over persistent connections.
Dialer: grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))), Dialer: grid.ConnectWS(
grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))),
newCachedAuthToken(),
&tls.Config{
RootCAs: globalRootCAs,
CipherSuites: fips.TLSCiphers(),
CurvePreferences: fips.TLSCurveIDs(),
}),
Local: local, Local: local,
Hosts: hosts, Hosts: hosts,
AddAuth: newCachedAuthToken(), AuthToken: validateStorageRequestToken,
AuthRequest: storageServerRequestValidate, AuthFn: newCachedAuthToken(),
BlockConnect: globalGridStart, BlockConnect: globalGridStart,
TLSConfig: &tls.Config{
RootCAs: globalRootCAs,
CipherSuites: fips.TLSCiphers(),
CurvePreferences: fips.TLSCurveIDs(),
},
// Record incoming and outgoing bytes. // Record incoming and outgoing bytes.
Incoming: globalConnStats.incInternodeInputBytes, Incoming: globalConnStats.incInternodeInputBytes,
Outgoing: globalConnStats.incInternodeOutputBytes, Outgoing: globalConnStats.incInternodeOutputBytes,

View File

@ -39,7 +39,7 @@ func registerDistErasureRouters(router *mux.Router, endpointServerPools Endpoint
registerLockRESTHandlers() registerLockRESTHandlers()
// Add grid to router // Add grid to router
router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(), noGZFlag, noObjLayerFlag)) router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(storageServerRequestValidate), noGZFlag, noObjLayerFlag))
} }
// List of some generic middlewares which are applied for all incoming requests. // List of some generic middlewares which are applied for all incoming requests.

View File

@ -109,6 +109,24 @@ func (s *storageRESTServer) writeErrorResponse(w http.ResponseWriter, err error)
// DefaultSkewTime - skew time is 15 minutes between minio peers. // DefaultSkewTime - skew time is 15 minutes between minio peers.
const DefaultSkewTime = 15 * time.Minute const DefaultSkewTime = 15 * time.Minute
// validateStorageRequestToken will validate the token against the provided audience.
func validateStorageRequestToken(token, audience string) error {
claims := xjwt.NewStandardClaims()
if err := xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil {
return errAuthentication
}
owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey
if !owner {
return errAuthentication
}
if claims.Audience != audience {
return errAuthentication
}
return nil
}
// Authenticates storage client's requests and validates for skewed time. // Authenticates storage client's requests and validates for skewed time.
func storageServerRequestValidate(r *http.Request) error { func storageServerRequestValidate(r *http.Request) error {
token, err := jwtreq.AuthorizationHeaderExtractor.ExtractToken(r) token, err := jwtreq.AuthorizationHeaderExtractor.ExtractToken(r)
@ -118,19 +136,8 @@ func storageServerRequestValidate(r *http.Request) error {
} }
return errMalformedAuth return errMalformedAuth
} }
if err = validateStorageRequestToken(token, r.URL.RawQuery); err != nil {
claims := xjwt.NewStandardClaims() return err
if err = xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil {
return errAuthentication
}
owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey
if !owner {
return errAuthentication
}
if claims.Audience != r.URL.RawQuery {
return errAuthentication
} }
requestTimeStr := r.Header.Get("X-Minio-Time") requestTimeStr := r.Header.Get("X-Minio-Time")

View File

@ -20,7 +20,6 @@ package grid
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -28,7 +27,6 @@ import (
"math" "math"
"math/rand" "math/rand"
"net" "net"
"net/http"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync" "sync"
@ -100,9 +98,9 @@ type Connection struct {
// Client or serverside. // Client or serverside.
side ws.State side ws.State
// Transport for outgoing connections. // Dialer for outgoing connections.
dialer ContextDialer dial ConnDialer
header http.Header authFn AuthFn
handleMsgWg sync.WaitGroup handleMsgWg sync.WaitGroup
@ -112,10 +110,8 @@ type Connection struct {
handlers *handlers handlers *handlers
remote *RemoteClient remote *RemoteClient
auth AuthFn
clientPingInterval time.Duration clientPingInterval time.Duration
connPingInterval time.Duration connPingInterval time.Duration
tlsConfig *tls.Config
blockConnect chan struct{} blockConnect chan struct{}
incomingBytes func(n int64) // Record incoming bytes. incomingBytes func(n int64) // Record incoming bytes.
@ -205,13 +201,12 @@ type connectionParams struct {
ctx context.Context ctx context.Context
id uuid.UUID id uuid.UUID
local, remote string local, remote string
dial ContextDialer
handlers *handlers handlers *handlers
auth AuthFn
tlsConfig *tls.Config
incomingBytes func(n int64) // Record incoming bytes. incomingBytes func(n int64) // Record incoming bytes.
outgoingBytes func(n int64) // Record outgoing bytes. outgoingBytes func(n int64) // Record outgoing bytes.
publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
dialer ConnDialer
authFn AuthFn
blockConnect chan struct{} blockConnect chan struct{}
} }
@ -227,16 +222,14 @@ func newConnection(o connectionParams) *Connection {
outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000), outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000),
inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000), inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000),
outQueue: make(chan []byte, defaultOutQueue), outQueue: make(chan []byte, defaultOutQueue),
dialer: o.dial,
side: ws.StateServerSide, side: ws.StateServerSide,
connChange: &sync.Cond{L: &sync.Mutex{}}, connChange: &sync.Cond{L: &sync.Mutex{}},
handlers: o.handlers, handlers: o.handlers,
auth: o.auth,
header: make(http.Header, 1),
remote: &RemoteClient{Name: o.remote}, remote: &RemoteClient{Name: o.remote},
clientPingInterval: clientPingInterval, clientPingInterval: clientPingInterval,
connPingInterval: connPingInterval, connPingInterval: connPingInterval,
tlsConfig: o.tlsConfig, dial: o.dialer,
authFn: o.authFn,
} }
if debugPrint { if debugPrint {
// Random Mux ID // Random Mux ID
@ -648,41 +641,17 @@ func (c *Connection) connect() {
if c.State() == StateShutdown { if c.State() == StateShutdown {
return return
} }
toDial := strings.Replace(c.Remote, "http://", "ws://", 1)
toDial = strings.Replace(toDial, "https://", "wss://", 1)
toDial += RoutePath
dialer := ws.DefaultDialer
dialer.ReadBufferSize = readBufferSize
dialer.WriteBufferSize = writeBufferSize
dialer.Timeout = defaultDialTimeout
if c.dialer != nil {
dialer.NetDial = c.dialer.DialContext
}
if c.header == nil {
c.header = make(http.Header, 2)
}
c.header.Set("Authorization", "Bearer "+c.auth(""))
c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
if len(c.header) > 0 {
dialer.Header = ws.HandshakeHeaderHTTP(c.header)
}
dialer.TLSConfig = c.tlsConfig
dialStarted := time.Now() dialStarted := time.Now()
if debugPrint { if debugPrint {
fmt.Println(c.Local, "Connecting to ", toDial) fmt.Println(c.Local, "Connecting to ", c.Remote)
}
conn, br, _, err := dialer.Dial(c.ctx, toDial)
if br != nil {
ws.PutReader(br)
} }
conn, err := c.dial(c.ctx, c.Remote)
c.connMu.Lock() c.connMu.Lock()
c.debugOutConn = conn c.debugOutConn = conn
c.connMu.Unlock() c.connMu.Unlock()
retry := func(err error) { retry := func(err error) {
if debugPrint { if debugPrint {
fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err) fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, c.Remote, err)
} }
sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout))) sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout)))
next := dialStarted.Add(sleep / 2) next := dialStarted.Add(sleep / 2)
@ -696,7 +665,7 @@ func (c *Connection) connect() {
} }
if gotState != StateConnecting { if gotState != StateConnecting {
// Don't print error on first attempt, and after that only once per hour. // Don't print error on first attempt, and after that only once per hour.
gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), toDial) gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, c.Remote, err, err, sleep, gotState), c.Remote)
} }
c.updateState(StateConnectionError) c.updateState(StateConnectionError)
time.Sleep(sleep) time.Sleep(sleep)
@ -712,7 +681,9 @@ func (c *Connection) connect() {
req := connectReq{ req := connectReq{
Host: c.Local, Host: c.Local,
ID: c.id, ID: c.id,
Time: time.Now(),
} }
req.addToken(c.authFn)
err = c.sendMsg(conn, m, &req) err = c.sendMsg(conn, m, &req)
if err != nil { if err != nil {
retry(err) retry(err)

View File

@ -52,11 +52,13 @@ func TestDisconnect(t *testing.T) {
localHost := hosts[0] localHost := hosts[0]
remoteHost := hosts[1] remoteHost := hosts[1]
local, err := NewManager(context.Background(), ManagerOptions{ local, err := NewManager(context.Background(), ManagerOptions{
Dialer: dialer.DialContext, Dialer: ConnectWS(dialer.DialContext,
dummyNewToken,
nil),
Local: localHost, Local: localHost,
Hosts: hosts, Hosts: hosts,
AddAuth: func(aud string) string { return aud }, AuthFn: dummyNewToken,
AuthRequest: dummyRequestValidate, AuthToken: dummyTokenValidate,
BlockConnect: connReady, BlockConnect: connReady,
}) })
errFatal(err) errFatal(err)
@ -74,17 +76,19 @@ func TestDisconnect(t *testing.T) {
})) }))
remote, err := NewManager(context.Background(), ManagerOptions{ remote, err := NewManager(context.Background(), ManagerOptions{
Dialer: dialer.DialContext, Dialer: ConnectWS(dialer.DialContext,
dummyNewToken,
nil),
Local: remoteHost, Local: remoteHost,
Hosts: hosts, Hosts: hosts,
AddAuth: func(aud string) string { return aud }, AuthFn: dummyNewToken,
AuthRequest: dummyRequestValidate, AuthToken: dummyTokenValidate,
BlockConnect: connReady, BlockConnect: connReady,
}) })
errFatal(err) errFatal(err)
localServer := startServer(t, listeners[0], wrapServer(local.Handler())) localServer := startServer(t, listeners[0], wrapServer(local.Handler(dummyRequestValidate)))
remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler())) remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler(dummyRequestValidate)))
close(connReady) close(connReady)
defer func() { defer func() {
@ -165,10 +169,6 @@ func TestDisconnect(t *testing.T) {
<-gotCall <-gotCall
} }
func dummyRequestValidate(r *http.Request) error {
return nil
}
func TestShouldConnect(t *testing.T) { func TestShouldConnect(t *testing.T) {
var c Connection var c Connection
var cReverse Connection var cReverse Connection

View File

@ -82,20 +82,20 @@ func SetupTestGrid(n int) (*TestGrid, error) {
res.cancel = cancel res.cancel = cancel
for i, host := range hosts { for i, host := range hosts {
manager, err := NewManager(ctx, ManagerOptions{ manager, err := NewManager(ctx, ManagerOptions{
Dialer: dialer.DialContext, Dialer: ConnectWS(dialer.DialContext,
Local: host, dummyNewToken,
Hosts: hosts, nil),
AuthRequest: func(r *http.Request) error { Local: host,
return nil Hosts: hosts,
}, AuthFn: dummyNewToken,
AddAuth: func(aud string) string { return aud }, AuthToken: dummyTokenValidate,
BlockConnect: ready, BlockConnect: ready,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := mux.NewRouter() m := mux.NewRouter()
m.Handle(RoutePath, manager.Handler()) m.Handle(RoutePath, manager.Handler(dummyRequestValidate))
res.Managers = append(res.Managers, manager) res.Managers = append(res.Managers, manager)
res.Servers = append(res.Servers, startHTTPServer(listeners[i], m)) res.Servers = append(res.Servers, startHTTPServer(listeners[i], m))
res.Listeners = append(res.Listeners, listeners[i]) res.Listeners = append(res.Listeners, listeners[i])
@ -164,3 +164,18 @@ func startHTTPServer(listener net.Listener, handler http.Handler) (server *httpt
server.Start() server.Start()
return server return server
} }
func dummyRequestValidate(r *http.Request) error {
return nil
}
func dummyTokenValidate(token, audience string) error {
if token == audience {
return nil
}
return fmt.Errorf("invalid token. want %s, got %s", audience, token)
}
func dummyNewToken(audience string) string {
return audience
}

View File

@ -20,12 +20,17 @@ package grid
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http"
"strings"
"sync" "sync"
"time" "time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
) )
@ -179,3 +184,45 @@ func bytesOrLength(b []byte) string {
} }
return fmt.Sprint(b) return fmt.Sprint(b)
} }
// ConnDialer is a function that dials a connection to the given address.
// There should be no retries in this function,
// and should have a timeout of something like 2 seconds.
// The returned net.Conn should also have quick disconnect on errors.
// The net.Conn must support all features as described by the net.Conn interface.
type ConnDialer func(ctx context.Context, address string) (net.Conn, error)
// ConnectWS returns a function that dials a websocket connection to the given address.
// Route and auth are added to the connection.
func ConnectWS(dial ContextDialer, auth AuthFn, tls *tls.Config) func(ctx context.Context, remote string) (net.Conn, error) {
return func(ctx context.Context, remote string) (net.Conn, error) {
toDial := strings.Replace(remote, "http://", "ws://", 1)
toDial = strings.Replace(toDial, "https://", "wss://", 1)
toDial += RoutePath
dialer := ws.DefaultDialer
dialer.ReadBufferSize = readBufferSize
dialer.WriteBufferSize = writeBufferSize
dialer.Timeout = defaultDialTimeout
if dial != nil {
dialer.NetDial = dial
}
header := make(http.Header, 2)
header.Set("Authorization", "Bearer "+auth(""))
header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
if len(header) > 0 {
dialer.Header = ws.HandshakeHeaderHTTP(header)
}
dialer.TLSConfig = tls
conn, br, _, err := dialer.Dial(ctx, toDial)
if br != nil {
ws.PutReader(br)
}
return conn, err
}
}
// ValidateTokenFn must validate the token and return an error if it is invalid.
type ValidateTokenFn func(token, audience string) error

View File

@ -19,13 +19,14 @@ package grid
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"strings" "strings"
"time"
"github.com/gobwas/ws" "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
@ -62,40 +63,48 @@ type Manager struct {
// local host name. // local host name.
local string local string
// Validate incoming requests. // authToken is a function that will validate a token.
authRequest func(r *http.Request) error authToken ValidateTokenFn
} }
// ManagerOptions are options for creating a new grid manager. // ManagerOptions are options for creating a new grid manager.
type ManagerOptions struct { type ManagerOptions struct {
Dialer ContextDialer // Outgoing dialer. Local string // Local host name.
Local string // Local host name. Hosts []string // All hosts, including local in the grid.
Hosts []string // All hosts, including local in the grid. Incoming func(n int64) // Record incoming bytes.
AddAuth AuthFn // Add authentication to the given audience. Outgoing func(n int64) // Record outgoing bytes.
AuthRequest func(r *http.Request) error // Validate incoming requests. BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
TLSConfig *tls.Config // TLS to apply to the connections.
Incoming func(n int64) // Record incoming bytes.
Outgoing func(n int64) // Record outgoing bytes.
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
Dialer ConnDialer
// Sign a token for the given audience.
AuthFn AuthFn
// Callbacks to validate incoming connections.
AuthToken ValidateTokenFn
} }
// NewManager creates a new grid manager // NewManager creates a new grid manager
func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) { func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
found := false found := false
if o.AuthRequest == nil { if o.AuthToken == nil {
return nil, fmt.Errorf("grid: AuthRequest must be set") return nil, fmt.Errorf("grid: AuthToken not set")
}
if o.Dialer == nil {
return nil, fmt.Errorf("grid: Dialer not set")
}
if o.AuthFn == nil {
return nil, fmt.Errorf("grid: AuthFn not set")
} }
m := &Manager{ m := &Manager{
ID: uuid.New(), ID: uuid.New(),
targets: make(map[string]*Connection, len(o.Hosts)), targets: make(map[string]*Connection, len(o.Hosts)),
local: o.Local, local: o.Local,
authRequest: o.AuthRequest, authToken: o.AuthToken,
} }
m.handlers.init() m.handlers.init()
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
for _, host := range o.Hosts { for _, host := range o.Hosts {
if host == o.Local { if host == o.Local {
if found { if found {
@ -110,14 +119,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
id: m.ID, id: m.ID,
local: o.Local, local: o.Local,
remote: host, remote: host,
dial: o.Dialer,
handlers: &m.handlers, handlers: &m.handlers,
auth: o.AddAuth,
blockConnect: o.BlockConnect, blockConnect: o.BlockConnect,
tlsConfig: o.TLSConfig,
publisher: o.TraceTo, publisher: o.TraceTo,
incomingBytes: o.Incoming, incomingBytes: o.Incoming,
outgoingBytes: o.Outgoing, outgoingBytes: o.Outgoing,
dialer: o.Dialer,
authFn: o.AuthFn,
}) })
} }
if !found { if !found {
@ -128,13 +136,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
} }
// AddToMux will add the grid manager to the given mux. // AddToMux will add the grid manager to the given mux.
func (m *Manager) AddToMux(router *mux.Router) { func (m *Manager) AddToMux(router *mux.Router, authReq func(r *http.Request) error) {
router.Handle(RoutePath, m.Handler()) router.Handle(RoutePath, m.Handler(authReq))
} }
// Handler returns a handler that can be used to serve grid requests. // Handler returns a handler that can be used to serve grid requests.
// This should be connected on RoutePath to the main server. // This should be connected on RoutePath to the main server.
func (m *Manager) Handler() http.HandlerFunc { func (m *Manager) Handler(authReq func(r *http.Request) error) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) { return func(w http.ResponseWriter, req *http.Request) {
defer func() { defer func() {
if debugPrint { if debugPrint {
@ -151,7 +159,7 @@ func (m *Manager) Handler() http.HandlerFunc {
fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL) fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL)
} }
ctx := req.Context() ctx := req.Context()
if err := m.authRequest(req); err != nil { if err := authReq(req); err != nil {
gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr) gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr)
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return return
@ -164,76 +172,96 @@ func (m *Manager) Handler() http.HandlerFunc {
w.WriteHeader(http.StatusUpgradeRequired) w.WriteHeader(http.StatusUpgradeRequired)
return return
} }
// will write an OpConnectResponse message to the remote and log it once locally. m.IncomingConn(ctx, conn)
writeErr := func(err error) {
if err == nil {
return
}
if errors.Is(err, io.EOF) {
return
}
gridLogOnceIf(ctx, err, req.RemoteAddr)
resp := connectResp{
ID: m.ID,
Accepted: false,
RejectedReason: err.Error(),
}
if b, err := resp.MarshalMsg(nil); err == nil {
msg := message{
Op: OpConnectResponse,
Payload: b,
}
if b, err := msg.MarshalMsg(nil); err == nil {
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
}
}
}
defer conn.Close()
if debugPrint {
fmt.Printf("grid: Upgraded request: %v\n", req.URL)
}
msg, _, err := wsutil.ReadClientData(conn)
if err != nil {
writeErr(fmt.Errorf("reading connect: %w", err))
w.WriteHeader(http.StatusForbidden)
return
}
if debugPrint {
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
}
var message message
_, _, err = message.parse(msg)
if err != nil {
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
return
}
if message.Op != OpConnect {
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
return
}
var cReq connectReq
_, err = cReq.UnmarshalMsg(message.Payload)
if err != nil {
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
return
}
remote := m.targets[cReq.Host]
if remote == nil {
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
return
}
if debugPrint {
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
}
writeErr(remote.handleIncoming(ctx, conn, cReq))
} }
} }
// IncomingConn will handle an incoming connection.
// This should be called with the incoming connection after accept.
// Auth is handled internally, as well as disconnecting any connections from the same host.
func (m *Manager) IncomingConn(ctx context.Context, conn net.Conn) {
remoteAddr := conn.RemoteAddr().String()
// will write an OpConnectResponse message to the remote and log it once locally.
defer conn.Close()
writeErr := func(err error) {
if err == nil {
return
}
if errors.Is(err, io.EOF) {
return
}
gridLogOnceIf(ctx, err, remoteAddr)
resp := connectResp{
ID: m.ID,
Accepted: false,
RejectedReason: err.Error(),
}
if b, err := resp.MarshalMsg(nil); err == nil {
msg := message{
Op: OpConnectResponse,
Payload: b,
}
if b, err := msg.MarshalMsg(nil); err == nil {
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
}
}
}
defer conn.Close()
if debugPrint {
fmt.Printf("grid: Upgraded request: %v\n", remoteAddr)
}
msg, _, err := wsutil.ReadClientData(conn)
if err != nil {
writeErr(fmt.Errorf("reading connect: %w", err))
return
}
if debugPrint {
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
}
var message message
_, _, err = message.parse(msg)
if err != nil {
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
return
}
if message.Op != OpConnect {
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
return
}
var cReq connectReq
_, err = cReq.UnmarshalMsg(message.Payload)
if err != nil {
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
return
}
remote := m.targets[cReq.Host]
if remote == nil {
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
return
}
if time.Since(cReq.Time).Abs() > 5*time.Minute {
writeErr(fmt.Errorf("time difference too large between servers: %v", time.Since(cReq.Time).Abs()))
return
}
if err := m.authToken(cReq.Token, cReq.audience()); err != nil {
writeErr(fmt.Errorf("auth token: %w", err))
return
}
if debugPrint {
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
}
writeErr(remote.handleIncoming(ctx, conn, cReq))
}
// AuthFn should provide an authentication string for the given aud. // AuthFn should provide an authentication string for the given aud.
type AuthFn func(aud string) string type AuthFn func(aud string) string
// ValidateAuthFn should check authentication for the given aud.
type ValidateAuthFn func(auth, aud string) string
// Connection will return the connection for the specified host. // Connection will return the connection for the specified host.
// If the host does not exist nil will be returned. // If the host does not exist nil will be returned.
func (m *Manager) Connection(host string) *Connection { func (m *Manager) Connection(host string) *Connection {

View File

@ -21,6 +21,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
"github.com/zeebo/xxh3" "github.com/zeebo/xxh3"
@ -255,8 +256,20 @@ type sender interface {
} }
type connectReq struct { type connectReq struct {
ID [16]byte ID [16]byte
Host string Host string
Time time.Time
Token string
}
// audience returns the audience for the connect call.
func (c *connectReq) audience() string {
return fmt.Sprintf("%s-%d", c.Host, c.Time.Unix())
}
// addToken will add the token to the connect request.
func (c *connectReq) addToken(fn AuthFn) {
c.Token = fn(c.audience())
} }
func (connectReq) Op() Op { func (connectReq) Op() Op {

View File

@ -192,6 +192,18 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) {
err = msgp.WrapError(err, "Host") err = msgp.WrapError(err, "Host")
return return
} }
case "Time":
z.Time, err = dc.ReadTime()
if err != nil {
err = msgp.WrapError(err, "Time")
return
}
case "Token":
z.Token, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "Token")
return
}
default: default:
err = dc.Skip() err = dc.Skip()
if err != nil { if err != nil {
@ -205,9 +217,9 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) {
// EncodeMsg implements msgp.Encodable // EncodeMsg implements msgp.Encodable
func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) { func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2 // map header, size 4
// write "ID" // write "ID"
err = en.Append(0x82, 0xa2, 0x49, 0x44) err = en.Append(0x84, 0xa2, 0x49, 0x44)
if err != nil { if err != nil {
return return
} }
@ -226,19 +238,45 @@ func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
err = msgp.WrapError(err, "Host") err = msgp.WrapError(err, "Host")
return return
} }
// write "Time"
err = en.Append(0xa4, 0x54, 0x69, 0x6d, 0x65)
if err != nil {
return
}
err = en.WriteTime(z.Time)
if err != nil {
err = msgp.WrapError(err, "Time")
return
}
// write "Token"
err = en.Append(0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e)
if err != nil {
return
}
err = en.WriteString(z.Token)
if err != nil {
err = msgp.WrapError(err, "Token")
return
}
return return
} }
// MarshalMsg implements msgp.Marshaler // MarshalMsg implements msgp.Marshaler
func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) { func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize()) o = msgp.Require(b, z.Msgsize())
// map header, size 2 // map header, size 4
// string "ID" // string "ID"
o = append(o, 0x82, 0xa2, 0x49, 0x44) o = append(o, 0x84, 0xa2, 0x49, 0x44)
o = msgp.AppendBytes(o, (z.ID)[:]) o = msgp.AppendBytes(o, (z.ID)[:])
// string "Host" // string "Host"
o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74) o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74)
o = msgp.AppendString(o, z.Host) o = msgp.AppendString(o, z.Host)
// string "Time"
o = append(o, 0xa4, 0x54, 0x69, 0x6d, 0x65)
o = msgp.AppendTime(o, z.Time)
// string "Token"
o = append(o, 0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e)
o = msgp.AppendString(o, z.Token)
return return
} }
@ -272,6 +310,18 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) {
err = msgp.WrapError(err, "Host") err = msgp.WrapError(err, "Host")
return return
} }
case "Time":
z.Time, bts, err = msgp.ReadTimeBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Time")
return
}
case "Token":
z.Token, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Token")
return
}
default: default:
bts, err = msgp.Skip(bts) bts, err = msgp.Skip(bts)
if err != nil { if err != nil {
@ -286,7 +336,7 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) {
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *connectReq) Msgsize() (s int) { func (z *connectReq) Msgsize() (s int) {
s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) + 5 + msgp.TimeSize + 6 + msgp.StringPrefixSize + len(z.Token)
return return
} }