mirror of https://github.com/minio/minio.git
Abstract grid connections (#20038)
Add `ConnDialer` to abstract connection creation. - `IncomingConn(ctx context.Context, conn net.Conn)` is provided as an entry point for incoming custom connections. - `ConnectWS` is provided to create web socket connections.
This commit is contained in:
parent
b433bf14ba
commit
0d0b0aa599
18
cmd/grid.go
18
cmd/grid.go
|
@ -41,17 +41,19 @@ func initGlobalGrid(ctx context.Context, eps EndpointServerPools) error {
|
||||||
// Pass Dialer for websocket grid, make sure we do not
|
// Pass Dialer for websocket grid, make sure we do not
|
||||||
// provide any DriveOPTimeout() function, as that is not
|
// provide any DriveOPTimeout() function, as that is not
|
||||||
// useful over persistent connections.
|
// useful over persistent connections.
|
||||||
Dialer: grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))),
|
Dialer: grid.ConnectWS(
|
||||||
|
grid.ContextDialer(xhttp.DialContextWithLookupHost(lookupHost, xhttp.NewInternodeDialContext(rest.DefaultTimeout, globalTCPOptions.ForWebsocket()))),
|
||||||
|
newCachedAuthToken(),
|
||||||
|
&tls.Config{
|
||||||
|
RootCAs: globalRootCAs,
|
||||||
|
CipherSuites: fips.TLSCiphers(),
|
||||||
|
CurvePreferences: fips.TLSCurveIDs(),
|
||||||
|
}),
|
||||||
Local: local,
|
Local: local,
|
||||||
Hosts: hosts,
|
Hosts: hosts,
|
||||||
AddAuth: newCachedAuthToken(),
|
AuthToken: validateStorageRequestToken,
|
||||||
AuthRequest: storageServerRequestValidate,
|
AuthFn: newCachedAuthToken(),
|
||||||
BlockConnect: globalGridStart,
|
BlockConnect: globalGridStart,
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
RootCAs: globalRootCAs,
|
|
||||||
CipherSuites: fips.TLSCiphers(),
|
|
||||||
CurvePreferences: fips.TLSCurveIDs(),
|
|
||||||
},
|
|
||||||
// Record incoming and outgoing bytes.
|
// Record incoming and outgoing bytes.
|
||||||
Incoming: globalConnStats.incInternodeInputBytes,
|
Incoming: globalConnStats.incInternodeInputBytes,
|
||||||
Outgoing: globalConnStats.incInternodeOutputBytes,
|
Outgoing: globalConnStats.incInternodeOutputBytes,
|
||||||
|
|
|
@ -39,7 +39,7 @@ func registerDistErasureRouters(router *mux.Router, endpointServerPools Endpoint
|
||||||
registerLockRESTHandlers()
|
registerLockRESTHandlers()
|
||||||
|
|
||||||
// Add grid to router
|
// Add grid to router
|
||||||
router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(), noGZFlag, noObjLayerFlag))
|
router.Handle(grid.RoutePath, adminMiddleware(globalGrid.Load().Handler(storageServerRequestValidate), noGZFlag, noObjLayerFlag))
|
||||||
}
|
}
|
||||||
|
|
||||||
// List of some generic middlewares which are applied for all incoming requests.
|
// List of some generic middlewares which are applied for all incoming requests.
|
||||||
|
|
|
@ -109,6 +109,24 @@ func (s *storageRESTServer) writeErrorResponse(w http.ResponseWriter, err error)
|
||||||
// DefaultSkewTime - skew time is 15 minutes between minio peers.
|
// DefaultSkewTime - skew time is 15 minutes between minio peers.
|
||||||
const DefaultSkewTime = 15 * time.Minute
|
const DefaultSkewTime = 15 * time.Minute
|
||||||
|
|
||||||
|
// validateStorageRequestToken will validate the token against the provided audience.
|
||||||
|
func validateStorageRequestToken(token, audience string) error {
|
||||||
|
claims := xjwt.NewStandardClaims()
|
||||||
|
if err := xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil {
|
||||||
|
return errAuthentication
|
||||||
|
}
|
||||||
|
|
||||||
|
owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey
|
||||||
|
if !owner {
|
||||||
|
return errAuthentication
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Audience != audience {
|
||||||
|
return errAuthentication
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticates storage client's requests and validates for skewed time.
|
// Authenticates storage client's requests and validates for skewed time.
|
||||||
func storageServerRequestValidate(r *http.Request) error {
|
func storageServerRequestValidate(r *http.Request) error {
|
||||||
token, err := jwtreq.AuthorizationHeaderExtractor.ExtractToken(r)
|
token, err := jwtreq.AuthorizationHeaderExtractor.ExtractToken(r)
|
||||||
|
@ -118,19 +136,8 @@ func storageServerRequestValidate(r *http.Request) error {
|
||||||
}
|
}
|
||||||
return errMalformedAuth
|
return errMalformedAuth
|
||||||
}
|
}
|
||||||
|
if err = validateStorageRequestToken(token, r.URL.RawQuery); err != nil {
|
||||||
claims := xjwt.NewStandardClaims()
|
return err
|
||||||
if err = xjwt.ParseWithStandardClaims(token, claims, []byte(globalActiveCred.SecretKey)); err != nil {
|
|
||||||
return errAuthentication
|
|
||||||
}
|
|
||||||
|
|
||||||
owner := claims.AccessKey == globalActiveCred.AccessKey || claims.Subject == globalActiveCred.AccessKey
|
|
||||||
if !owner {
|
|
||||||
return errAuthentication
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Audience != r.URL.RawQuery {
|
|
||||||
return errAuthentication
|
|
||||||
}
|
}
|
||||||
|
|
||||||
requestTimeStr := r.Header.Get("X-Minio-Time")
|
requestTimeStr := r.Header.Get("X-Minio-Time")
|
||||||
|
|
|
@ -20,7 +20,6 @@ package grid
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -28,7 +27,6 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -100,9 +98,9 @@ type Connection struct {
|
||||||
// Client or serverside.
|
// Client or serverside.
|
||||||
side ws.State
|
side ws.State
|
||||||
|
|
||||||
// Transport for outgoing connections.
|
// Dialer for outgoing connections.
|
||||||
dialer ContextDialer
|
dial ConnDialer
|
||||||
header http.Header
|
authFn AuthFn
|
||||||
|
|
||||||
handleMsgWg sync.WaitGroup
|
handleMsgWg sync.WaitGroup
|
||||||
|
|
||||||
|
@ -112,10 +110,8 @@ type Connection struct {
|
||||||
handlers *handlers
|
handlers *handlers
|
||||||
|
|
||||||
remote *RemoteClient
|
remote *RemoteClient
|
||||||
auth AuthFn
|
|
||||||
clientPingInterval time.Duration
|
clientPingInterval time.Duration
|
||||||
connPingInterval time.Duration
|
connPingInterval time.Duration
|
||||||
tlsConfig *tls.Config
|
|
||||||
blockConnect chan struct{}
|
blockConnect chan struct{}
|
||||||
|
|
||||||
incomingBytes func(n int64) // Record incoming bytes.
|
incomingBytes func(n int64) // Record incoming bytes.
|
||||||
|
@ -205,13 +201,12 @@ type connectionParams struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
local, remote string
|
local, remote string
|
||||||
dial ContextDialer
|
|
||||||
handlers *handlers
|
handlers *handlers
|
||||||
auth AuthFn
|
|
||||||
tlsConfig *tls.Config
|
|
||||||
incomingBytes func(n int64) // Record incoming bytes.
|
incomingBytes func(n int64) // Record incoming bytes.
|
||||||
outgoingBytes func(n int64) // Record outgoing bytes.
|
outgoingBytes func(n int64) // Record outgoing bytes.
|
||||||
publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
||||||
|
dialer ConnDialer
|
||||||
|
authFn AuthFn
|
||||||
|
|
||||||
blockConnect chan struct{}
|
blockConnect chan struct{}
|
||||||
}
|
}
|
||||||
|
@ -227,16 +222,14 @@ func newConnection(o connectionParams) *Connection {
|
||||||
outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000),
|
outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000),
|
||||||
inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000),
|
inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000),
|
||||||
outQueue: make(chan []byte, defaultOutQueue),
|
outQueue: make(chan []byte, defaultOutQueue),
|
||||||
dialer: o.dial,
|
|
||||||
side: ws.StateServerSide,
|
side: ws.StateServerSide,
|
||||||
connChange: &sync.Cond{L: &sync.Mutex{}},
|
connChange: &sync.Cond{L: &sync.Mutex{}},
|
||||||
handlers: o.handlers,
|
handlers: o.handlers,
|
||||||
auth: o.auth,
|
|
||||||
header: make(http.Header, 1),
|
|
||||||
remote: &RemoteClient{Name: o.remote},
|
remote: &RemoteClient{Name: o.remote},
|
||||||
clientPingInterval: clientPingInterval,
|
clientPingInterval: clientPingInterval,
|
||||||
connPingInterval: connPingInterval,
|
connPingInterval: connPingInterval,
|
||||||
tlsConfig: o.tlsConfig,
|
dial: o.dialer,
|
||||||
|
authFn: o.authFn,
|
||||||
}
|
}
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
// Random Mux ID
|
// Random Mux ID
|
||||||
|
@ -648,41 +641,17 @@ func (c *Connection) connect() {
|
||||||
if c.State() == StateShutdown {
|
if c.State() == StateShutdown {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
toDial := strings.Replace(c.Remote, "http://", "ws://", 1)
|
|
||||||
toDial = strings.Replace(toDial, "https://", "wss://", 1)
|
|
||||||
toDial += RoutePath
|
|
||||||
|
|
||||||
dialer := ws.DefaultDialer
|
|
||||||
dialer.ReadBufferSize = readBufferSize
|
|
||||||
dialer.WriteBufferSize = writeBufferSize
|
|
||||||
dialer.Timeout = defaultDialTimeout
|
|
||||||
if c.dialer != nil {
|
|
||||||
dialer.NetDial = c.dialer.DialContext
|
|
||||||
}
|
|
||||||
if c.header == nil {
|
|
||||||
c.header = make(http.Header, 2)
|
|
||||||
}
|
|
||||||
c.header.Set("Authorization", "Bearer "+c.auth(""))
|
|
||||||
c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
|
|
||||||
|
|
||||||
if len(c.header) > 0 {
|
|
||||||
dialer.Header = ws.HandshakeHeaderHTTP(c.header)
|
|
||||||
}
|
|
||||||
dialer.TLSConfig = c.tlsConfig
|
|
||||||
dialStarted := time.Now()
|
dialStarted := time.Now()
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Println(c.Local, "Connecting to ", toDial)
|
fmt.Println(c.Local, "Connecting to ", c.Remote)
|
||||||
}
|
|
||||||
conn, br, _, err := dialer.Dial(c.ctx, toDial)
|
|
||||||
if br != nil {
|
|
||||||
ws.PutReader(br)
|
|
||||||
}
|
}
|
||||||
|
conn, err := c.dial(c.ctx, c.Remote)
|
||||||
c.connMu.Lock()
|
c.connMu.Lock()
|
||||||
c.debugOutConn = conn
|
c.debugOutConn = conn
|
||||||
c.connMu.Unlock()
|
c.connMu.Unlock()
|
||||||
retry := func(err error) {
|
retry := func(err error) {
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err)
|
fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, c.Remote, err)
|
||||||
}
|
}
|
||||||
sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout)))
|
sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout)))
|
||||||
next := dialStarted.Add(sleep / 2)
|
next := dialStarted.Add(sleep / 2)
|
||||||
|
@ -696,7 +665,7 @@ func (c *Connection) connect() {
|
||||||
}
|
}
|
||||||
if gotState != StateConnecting {
|
if gotState != StateConnecting {
|
||||||
// Don't print error on first attempt, and after that only once per hour.
|
// Don't print error on first attempt, and after that only once per hour.
|
||||||
gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), toDial)
|
gridLogOnceIf(c.ctx, fmt.Errorf("grid: %s re-connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, c.Remote, err, err, sleep, gotState), c.Remote)
|
||||||
}
|
}
|
||||||
c.updateState(StateConnectionError)
|
c.updateState(StateConnectionError)
|
||||||
time.Sleep(sleep)
|
time.Sleep(sleep)
|
||||||
|
@ -712,7 +681,9 @@ func (c *Connection) connect() {
|
||||||
req := connectReq{
|
req := connectReq{
|
||||||
Host: c.Local,
|
Host: c.Local,
|
||||||
ID: c.id,
|
ID: c.id,
|
||||||
|
Time: time.Now(),
|
||||||
}
|
}
|
||||||
|
req.addToken(c.authFn)
|
||||||
err = c.sendMsg(conn, m, &req)
|
err = c.sendMsg(conn, m, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retry(err)
|
retry(err)
|
||||||
|
|
|
@ -52,11 +52,13 @@ func TestDisconnect(t *testing.T) {
|
||||||
localHost := hosts[0]
|
localHost := hosts[0]
|
||||||
remoteHost := hosts[1]
|
remoteHost := hosts[1]
|
||||||
local, err := NewManager(context.Background(), ManagerOptions{
|
local, err := NewManager(context.Background(), ManagerOptions{
|
||||||
Dialer: dialer.DialContext,
|
Dialer: ConnectWS(dialer.DialContext,
|
||||||
|
dummyNewToken,
|
||||||
|
nil),
|
||||||
Local: localHost,
|
Local: localHost,
|
||||||
Hosts: hosts,
|
Hosts: hosts,
|
||||||
AddAuth: func(aud string) string { return aud },
|
AuthFn: dummyNewToken,
|
||||||
AuthRequest: dummyRequestValidate,
|
AuthToken: dummyTokenValidate,
|
||||||
BlockConnect: connReady,
|
BlockConnect: connReady,
|
||||||
})
|
})
|
||||||
errFatal(err)
|
errFatal(err)
|
||||||
|
@ -74,17 +76,19 @@ func TestDisconnect(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
remote, err := NewManager(context.Background(), ManagerOptions{
|
remote, err := NewManager(context.Background(), ManagerOptions{
|
||||||
Dialer: dialer.DialContext,
|
Dialer: ConnectWS(dialer.DialContext,
|
||||||
|
dummyNewToken,
|
||||||
|
nil),
|
||||||
Local: remoteHost,
|
Local: remoteHost,
|
||||||
Hosts: hosts,
|
Hosts: hosts,
|
||||||
AddAuth: func(aud string) string { return aud },
|
AuthFn: dummyNewToken,
|
||||||
AuthRequest: dummyRequestValidate,
|
AuthToken: dummyTokenValidate,
|
||||||
BlockConnect: connReady,
|
BlockConnect: connReady,
|
||||||
})
|
})
|
||||||
errFatal(err)
|
errFatal(err)
|
||||||
|
|
||||||
localServer := startServer(t, listeners[0], wrapServer(local.Handler()))
|
localServer := startServer(t, listeners[0], wrapServer(local.Handler(dummyRequestValidate)))
|
||||||
remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler()))
|
remoteServer := startServer(t, listeners[1], wrapServer(remote.Handler(dummyRequestValidate)))
|
||||||
close(connReady)
|
close(connReady)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -165,10 +169,6 @@ func TestDisconnect(t *testing.T) {
|
||||||
<-gotCall
|
<-gotCall
|
||||||
}
|
}
|
||||||
|
|
||||||
func dummyRequestValidate(r *http.Request) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldConnect(t *testing.T) {
|
func TestShouldConnect(t *testing.T) {
|
||||||
var c Connection
|
var c Connection
|
||||||
var cReverse Connection
|
var cReverse Connection
|
||||||
|
|
|
@ -82,20 +82,20 @@ func SetupTestGrid(n int) (*TestGrid, error) {
|
||||||
res.cancel = cancel
|
res.cancel = cancel
|
||||||
for i, host := range hosts {
|
for i, host := range hosts {
|
||||||
manager, err := NewManager(ctx, ManagerOptions{
|
manager, err := NewManager(ctx, ManagerOptions{
|
||||||
Dialer: dialer.DialContext,
|
Dialer: ConnectWS(dialer.DialContext,
|
||||||
Local: host,
|
dummyNewToken,
|
||||||
Hosts: hosts,
|
nil),
|
||||||
AuthRequest: func(r *http.Request) error {
|
Local: host,
|
||||||
return nil
|
Hosts: hosts,
|
||||||
},
|
AuthFn: dummyNewToken,
|
||||||
AddAuth: func(aud string) string { return aud },
|
AuthToken: dummyTokenValidate,
|
||||||
BlockConnect: ready,
|
BlockConnect: ready,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := mux.NewRouter()
|
m := mux.NewRouter()
|
||||||
m.Handle(RoutePath, manager.Handler())
|
m.Handle(RoutePath, manager.Handler(dummyRequestValidate))
|
||||||
res.Managers = append(res.Managers, manager)
|
res.Managers = append(res.Managers, manager)
|
||||||
res.Servers = append(res.Servers, startHTTPServer(listeners[i], m))
|
res.Servers = append(res.Servers, startHTTPServer(listeners[i], m))
|
||||||
res.Listeners = append(res.Listeners, listeners[i])
|
res.Listeners = append(res.Listeners, listeners[i])
|
||||||
|
@ -164,3 +164,18 @@ func startHTTPServer(listener net.Listener, handler http.Handler) (server *httpt
|
||||||
server.Start()
|
server.Start()
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dummyRequestValidate(r *http.Request) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dummyTokenValidate(token, audience string) error {
|
||||||
|
if token == audience {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid token. want %s, got %s", audience, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dummyNewToken(audience string) string {
|
||||||
|
return audience
|
||||||
|
}
|
||||||
|
|
|
@ -20,12 +20,17 @@ package grid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -179,3 +184,45 @@ func bytesOrLength(b []byte) string {
|
||||||
}
|
}
|
||||||
return fmt.Sprint(b)
|
return fmt.Sprint(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnDialer is a function that dials a connection to the given address.
|
||||||
|
// There should be no retries in this function,
|
||||||
|
// and should have a timeout of something like 2 seconds.
|
||||||
|
// The returned net.Conn should also have quick disconnect on errors.
|
||||||
|
// The net.Conn must support all features as described by the net.Conn interface.
|
||||||
|
type ConnDialer func(ctx context.Context, address string) (net.Conn, error)
|
||||||
|
|
||||||
|
// ConnectWS returns a function that dials a websocket connection to the given address.
|
||||||
|
// Route and auth are added to the connection.
|
||||||
|
func ConnectWS(dial ContextDialer, auth AuthFn, tls *tls.Config) func(ctx context.Context, remote string) (net.Conn, error) {
|
||||||
|
return func(ctx context.Context, remote string) (net.Conn, error) {
|
||||||
|
toDial := strings.Replace(remote, "http://", "ws://", 1)
|
||||||
|
toDial = strings.Replace(toDial, "https://", "wss://", 1)
|
||||||
|
toDial += RoutePath
|
||||||
|
|
||||||
|
dialer := ws.DefaultDialer
|
||||||
|
dialer.ReadBufferSize = readBufferSize
|
||||||
|
dialer.WriteBufferSize = writeBufferSize
|
||||||
|
dialer.Timeout = defaultDialTimeout
|
||||||
|
if dial != nil {
|
||||||
|
dialer.NetDial = dial
|
||||||
|
}
|
||||||
|
header := make(http.Header, 2)
|
||||||
|
header.Set("Authorization", "Bearer "+auth(""))
|
||||||
|
header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
|
||||||
|
|
||||||
|
if len(header) > 0 {
|
||||||
|
dialer.Header = ws.HandshakeHeaderHTTP(header)
|
||||||
|
}
|
||||||
|
dialer.TLSConfig = tls
|
||||||
|
|
||||||
|
conn, br, _, err := dialer.Dial(ctx, toDial)
|
||||||
|
if br != nil {
|
||||||
|
ws.PutReader(br)
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTokenFn must validate the token and return an error if it is invalid.
|
||||||
|
type ValidateTokenFn func(token, audience string) error
|
||||||
|
|
|
@ -19,13 +19,14 @@ package grid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gobwas/ws"
|
"github.com/gobwas/ws"
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
@ -62,40 +63,48 @@ type Manager struct {
|
||||||
// local host name.
|
// local host name.
|
||||||
local string
|
local string
|
||||||
|
|
||||||
// Validate incoming requests.
|
// authToken is a function that will validate a token.
|
||||||
authRequest func(r *http.Request) error
|
authToken ValidateTokenFn
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerOptions are options for creating a new grid manager.
|
// ManagerOptions are options for creating a new grid manager.
|
||||||
type ManagerOptions struct {
|
type ManagerOptions struct {
|
||||||
Dialer ContextDialer // Outgoing dialer.
|
Local string // Local host name.
|
||||||
Local string // Local host name.
|
Hosts []string // All hosts, including local in the grid.
|
||||||
Hosts []string // All hosts, including local in the grid.
|
Incoming func(n int64) // Record incoming bytes.
|
||||||
AddAuth AuthFn // Add authentication to the given audience.
|
Outgoing func(n int64) // Record outgoing bytes.
|
||||||
AuthRequest func(r *http.Request) error // Validate incoming requests.
|
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
|
||||||
TLSConfig *tls.Config // TLS to apply to the connections.
|
|
||||||
Incoming func(n int64) // Record incoming bytes.
|
|
||||||
Outgoing func(n int64) // Record outgoing bytes.
|
|
||||||
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
|
|
||||||
TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
||||||
|
Dialer ConnDialer
|
||||||
|
// Sign a token for the given audience.
|
||||||
|
AuthFn AuthFn
|
||||||
|
// Callbacks to validate incoming connections.
|
||||||
|
AuthToken ValidateTokenFn
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new grid manager
|
// NewManager creates a new grid manager
|
||||||
func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||||
found := false
|
found := false
|
||||||
if o.AuthRequest == nil {
|
if o.AuthToken == nil {
|
||||||
return nil, fmt.Errorf("grid: AuthRequest must be set")
|
return nil, fmt.Errorf("grid: AuthToken not set")
|
||||||
|
}
|
||||||
|
if o.Dialer == nil {
|
||||||
|
return nil, fmt.Errorf("grid: Dialer not set")
|
||||||
|
}
|
||||||
|
if o.AuthFn == nil {
|
||||||
|
return nil, fmt.Errorf("grid: AuthFn not set")
|
||||||
}
|
}
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
targets: make(map[string]*Connection, len(o.Hosts)),
|
targets: make(map[string]*Connection, len(o.Hosts)),
|
||||||
local: o.Local,
|
local: o.Local,
|
||||||
authRequest: o.AuthRequest,
|
authToken: o.AuthToken,
|
||||||
}
|
}
|
||||||
m.handlers.init()
|
m.handlers.init()
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, host := range o.Hosts {
|
for _, host := range o.Hosts {
|
||||||
if host == o.Local {
|
if host == o.Local {
|
||||||
if found {
|
if found {
|
||||||
|
@ -110,14 +119,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||||
id: m.ID,
|
id: m.ID,
|
||||||
local: o.Local,
|
local: o.Local,
|
||||||
remote: host,
|
remote: host,
|
||||||
dial: o.Dialer,
|
|
||||||
handlers: &m.handlers,
|
handlers: &m.handlers,
|
||||||
auth: o.AddAuth,
|
|
||||||
blockConnect: o.BlockConnect,
|
blockConnect: o.BlockConnect,
|
||||||
tlsConfig: o.TLSConfig,
|
|
||||||
publisher: o.TraceTo,
|
publisher: o.TraceTo,
|
||||||
incomingBytes: o.Incoming,
|
incomingBytes: o.Incoming,
|
||||||
outgoingBytes: o.Outgoing,
|
outgoingBytes: o.Outgoing,
|
||||||
|
dialer: o.Dialer,
|
||||||
|
authFn: o.AuthFn,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
|
@ -128,13 +136,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddToMux will add the grid manager to the given mux.
|
// AddToMux will add the grid manager to the given mux.
|
||||||
func (m *Manager) AddToMux(router *mux.Router) {
|
func (m *Manager) AddToMux(router *mux.Router, authReq func(r *http.Request) error) {
|
||||||
router.Handle(RoutePath, m.Handler())
|
router.Handle(RoutePath, m.Handler(authReq))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler returns a handler that can be used to serve grid requests.
|
// Handler returns a handler that can be used to serve grid requests.
|
||||||
// This should be connected on RoutePath to the main server.
|
// This should be connected on RoutePath to the main server.
|
||||||
func (m *Manager) Handler() http.HandlerFunc {
|
func (m *Manager) Handler(authReq func(r *http.Request) error) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, req *http.Request) {
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
|
@ -151,7 +159,7 @@ func (m *Manager) Handler() http.HandlerFunc {
|
||||||
fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL)
|
fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL)
|
||||||
}
|
}
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
if err := m.authRequest(req); err != nil {
|
if err := authReq(req); err != nil {
|
||||||
gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr)
|
gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr)
|
||||||
w.WriteHeader(http.StatusForbidden)
|
w.WriteHeader(http.StatusForbidden)
|
||||||
return
|
return
|
||||||
|
@ -164,76 +172,96 @@ func (m *Manager) Handler() http.HandlerFunc {
|
||||||
w.WriteHeader(http.StatusUpgradeRequired)
|
w.WriteHeader(http.StatusUpgradeRequired)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// will write an OpConnectResponse message to the remote and log it once locally.
|
m.IncomingConn(ctx, conn)
|
||||||
writeErr := func(err error) {
|
|
||||||
if err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
gridLogOnceIf(ctx, err, req.RemoteAddr)
|
|
||||||
resp := connectResp{
|
|
||||||
ID: m.ID,
|
|
||||||
Accepted: false,
|
|
||||||
RejectedReason: err.Error(),
|
|
||||||
}
|
|
||||||
if b, err := resp.MarshalMsg(nil); err == nil {
|
|
||||||
msg := message{
|
|
||||||
Op: OpConnectResponse,
|
|
||||||
Payload: b,
|
|
||||||
}
|
|
||||||
if b, err := msg.MarshalMsg(nil); err == nil {
|
|
||||||
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
if debugPrint {
|
|
||||||
fmt.Printf("grid: Upgraded request: %v\n", req.URL)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, _, err := wsutil.ReadClientData(conn)
|
|
||||||
if err != nil {
|
|
||||||
writeErr(fmt.Errorf("reading connect: %w", err))
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if debugPrint {
|
|
||||||
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
var message message
|
|
||||||
_, _, err = message.parse(msg)
|
|
||||||
if err != nil {
|
|
||||||
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if message.Op != OpConnect {
|
|
||||||
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var cReq connectReq
|
|
||||||
_, err = cReq.UnmarshalMsg(message.Payload)
|
|
||||||
if err != nil {
|
|
||||||
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
remote := m.targets[cReq.Host]
|
|
||||||
if remote == nil {
|
|
||||||
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if debugPrint {
|
|
||||||
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
|
|
||||||
}
|
|
||||||
writeErr(remote.handleIncoming(ctx, conn, cReq))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncomingConn will handle an incoming connection.
|
||||||
|
// This should be called with the incoming connection after accept.
|
||||||
|
// Auth is handled internally, as well as disconnecting any connections from the same host.
|
||||||
|
func (m *Manager) IncomingConn(ctx context.Context, conn net.Conn) {
|
||||||
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
|
// will write an OpConnectResponse message to the remote and log it once locally.
|
||||||
|
defer conn.Close()
|
||||||
|
writeErr := func(err error) {
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gridLogOnceIf(ctx, err, remoteAddr)
|
||||||
|
resp := connectResp{
|
||||||
|
ID: m.ID,
|
||||||
|
Accepted: false,
|
||||||
|
RejectedReason: err.Error(),
|
||||||
|
}
|
||||||
|
if b, err := resp.MarshalMsg(nil); err == nil {
|
||||||
|
msg := message{
|
||||||
|
Op: OpConnectResponse,
|
||||||
|
Payload: b,
|
||||||
|
}
|
||||||
|
if b, err := msg.MarshalMsg(nil); err == nil {
|
||||||
|
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
if debugPrint {
|
||||||
|
fmt.Printf("grid: Upgraded request: %v\n", remoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, _, err := wsutil.ReadClientData(conn)
|
||||||
|
if err != nil {
|
||||||
|
writeErr(fmt.Errorf("reading connect: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if debugPrint {
|
||||||
|
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
var message message
|
||||||
|
_, _, err = message.parse(msg)
|
||||||
|
if err != nil {
|
||||||
|
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if message.Op != OpConnect {
|
||||||
|
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var cReq connectReq
|
||||||
|
_, err = cReq.UnmarshalMsg(message.Payload)
|
||||||
|
if err != nil {
|
||||||
|
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
remote := m.targets[cReq.Host]
|
||||||
|
if remote == nil {
|
||||||
|
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Since(cReq.Time).Abs() > 5*time.Minute {
|
||||||
|
writeErr(fmt.Errorf("time difference too large between servers: %v", time.Since(cReq.Time).Abs()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := m.authToken(cReq.Token, cReq.audience()); err != nil {
|
||||||
|
writeErr(fmt.Errorf("auth token: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if debugPrint {
|
||||||
|
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
|
||||||
|
}
|
||||||
|
writeErr(remote.handleIncoming(ctx, conn, cReq))
|
||||||
|
}
|
||||||
|
|
||||||
// AuthFn should provide an authentication string for the given aud.
|
// AuthFn should provide an authentication string for the given aud.
|
||||||
type AuthFn func(aud string) string
|
type AuthFn func(aud string) string
|
||||||
|
|
||||||
|
// ValidateAuthFn should check authentication for the given aud.
|
||||||
|
type ValidateAuthFn func(auth, aud string) string
|
||||||
|
|
||||||
// Connection will return the connection for the specified host.
|
// Connection will return the connection for the specified host.
|
||||||
// If the host does not exist nil will be returned.
|
// If the host does not exist nil will be returned.
|
||||||
func (m *Manager) Connection(host string) *Connection {
|
func (m *Manager) Connection(host string) *Connection {
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/tinylib/msgp/msgp"
|
"github.com/tinylib/msgp/msgp"
|
||||||
"github.com/zeebo/xxh3"
|
"github.com/zeebo/xxh3"
|
||||||
|
@ -255,8 +256,20 @@ type sender interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type connectReq struct {
|
type connectReq struct {
|
||||||
ID [16]byte
|
ID [16]byte
|
||||||
Host string
|
Host string
|
||||||
|
Time time.Time
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
// audience returns the audience for the connect call.
|
||||||
|
func (c *connectReq) audience() string {
|
||||||
|
return fmt.Sprintf("%s-%d", c.Host, c.Time.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
// addToken will add the token to the connect request.
|
||||||
|
func (c *connectReq) addToken(fn AuthFn) {
|
||||||
|
c.Token = fn(c.audience())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (connectReq) Op() Op {
|
func (connectReq) Op() Op {
|
||||||
|
|
|
@ -192,6 +192,18 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||||
err = msgp.WrapError(err, "Host")
|
err = msgp.WrapError(err, "Host")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "Time":
|
||||||
|
z.Time, err = dc.ReadTime()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "Token":
|
||||||
|
z.Token, err = dc.ReadString()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Token")
|
||||||
|
return
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
err = dc.Skip()
|
err = dc.Skip()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -205,9 +217,9 @@ func (z *connectReq) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||||
|
|
||||||
// EncodeMsg implements msgp.Encodable
|
// EncodeMsg implements msgp.Encodable
|
||||||
func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
|
func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
|
||||||
// map header, size 2
|
// map header, size 4
|
||||||
// write "ID"
|
// write "ID"
|
||||||
err = en.Append(0x82, 0xa2, 0x49, 0x44)
|
err = en.Append(0x84, 0xa2, 0x49, 0x44)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -226,19 +238,45 @@ func (z *connectReq) EncodeMsg(en *msgp.Writer) (err error) {
|
||||||
err = msgp.WrapError(err, "Host")
|
err = msgp.WrapError(err, "Host")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// write "Time"
|
||||||
|
err = en.Append(0xa4, 0x54, 0x69, 0x6d, 0x65)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = en.WriteTime(z.Time)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// write "Token"
|
||||||
|
err = en.Append(0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = en.WriteString(z.Token)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Token")
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalMsg implements msgp.Marshaler
|
// MarshalMsg implements msgp.Marshaler
|
||||||
func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) {
|
func (z *connectReq) MarshalMsg(b []byte) (o []byte, err error) {
|
||||||
o = msgp.Require(b, z.Msgsize())
|
o = msgp.Require(b, z.Msgsize())
|
||||||
// map header, size 2
|
// map header, size 4
|
||||||
// string "ID"
|
// string "ID"
|
||||||
o = append(o, 0x82, 0xa2, 0x49, 0x44)
|
o = append(o, 0x84, 0xa2, 0x49, 0x44)
|
||||||
o = msgp.AppendBytes(o, (z.ID)[:])
|
o = msgp.AppendBytes(o, (z.ID)[:])
|
||||||
// string "Host"
|
// string "Host"
|
||||||
o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74)
|
o = append(o, 0xa4, 0x48, 0x6f, 0x73, 0x74)
|
||||||
o = msgp.AppendString(o, z.Host)
|
o = msgp.AppendString(o, z.Host)
|
||||||
|
// string "Time"
|
||||||
|
o = append(o, 0xa4, 0x54, 0x69, 0x6d, 0x65)
|
||||||
|
o = msgp.AppendTime(o, z.Time)
|
||||||
|
// string "Token"
|
||||||
|
o = append(o, 0xa5, 0x54, 0x6f, 0x6b, 0x65, 0x6e)
|
||||||
|
o = msgp.AppendString(o, z.Token)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -272,6 +310,18 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||||
err = msgp.WrapError(err, "Host")
|
err = msgp.WrapError(err, "Host")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "Time":
|
||||||
|
z.Time, bts, err = msgp.ReadTimeBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "Token":
|
||||||
|
z.Token, bts, err = msgp.ReadStringBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Token")
|
||||||
|
return
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
bts, err = msgp.Skip(bts)
|
bts, err = msgp.Skip(bts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -286,7 +336,7 @@ func (z *connectReq) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||||
|
|
||||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||||
func (z *connectReq) Msgsize() (s int) {
|
func (z *connectReq) Msgsize() (s int) {
|
||||||
s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host)
|
s = 1 + 3 + msgp.ArrayHeaderSize + (16 * (msgp.ByteSize)) + 5 + msgp.StringPrefixSize + len(z.Host) + 5 + msgp.TimeSize + 6 + msgp.StringPrefixSize + len(z.Token)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue