diff --git a/internal/grid/connection.go b/internal/grid/connection.go
index 72925c5d2..b5e43ad71 100644
--- a/internal/grid/connection.go
+++ b/internal/grid/connection.go
@@ -123,10 +123,11 @@ type Connection struct {
 	baseFlags Flags
 
 	// For testing only
-	debugInConn  net.Conn
-	debugOutConn net.Conn
-	addDeadline  time.Duration
-	connMu       sync.Mutex
+	debugInConn   net.Conn
+	debugOutConn  net.Conn
+	blockMessages atomic.Pointer[<-chan struct{}]
+	addDeadline   time.Duration
+	connMu        sync.Mutex
 }
 
 // Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
@@ -975,6 +976,11 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
 				gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
 				return
 			}
+			block := c.blockMessages.Load()
+			if block != nil && *block != nil {
+				<-*block
+			}
+
 			if c.incomingBytes != nil {
 				c.incomingBytes(int64(len(msg)))
 			}
@@ -1363,6 +1369,10 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHan
 }
 
 func (c *Connection) handlePong(ctx context.Context, m message) {
+	if m.MuxID == 0 && m.Payload == nil {
+		atomic.StoreInt64(&c.LastPong, time.Now().Unix())
+		return
+	}
 	var pong pongMsg
 	_, err := pong.UnmarshalMsg(m.Payload)
 	PutByteBuffer(m.Payload)
@@ -1382,7 +1392,9 @@ func (c *Connection) handlePong(ctx context.Context, m message) {
 
 func (c *Connection) handlePing(ctx context.Context, m message) {
 	if m.MuxID == 0 {
-		gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
+		m.Flags.Clear(FlagPayloadIsZero)
+		m.Op = OpPong
+		gridLogIf(ctx, c.queueMsg(m, nil))
 		return
 	}
 	// Single calls do not support pinging.
@@ -1599,7 +1611,12 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
 		c.connMu.Lock()
 		defer c.connMu.Unlock()
 		c.connPingInterval = args[0].(time.Duration)
+		if c.connPingInterval < time.Second {
+			panic("CONN ping interval too low")
+		}
 	case debugSetClientPingDuration:
+		c.connMu.Lock()
+		defer c.connMu.Unlock()
 		c.clientPingInterval = args[0].(time.Duration)
 	case debugAddToDeadline:
 		c.addDeadline = args[0].(time.Duration)
@@ -1615,6 +1632,11 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
 		mid.respMu.Lock()
 		resp(mid.closed)
 		mid.respMu.Unlock()
+	case debugBlockInboundMessages:
+		c.connMu.Lock()
+		block := (<-chan struct{})(args[0].(chan struct{}))
+		c.blockMessages.Store(&block)
+		c.connMu.Unlock()
 	}
 }
 
diff --git a/internal/grid/debug.go b/internal/grid/debug.go
index 0172f87e2..8110acb65 100644
--- a/internal/grid/debug.go
+++ b/internal/grid/debug.go
@@ -50,6 +50,7 @@ const (
 	debugSetClientPingDuration
 	debugAddToDeadline
 	debugIsOutgoingClosed
+	debugBlockInboundMessages
 )
 
 // TestGrid contains a grid of servers for testing purposes.
diff --git a/internal/grid/debugmsg_string.go b/internal/grid/debugmsg_string.go
index a84f811b5..52c92cb4b 100644
--- a/internal/grid/debugmsg_string.go
+++ b/internal/grid/debugmsg_string.go
@@ -16,11 +16,12 @@ func _() {
 	_ = x[debugSetClientPingDuration-5]
 	_ = x[debugAddToDeadline-6]
 	_ = x[debugIsOutgoingClosed-7]
+	_ = x[debugBlockInboundMessages-8]
 }
 
-const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
+const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages"
 
-var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
+var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176}
 
 func (i debugMsg) String() string {
 	if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {
diff --git a/internal/grid/grid_test.go b/internal/grid/grid_test.go
index fcf325168..f488b5947 100644
--- a/internal/grid/grid_test.go
+++ b/internal/grid/grid_test.go
@@ -378,6 +378,54 @@ func TestStreamSuite(t *testing.T) {
 		assertNoActive(t, connRemoteLocal)
 		assertNoActive(t, connLocalToRemote)
 	})
+	t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamNoPing(t, local, remote, 0)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamNoPing(t, local, remote, 1)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamTwowayPing", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamPingRunning(t, local, remote, 1, false, false)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamPingRunning(t, local, remote, 1, false, true)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamPingRunning(t, local, remote, 1, true, false)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamPingRunning(t, local, remote, 1, true, true)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamOnewayPing", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamPingRunning(t, local, remote, 0, false, true)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
+	t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamPingRunning(t, local, remote, 0, false, false)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
 }
 
 func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@@ -491,12 +539,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
 	register(remote)
 
 	// local to remote
-	testHandler := func(t *testing.T, handler HandlerID) {
+	testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
 		remoteConn := local.Connection(remoteHost)
 		const testPayload = "Hello Grid World!"
 
 		ctx, cancel := context.WithCancel(context.Background())
-		st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+		st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
 		errFatal(err)
 		clientCanceled := make(chan time.Time, 1)
 		err = nil
@@ -513,6 +561,18 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
 			clientCanceled <- time.Now()
 		}(t)
 		start := time.Now()
+		if st.Requests != nil {
+			defer close(st.Requests)
+		}
+		// Fill up queue.
+		for sendReq {
+			select {
+			case st.Requests <- []byte("Hello"):
+				time.Sleep(10 * time.Millisecond)
+			default:
+				sendReq = false
+			}
+		}
 		cancel()
 		<-serverCanceled
 		t.Log("server cancel time:", time.Since(start))
@@ -524,11 +584,13 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
 	}
 
 	// local to remote, unbuffered
 	t.Run("unbuffered", func(t *testing.T) {
-		testHandler(t, handlerTest)
+		testHandler(t, handlerTest, false)
 	})
-
 	t.Run("buffered", func(t *testing.T) {
-		testHandler(t, handlerTest2)
+		testHandler(t, handlerTest2, false)
+	})
+	t.Run("buffered", func(t *testing.T) {
+		testHandler(t, handlerTest2, true)
 	})
 }
 
@@ -1025,6 +1087,167 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
 	}
 }
 
+// testServerStreamNoPing will test if server and client handle no pings.
+func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
+	defer testlogger.T.SetErrorTB(t)()
+	errFatal := func(err error) {
+		t.Helper()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// We fake a local and remote server.
+	remoteHost := remote.HostName()
+
+	// 1: Echo
+	reqStarted := make(chan struct{})
+	serverCanceled := make(chan struct{})
+	register := func(manager *Manager) {
+		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
+			Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
+				close(reqStarted)
+				// Just wait for it to cancel.
+				<-ctx.Done()
+				close(serverCanceled)
+				return NewRemoteErr(ctx.Err())
+			},
+			OutCapacity: 1,
+			InCapacity:  inCap,
+		}))
+	}
+	register(local)
+	register(remote)
+
+	remoteConn := local.Connection(remoteHost)
+	const testPayload = "Hello Grid World!"
+	remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+	errFatal(err)
+
+	// Wait for the server to start the request.
+	<-reqStarted
+
+	// Stop processing requests
+	nowBlocking := make(chan struct{})
+	remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
+
+	// Check that local returned.
+	err = st.Results(func(b []byte) error {
+		return nil
+	})
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	t.Logf("error: %v", err)
+	// Check that remote is canceled.
+	<-serverCanceled
+	close(nowBlocking)
+}
+
+// testServerStreamPingRunning will test if server and client handle ping even when blocked.
+func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
+	defer testlogger.T.SetErrorTB(t)()
+	errFatal := func(err error) {
+		t.Helper()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// We fake a local and remote server.
+	remoteHost := remote.HostName()
+
+	// 1: Echo
+	reqStarted := make(chan struct{})
+	serverCanceled := make(chan struct{})
+	register := func(manager *Manager) {
+		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
+			Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
+				close(reqStarted)
+				// Just wait for it to cancel.
+				for blockResp {
+					select {
+					case <-ctx.Done():
+						close(serverCanceled)
+						return NewRemoteErr(ctx.Err())
+					case resp <- []byte{1}:
+						time.Sleep(10 * time.Millisecond)
+					}
+				}
+				// Just wait for it to cancel.
+				<-ctx.Done()
+				close(serverCanceled)
+				return NewRemoteErr(ctx.Err())
+			},
+			OutCapacity: 1,
+			InCapacity:  inCap,
+		}))
+	}
+	register(local)
+	register(remote)
+
+	remoteConn := local.Connection(remoteHost)
+	const testPayload = "Hello Grid World!"
+	remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+	errFatal(err)
+
+	// Wait for the server to start the request.
+	<-reqStarted
+
+	// Block until we have exceeded the deadline several times over.
+	nowBlocking := make(chan struct{})
+	time.AfterFunc(time.Second, func() {
+		cancel()
+		close(nowBlocking)
+	})
+	if inCap > 0 {
+		go func() {
+			defer close(st.Requests)
+			if !blockReq {
+				<-nowBlocking
+				return
+			}
+			for {
+				select {
+				case <-nowBlocking:
+					return
+				case <-st.Done():
+				case st.Requests <- []byte{1}:
+					time.Sleep(10 * time.Millisecond)
+				}
+			}
+		}()
+	}
+	// Check that local returned.
+	err = st.Results(func(b []byte) error {
+		select {
+		case <-nowBlocking:
+		case <-st.Done():
+			return ctx.Err()
+		}
+		return nil
+	})
+	select {
+	case <-nowBlocking:
+	default:
+		t.Fatal("expected to be blocked. got err", err)
+	}
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	t.Logf("error: %v", err)
+	// Check that remote is canceled.
+	<-serverCanceled
+}
+
 func timeout(after time.Duration) (cancel func()) {
 	c := time.After(after)
 	cc := make(chan struct{})
diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go
index dd5331a31..5ec8fb347 100644
--- a/internal/grid/muxclient.go
+++ b/internal/grid/muxclient.go
@@ -32,24 +32,25 @@ import (
 
 // muxClient is a stateful connection to a remote.
 type muxClient struct {
-	MuxID            uint64
-	SendSeq, RecvSeq uint32
-	LastPong         int64
-	BaseFlags        Flags
-	ctx              context.Context
-	cancelFn         context.CancelCauseFunc
-	parent           *Connection
-	respWait         chan<- Response
-	respMu           sync.Mutex
-	singleResp       bool
-	closed           bool
-	stateless        bool
-	acked            bool
-	init             bool
-	deadline         time.Duration
-	outBlock         chan struct{}
-	subroute         *subHandlerID
-	respErr          atomic.Pointer[error]
+	MuxID              uint64
+	SendSeq, RecvSeq   uint32
+	LastPong           int64
+	BaseFlags          Flags
+	ctx                context.Context
+	cancelFn           context.CancelCauseFunc
+	parent             *Connection
+	respWait           chan<- Response
+	respMu             sync.Mutex
+	singleResp         bool
+	closed             bool
+	stateless          bool
+	acked              bool
+	init               bool
+	deadline           time.Duration
+	outBlock           chan struct{}
+	subroute           *subHandlerID
+	respErr            atomic.Pointer[error]
+	clientPingInterval time.Duration
 }
 
 // Response is a response from the server.
@@ -61,12 +62,13 @@ type Response struct {
 
 func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
 	ctx, cancelFn := context.WithCancelCause(ctx)
 	return &muxClient{
-		MuxID:     muxID,
-		ctx:       ctx,
-		cancelFn:  cancelFn,
-		parent:    parent,
-		LastPong:  time.Now().Unix(),
-		BaseFlags: parent.baseFlags,
+		MuxID:              muxID,
+		ctx:                ctx,
+		cancelFn:           cancelFn,
+		parent:             parent,
+		LastPong:           time.Now().UnixNano(),
+		BaseFlags:          parent.baseFlags,
+		clientPingInterval: parent.clientPingInterval,
 	}
 }
@@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <
 		}
 	}()
 	var pingTimer <-chan time.Time
-	if m.deadline == 0 || m.deadline > clientPingInterval {
-		ticker := time.NewTicker(clientPingInterval)
+	if m.deadline == 0 || m.deadline > m.clientPingInterval {
+		ticker := time.NewTicker(m.clientPingInterval)
 		defer ticker.Stop()
 		pingTimer = ticker.C
-		atomic.StoreInt64(&m.LastPong, time.Now().Unix())
+		atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
 	}
 	defer m.parent.deleteMux(false, m.MuxID)
 	for {
@@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
 	}
 
 	// Only check ping when not closed.
-	if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
+	if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
 		m.respMu.Unlock()
 		if debugPrint {
 			fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
 // responseCh is the channel to that goes to the requester.
 // internalResp is the channel that comes from the server.
 func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
-	defer m.parent.deleteMux(false, m.MuxID)
-	defer xioutil.SafeClose(responseCh)
+	defer func() {
+		m.parent.deleteMux(false, m.MuxID)
+		// addErrorNonBlockingClose will close the response channel.
+		xioutil.SafeClose(responseCh)
+	}()
+
+	// Cancelation and errors are handled by handleTwowayRequests below.
 	for resp := range internalResp {
-		responseCh <- resp
 		m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
+		responseCh <- resp
 	}
 }
 
-func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
+func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
 	var errState bool
 	if debugPrint {
 		start := time.Now()
@@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
 		}()
 	}
 
+	var pingTimer <-chan time.Time
+	if m.deadline == 0 || m.deadline > m.clientPingInterval {
+		ticker := time.NewTicker(m.clientPingInterval)
+		defer ticker.Stop()
+		pingTimer = ticker.C
+		atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
+	}
+
 	// Listen for client messages.
-	for {
-		if errState {
-			go func() {
-				// Drain requests.
-				for range requests {
-				}
-			}()
-			return
-		}
+reqLoop:
+	for !errState {
 		select {
 		case <-m.ctx.Done():
 			if debugPrint {
 				fmt.Println("Client sending disconnect to mux", m.MuxID)
 			}
-			m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
+			m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
 			errState = true
 			continue
+		case <-pingTimer:
+			if !m.doPing(errResp) {
+				errState = true
+				continue
+			}
 		case req, ok := <-requests:
 			if !ok {
 				// Done send EOF
@@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
 				msg := message{
 					Op:    OpMuxClientMsg,
 					MuxID: m.MuxID,
-					Seq:   1,
 					Flags: FlagEOF,
 				}
 				msg.setZeroPayloadFlag()
 				err := m.send(msg)
 				if err != nil {
-					m.addErrorNonBlockingClose(internalResp, err)
+					m.addErrorNonBlockingClose(errResp, err)
 				}
-				return
+				break reqLoop
 			}
 			// Grab a send token.
+		sendReq:
 			select {
 			case <-m.ctx.Done():
-				m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
+				m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
 				errState = true
 				continue
+			case <-pingTimer:
+				if !m.doPing(errResp) {
+					errState = true
+					continue
+				}
+				goto sendReq
 			case <-m.outBlock:
 			}
 			msg := message{
@@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
 			err := m.send(msg)
 			PutByteBuffer(req)
 			if err != nil {
-				m.addErrorNonBlockingClose(internalResp, err)
+				m.addErrorNonBlockingClose(errResp, err)
 				errState = true
 				continue
 			}
 			msg.Seq++
 		}
 	}
+
+	if errState {
+		// Drain requests.
+		for {
+			select {
+			case r, ok := <-requests:
+				if !ok {
+					return
+				}
+				PutByteBuffer(r)
+			default:
+				return
+			}
+		}
+	}
+
+	for !errState {
+		select {
+		case <-m.ctx.Done():
+			if debugPrint {
+				fmt.Println("Client sending disconnect to mux", m.MuxID)
+			}
+			m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
+			return
+		case <-pingTimer:
+			errState = !m.doPing(errResp)
+		}
+	}
 }
 
 // checkSeq will check if sequence number is correct and increment it by 1.
@@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) {
 		m.addResponse(r)
 		return
 	}
-	atomic.StoreInt64(&m.LastPong, time.Now().Unix())
+	atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
 	ok := m.addResponse(r)
 	if !ok {
 		PutByteBuffer(r.Msg)
@@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) {
 		m.addResponse(Response{Err: err})
 		return
 	}
-	atomic.StoreInt64(&m.LastPong, time.Now().Unix())
+	atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
 }
 
 // addResponse will add a response to the response channel.
diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go
index e9d1db659..8a835a16f 100644
--- a/internal/grid/muxserver.go
+++ b/internal/grid/muxserver.go
@@ -28,21 +28,20 @@ import (
 	xioutil "github.com/minio/minio/internal/ioutil"
 )
 
-const lastPingThreshold = 4 * clientPingInterval
-
 type muxServer struct {
-	ID               uint64
-	LastPing         int64
-	SendSeq, RecvSeq uint32
-	Resp             chan []byte
-	BaseFlags        Flags
-	ctx              context.Context
-	cancel           context.CancelFunc
-	inbound          chan []byte
-	parent           *Connection
-	sendMu           sync.Mutex
-	recvMu           sync.Mutex
-	outBlock         chan struct{}
+	ID                 uint64
+	LastPing           int64
+	SendSeq, RecvSeq   uint32
+	Resp               chan []byte
+	BaseFlags          Flags
+	ctx                context.Context
+	cancel             context.CancelFunc
+	inbound            chan []byte
+	parent             *Connection
+	sendMu             sync.Mutex
+	recvMu             sync.Mutex
+	outBlock           chan struct{}
+	clientPingInterval time.Duration
 }
 
 func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
@@ -89,16 +88,17 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
 	}
 
 	m := muxServer{
-		ID:        msg.MuxID,
-		RecvSeq:   msg.Seq + 1,
-		SendSeq:   msg.Seq,
-		ctx:       ctx,
-		cancel:    cancel,
-		parent:    c,
-		inbound:   nil,
-		outBlock:  make(chan struct{}, outboundCap),
-		LastPing:  time.Now().Unix(),
-		BaseFlags: c.baseFlags,
+		ID:                 msg.MuxID,
+		RecvSeq:            msg.Seq + 1,
+		SendSeq:            msg.Seq,
+		ctx:                ctx,
+		cancel:             cancel,
+		parent:             c,
+		inbound:            nil,
+		outBlock:           make(chan struct{}, outboundCap),
+		LastPing:           time.Now().Unix(),
+		BaseFlags:          c.baseFlags,
+		clientPingInterval: c.clientPingInterval,
 	}
 	// Acknowledge Mux created.
 	// Send async.
@@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
 	}(m.outBlock)
 
 	// Remote aliveness check if needed.
-	if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
+	if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) {
 		go func() {
 			wg.Wait()
 			m.checkRemoteAlive()
@@ -234,7 +234,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<-
 
 // checkRemoteAlive will check if the remote is alive.
 func (m *muxServer) checkRemoteAlive() {
-	t := time.NewTicker(lastPingThreshold / 4)
+	t := time.NewTicker(m.clientPingInterval)
 	defer t.Stop()
 	for {
 		select {
@@ -242,7 +242,7 @@ func (m *muxServer) checkRemoteAlive() {
 			return
 		case <-t.C:
 			last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
-			if last > lastPingThreshold {
+			if last > 4*m.clientPingInterval {
 				gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
 				m.close()
 				return
diff --git a/internal/grid/stream.go b/internal/grid/stream.go
index 29e0cebf4..cbc69c1ba 100644
--- a/internal/grid/stream.go
+++ b/internal/grid/stream.go
@@ -89,12 +89,26 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
 				return nil
 			}
 			if resp.Err != nil {
+				s.cancel(resp.Err)
 				return resp.Err
 			}
 			err = next(resp.Msg)
 			if err != nil {
+				s.cancel(err)
 				return err
 			}
 		}
 	}
 }
+
+// Done will return a channel that will be closed when the stream is done.
+// This mirrors context.Done().
+func (s *Stream) Done() <-chan struct{} {
+	return s.ctx.Done()
+}
+
+// Err will return the error that caused the stream to end.
+// This mirrors context.Err().
+func (s *Stream) Err() error {
+	return s.ctx.Err()
+}