mirror of
https://github.com/minio/minio.git
synced 2025-11-07 21:02:58 -05:00
Abstract grid connections (#20038)
Add `ConnDialer` to abstract connection creation. - `IncomingConn(ctx context.Context, conn net.Conn)` is provided as an entry point for incoming custom connections. - `ConnectWS` is provided to create web socket connections.
This commit is contained in:
@@ -19,13 +19,14 @@ package grid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
@@ -62,40 +63,48 @@ type Manager struct {
|
||||
// local host name.
|
||||
local string
|
||||
|
||||
// Validate incoming requests.
|
||||
authRequest func(r *http.Request) error
|
||||
// authToken is a function that will validate a token.
|
||||
authToken ValidateTokenFn
|
||||
}
|
||||
|
||||
// ManagerOptions are options for creating a new grid manager.
|
||||
type ManagerOptions struct {
|
||||
Dialer ContextDialer // Outgoing dialer.
|
||||
Local string // Local host name.
|
||||
Hosts []string // All hosts, including local in the grid.
|
||||
AddAuth AuthFn // Add authentication to the given audience.
|
||||
AuthRequest func(r *http.Request) error // Validate incoming requests.
|
||||
TLSConfig *tls.Config // TLS to apply to the connections.
|
||||
Incoming func(n int64) // Record incoming bytes.
|
||||
Outgoing func(n int64) // Record outgoing bytes.
|
||||
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
|
||||
Local string // Local host name.
|
||||
Hosts []string // All hosts, including local in the grid.
|
||||
Incoming func(n int64) // Record incoming bytes.
|
||||
Outgoing func(n int64) // Record outgoing bytes.
|
||||
BlockConnect chan struct{} // If set, incoming and outgoing connections will be blocked until closed.
|
||||
TraceTo *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
|
||||
Dialer ConnDialer
|
||||
// Sign a token for the given audience.
|
||||
AuthFn AuthFn
|
||||
// Callbacks to validate incoming connections.
|
||||
AuthToken ValidateTokenFn
|
||||
}
|
||||
|
||||
// NewManager creates a new grid manager
|
||||
func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||
found := false
|
||||
if o.AuthRequest == nil {
|
||||
return nil, fmt.Errorf("grid: AuthRequest must be set")
|
||||
if o.AuthToken == nil {
|
||||
return nil, fmt.Errorf("grid: AuthToken not set")
|
||||
}
|
||||
if o.Dialer == nil {
|
||||
return nil, fmt.Errorf("grid: Dialer not set")
|
||||
}
|
||||
if o.AuthFn == nil {
|
||||
return nil, fmt.Errorf("grid: AuthFn not set")
|
||||
}
|
||||
m := &Manager{
|
||||
ID: uuid.New(),
|
||||
targets: make(map[string]*Connection, len(o.Hosts)),
|
||||
local: o.Local,
|
||||
authRequest: o.AuthRequest,
|
||||
ID: uuid.New(),
|
||||
targets: make(map[string]*Connection, len(o.Hosts)),
|
||||
local: o.Local,
|
||||
authToken: o.AuthToken,
|
||||
}
|
||||
m.handlers.init()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
for _, host := range o.Hosts {
|
||||
if host == o.Local {
|
||||
if found {
|
||||
@@ -110,14 +119,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||
id: m.ID,
|
||||
local: o.Local,
|
||||
remote: host,
|
||||
dial: o.Dialer,
|
||||
handlers: &m.handlers,
|
||||
auth: o.AddAuth,
|
||||
blockConnect: o.BlockConnect,
|
||||
tlsConfig: o.TLSConfig,
|
||||
publisher: o.TraceTo,
|
||||
incomingBytes: o.Incoming,
|
||||
outgoingBytes: o.Outgoing,
|
||||
dialer: o.Dialer,
|
||||
authFn: o.AuthFn,
|
||||
})
|
||||
}
|
||||
if !found {
|
||||
@@ -128,13 +136,13 @@ func NewManager(ctx context.Context, o ManagerOptions) (*Manager, error) {
|
||||
}
|
||||
|
||||
// AddToMux will add the grid manager to the given mux.
|
||||
func (m *Manager) AddToMux(router *mux.Router) {
|
||||
router.Handle(RoutePath, m.Handler())
|
||||
func (m *Manager) AddToMux(router *mux.Router, authReq func(r *http.Request) error) {
|
||||
router.Handle(RoutePath, m.Handler(authReq))
|
||||
}
|
||||
|
||||
// Handler returns a handler that can be used to serve grid requests.
|
||||
// This should be connected on RoutePath to the main server.
|
||||
func (m *Manager) Handler() http.HandlerFunc {
|
||||
func (m *Manager) Handler(authReq func(r *http.Request) error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
defer func() {
|
||||
if debugPrint {
|
||||
@@ -151,7 +159,7 @@ func (m *Manager) Handler() http.HandlerFunc {
|
||||
fmt.Printf("grid: Got a %s request for: %v\n", req.Method, req.URL)
|
||||
}
|
||||
ctx := req.Context()
|
||||
if err := m.authRequest(req); err != nil {
|
||||
if err := authReq(req); err != nil {
|
||||
gridLogOnceIf(ctx, fmt.Errorf("auth %s: %w", req.RemoteAddr, err), req.RemoteAddr)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
@@ -164,76 +172,96 @@ func (m *Manager) Handler() http.HandlerFunc {
|
||||
w.WriteHeader(http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
// will write an OpConnectResponse message to the remote and log it once locally.
|
||||
writeErr := func(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
gridLogOnceIf(ctx, err, req.RemoteAddr)
|
||||
resp := connectResp{
|
||||
ID: m.ID,
|
||||
Accepted: false,
|
||||
RejectedReason: err.Error(),
|
||||
}
|
||||
if b, err := resp.MarshalMsg(nil); err == nil {
|
||||
msg := message{
|
||||
Op: OpConnectResponse,
|
||||
Payload: b,
|
||||
}
|
||||
if b, err := msg.MarshalMsg(nil); err == nil {
|
||||
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
if debugPrint {
|
||||
fmt.Printf("grid: Upgraded request: %v\n", req.URL)
|
||||
}
|
||||
|
||||
msg, _, err := wsutil.ReadClientData(conn)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("reading connect: %w", err))
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if debugPrint {
|
||||
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
|
||||
}
|
||||
|
||||
var message message
|
||||
_, _, err = message.parse(msg)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
|
||||
return
|
||||
}
|
||||
if message.Op != OpConnect {
|
||||
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
|
||||
return
|
||||
}
|
||||
var cReq connectReq
|
||||
_, err = cReq.UnmarshalMsg(message.Payload)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
|
||||
return
|
||||
}
|
||||
remote := m.targets[cReq.Host]
|
||||
if remote == nil {
|
||||
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
|
||||
return
|
||||
}
|
||||
if debugPrint {
|
||||
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
|
||||
}
|
||||
writeErr(remote.handleIncoming(ctx, conn, cReq))
|
||||
m.IncomingConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// IncomingConn will handle an incoming connection.
|
||||
// This should be called with the incoming connection after accept.
|
||||
// Auth is handled internally, as well as disconnecting any connections from the same host.
|
||||
func (m *Manager) IncomingConn(ctx context.Context, conn net.Conn) {
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
// will write an OpConnectResponse message to the remote and log it once locally.
|
||||
defer conn.Close()
|
||||
writeErr := func(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
gridLogOnceIf(ctx, err, remoteAddr)
|
||||
resp := connectResp{
|
||||
ID: m.ID,
|
||||
Accepted: false,
|
||||
RejectedReason: err.Error(),
|
||||
}
|
||||
if b, err := resp.MarshalMsg(nil); err == nil {
|
||||
msg := message{
|
||||
Op: OpConnectResponse,
|
||||
Payload: b,
|
||||
}
|
||||
if b, err := msg.MarshalMsg(nil); err == nil {
|
||||
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
if debugPrint {
|
||||
fmt.Printf("grid: Upgraded request: %v\n", remoteAddr)
|
||||
}
|
||||
|
||||
msg, _, err := wsutil.ReadClientData(conn)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("reading connect: %w", err))
|
||||
return
|
||||
}
|
||||
if debugPrint {
|
||||
fmt.Printf("%s handler: Got message, length %v\n", m.local, len(msg))
|
||||
}
|
||||
|
||||
var message message
|
||||
_, _, err = message.parse(msg)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
|
||||
return
|
||||
}
|
||||
if message.Op != OpConnect {
|
||||
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
|
||||
return
|
||||
}
|
||||
var cReq connectReq
|
||||
_, err = cReq.UnmarshalMsg(message.Payload)
|
||||
if err != nil {
|
||||
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
|
||||
return
|
||||
}
|
||||
remote := m.targets[cReq.Host]
|
||||
if remote == nil {
|
||||
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
|
||||
return
|
||||
}
|
||||
if time.Since(cReq.Time).Abs() > 5*time.Minute {
|
||||
writeErr(fmt.Errorf("time difference too large between servers: %v", time.Since(cReq.Time).Abs()))
|
||||
return
|
||||
}
|
||||
if err := m.authToken(cReq.Token, cReq.audience()); err != nil {
|
||||
writeErr(fmt.Errorf("auth token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if debugPrint {
|
||||
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
|
||||
}
|
||||
writeErr(remote.handleIncoming(ctx, conn, cReq))
|
||||
}
|
||||
|
||||
// AuthFn should provide an authentication string for the given aud.
|
||||
type AuthFn func(aud string) string
|
||||
|
||||
// ValidateAuthFn should check authentication for the given aud.
|
||||
type ValidateAuthFn func(auth, aud string) string
|
||||
|
||||
// Connection will return the connection for the specified host.
|
||||
// If the host does not exist nil will be returned.
|
||||
func (m *Manager) Connection(host string) *Connection {
|
||||
|
||||
Reference in New Issue
Block a user