mirror of https://github.com/minio/minio.git
Abstract grid connections (#20038)
Add `ConnDialer` to abstract connection creation. - `IncomingConn(ctx context.Context, conn net.Conn)` is provided as an entry point for incoming custom connections. - `ConnectWS` is provided to create web socket connections.
This commit is contained in:
parent
b433bf14ba
commit
0d0b0aa599
18
cmd/grid.go
18
cmd/grid.go
|
@ -41,17 +41,19 @@ func initGlobalGrid(ctx context.Context, eps EndpointServerPools) error {
|
|||
// Pass Dialer for websocket grid, make sure we do not
|
||||
// provide any DriveOPTimeout() function, as that is not
|
||||
// useful over persistent connections.
|
||||
Dialer: grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))),
|
||||
Dialer: grid.ConnectWS(
|
||||
grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))),
|
||||
newCachedAuthToken(),
|
||||
&tls.Config{
|
||||
RootCAs: globalRootCAs,
|
||||
CipherSuites: fips.TLSCiphers(),
|
||||
CurvePreferences: fips.TLSCurveIDs(),
|
||||
}),
|
||||
Local: local,
|
||||
Hosts: hosts,
|
||||
AddAuth: newCachedAuthToken(),
|
||||
AuthRequest: storageServerRequestValidate,
|
||||
AuthToken: validateStorageRequestToken,
|
||||
AuthFn: newCachedAuthToken(),
|
||||
BlockConnect: globalGridStart,
|
||||
TLSConfig: &tls.Config{
|
||||
RootCAs: globalRootCAs,
|
||||
CipherSuites: fips.TLSCiphers(),
|
||||
CurvePreferences: fips.TLSCurveIDs(),
|
||||
},
|
||||
// Record incoming and outgoing bytes.
|
||||
Incoming: globalConnStats.incInternodeInputBytes,
|
||||
Outgoing: globalConnStats.incInternodeOutputBytes,
|
||||
|
|
|
@ -39,7 +39,7 @@ func registerDistErasureRouters(router *mux.Router, endpointServerPools Endpoint
|
|||
registerLockRESTHandlers()
|
||||
|
||||
// Add grid to router
|
||||
router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(), noGZFlag, noObjLayerFlag))
|
||||
router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(storageServerRequestValidate), noGZFlag, noObjLayerFlag))
|
||||
}
|
||||
|
||||
// List of some generic middlewares which are applied for all incoming requests.
|
||||
|
|
|
@ -109,6 +109,24 @@ func (s *storageRESTServer) writeErrorResponse(w http.ResponseWriter, err error)
|
|||
// DefaultSkewTime - skew time is 15 minutes between minio peers.
|
||||
const DefaultSkewTime = 15 * time.Minute
|
||||
|
||||
// validateStorageRequestToken will validate the token against the provided audience.
|
||||
func validateStorageRequestToken(token, audience string) error {
|
||||
claims := xjwt.NewStandardClaims()
|
||||
if err := xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil {
|
||||
return errAuthentication
|
||||
}
|
||||
|
||||
owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey
|
||||
if !owner {
|
||||
return errAuthentication
|
||||
}
|
||||
|
||||
if claims.Audience != audience {
|
||||
return errAuthentication
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticates storage client's requests and validates for skewed time.
|
||||
func storageServerRequestValidate(r *http.Request) error {
|
||||
token, err := jwtreq.AuthorizationHeaderExtractor.ExtractToken(r)
|
||||
|
@ -118,19 +136,8 @@ func storageServerRequestValidate(r *http.Request) error {
|
|||
}
|
||||
return errMalformedAuth
|
||||
}
|
||||
|
||||
claims := xjwt.NewStandardClaims()
|
||||
if err = xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil {
|
||||
return errAuthentication
|
||||
}
|
||||
|
||||
owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey
|
||||
if !owner {
|
||||
return errAuthentication
|
||||
}
|
||||
|
||||
if claims.Audience != r.URL.RawQuery {
|
||||
return errAuthentication
|
||||
if err = validateStorageRequestToken(token, r.URL.RawQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestTimeStr := r.Header.Get("X-Minio-Time")
|
||||
|
|
|
@ -20,7 +20,6 @@ package grid
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -28,7 +27,6 @@ import (
|
|||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -100,9 +98,9 @@ type Connection struct {
|
|||
// Client or serverside.
|
||||
side ws.State
|
||||
|
||||
// Transport for outgoing connections.
|
||||
dialer ContextDialer
|
||||
header http.Header
|
||||
// Dialer for outgoing connections.
|
||||
dial ConnDialer
|
||||
authFn AuthFn
|
||||
|
||||
handleMsgWg sync.WaitGroup
|
||||
|
||||
|
@ -112,10 +110,8 @@ type Connection struct {
|
|||
handlers *handlers
|
||||
|
||||
remote *RemoteClient
|
||||
auth AuthFn
|
||||
clientPingInterval time.Duration
|
||||
connPingInterval time.Duration
|
||||
tlsConfig *tls.Config
|
||||
blockConnect chan struct{}
|
||||
|
||||
incomingBytes func(n int64) // Record incoming bytes.
|
||||
|
@ -205,13 +201,12 @@ type connectionParams struct {
|
|||
ctx context.Context
|
||||
id uuid.UUID
|
||||
local, remote string
|
||||
dial ContextDialer
|
||||
handlers *handlers
|
||||
auth AuthFn
|
||||
tlsConfig *tls.Config
|
||||
incomingBytes func(n int64) // Record incoming bytes.
|
||||
outgoingBytes func(n int64) // Record outgoing bytes.
|
||||
publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
||||
dialer ConnDialer
|
||||
authFn AuthFn
|
||||
|
||||
blockConnect chan struct{}
|
||||
}
|
||||
|
@ -227,16 +222,14 @@ func newConnection(o connectionParams) *Connection {
|
|||
outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000),
|
||||
inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000),
|
||||
outQueue: make(chan []byte, defaultOutQueue),
|
||||
dialer: o.dial,
|
||||
side: ws.StateServerSide,
|
||||
connChange: &sync.Cond{L: &sync.Mutex{}},
|
||||
handlers: o.handlers,
|
||||
auth: o.auth,
|
||||
header: make(http.Header, 1),
|
||||
remote: &RemoteClient{Name: o.remote},
|
||||
clientPingInterval: clientPingInterval,
|
||||
connPingInterval: connPingInterval,
|
||||
tlsConfig: o.tlsConfig,
|
||||
dial: o.dialer,
|
||||
authFn: o.authFn,
|
||||
}
|
||||
if debugPrint {
|
||||
// Random Mux ID
|
||||
|
@ -648,41 +641,17 @@ func (c *Connection) connect() {
|
|||
if c.State() == StateShutdown {
|
||||
return
|
||||
}
|
||||
toDial := strings.Replace(c.Remote, "http://", "ws://", 1)
|
||||
toDial = strings.Replace(toDial, "https://", "wss://", 1)
|
||||
toDial += RoutePath
|
||||
|
||||
dialer := ws.DefaultDialer
|
||||
dialer.ReadBufferSize = readBufferSize
|
||||
dialer.WriteBufferSize = writeBufferSize
|
||||
dialer.Timeout = defaultDialTimeout
|
||||
if c.dialer != nil {
|
||||
dialer.NetDial = c.dialer.DialContext
|
||||
}
|
||||
if c.header == nil {
|
||||
c.header = make(http.Header, 2)
|
||||
}
|
||||
c.header.Set("Authorization", "Bearer "+c.auth(""))
|
||||
c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
|
||||
|
||||
if len(c.header) > 0 {
|
||||
dialer.Header = ws.HandshakeHeaderHTTP(c.header)
|
||||
}
|
||||
dialer.TLSConfig = c.tlsConfig
|
||||
dialStarted := time.Now()
|
||||
if debugPrint {
|
||||
fmt.Println(c.Local, "Connecting to ", toDial)
|
||||
}
|
||||
conn, br, _, err := dialer.Dial(c.ctx, toDial)
|
||||
if br != nil {
|
||||
ws.PutReader(br)
|
||||
fmt.Println(c.Local, "Connecting to ", c.Remote)
|
||||
}
|
||||
conn, err := c.dial(c.ctx, c.Remote)
|
||||
c.connMu.Lock()
|
||||
c.debugOutConn = conn
|
||||
c.connMu.Unlock()
|
||||
retry := func(err error) {
|
||||
if debugPrint {
|
||||
fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err)
|
||||
fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, c.Remote, err)
|
||||
}
|
||||
sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout)))
|
||||
next := dialStarted.Add(sleep / 2)
|
||||
|
@ -696,7 +665,7 @@ func (c *Connection) connect() {
|
|||
}
|
||||
if gotState != StateConnecting {
|
||||
// Don't print error on first attempt, and after that only once per hour.
|
||||
gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), toDial)
|
||||
gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, c.Remote, err, err, sleep, gotState), c.Remote)
|
||||
}
|
||||
c.updateState(StateConnectionError)
|
||||
time.Sleep(sleep)
|
||||
|
@ -712,7 +681,9 @@ func (c *Connection) connect() {
|
|||
req := connectReq{
|
||||
Host: c.Local,
|
||||
ID: c.id,
|
||||
Time: time.Now(),
|
||||
}
|
||||
req.addToken(c.authFn)
|
||||
err = c.sendMsg(conn, m, &req)
|
||||
if err != nil {
|
||||
retry(err)
|
||||
|
|
|
@ -52,11 +52,13 @@ func TestDisconnect(t *testing.T) {
|
|||
localHost := hosts[0]
|
||||
remoteHost := hosts[1]
|
||||
local, err := NewManager(context.Background(), ManagerOptions{
|
||||
Dialer: dialer.DialContext,
|
||||
Dialer: ConnectWS(dialer.DialContext,
|
||||
dummyNewToken,
|
||||
nil),
|
||||
Local: localHost,
|
||||
Hosts: hosts,
|
||||
AddAuth: func(aud string) string { return aud },
|
||||
AuthRequest: dummyRequestValidate,
|
||||
AuthFn: dummyNewToken,
|
||||
AuthToken: dummyTokenValidate,
|
||||
BlockConnect: connReady,
|
||||
})
|
||||
errFatal(err)
|
||||
|
@ -74,17 +76,19 @@ func TestDisconnect(t *testing.T) {
|
|||
}))
|
||||
|
||||
remote, err := NewManager(context.Background(), ManagerOptions{
|
||||
Dialer: dialer.DialContext,
|
||||
Dialer: ConnectWS(dialer.DialContext,
|
||||
dummyNewToken,
|
||||
nil),
|
||||
Local: remoteHost,
|
||||
Hosts: hosts,
|
||||
AddAuth: func(aud string) string { return aud },
|
||||
AuthRequest: dummyRequestValidate,
|
||||
AuthFn: dummyNewToken,
|
||||
AuthToken: dummyTokenValidate,
|
||||
BlockConnect: connReady,
|
||||
})
|
||||
errFatal(err)
|
||||
|
||||
localServer := startServer(t, listeners[0], wrapServer(local.Handler()))
|
||||
remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler()))
|
||||
localServer := startServer(t, listeners[0], wrapServer(local.Handler(dummyRequestValidate)))
|
||||
remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler(dummyRequestValidate)))
|
||||
close(connReady)
|
||||
|
||||
defer func() {
|
||||
|
@ -165,10 +169,6 @@ func TestDisconnect(t *testing.T) {
|
|||
<-gotCall
|
||||
}
|
||||
|
||||
func dummyRequestValidate(r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestShouldConnect(t *testing.T) {
|
||||
var c Connection
|
||||
var cReverse Connection
|
||||
|
|
|
@ -82,20 +82,20 @@ func SetupTestGrid(n int) (*TestGrid, error) {
|
|||
res.cancel = cancel
|
||||
for i, host := range hosts {
|
||||
manager, err := NewManager(ctx, ManagerOptions{
|
||||
Dialer: dialer.DialContext,
|
||||
Local: host,
|
||||
Hosts: hosts,
|
||||
AuthRequest: func(r *http.Request) error {
|
||||
return nil
|
||||
},
|
||||
AddAuth: func(aud string) string { return aud },
|
||||
Dialer: ConnectWS(dialer.DialContext,
|
||||
dummyNewToken,
|
||||
nil),
|
||||
Local: host,
|
||||
Hosts: hosts,
|
||||
AuthFn: dummyNewToken,
|
||||
AuthToken: dummyTokenValidate,
|
||||
BlockConnect: ready,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := mux.NewRouter()
|
||||
m.Handle(RoutePath, manager.Handler())
|
||||
m.Handle(RoutePath, manager.Handler(dummyRequestValidate))
|
||||
res.Managers = append(res.Managers, manager)
|
||||
res.Servers = append(res.Servers, startHTTPServer(listeners[i], m))
|
||||
res.Listeners = append(res.Listeners, listeners[i])
|
||||
|
@ -164,3 +164,18 @@ func startHTTPServer(listener net.Listener, handler http.Handler) (server *httpt
|
|||
server.Start()
|
||||
return server
|
||||
}
|
||||
|
||||
func dummyRequestValidate(r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func dummyTokenValidate(token, audience string) error {
|
||||
if token == audience {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid token. want %s, got %s", audience, token)
|
||||
}
|
||||
|
||||
func dummyNewToken(audience string) string {
|
||||
return audience
|
||||
}
|
||||
|
|
|
@ -20,12 +20,17 @@ package grid
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
)
|
||||
|
||||
|
@ -179,3 +184,45 @@ func bytesOrLength(b []byte) string {
|
|||
}
|
||||
return fmt.Sprint(b)
|
||||
}
|
||||
|
||||
// ConnDialer is a function that dials a connection to the given address.
|
||||
// There should be no retries in this function,
|
||||
// and should have a timeout of something like 2 seconds.
|
||||
// The returned net.Conn should also have quick disconnect on errors.
|
||||
// The net.Conn must support all features as described by the net.Conn interface.
|
||||
type ConnDialer func(ctx context.Context, address string) (net.Conn, error)
|
||||
|
||||
// ConnectWS returns a function that dials a websocket connection to the given address.
|
||||
// Route and auth are added to the connection.
|
||||
func ConnectWS(dial ContextDialer, auth AuthFn, tls *tls.Config) func(ctx context.Context, remote string) (net.Conn, error) {
|
||||
return func(ctx context.Context, remote string) (net.Conn, error) {
|
||||
toDial := strings.Replace(remote, "http://", "ws://", 1)
|
||||
toDial = strings.Replace(toDial, "https://", "wss://", 1)
|
||||
toDial += RoutePath
|
||||
|
||||
dialer := ws.DefaultDialer
|
||||
dialer.ReadBufferSize = readBufferSize
|
||||
dialer.WriteBufferSize = writeBufferSize
|
||||
dialer.Timeout = defaultDialTimeout
|
||||
if dial != nil {
|
||||
dialer.NetDial = dial
|
||||
}
|
||||
header := make(http.Header, 2)
|
||||
header.Set("Authorization", "Bearer "+auth(""))
|
||||
header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
|
||||
|
||||
if len(header) > 0 {
|
||||
dialer.Header = ws.HandshakeHeaderHTTP(header)
|
||||
}
|
||||
dialer.TLSConfig = tls
|
||||
|
||||
conn, br, _, err := dialer.Dial(ctx, toDial)
|
||||
if br != nil {
|
||||
ws.PutReader(br)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateTokenFn must validate the token and return an error if it is invalid.
|
||||
type ValidateTokenFn func(token, audience string) error
|
||||
|
|
|
@ -19,13 +19,14 @@ package grid
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
|
@ -62,40 +63,48 @@ type Manager struct {
|
|||
// local host name.
|
||||
local string
|
||||
|
||||
// Validate incoming requests.
|
||||
authRequest func(r *http.Request) error
|
||||
// authToken is a function that will validate a token.
|
||||
authToken ValidateTokenFn
|
||||
}
|
||||
|
||||
// ManagerOptions are options for creating a new grid manager.
|
||||
type ManagerOptions struct {
|
||||
Dialer ContextDialer // Outgoing dialer.
|
||||
Local string // Local host name.
|
||||
Hosts []string // All hosts, including local in the grid.
|
||||
AddAuth AuthFn // Add authentication to the given audience.
|
||||
AuthRequest func(r *http.Request) error // Validate incoming requests.
|
||||
TLSConfig *tls.Config // TLS to apply to the connections.
|
||||
Incoming func(n int64) // Record incoming bytes.
|
||||
Outgoing func(n int64) // Record outgoing bytes.
|
||||
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
|
||||
Local string // Local host name.
|
||||
Hosts []string // All hosts, including local in the grid.
|
||||
Incoming func(n int64) // Record incoming bytes.
|
||||
Outgoing func(n int64) // Record outgoing bytes.
|
||||
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
|
||||
TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
||||
Dialer ConnDialer
|
||||
// Sign a token for the given audience.
|
||||
AuthFn AuthFn
|
||||
// Callbacks to validate incoming connections.
|
||||
AuthToken ValidateTokenFn
|
||||
}
|
||||
|
||||
// NewManager creates a new grid manager
|
||||
func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||
found := false
|
||||
if o.AuthRequest == nil {
|
||||
return nil, fmt.Errorf("grid: AuthRequest must be set")
|
||||
if o.AuthToken == nil {
|
||||
return nil, fmt.Errorf("grid: AuthToken not set")
|
||||
}
|
||||
if o.Dialer == nil {
|
||||
return nil, fmt.Errorf("grid: Dialer not set")
|
||||
}
|
||||
if o.AuthFn == nil {
|
||||
return nil, fmt.Errorf("grid: AuthFn not set")
|
||||
}
|
||||
m := &Manager{
|
||||
ID: uuid.New(),
|
||||
targets: make(map[string]*Connection, len(o.Hosts)),
|
||||
local: o.Local,
|
||||
authRequest: o.AuthRequest,
|
||||
ID: uuid.New(),
|
||||
targets: make(map[string]*Connection, len(o.Hosts)),
|
||||
local: o.Local,
|
||||
authToken: o.AuthToken,
|
||||
}
|
||||
m.handlers.init()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
for _, host := range o.Hosts {
|
||||
if host == o.Local {
|
||||
if found {
|
||||
|
@ -110,14 +119,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
|||
id: m.ID,
|
||||
local: o.Local,
|
||||
remote: host,
|
||||
dial: o.Dialer,
|
||||
handlers: &m.handlers,
|
||||
auth: o.AddAuth,
|
||||
blockConnect: o.BlockConnect,
|
||||
tlsConfig: o.TLSConfig,
|
||||
publisher: o.TraceTo,
|
||||
incomingBytes: o.Incoming,
|
||||
outgoingBytes: o.Outgoing,
|
||||
dialer: o.Dialer,
|
||||
authFn: o.AuthFn,
|
||||
})
|
||||
}
|
||||
if !found {
|
||||
|
@ -128,13 +136,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
|||
}
|
||||
|
||||
// AddToMux will add the grid manager to the given mux.
|
||||
func (m *Manager) AddToMux(router *mux.Router) {
|
||||
router.Handle(RoutePath, m.Handler())
|
||||
func (m *Manager) AddToMux(router *mux.Router, authReq func(r *http.Request) error) {
|
||||
router.Handle(RoutePath, m.Handler(authReq))
|
||||
}
|
||||
|
||||
// Handler returns a handler that can be used to serve grid requests.
|
||||
// This should be connected on RoutePath to the main server.
|
||||
func (m *Manager) Handler() http.HandlerFunc {
|
||||
func (m *Manager) Handler(authReq func(r *http.Request) error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
defer func() {
|
||||
if debugPrint {
|
||||
|
@ -151,7 +159,7 @@ func (m *Manager) Handler() http.HandlerFunc {
|
|||
fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL)
|
||||
}
|
||||
ctx := req.Context()
|
||||
if err := m.authRequest(req); err != nil {
|
||||
if err := authReq(req); err != nil {
|
||||
gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
|
@ -164,76 +172,96 @@ func (m *Manager) Handler() http.HandlerFunc {
|
|||
w.WriteHeader(http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
// will write an OpConnectResponse message to the remote and log it once locally.
|
||||
writeErr := func(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
gridLogOnceIf(ctx, err, req.RemoteAddr)
|
||||
resp := connectResp{
|
||||
ID: m.ID,
|
||||
Accepted: false,
|
||||
RejectedReason: err.Error(),
|
||||
}
|
||||
if b, err := resp.MarshalMsg(nil); err == nil {
|
||||
msg := message{
|
||||
Op: OpConnectResponse,
|
||||
Payload: b,
|
||||
}
|
||||
if b, err := msg.MarshalMsg(nil); err == nil {
|
||||
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
if debugPrint {
|
||||
fmt.Printf("grid: Upgraded request: %v\n", req.URL)
|
||||
}
|
||||
|
||||
msg, _, err := wsutil.ReadClientData(conn)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("reading connect: %w", err))
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if debugPrint {
|
||||
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
|
||||
}
|
||||
|
||||
var message message
|
||||
_, _, err = message.parse(msg)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
|
||||
return
|
||||
}
|
||||
if message.Op != OpConnect {
|
||||
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
|
||||
return
|
||||
}
|
||||
var cReq connectReq
|
||||
_, err = cReq.UnmarshalMsg(message.Payload)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
|
||||
return
|
||||
}
|
||||
remote := m.targets[cReq.Host]
|
||||
if remote == nil {
|
||||
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
|
||||
return
|
||||
}
|
||||
if debugPrint {
|
||||
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
|
||||
}
|
||||
writeErr(remote.handleIncoming(ctx, conn, cReq))
|
||||
m.IncomingConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// IncomingConn will handle an incoming connection.
|
||||
// This should be called with the incoming connection after accept.
|
||||
// Auth is handled internally, as well as disconnecting any connections from the same host.
|
||||
func (m *Manager) IncomingConn(ctx context.Context, conn net.Conn) {
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
// will write an OpConnectResponse message to the remote and log it once locally.
|
||||
defer conn.Close()
|
||||
writeErr := func(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
gridLogOnceIf(ctx, err, remoteAddr)
|
||||
resp := connectResp{
|
||||
ID: m.ID,
|
||||
Accepted: false,
|
||||
RejectedReason: err.Error(),
|
||||
}
|
||||
if b, err := resp.MarshalMsg(nil); err == nil {
|
||||
msg := message{
|
||||
Op: OpConnectResponse,
|
||||
Payload: b,
|
||||
}
|
||||
if b, err := msg.MarshalMsg(nil); err == nil {
|
||||
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
if debugPrint {
|
||||
fmt.Printf("grid: Upgraded request: %v\n", remoteAddr)
|
||||
}
|
||||
|
||||
msg, _, err := wsutil.ReadClientData(conn)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("reading connect: %w", err))
|
||||
return
|
||||
}
|
||||
if debugPrint {
|
||||
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
|
||||
}
|
||||
|
||||
var message message
|
||||
_, _, err = message.parse(msg)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
|
||||
return
|
||||
}
|
||||
if message.Op != OpConnect {
|
||||
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
|
||||
return
|
||||
}
|
||||
var cReq connectReq
|
||||
_, err = cReq.UnmarshalMsg(message.Payload)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
|
||||
return
|
||||
}
|
||||
remote := m.targets[cReq.Host]
|
||||
if remote == nil {
|
||||
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
|
||||
return
|
||||
}
|
||||
if time.Since(cReq.Time).Abs() > 5*time.Minute {
|
||||
writeErr(fmt.Errorf("time difference too large between servers: %v", time.Since(cReq.Time).Abs()))
|
||||
return
|
||||
}
|
||||
if err := m.authToken(cReq.Token, cReq.audience()); err != nil {
|
||||
writeErr(fmt.Errorf("auth token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if debugPrint {
|
||||
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
|
||||
}
|
||||
writeErr(remote.handleIncoming(ctx, conn, cReq))
|
||||
}
|
||||
|
||||
// AuthFn should provide an authentication string for the given aud.
|
||||
type AuthFn func(aud string) string
|
||||
|
||||
// ValidateAuthFn should check authentication for the given aud.
|
||||
type ValidateAuthFn func(auth, aud string) string
|
||||
|
||||
// Connection will return the connection for the specified host.
|
||||
// If the host does not exist nil will be returned.
|
||||
func (m *Manager) Connection(host string) *Connection {
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
"github.com/zeebo/xxh3"
|
||||
|
@ -255,8 +256,20 @@ type sender interface {
|
|||
}
|
||||
|
||||
type connectReq struct {
|
||||
ID [16]byte
|
||||
Host string
|
||||
ID [16]byte
|
||||
Host string
|
||||
Time time.Time
|
||||
Token string
|
||||
}
|
||||
|
||||
// audience returns the audience for the connect call.
|
||||
func (c *connectReq) audience() string {
|
||||
return fmt.Sprintf("%s-%d", c.Host, c.Time.Unix())
|
||||
}
|
||||
|
||||
// addToken will add the token to the connect request.
|
||||
func (c *connectReq) addToken(fn AuthFn) {
|
||||
c.Token = fn(c.audience())
|
||||
}
|
||||
|
||||
func (connectReq) Op() Op {
|
||||
|
|
|
@ -192,6 +192,18 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||
err = msgp.WrapError(err, "Host")
|
||||
return
|
||||
}
|
||||
case "Time":
|
||||
z.Time, err = dc.ReadTime()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Time")
|
||||
return
|
||||
}
|
||||
case "Token":
|
||||
z.Token, err = dc.ReadString()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Token")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
|
@ -205,9 +217,9 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 2
|
||||
// map header, size 4
|
||||
// write "ID"
|
||||
err = en.Append(0x82, 0xa2, 0x49, 0x44)
|
||||
err = en.Append(0x84, 0xa2, 0x49, 0x44)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -226,19 +238,45 @@ func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
|
|||
err = msgp.WrapError(err, "Host")
|
||||
return
|
||||
}
|
||||
// write "Time"
|
||||
err = en.Append(0xa4, 0x54, 0x69, 0x6d, 0x65)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteTime(z.Time)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Time")
|
||||
return
|
||||
}
|
||||
// write "Token"
|
||||
err = en.Append(0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteString(z.Token)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Token")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 2
|
||||
// map header, size 4
|
||||
// string "ID"
|
||||
o = append(o, 0x82, 0xa2, 0x49, 0x44)
|
||||
o = append(o, 0x84, 0xa2, 0x49, 0x44)
|
||||
o = msgp.AppendBytes(o, (z.ID)[:])
|
||||
// string "Host"
|
||||
o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74)
|
||||
o = msgp.AppendString(o, z.Host)
|
||||
// string "Time"
|
||||
o = append(o, 0xa4, 0x54, 0x69, 0x6d, 0x65)
|
||||
o = msgp.AppendTime(o, z.Time)
|
||||
// string "Token"
|
||||
o = append(o, 0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e)
|
||||
o = msgp.AppendString(o, z.Token)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -272,6 +310,18 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||
err = msgp.WrapError(err, "Host")
|
||||
return
|
||||
}
|
||||
case "Time":
|
||||
z.Time, bts, err = msgp.ReadTimeBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Time")
|
||||
return
|
||||
}
|
||||
case "Token":
|
||||
z.Token, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Token")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
|
@ -286,7 +336,7 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *connectReq) Msgsize() (s int) {
|
||||
s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host)
|
||||
s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) + 5 + msgp.TimeSize + 6 + msgp.StringPrefixSize + len(z.Token)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue