// Copyright (c) 2015-2023 MinIO, Inc. // // This file is part of MinIO Object Storage stack // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package grid import ( "context" "encoding/binary" "errors" "fmt" "sync" "sync/atomic" "time" xioutil "github.com/minio/minio/internal/ioutil" "github.com/zeebo/xxh3" ) // muxClient is a stateful connection to a remote. type muxClient struct { MuxID uint64 SendSeq, RecvSeq uint32 LastPong int64 BaseFlags Flags ctx context.Context cancelFn context.CancelCauseFunc parent *Connection respWait chan<- Response respMu sync.Mutex singleResp bool closed bool stateless bool acked bool init bool deadline time.Duration outBlock chan struct{} subroute *subHandlerID respErr atomic.Pointer[error] clientPingInterval time.Duration } // Response is a response from the server. type Response struct { Msg []byte Err error } func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient { ctx, cancelFn := context.WithCancelCause(ctx) return &muxClient{ MuxID: muxID, ctx: ctx, cancelFn: cancelFn, parent: parent, LastPong: time.Now().UnixNano(), BaseFlags: parent.baseFlags, clientPingInterval: parent.clientPingInterval, } } // roundtrip performs a roundtrip, returning the first response. // This cannot be used concurrently. func (m *muxClient) roundtrip(h HandlerID, req []byte) ([]byte, error) { if m.init { return nil, errors.New("mux client already used") } m.init = true m.singleResp = true msg := message{ Op: OpRequest, MuxID: m.MuxID, Handler: h, Flags: m.BaseFlags | FlagEOF, Payload: req, DeadlineMS: uint32(m.deadline.Milliseconds()), } if m.subroute != nil { msg.Flags |= FlagSubroute } ch := make(chan Response, 1) m.respMu.Lock() if m.closed { m.respMu.Unlock() return nil, ErrDisconnected } m.respWait = ch m.respMu.Unlock() ctx := m.ctx // Add deadline if none. if msg.DeadlineMS == 0 { msg.DeadlineMS = uint32(defaultSingleRequestTimeout / time.Millisecond) var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, defaultSingleRequestTimeout) defer cancel() } // Send request if err := m.send(msg); err != nil { return nil, err } if debugReqs { fmt.Println(m.MuxID, m.parent.String(), "SEND") } // Wait for response or context. select { case v, ok := <-ch: if !ok { return nil, ErrDisconnected } if debugReqs && v.Err != nil { v.Err = fmt.Errorf("%d %s RESP ERR: %w", m.MuxID, m.parent.String(), v.Err) } return v.Msg, v.Err case <-ctx.Done(): if debugReqs { return nil, fmt.Errorf("%d %s ERR: %w", m.MuxID, m.parent.String(), context.Cause(ctx)) } return nil, context.Cause(ctx) } } // send the message. msg.Seq and msg.MuxID will be set func (m *muxClient) send(msg message) error { m.respMu.Lock() defer m.respMu.Unlock() if m.closed { return errors.New("mux client closed") } return m.sendLocked(msg) } // sendLocked the message. msg.Seq and msg.MuxID will be set. // m.respMu must be held. func (m *muxClient) sendLocked(msg message) error { dst := GetByteBufferCap(msg.Msgsize()) msg.Seq = m.SendSeq msg.MuxID = m.MuxID msg.Flags |= m.BaseFlags if debugPrint { fmt.Println("Client sending", &msg, "to", m.parent.Remote) } m.SendSeq++ dst, err := msg.MarshalMsg(dst) if err != nil { return err } if msg.Flags&FlagSubroute != 0 { if m.subroute == nil { return fmt.Errorf("internal error: subroute not defined on client") } hid := m.subroute.withHandler(msg.Handler) before := len(dst) dst = append(dst, hid[:]...) if debugPrint { fmt.Println("Added subroute", hid.String(), "to message", msg, "len", len(dst)-before) } } if msg.Flags&FlagCRCxxh3 != 0 { h := xxh3.Hash(dst) dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) } return m.parent.send(m.ctx, dst) } // RequestStateless will send a single payload request and stream back results. // req may not be read/written to after calling. // TODO: Not implemented func (m *muxClient) RequestStateless(h HandlerID, req []byte, out chan<- Response) { if m.init { out <- Response{Err: errors.New("mux client already used")} } m.init = true // Try to grab an initial block. m.singleResp = false msg := message{ Op: OpConnectMux, Handler: h, Flags: FlagEOF, Payload: req, DeadlineMS: uint32(m.deadline.Milliseconds()), } msg.setZeroPayloadFlag() if m.subroute != nil { msg.Flags |= FlagSubroute } // Send... err := m.send(msg) if err != nil { out <- Response{Err: err} return } // Route directly to output. m.respWait = out } // RequestStream will send a single payload request and stream back results. // 'requests' can be nil, in which case only req is sent as input. // It will however take less resources. func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []byte, responses chan Response) (*Stream, error) { if m.init { return nil, errors.New("mux client already used") } if responses == nil { return nil, errors.New("RequestStream: responses channel is nil") } m.init = true m.respMu.Lock() if m.closed { m.respMu.Unlock() return nil, ErrDisconnected } m.respWait = responses // Route directly to output. m.respMu.Unlock() // Try to grab an initial block. m.singleResp = false m.RecvSeq = m.SendSeq // Sync if cap(requests) > 0 { m.outBlock = make(chan struct{}, cap(requests)) } msg := message{ Op: OpConnectMux, Handler: h, Payload: payload, DeadlineMS: uint32(m.deadline.Milliseconds()), } msg.setZeroPayloadFlag() if requests == nil { msg.Flags |= FlagEOF } if m.subroute != nil { msg.Flags |= FlagSubroute } // Send... err := m.send(msg) if err != nil { return nil, err } if debugPrint { fmt.Println("Connecting Mux", m.MuxID, ",to", m.parent.Remote) } // Space for one message and an error. responseCh := make(chan Response, 1) // Spawn simple disconnect if requests == nil { go m.handleOneWayStream(responseCh, responses) return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil } // Deliver responses and send unblocks back to the server. go m.handleTwowayResponses(responseCh, responses) go m.handleTwowayRequests(responses, requests) return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil } func (m *muxClient) addErrorNonBlockingClose(respHandler chan<- Response, err error) { m.respMu.Lock() defer m.respMu.Unlock() if !m.closed { m.respErr.Store(&err) // Do not block. select { case respHandler <- Response{Err: err}: xioutil.SafeClose(respHandler) default: go func() { respHandler <- Response{Err: err} xioutil.SafeClose(respHandler) }() } gridLogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) m.closed = true } } // respHandler func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) { if debugPrint { start := time.Now() defer func() { fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) }() } defer func() { // addErrorNonBlockingClose will close the response channel // - maybe async, so we shouldn't do it here. if m.respErr.Load() == nil { xioutil.SafeClose(respHandler) } }() var pingTimer <-chan time.Time if m.deadline == 0 || m.deadline > m.clientPingInterval { ticker := time.NewTicker(m.clientPingInterval) defer ticker.Stop() pingTimer = ticker.C atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) } defer m.parent.deleteMux(false, m.MuxID) for { select { case <-m.ctx.Done(): if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } err := context.Cause(m.ctx) if !errors.Is(err, errStreamEOF) { m.addErrorNonBlockingClose(respHandler, err) } return case resp, ok := <-respServer: if !ok { return } sendResp: select { case respHandler <- resp: m.respMu.Lock() if !m.closed { gridLogIf(m.ctx, m.sendLocked(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})) } m.respMu.Unlock() case <-m.ctx.Done(): // Client canceled. Don't block. // Next loop will catch it. case <-pingTimer: if !m.doPing(respHandler) { return } goto sendResp } case <-pingTimer: if !m.doPing(respHandler) { return } } } } // doPing checks last ping time and sends another ping. func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) { m.respMu.Lock() if m.closed { m.respMu.Unlock() // Already closed. This is not an error state; // we may just be delivering the last responses. return true } // Only check ping when not closed. if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 { m.respMu.Unlock() if debugPrint { fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got) } m.addErrorNonBlockingClose(respHandler, ErrDisconnected) return false } // Send new ping err := m.sendLocked(message{Op: OpPing, MuxID: m.MuxID}) m.respMu.Unlock() if err != nil { m.addErrorNonBlockingClose(respHandler, err) } return err == nil } // responseCh is the channel to that goes to the requester. // internalResp is the channel that comes from the server. func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) { defer func() { m.parent.deleteMux(false, m.MuxID) // addErrorNonBlockingClose will close the response channel. xioutil.SafeClose(responseCh) }() // Cancelation and errors are handled by handleTwowayRequests below. for resp := range internalResp { m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID}) responseCh <- resp } } func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) { var errState bool if debugPrint { start := time.Now() defer func() { fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) }() } var pingTimer <-chan time.Time if m.deadline == 0 || m.deadline > m.clientPingInterval { ticker := time.NewTicker(m.clientPingInterval) defer ticker.Stop() pingTimer = ticker.C atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) } // Listen for client messages. reqLoop: for !errState { select { case <-m.ctx.Done(): if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) errState = true continue case <-pingTimer: if !m.doPing(errResp) { errState = true continue } case req, ok := <-requests: if !ok { // Done send EOF if debugPrint { fmt.Println("Client done, sending EOF to mux", m.MuxID) } msg := message{ Op: OpMuxClientMsg, MuxID: m.MuxID, Flags: FlagEOF, } msg.setZeroPayloadFlag() err := m.send(msg) if err != nil { m.addErrorNonBlockingClose(errResp, err) } break reqLoop } // Grab a send token. sendReq: select { case <-m.ctx.Done(): m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) errState = true continue case <-pingTimer: if !m.doPing(errResp) { errState = true continue } goto sendReq case <-m.outBlock: } msg := message{ Op: OpMuxClientMsg, MuxID: m.MuxID, Seq: 1, Payload: req, } msg.setZeroPayloadFlag() err := m.send(msg) PutByteBuffer(req) if err != nil { m.addErrorNonBlockingClose(errResp, err) errState = true continue } msg.Seq++ } } if errState { // Drain requests. for { select { case r, ok := <-requests: if !ok { return } PutByteBuffer(r) default: return } } } for !errState { select { case <-m.ctx.Done(): if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) return case <-pingTimer: errState = !m.doPing(errResp) } } } // checkSeq will check if sequence number is correct and increment it by 1. func (m *muxClient) checkSeq(seq uint32) (ok bool) { if seq != m.RecvSeq { if debugPrint { fmt.Printf("MuxID: %d client, expected sequence %d, got %d\n", m.MuxID, m.RecvSeq, seq) } m.addResponse(Response{Err: ErrIncorrectSequence}) return false } m.RecvSeq++ return true } // response will send handleIncoming response to client. // may never block. // Should return whether the next call would block. func (m *muxClient) response(seq uint32, r Response) { if debugReqs { fmt.Println(m.MuxID, m.parent.String(), "RESP") } if debugPrint { fmt.Printf("mux %d: got msg seqid %d, payload length: %d, err:%v\n", m.MuxID, seq, len(r.Msg), r.Err) } if !m.checkSeq(seq) { if debugReqs { fmt.Println(m.MuxID, m.parent.String(), "CHECKSEQ FAIL", m.RecvSeq, seq) } PutByteBuffer(r.Msg) r.Msg = nil r.Err = ErrIncorrectSequence m.addResponse(r) return } atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) ok := m.addResponse(r) if !ok { PutByteBuffer(r.Msg) } } var errStreamEOF = errors.New("stream EOF") // error is a message from the server to disconnect. func (m *muxClient) error(err RemoteErr) { if debugPrint { fmt.Printf("mux %d: got remote err:%v\n", m.MuxID, string(err)) } m.addResponse(Response{Err: &err}) } func (m *muxClient) ack(seq uint32) { if !m.checkSeq(seq) { return } if m.acked || m.outBlock == nil { return } available := cap(m.outBlock) for i := 0; i < available; i++ { m.outBlock <- struct{}{} } m.acked = true } func (m *muxClient) unblockSend(seq uint32) { if !m.checkSeq(seq) { return } select { case m.outBlock <- struct{}{}: default: gridLogIf(m.ctx, errors.New("output unblocked overflow")) } } func (m *muxClient) pong(msg pongMsg) { if msg.NotFound || msg.Err != nil { err := errors.New("remote terminated call") if msg.Err != nil { err = fmt.Errorf("remove pong failed: %v", &msg.Err) } m.addResponse(Response{Err: err}) return } atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) } // addResponse will add a response to the response channel. // This function will never block func (m *muxClient) addResponse(r Response) (ok bool) { m.respMu.Lock() defer m.respMu.Unlock() if m.closed { return false } select { case m.respWait <- r: if r.Err != nil { if debugPrint { fmt.Println("Closing mux", m.MuxID, "due to error:", r.Err) } m.closeLocked() } return true default: if m.stateless { // Drop message if not stateful. return } err := errors.New("INTERNAL ERROR: Response was blocked") gridLogIf(m.ctx, err) m.closeLocked() return false } } func (m *muxClient) close() { if debugPrint { fmt.Println("closing outgoing mux", m.MuxID) } if !m.respMu.TryLock() { // Cancel before locking - will unblock any pending sends. if m.cancelFn != nil { m.cancelFn(context.Canceled) } // Wait for senders to release. m.respMu.Lock() } defer m.respMu.Unlock() m.closeLocked() } func (m *muxClient) closeLocked() { if m.closed { return } // We hold the lock, so nobody can modify m.respWait while we're closing. if m.respWait != nil { xioutil.SafeClose(m.respWait) m.respWait = nil } m.closed = true }