minio/internal/grid/connection.go
Klaus Post feeeef71f1
Add extra protection for grid reconnects (#18840)
Race checks would occasionally show race on handleMsgWg WaitGroup by debug messages (used in test only).

Use the `connMu` mutex to protect this against concurrent Wait/Add.

Fixes #18827
2024-01-22 09:39:06 -08:00

1625 lines
40 KiB
Go

// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package grid
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/google/uuid"
"github.com/minio/madmin-go/v3"
"github.com/minio/minio/internal/logger"
"github.com/minio/minio/internal/pubsub"
"github.com/tinylib/msgp/msgp"
"github.com/zeebo/xxh3"
)
// A Connection is a remote connection.
// There is no distinction externally whether the connection was initiated from
// this server or from the remote.
type Connection struct {
// NextID is the next ID that can be used (atomic).
NextID uint64
// LastPong is last pong time (atomic)
// Only valid when StateConnected.
LastPong int64
// State of the connection (atomic)
state State
// Non-atomic
Remote string
Local string
// ID of this connection instance.
id uuid.UUID
// Remote uuid, if we have been connected.
remoteID *uuid.UUID
// Context for the server.
ctx context.Context
// Active mux connections.
outgoing *lockedClientMap
// Incoming streams
inStream *lockedServerMap
// outQueue is the output queue
outQueue chan []byte
// Client or serverside.
side ws.State
// Transport for outgoing connections.
dialer ContextDialer
header http.Header
handleMsgWg sync.WaitGroup
// connChange will be signaled whenever State has been updated, or at regular intervals.
// Holding the lock allows safe reads of State, and guarantees that changes will be detected.
connChange *sync.Cond
handlers *handlers
remote *RemoteClient
auth AuthFn
clientPingInterval time.Duration
connPingInterval time.Duration
tlsConfig *tls.Config
blockConnect chan struct{}
incomingBytes func(n int64) // Record incoming bytes.
outgoingBytes func(n int64) // Record outgoing bytes.
trace *tracer // tracer for this connection.
baseFlags Flags
// For testing only
debugInConn net.Conn
debugOutConn net.Conn
addDeadline time.Duration
connMu sync.Mutex
}
// Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
type Subroute struct {
*Connection
trace *tracer
route string
subID subHandlerID
}
// String returns a string representation of the connection.
func (c *Connection) String() string {
return fmt.Sprintf("%s->%s", c.Local, c.Remote)
}
// StringReverse returns a string representation of the reverse connection.
func (c *Connection) StringReverse() string {
return fmt.Sprintf("%s->%s", c.Remote, c.Local)
}
// State is a connection state.
type State uint32
// MANUAL go:generate stringer -type=State -output=state_string.go -trimprefix=State $GOFILE
const (
// StateUnconnected is the initial state of a connection.
// When the first message is sent it will attempt to connect.
StateUnconnected = iota
// StateConnecting is the state from StateUnconnected while the connection is attempted to be established.
// After this connection will be StateConnected or StateConnectionError.
StateConnecting
// StateConnected is the state when the connection has been established and is considered stable.
// If the connection is lost, state will switch to StateConnecting.
StateConnected
// StateConnectionError is the state once a connection attempt has been made, and it failed.
// The connection will remain in this stat until the connection has been successfully re-established.
StateConnectionError
// StateShutdown is the state when the server has been shut down.
// This will not be used under normal operation.
StateShutdown
// MaxDeadline is the maximum deadline allowed,
// Approx 49 days.
MaxDeadline = time.Duration(math.MaxUint32) * time.Millisecond
)
// ContextDialer is a dialer that can be used to dial a remote.
type ContextDialer func(ctx context.Context, network, address string) (net.Conn, error)
// DialContext implements the Dialer interface.
func (c ContextDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return c(ctx, network, address)
}
const (
defaultOutQueue = 10000
readBufferSize = 16 << 10
writeBufferSize = 16 << 10
defaultDialTimeout = 2 * time.Second
connPingInterval = 10 * time.Second
)
type connectionParams struct {
ctx context.Context
id uuid.UUID
local, remote string
dial ContextDialer
handlers *handlers
auth AuthFn
tlsConfig *tls.Config
incomingBytes func(n int64) // Record incoming bytes.
outgoingBytes func(n int64) // Record outgoing bytes.
publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]
blockConnect chan struct{}
}
// newConnection will create an unconnected connection to a remote.
func newConnection(o connectionParams) *Connection {
c := &Connection{
state: StateUnconnected,
Remote: o.remote,
Local: o.local,
id: o.id,
ctx: o.ctx,
outgoing: &lockedClientMap{m: make(map[uint64]*muxClient, 1000)},
inStream: &lockedServerMap{m: make(map[uint64]*muxServer, 1000)},
outQueue: make(chan []byte, defaultOutQueue),
dialer: o.dial,
side: ws.StateServerSide,
connChange: &sync.Cond{L: &sync.Mutex{}},
handlers: o.handlers,
auth: o.auth,
header: make(http.Header, 1),
remote: &RemoteClient{Name: o.remote},
clientPingInterval: clientPingInterval,
connPingInterval: connPingInterval,
tlsConfig: o.tlsConfig,
incomingBytes: o.incomingBytes,
outgoingBytes: o.outgoingBytes,
}
if debugPrint {
// Random Mux ID
c.NextID = rand.Uint64()
}
if !strings.HasPrefix(o.remote, "https://") && !strings.HasPrefix(o.remote, "wss://") {
c.baseFlags |= FlagCRCxxh3
}
if !strings.HasPrefix(o.local, "https://") && !strings.HasPrefix(o.local, "wss://") {
c.baseFlags |= FlagCRCxxh3
}
if o.publisher != nil {
c.traceRequests(o.publisher)
}
if o.local == o.remote {
panic("equal hosts")
}
if c.shouldConnect() {
c.side = ws.StateClientSide
go func() {
if o.blockConnect != nil {
<-o.blockConnect
}
c.connect()
}()
}
if debugPrint {
fmt.Println(c.Local, "->", c.Remote, "Should local connect:", c.shouldConnect(), "side:", c.side)
}
if debugReqs {
fmt.Println("Created connection", c.String())
}
return c
}
// Subroute returns a static subroute for the connection.
func (c *Connection) Subroute(s string) *Subroute {
if c == nil {
return nil
}
return &Subroute{
Connection: c,
route: s,
subID: makeSubHandlerID(0, s),
trace: c.trace.subroute(s),
}
}
// Subroute adds a subroute to the subroute.
// The subroutes are combined with '/'.
func (c *Subroute) Subroute(s string) *Subroute {
route := strings.Join([]string{c.route, s}, "/")
return &Subroute{
Connection: c.Connection,
route: route,
subID: makeSubHandlerID(0, route),
trace: c.trace.subroute(route),
}
}
// newMuxClient returns a mux client for manual use.
func (c *Connection) newMuxClient(ctx context.Context) (*muxClient, error) {
client := newMuxClient(ctx, atomic.AddUint64(&c.NextID, 1), c)
if dl, ok := ctx.Deadline(); ok {
client.deadline = getDeadline(time.Until(dl))
if client.deadline == 0 {
return nil, context.DeadlineExceeded
}
}
for {
// Handle the extremely unlikely scenario that we wrapped.
if _, loaded := c.outgoing.LoadOrStore(client.MuxID, client); client.MuxID != 0 && !loaded {
if debugReqs {
_, found := c.outgoing.Load(client.MuxID)
fmt.Println(client.MuxID, c.String(), "Connection.newMuxClient: RELOADED MUX. loaded:", loaded, "found:", found)
}
return client, nil
}
client.MuxID = atomic.AddUint64(&c.NextID, 1)
}
}
// newMuxClient returns a mux client for manual use.
func (c *Subroute) newMuxClient(ctx context.Context) (*muxClient, error) {
cl, err := c.Connection.newMuxClient(ctx)
if err != nil {
return nil, err
}
cl.subroute = &c.subID
return cl, nil
}
// Request allows to do a single remote request.
// 'req' will not be used after the call and caller can reuse.
// If no deadline is set on ctx, a 1-minute deadline will be added.
func (c *Connection) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) {
if !h.valid() {
return nil, ErrUnknownHandler
}
if c.State() != StateConnected {
return nil, ErrDisconnected
}
// Create mux client and call.
client, err := c.newMuxClient(ctx)
if err != nil {
return nil, err
}
defer func() {
if debugReqs {
_, ok := c.outgoing.Load(client.MuxID)
fmt.Println(client.MuxID, c.String(), "Connection.Request: DELETING MUX. Exists:", ok)
}
c.outgoing.Delete(client.MuxID)
}()
return client.traceRoundtrip(ctx, c.trace, h, req)
}
// Request allows to do a single remote request.
// 'req' will not be used after the call and caller can reuse.
// If no deadline is set on ctx, a 1-minute deadline will be added.
func (c *Subroute) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) {
if !h.valid() {
return nil, ErrUnknownHandler
}
if c.State() != StateConnected {
return nil, ErrDisconnected
}
// Create mux client and call.
client, err := c.newMuxClient(ctx)
if err != nil {
return nil, err
}
client.subroute = &c.subID
defer func() {
if debugReqs {
fmt.Println(client.MuxID, c.String(), "Subroute.Request: DELETING MUX")
}
c.outgoing.Delete(client.MuxID)
}()
return client.traceRoundtrip(ctx, c.trace, h, req)
}
// NewStream creates a new stream.
// Initial payload can be reused by the caller.
func (c *Connection) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) {
if !h.valid() {
return nil, ErrUnknownHandler
}
if c.State() != StateConnected {
return nil, ErrDisconnected
}
handler := c.handlers.streams[h]
if handler == nil {
return nil, ErrUnknownHandler
}
var requests chan []byte
var responses chan Response
if handler.InCapacity > 0 {
requests = make(chan []byte, handler.InCapacity)
}
if handler.OutCapacity > 0 {
responses = make(chan Response, handler.OutCapacity)
} else {
responses = make(chan Response, 1)
}
cl, err := c.newMuxClient(ctx)
if err != nil {
return nil, err
}
return cl.RequestStream(h, payload, requests, responses)
}
// NewStream creates a new stream.
// Initial payload can be reused by the caller.
func (c *Subroute) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) {
if !h.valid() {
return nil, ErrUnknownHandler
}
if c.State() != StateConnected {
return nil, ErrDisconnected
}
handler := c.handlers.subStreams[makeZeroSubHandlerID(h)]
if handler == nil {
if debugPrint {
fmt.Println("want", makeZeroSubHandlerID(h), c.route, "got", c.handlers.subStreams)
}
return nil, ErrUnknownHandler
}
var requests chan []byte
var responses chan Response
if handler.InCapacity > 0 {
requests = make(chan []byte, handler.InCapacity)
}
if handler.OutCapacity > 0 {
responses = make(chan Response, handler.OutCapacity)
} else {
responses = make(chan Response, 1)
}
cl, err := c.newMuxClient(ctx)
if err != nil {
return nil, err
}
cl.subroute = &c.subID
return cl.RequestStream(h, payload, requests, responses)
}
// WaitForConnect will block until a connection has been established or
// the context is canceled, in which case the context error is returned.
func (c *Connection) WaitForConnect(ctx context.Context) error {
if debugPrint {
fmt.Println(c.Local, "->", c.Remote, "WaitForConnect")
defer fmt.Println(c.Local, "->", c.Remote, "WaitForConnect done")
}
c.connChange.L.Lock()
if atomic.LoadUint32((*uint32)(&c.state)) == StateConnected {
c.connChange.L.Unlock()
// Happy path.
return nil
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
changed := make(chan State, 1)
go func() {
defer close(changed)
for {
c.connChange.Wait()
newState := c.State()
select {
case changed <- newState:
if newState == StateConnected || newState == StateShutdown {
c.connChange.L.Unlock()
return
}
case <-ctx.Done():
c.connChange.L.Unlock()
return
}
}
}()
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case newState := <-changed:
if newState == StateConnected {
return nil
}
}
}
}
/*
var ErrDone = errors.New("done for now")
var ErrRemoteRestart = errors.New("remote restarted")
// Stateless connects to the remote handler and return all packets sent back.
// If the remote is restarted will return ErrRemoteRestart.
// If nil will be returned remote call sent EOF or ErrDone is returned by the callback.
// If ErrDone is returned on cb nil will be returned.
func (c *Connection) Stateless(ctx context.Context, h HandlerID, req []byte, cb func([]byte) error) error {
client, err := c.newMuxClient(ctx)
if err != nil {
return err
}
defer c.outgoing.Delete(client.MuxID)
resp := make(chan Response, 10)
client.RequestStateless(h, req, resp)
for r := range resp {
if r.Err != nil {
return r.Err
}
if len(r.Msg) > 0 {
err := cb(r.Msg)
if err != nil {
if errors.Is(err, ErrDone) {
break
}
return err
}
}
}
return nil
}
*/
// shouldConnect returns a deterministic bool whether the local should initiate the connection.
// It should be 50% chance of any host initiating the connection.
func (c *Connection) shouldConnect() bool {
// The remote should have the opposite result.
h0 := xxh3.HashString(c.Local + c.Remote)
h1 := xxh3.HashString(c.Remote + c.Local)
if h0 == h1 {
return c.Local < c.Remote
}
return h0 < h1
}
func (c *Connection) send(msg []byte) error {
select {
case <-c.ctx.Done():
return context.Cause(c.ctx)
case c.outQueue <- msg:
return nil
}
}
// queueMsg queues a message, with an optional payload.
// sender should not reference msg.Payload
func (c *Connection) queueMsg(msg message, payload sender) error {
// Add baseflags.
msg.Flags.Set(c.baseFlags)
// This cannot encode subroute.
msg.Flags.Clear(FlagSubroute)
if payload != nil {
if cap(msg.Payload) < payload.Msgsize() {
old := msg.Payload
msg.Payload = GetByteBuffer()[:0]
PutByteBuffer(old)
}
var err error
msg.Payload, err = payload.MarshalMsg(msg.Payload[:0])
msg.Op = payload.Op()
if err != nil {
return err
}
}
defer PutByteBuffer(msg.Payload)
dst := GetByteBuffer()[:0]
dst, err := msg.MarshalMsg(dst)
if err != nil {
return err
}
if msg.Flags&FlagCRCxxh3 != 0 {
h := xxh3.Hash(dst)
dst = binary.LittleEndian.AppendUint32(dst, uint32(h))
}
return c.send(dst)
}
// sendMsg will send
func (c *Connection) sendMsg(conn net.Conn, msg message, payload msgp.MarshalSizer) error {
if payload != nil {
if cap(msg.Payload) < payload.Msgsize() {
PutByteBuffer(msg.Payload)
msg.Payload = GetByteBuffer()[:0]
}
var err error
msg.Payload, err = payload.MarshalMsg(msg.Payload)
if err != nil {
return err
}
defer PutByteBuffer(msg.Payload)
}
dst := GetByteBuffer()[:0]
dst, err := msg.MarshalMsg(dst)
if err != nil {
return err
}
if msg.Flags&FlagCRCxxh3 != 0 {
h := xxh3.Hash(dst)
dst = binary.LittleEndian.AppendUint32(dst, uint32(h))
}
if debugPrint {
fmt.Println(c.Local, "sendMsg: Sending", msg.Op, "as", len(dst), "bytes")
}
if c.outgoingBytes != nil {
c.outgoingBytes(int64(len(dst)))
}
return wsutil.WriteMessage(conn, c.side, ws.OpBinary, dst)
}
func (c *Connection) connect() {
c.updateState(StateConnecting)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
// Runs until the server is shut down.
for {
if c.State() == StateShutdown {
return
}
toDial := strings.Replace(c.Remote, "http://", "ws://", 1)
toDial = strings.Replace(toDial, "https://", "wss://", 1)
toDial += RoutePath
dialer := ws.DefaultDialer
dialer.ReadBufferSize = readBufferSize
dialer.WriteBufferSize = writeBufferSize
dialer.Timeout = defaultDialTimeout
if c.dialer != nil {
dialer.NetDial = c.dialer.DialContext
}
if c.header == nil {
c.header = make(http.Header, 2)
}
c.header.Set("Authorization", "Bearer "+c.auth(""))
c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
if len(c.header) > 0 {
dialer.Header = ws.HandshakeHeaderHTTP(c.header)
}
dialer.TLSConfig = c.tlsConfig
dialStarted := time.Now()
if debugPrint {
fmt.Println(c.Local, "Connecting to ", toDial)
}
conn, br, _, err := dialer.Dial(c.ctx, toDial)
if br != nil {
ws.PutReader(br)
}
c.connMu.Lock()
c.debugOutConn = conn
c.connMu.Unlock()
retry := func(err error) {
if debugPrint {
fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err)
}
sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout)))
next := dialStarted.Add(sleep)
sleep = time.Until(next).Round(time.Millisecond)
if sleep < 0 {
sleep = 0
}
gotState := c.State()
if gotState == StateShutdown {
return
}
if gotState != StateConnecting {
// Don't print error on first attempt,
// and after that only once per hour.
cHour := strconv.FormatInt(time.Now().Unix()/60/60, 10)
logger.LogOnceIf(c.ctx, fmt.Errorf("grid: %s connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), c.Local+toDial+cHour+err.Error())
}
c.updateState(StateConnectionError)
time.Sleep(sleep)
}
if err != nil {
retry(err)
continue
}
// Send connect message.
m := message{
Op: OpConnect,
}
req := connectReq{
Host: c.Local,
ID: c.id,
}
err = c.sendMsg(conn, m, &req)
if err != nil {
retry(err)
continue
}
// Wait for response
var r connectResp
err = c.receive(conn, &r)
if err != nil {
if debugPrint {
fmt.Println(c.Local, "receive err:", err, "side:", c.side)
}
retry(err)
continue
}
if debugPrint {
fmt.Println(c.Local, "Got connectResp:", r)
}
if !r.Accepted {
retry(fmt.Errorf("connection rejected: %s", r.RejectedReason))
continue
}
remoteUUID := uuid.UUID(r.ID)
if c.remoteID != nil {
c.reconnected()
}
c.remoteID = &remoteUUID
if debugPrint {
fmt.Println(c.Local, "Connected Waiting for Messages")
}
c.updateState(StateConnected)
go c.handleMessages(c.ctx, conn)
// Monitor state changes and reconnect if needed.
c.connChange.L.Lock()
for {
newState := c.State()
if newState != StateConnected {
c.connChange.L.Unlock()
if newState == StateShutdown {
conn.Close()
return
}
if debugPrint {
fmt.Println(c.Local, "Disconnected")
}
// Reconnect
break
}
// Unlock and wait for state change.
c.connChange.Wait()
}
}
}
func (c *Connection) disconnected() {
c.outgoing.Range(func(key uint64, client *muxClient) bool {
if !client.stateless {
client.cancelFn(ErrDisconnected)
}
return true
})
if debugReqs {
fmt.Println(c.String(), "Disconnected. Clearing outgoing.")
}
c.outgoing.Clear()
c.inStream.Range(func(key uint64, client *muxServer) bool {
client.cancel()
return true
})
c.inStream.Clear()
}
func (c *Connection) receive(conn net.Conn, r receiver) error {
b, op, err := wsutil.ReadData(conn, c.side)
if err != nil {
return err
}
if op != ws.OpBinary {
return fmt.Errorf("unexpected connect response type %v", op)
}
var m message
_, _, err = m.parse(b)
if err != nil {
return err
}
if m.Op != r.Op() {
return fmt.Errorf("unexpected response OP, want %v, got %v", r.Op(), m.Op)
}
_, err = r.UnmarshalMsg(m.Payload)
return err
}
func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req connectReq) error {
c.connMu.Lock()
c.debugInConn = conn
c.connMu.Unlock()
if c.blockConnect != nil {
// Block until we are allowed to connect.
<-c.blockConnect
}
if req.Host != c.Remote {
err := fmt.Errorf("expected remote '%s', got '%s'", c.Remote, req.Host)
if debugPrint {
fmt.Println(err)
}
return err
}
if c.shouldConnect() {
if debugPrint {
fmt.Println("expected to be client side, not server side")
}
return errors.New("grid: expected to be client side, not server side")
}
msg := message{
Op: OpConnectResponse,
}
if c.remoteID != nil {
c.reconnected()
}
rid := uuid.UUID(req.ID)
c.remoteID = &rid
resp := connectResp{
ID: c.id,
Accepted: true,
}
err := c.sendMsg(conn, msg, &resp)
if debugPrint {
fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side)
}
if err == nil {
c.updateState(StateConnected)
c.handleMessages(ctx, conn)
}
return err
}
func (c *Connection) reconnected() {
c.updateState(StateConnectionError)
// Close all active requests.
if debugReqs {
fmt.Println(c.String(), "Reconnected. Clearing outgoing.")
}
c.outgoing.Range(func(key uint64, client *muxClient) bool {
client.close()
return true
})
c.inStream.Range(func(key uint64, value *muxServer) bool {
value.close()
return true
})
c.inStream.Clear()
c.outgoing.Clear()
// Wait for existing to exit
c.connMu.Lock()
c.handleMsgWg.Wait()
c.connMu.Unlock()
}
func (c *Connection) updateState(s State) {
c.connChange.L.Lock()
defer c.connChange.L.Unlock()
// We may have reads that aren't locked, so update atomically.
gotState := atomic.LoadUint32((*uint32)(&c.state))
if gotState == StateShutdown || State(gotState) == s {
return
}
if s == StateConnected {
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
}
atomic.StoreUint32((*uint32)(&c.state), uint32(s))
if debugPrint {
fmt.Println(c.Local, "updateState:", gotState, "->", s)
}
c.connChange.Broadcast()
}
func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
// Read goroutine
c.connMu.Lock()
c.handleMsgWg.Add(2)
c.connMu.Unlock()
ctx, cancel := context.WithCancelCause(ctx)
go func() {
defer func() {
if rec := recover(); rec != nil {
logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
debug.PrintStack()
}
c.connChange.L.Lock()
if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
c.connChange.Broadcast()
}
c.connChange.L.Unlock()
conn.Close()
c.handleMsgWg.Done()
}()
controlHandler := wsutil.ControlFrameHandler(conn, c.side)
wsReader := wsutil.Reader{
Source: conn,
State: c.side,
CheckUTF8: true,
SkipHeaderCheck: false,
OnIntermediate: controlHandler,
}
readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) {
dst = dst[:0]
for {
hdr, err := wsReader.NextFrame()
if err != nil {
return nil, err
}
if hdr.OpCode.IsControl() {
if err := controlHandler(hdr, &wsReader); err != nil {
return nil, err
}
continue
}
if hdr.OpCode&want == 0 {
if err := wsReader.Discard(); err != nil {
return nil, err
}
continue
}
if int64(cap(dst)) < hdr.Length+1 {
dst = make([]byte, 0, hdr.Length+hdr.Length>>3)
}
return readAllInto(dst[:0], &wsReader)
}
}
// Keep reusing the same buffer.
var msg []byte
for {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
cancel(ErrDisconnected)
return
}
if cap(msg) > readBufferSize*8 {
// Don't keep too much memory around.
msg = nil
}
var err error
msg, err = readDataInto(msg, conn, c.side, ws.OpBinary)
if err != nil {
cancel(ErrDisconnected)
logger.LogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
return
}
if c.incomingBytes != nil {
c.incomingBytes(int64(len(msg)))
}
// Parse the received message
var m message
subID, remain, err := m.parse(msg)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws parse package: %w", err))
cancel(ErrDisconnected)
return
}
if debugPrint {
fmt.Printf("%s Got msg: %v\n", c.Local, m)
}
if m.Op != OpMerged {
c.handleMsg(ctx, m, subID)
continue
}
// Handle merged messages.
messages := int(m.Seq)
for i := 0; i < messages; i++ {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
cancel(ErrDisconnected)
return
}
var next []byte
next, remain, err = msgp.ReadBytesZC(remain)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws read merged: %w", err))
cancel(ErrDisconnected)
return
}
m.Payload = nil
subID, _, err = m.parse(next)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws parse merged: %w", err))
cancel(ErrDisconnected)
return
}
c.handleMsg(ctx, m, subID)
}
}
}()
// Write function.
defer func() {
if rec := recover(); rec != nil {
logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
debug.PrintStack()
}
if debugPrint {
fmt.Println("handleMessages: write goroutine exited")
}
cancel(ErrDisconnected)
c.connChange.L.Lock()
if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
c.connChange.Broadcast()
}
c.disconnected()
c.connChange.L.Unlock()
conn.Close()
c.handleMsgWg.Done()
}()
c.connMu.Lock()
connPingInterval := c.connPingInterval
c.connMu.Unlock()
ping := time.NewTicker(connPingInterval)
pingFrame := message{
Op: OpPing,
DeadlineMS: 5000,
}
defer ping.Stop()
queue := make([][]byte, 0, maxMergeMessages)
merged := make([]byte, 0, writeBufferSize)
var queueSize int
var buf bytes.Buffer
var wsw wsWriter
for {
var toSend []byte
select {
case <-ctx.Done():
return
case <-ping.C:
if c.State() != StateConnected {
continue
}
lastPong := atomic.LoadInt64(&c.LastPong)
if lastPong > 0 {
lastPongTime := time.Unix(lastPong, 0)
if d := time.Since(lastPongTime); d > connPingInterval*2 {
logger.LogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond)))
cancel(ErrDisconnected)
return
}
}
var err error
toSend, err = pingFrame.MarshalMsg(GetByteBuffer()[:0])
if err != nil {
logger.LogIf(ctx, err)
// Fake it...
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
continue
}
case toSend = <-c.outQueue:
if len(toSend) == 0 {
continue
}
}
if len(queue) < maxMergeMessages && queueSize+len(toSend) < writeBufferSize-1024 && len(c.outQueue) > 0 {
queue = append(queue, toSend)
queueSize += len(toSend)
continue
}
c.connChange.L.Lock()
for {
state := c.State()
if state == StateConnected {
break
}
if debugPrint {
fmt.Println(c.Local, "Waiting for connection ->", c.Remote, "state: ", state)
}
if state == StateShutdown || state == StateConnectionError {
c.connChange.L.Unlock()
return
}
c.connChange.Wait()
select {
case <-ctx.Done():
c.connChange.L.Unlock()
return
default:
}
}
c.connChange.L.Unlock()
if len(queue) == 0 {
// Combine writes.
buf.Reset()
err := wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err))
cancel(ErrDisconnected)
return
}
PutByteBuffer(toSend)
_, err = buf.WriteTo(conn)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws write: %w", err))
cancel(ErrDisconnected)
return
}
continue
}
// Merge entries and send
queue = append(queue, toSend)
if debugPrint {
fmt.Println("Merging", len(queue), "messages")
}
toSend = merged[:0]
m := message{Op: OpMerged, Seq: uint32(len(queue))}
var err error
toSend, err = m.MarshalMsg(toSend)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("msg.MarshalMsg: %w", err))
cancel(ErrDisconnected)
return
}
// Append as byte slices.
for _, q := range queue {
toSend = msgp.AppendBytes(toSend, q)
PutByteBuffer(q)
}
queue = queue[:0]
queueSize = 0
// Combine writes.
// Consider avoiding buffer copy.
buf.Reset()
err = wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err))
cancel(ErrDisconnected)
return
}
// Tosend is our local buffer, so we can reuse it.
_, err = buf.WriteTo(conn)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("ws write: %w", err))
cancel(ErrDisconnected)
return
}
if buf.Cap() > writeBufferSize*4 {
// Reset buffer if it gets too big, so we don't keep it around.
buf = bytes.Buffer{}
}
}
}
func (c *Connection) handleMsg(ctx context.Context, m message, subID *subHandlerID) {
switch m.Op {
case OpMuxServerMsg:
c.handleMuxServerMsg(ctx, m)
case OpResponse:
c.handleResponse(m)
case OpMuxClientMsg:
c.handleMuxClientMsg(ctx, m)
case OpUnblockSrvMux:
c.handleUnblockSrvMux(m)
case OpUnblockClMux:
c.handleUnblockClMux(m)
case OpDisconnectServerMux:
c.handleDisconnectServerMux(m)
case OpDisconnectClientMux:
c.handleDisconnectClientMux(m)
case OpPing:
c.handlePing(ctx, m)
case OpPong:
c.handlePong(ctx, m)
case OpRequest:
c.handleRequest(ctx, m, subID)
case OpAckMux:
c.handleAckMux(ctx, m)
case OpConnectMux:
c.handleConnectMux(ctx, m, subID)
case OpMuxConnectError:
c.handleConnectMuxError(ctx, m)
default:
logger.LogIf(ctx, fmt.Errorf("unknown message type: %v", m.Op))
}
}
func (c *Connection) handleConnectMux(ctx context.Context, m message, subID *subHandlerID) {
// Stateless stream:
if m.Flags&FlagStateless != 0 {
// Reject for now, so we can safely add it later.
if true {
logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Stateless streams not supported"}))
return
}
var handler *StatelessHandler
if subID == nil {
handler = c.handlers.stateless[m.Handler]
} else {
handler = c.handlers.subStateless[*subID]
}
if handler == nil {
logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"}))
return
}
_, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer {
return newMuxStateless(ctx, m, c, *handler)
})
} else {
// Stream:
var handler *StreamHandler
if subID == nil {
if !m.Handler.valid() {
logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"}))
return
}
handler = c.handlers.streams[m.Handler]
} else {
handler = c.handlers.subStreams[*subID]
}
if handler == nil {
logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"}))
return
}
// Start a new server handler if none exists.
_, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer {
return newMuxStream(ctx, m, c, *handler)
})
}
}
// handleConnectMuxError when mux connect was rejected.
func (c *Connection) handleConnectMuxError(ctx context.Context, m message) {
if v, ok := c.outgoing.Load(m.MuxID); ok {
var cErr muxConnectError
_, err := cErr.UnmarshalMsg(m.Payload)
logger.LogIf(ctx, err)
v.error(RemoteErr(cErr.Error))
return
}
PutByteBuffer(m.Payload)
}
func (c *Connection) handleAckMux(ctx context.Context, m message) {
PutByteBuffer(m.Payload)
v, ok := c.outgoing.Load(m.MuxID)
if !ok {
if m.Flags&FlagEOF == 0 {
logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
}
return
}
if debugPrint {
fmt.Println(c.Local, "Mux", m.MuxID, "Acknowledged")
}
v.ack(m.Seq)
}
func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) {
if !m.Handler.valid() {
logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"}))
return
}
if debugReqs {
fmt.Println(m.MuxID, c.StringReverse(), "INCOMING")
}
// Singleshot message
var handler SingleHandlerFn
if subID == nil {
handler = c.handlers.single[m.Handler]
} else {
handler = c.handlers.subSingle[*subID]
}
if handler == nil {
logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"}))
return
}
// TODO: This causes allocations, but escape analysis doesn't really show the cause.
// If another faithful engineer wants to take a stab, feel free.
go func(m message) {
var start time.Time
if m.DeadlineMS > 0 {
start = time.Now()
}
var b []byte
var err *RemoteErr
func() {
defer func() {
if rec := recover(); rec != nil {
err = NewRemoteErrString(fmt.Sprintf("handleMessages: panic recovered: %v", rec))
debug.PrintStack()
logger.LogIf(ctx, err)
}
}()
b, err = handler(m.Payload)
if debugPrint {
fmt.Println(c.Local, "Handler returned payload:", bytesOrLength(b), "err:", err)
}
}()
if m.DeadlineMS > 0 && time.Since(start).Milliseconds()+c.addDeadline.Milliseconds() > int64(m.DeadlineMS) {
if debugReqs {
fmt.Println(m.MuxID, c.StringReverse(), "DEADLINE EXCEEDED")
}
// No need to return result
PutByteBuffer(b)
return
}
if debugReqs {
fmt.Println(m.MuxID, c.StringReverse(), "RESPONDING")
}
m = message{
MuxID: m.MuxID,
Seq: m.Seq,
Op: OpResponse,
Flags: FlagEOF,
}
if err != nil {
m.Flags |= FlagPayloadIsErr
m.Payload = []byte(*err)
} else {
m.Payload = b
m.setZeroPayloadFlag()
}
logger.LogIf(ctx, c.queueMsg(m, nil))
}(m)
}
func (c *Connection) handlePong(ctx context.Context, m message) {
var pong pongMsg
_, err := pong.UnmarshalMsg(m.Payload)
PutByteBuffer(m.Payload)
logger.LogIf(ctx, err)
if m.MuxID == 0 {
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
return
}
if v, ok := c.outgoing.Load(m.MuxID); ok {
v.pong(pong)
} else {
// We don't care if the client was removed in the meantime,
// but we send a disconnect message to the server just in case.
logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
}
}
func (c *Connection) handlePing(ctx context.Context, m message) {
if m.MuxID == 0 {
logger.LogIf(ctx, c.queueMsg(m, &pongMsg{}))
return
}
// Single calls do not support pinging.
if v, ok := c.inStream.Load(m.MuxID); ok {
pong := v.ping(m.Seq)
logger.LogIf(ctx, c.queueMsg(m, &pong))
} else {
pong := pongMsg{NotFound: true}
logger.LogIf(ctx, c.queueMsg(m, &pong))
}
return
}
func (c *Connection) handleDisconnectClientMux(m message) {
if v, ok := c.outgoing.Load(m.MuxID); ok {
if m.Flags&FlagPayloadIsErr != 0 {
v.error(RemoteErr(m.Payload))
} else {
v.error("remote disconnected")
}
return
}
PutByteBuffer(m.Payload)
}
func (c *Connection) handleDisconnectServerMux(m message) {
if debugPrint {
fmt.Println(c.Local, "Disconnect server mux:", m.MuxID)
}
PutByteBuffer(m.Payload)
m.Payload = nil
if v, ok := c.inStream.Load(m.MuxID); ok {
v.close()
}
}
func (c *Connection) handleUnblockClMux(m message) {
PutByteBuffer(m.Payload)
m.Payload = nil
v, ok := c.outgoing.Load(m.MuxID)
if !ok {
if debugPrint {
fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID)
}
// We can expect to receive unblocks for closed muxes
return
}
v.unblockSend(m.Seq)
}
func (c *Connection) handleUnblockSrvMux(m message) {
if m.Payload != nil {
PutByteBuffer(m.Payload)
}
m.Payload = nil
if v, ok := c.inStream.Load(m.MuxID); ok {
v.unblockSend(m.Seq)
return
}
// We can expect to receive unblocks for closed muxes
if debugPrint {
fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID)
}
}
func (c *Connection) handleMuxClientMsg(ctx context.Context, m message) {
v, ok := c.inStream.Load(m.MuxID)
if !ok {
if debugPrint {
fmt.Println(c.Local, "OpMuxClientMsg: Unknown Mux:", m.MuxID)
}
logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
PutByteBuffer(m.Payload)
return
}
v.message(m)
}
func (c *Connection) handleResponse(m message) {
if debugPrint {
fmt.Printf("%s Got mux response: %v\n", c.Local, m)
}
v, ok := c.outgoing.Load(m.MuxID)
if !ok {
if debugReqs {
fmt.Println(m.MuxID, c.String(), "Got response for unknown mux")
}
PutByteBuffer(m.Payload)
return
}
if m.Flags&FlagPayloadIsErr != 0 {
v.response(m.Seq, Response{
Msg: nil,
Err: RemoteErr(m.Payload),
})
PutByteBuffer(m.Payload)
} else {
v.response(m.Seq, Response{
Msg: m.Payload,
Err: nil,
})
}
v.close()
if debugReqs {
fmt.Println(m.MuxID, c.String(), "handleResponse: closing mux")
}
}
func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) {
if debugPrint {
fmt.Printf("%s Got mux msg: %v\n", c.Local, m)
}
v, ok := c.outgoing.Load(m.MuxID)
if !ok {
if m.Flags&FlagEOF == 0 {
logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
}
PutByteBuffer(m.Payload)
return
}
if m.Flags&FlagPayloadIsErr != 0 {
v.response(m.Seq, Response{
Msg: nil,
Err: RemoteErr(m.Payload),
})
PutByteBuffer(m.Payload)
} else if m.Payload != nil {
v.response(m.Seq, Response{
Msg: m.Payload,
Err: nil,
})
}
if m.Flags&FlagEOF != 0 {
v.close()
if debugReqs {
fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX")
}
c.outgoing.Delete(m.MuxID)
}
}
func (c *Connection) deleteMux(incoming bool, muxID uint64) {
if incoming {
if debugPrint {
fmt.Println("deleteMux: disconnect incoming mux", muxID)
}
v, loaded := c.inStream.LoadAndDelete(muxID)
if loaded && v != nil {
logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: muxID}, nil))
v.close()
}
} else {
if debugPrint {
fmt.Println("deleteMux: disconnect outgoing mux", muxID)
}
v, loaded := c.outgoing.LoadAndDelete(muxID)
if loaded && v != nil {
if debugReqs {
fmt.Println(muxID, c.String(), "deleteMux: DELETING MUX")
}
v.close()
logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectServerMux, MuxID: muxID}, nil))
}
}
}
// State returns the current connection status.
func (c *Connection) State() State {
return State(atomic.LoadUint32((*uint32)(&c.state)))
}
// Stats returns the current connection stats.
func (c *Connection) Stats() ConnectionStats {
return ConnectionStats{
IncomingStreams: c.inStream.Size(),
OutgoingStreams: c.outgoing.Size(),
}
}
func (c *Connection) debugMsg(d debugMsg, args ...any) {
if debugPrint {
fmt.Println("debug: sending message", d, args)
}
switch d {
case debugShutdown:
c.updateState(StateShutdown)
case debugKillInbound:
c.connMu.Lock()
defer c.connMu.Unlock()
if c.debugInConn != nil {
if debugPrint {
fmt.Println("debug: closing inbound connection")
}
c.debugInConn.Close()
}
case debugKillOutbound:
c.connMu.Lock()
defer c.connMu.Unlock()
if c.debugInConn != nil {
if debugPrint {
fmt.Println("debug: closing outgoing connection")
}
c.debugInConn.Close()
}
case debugWaitForExit:
c.connMu.Lock()
c.handleMsgWg.Wait()
c.connMu.Unlock()
case debugSetConnPingDuration:
c.connMu.Lock()
defer c.connMu.Unlock()
c.connPingInterval = args[0].(time.Duration)
case debugSetClientPingDuration:
c.clientPingInterval = args[0].(time.Duration)
case debugAddToDeadline:
c.addDeadline = args[0].(time.Duration)
}
}
// wsWriter writes websocket messages.
type wsWriter struct {
tmp [ws.MaxHeaderSize]byte
}
// writeMessage writes a message to w without allocations.
func (ww *wsWriter) writeMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error {
const fin = true
var frame ws.Frame
if s.ClientSide() {
// We do not need to copy the payload, since we own it.
payload := p
frame = ws.NewFrame(op, fin, payload)
frame = ws.MaskFrameInPlace(frame)
} else {
frame = ws.NewFrame(op, fin, p)
}
return ww.writeFrame(w, frame)
}
// writeFrame writes frame binary representation into w.
func (ww *wsWriter) writeFrame(w io.Writer, f ws.Frame) error {
const (
bit0 = 0x80
len7 = int64(125)
len16 = int64(^(uint16(0)))
len64 = int64(^(uint64(0)) >> 1)
)
bts := ww.tmp[:]
if f.Header.Fin {
bts[0] |= bit0
}
bts[0] |= f.Header.Rsv << 4
bts[0] |= byte(f.Header.OpCode)
var n int
switch {
case f.Header.Length <= len7:
bts[1] = byte(f.Header.Length)
n = 2
case f.Header.Length <= len16:
bts[1] = 126
binary.BigEndian.PutUint16(bts[2:4], uint16(f.Header.Length))
n = 4
case f.Header.Length <= len64:
bts[1] = 127
binary.BigEndian.PutUint64(bts[2:10], uint64(f.Header.Length))
n = 10
default:
return ws.ErrHeaderLengthUnexpected
}
if f.Header.Masked {
bts[1] |= bit0
n += copy(bts[n:], f.Header.Mask[:])
}
if _, err := w.Write(bts[:n]); err != nil {
return err
}
_, err := w.Write(f.Payload)
return err
}