diff --git a/internal/grid/connection.go b/internal/grid/connection.go
index b5e43ad71..72925c5d2 100644
--- a/internal/grid/connection.go
+++ b/internal/grid/connection.go
@@ -123,11 +123,10 @@ type Connection struct {
 	baseFlags Flags
 
 	// For testing only
-	debugInConn   net.Conn
-	debugOutConn  net.Conn
-	blockMessages atomic.Pointer[<-chan struct{}]
-	addDeadline   time.Duration
-	connMu        sync.Mutex
+	debugInConn  net.Conn
+	debugOutConn net.Conn
+	addDeadline  time.Duration
+	connMu       sync.Mutex
 }
 
 // Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
@@ -976,11 +975,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
 			gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
 			return
 		}
-		block := c.blockMessages.Load()
-		if block != nil && *block != nil {
-			<-*block
-		}
-
 		if c.incomingBytes != nil {
 			c.incomingBytes(int64(len(msg)))
 		}
@@ -1369,10 +1363,6 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHan
 }
 
 func (c *Connection) handlePong(ctx context.Context, m message) {
-	if m.MuxID == 0 && m.Payload == nil {
-		atomic.StoreInt64(&c.LastPong, time.Now().Unix())
-		return
-	}
 	var pong pongMsg
 	_, err := pong.UnmarshalMsg(m.Payload)
 	PutByteBuffer(m.Payload)
@@ -1392,9 +1382,7 @@ func (c *Connection) handlePong(ctx context.Context, m message) {
 
 func (c *Connection) handlePing(ctx context.Context, m message) {
 	if m.MuxID == 0 {
-		m.Flags.Clear(FlagPayloadIsZero)
-		m.Op = OpPong
-		gridLogIf(ctx, c.queueMsg(m, nil))
+		gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
 		return
 	}
 	// Single calls do not support pinging.
@@ -1611,12 +1599,7 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
 		c.connMu.Lock()
 		defer c.connMu.Unlock()
 		c.connPingInterval = args[0].(time.Duration)
-		if c.connPingInterval < time.Second {
-			panic("CONN ping interval too low")
-		}
 	case debugSetClientPingDuration:
-		c.connMu.Lock()
-		defer c.connMu.Unlock()
 		c.clientPingInterval = args[0].(time.Duration)
 	case debugAddToDeadline:
 		c.addDeadline = args[0].(time.Duration)
@@ -1632,11 +1615,6 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
 		mid.respMu.Lock()
 		resp(mid.closed)
 		mid.respMu.Unlock()
-	case debugBlockInboundMessages:
-		c.connMu.Lock()
-		block := (<-chan struct{})(args[0].(chan struct{}))
-		c.blockMessages.Store(&block)
-		c.connMu.Unlock()
 	}
 }
 
diff --git a/internal/grid/debug.go b/internal/grid/debug.go
index 8110acb65..0172f87e2 100644
--- a/internal/grid/debug.go
+++ b/internal/grid/debug.go
@@ -50,7 +50,6 @@ const (
 	debugSetClientPingDuration
 	debugAddToDeadline
 	debugIsOutgoingClosed
-	debugBlockInboundMessages
 )
 
 // TestGrid contains a grid of servers for testing purposes.
diff --git a/internal/grid/debugmsg_string.go b/internal/grid/debugmsg_string.go
index 52c92cb4b..a84f811b5 100644
--- a/internal/grid/debugmsg_string.go
+++ b/internal/grid/debugmsg_string.go
@@ -16,12 +16,11 @@ func _() {
 	_ = x[debugSetClientPingDuration-5]
 	_ = x[debugAddToDeadline-6]
 	_ = x[debugIsOutgoingClosed-7]
-	_ = x[debugBlockInboundMessages-8]
 }
 
-const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages"
+const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
 
-var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176}
+var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
 
 func (i debugMsg) String() string {
 	if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {
diff --git a/internal/grid/grid_test.go b/internal/grid/grid_test.go
index f488b5947..fcf325168 100644
--- a/internal/grid/grid_test.go
+++ b/internal/grid/grid_test.go
@@ -378,54 +378,6 @@ func TestStreamSuite(t *testing.T) {
 		assertNoActive(t, connRemoteLocal)
 		assertNoActive(t, connLocalToRemote)
 	})
-	t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamNoPing(t, local, remote, 0)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamNoPing(t, local, remote, 1)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamTwowayPing", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamPingRunning(t, local, remote, 1, false, false)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamPingRunning(t, local, remote, 1, false, true)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamPingRunning(t, local, remote, 1, true, false)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamPingRunning(t, local, remote, 1, true, true)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamOnewayPing", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamPingRunning(t, local, remote, 0, false, true)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
-	t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
-		defer timeout(1 * time.Minute)()
-		testServerStreamPingRunning(t, local, remote, 0, false, false)
-		assertNoActive(t, connRemoteLocal)
-		assertNoActive(t, connLocalToRemote)
-	})
 }
 
 func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@@ -539,12 +491,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
 	register(remote)
 
 	// local to remote
-	testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
+	testHandler := func(t *testing.T, handler HandlerID) {
 		remoteConn := local.Connection(remoteHost)
 		const testPayload = "Hello Grid World!"
 
 		ctx, cancel := context.WithCancel(context.Background())
-		st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
+		st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
 		errFatal(err)
 		clientCanceled := make(chan time.Time, 1)
 		err = nil
@@ -561,18 +513,6 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
 			clientCanceled <- time.Now()
 		}(t)
 		start := time.Now()
-		if st.Requests != nil {
-			defer close(st.Requests)
-		}
-		// Fill up queue.
-		for sendReq {
-			select {
-			case st.Requests <- []byte("Hello"):
-				time.Sleep(10 * time.Millisecond)
-			default:
-				sendReq = false
-			}
-		}
 		cancel()
 		<-serverCanceled
 		t.Log("server cancel time:", time.Since(start))
@@ -584,13 +524,11 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
 	}
 
 	// local to remote, unbuffered
 	t.Run("unbuffered", func(t *testing.T) {
-		testHandler(t, handlerTest, false)
+		testHandler(t, handlerTest)
 	})
+
 	t.Run("buffered", func(t *testing.T) {
-		testHandler(t, handlerTest2, false)
-	})
-	t.Run("buffered", func(t *testing.T) {
-		testHandler(t, handlerTest2, true)
+		testHandler(t, handlerTest2)
 	})
 }
@@ -1087,167 +1025,6 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
 	}
 }
 
-// testServerStreamNoPing will test if server and client handle no pings.
-func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
-	defer testlogger.T.SetErrorTB(t)()
-	errFatal := func(err error) {
-		t.Helper()
-		if err != nil {
-			t.Fatal(err)
-		}
-	}
-
-	// We fake a local and remote server.
-	remoteHost := remote.HostName()
-
-	// 1: Echo
-	reqStarted := make(chan struct{})
-	serverCanceled := make(chan struct{})
-	register := func(manager *Manager) {
-		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
-			Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
-				close(reqStarted)
-				// Just wait for it to cancel.
-				<-ctx.Done()
-				close(serverCanceled)
-				return NewRemoteErr(ctx.Err())
-			},
-			OutCapacity: 1,
-			InCapacity:  inCap,
-		}))
-	}
-	register(local)
-	register(remote)
-
-	remoteConn := local.Connection(remoteHost)
-	const testPayload = "Hello Grid World!"
-	remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-	defer cancel()
-	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
-	errFatal(err)
-
-	// Wait for the server start the request.
-	<-reqStarted
-
-	// Stop processing requests
-	nowBlocking := make(chan struct{})
-	remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
-
-	// Check that local returned.
-	err = st.Results(func(b []byte) error {
-		return nil
-	})
-	if err == nil {
-		t.Fatal("expected error, got nil")
-	}
-	t.Logf("error: %v", err)
-	// Check that remote is canceled.
-	<-serverCanceled
-	close(nowBlocking)
-}
-
-// testServerStreamNoPing will test if server and client handle ping even when blocked.
-func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
-	defer testlogger.T.SetErrorTB(t)()
-	errFatal := func(err error) {
-		t.Helper()
-		if err != nil {
-			t.Fatal(err)
-		}
-	}
-
-	// We fake a local and remote server.
-	remoteHost := remote.HostName()
-
-	// 1: Echo
-	reqStarted := make(chan struct{})
-	serverCanceled := make(chan struct{})
-	register := func(manager *Manager) {
-		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
-			Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
-				close(reqStarted)
-				// Just wait for it to cancel.
-				for blockResp {
-					select {
-					case <-ctx.Done():
-						close(serverCanceled)
-						return NewRemoteErr(ctx.Err())
-					case resp <- []byte{1}:
-						time.Sleep(10 * time.Millisecond)
-					}
-				}
-				// Just wait for it to cancel.
-				<-ctx.Done()
-				close(serverCanceled)
-				return NewRemoteErr(ctx.Err())
-			},
-			OutCapacity: 1,
-			InCapacity:  inCap,
-		}))
-	}
-	register(local)
-	register(remote)
-
-	remoteConn := local.Connection(remoteHost)
-	const testPayload = "Hello Grid World!"
-	remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-	defer cancel()
-	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
-	errFatal(err)
-
-	// Wait for the server start the request.
-	<-reqStarted
-
-	// Block until we have exceeded the deadline several times over.
-	nowBlocking := make(chan struct{})
-	time.AfterFunc(time.Second, func() {
-		cancel()
-		close(nowBlocking)
-	})
-	if inCap > 0 {
-		go func() {
-			defer close(st.Requests)
-			if !blockReq {
-				<-nowBlocking
-				return
-			}
-			for {
-				select {
-				case <-nowBlocking:
-					return
-				case <-st.Done():
-				case st.Requests <- []byte{1}:
-					time.Sleep(10 * time.Millisecond)
-				}
-			}
-		}()
-	}
-	// Check that local returned.
-	err = st.Results(func(b []byte) error {
-		select {
-		case <-nowBlocking:
-		case <-st.Done():
-			return ctx.Err()
-		}
-		return nil
-	})
-	select {
-	case <-nowBlocking:
-	default:
-		t.Fatal("expected to be blocked. got err", err)
-	}
-	if err == nil {
-		t.Fatal("expected error, got nil")
-	}
-	t.Logf("error: %v", err)
-	// Check that remote is canceled.
-	<-serverCanceled
-}
-
 func timeout(after time.Duration) (cancel func()) {
 	c := time.After(after)
 	cc := make(chan struct{})
diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go
index 5ec8fb347..dd5331a31 100644
--- a/internal/grid/muxclient.go
+++ b/internal/grid/muxclient.go
@@ -32,25 +32,24 @@ import (
 
 // muxClient is a stateful connection to a remote.
 type muxClient struct {
-	MuxID              uint64
-	SendSeq, RecvSeq   uint32
-	LastPong           int64
-	BaseFlags          Flags
-	ctx                context.Context
-	cancelFn           context.CancelCauseFunc
-	parent             *Connection
-	respWait           chan<- Response
-	respMu             sync.Mutex
-	singleResp         bool
-	closed             bool
-	stateless          bool
-	acked              bool
-	init               bool
-	deadline           time.Duration
-	outBlock           chan struct{}
-	subroute           *subHandlerID
-	respErr            atomic.Pointer[error]
-	clientPingInterval time.Duration
+	MuxID            uint64
+	SendSeq, RecvSeq uint32
+	LastPong         int64
+	BaseFlags        Flags
+	ctx              context.Context
+	cancelFn         context.CancelCauseFunc
+	parent           *Connection
+	respWait         chan<- Response
+	respMu           sync.Mutex
+	singleResp       bool
+	closed           bool
+	stateless        bool
+	acked            bool
+	init             bool
+	deadline         time.Duration
+	outBlock         chan struct{}
+	subroute         *subHandlerID
+	respErr          atomic.Pointer[error]
 }
 
 // Response is a response from the server.
@@ -62,13 +61,12 @@ type Response struct {
 func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
 	ctx, cancelFn := context.WithCancelCause(ctx)
 	return &muxClient{
-		MuxID:              muxID,
-		ctx:                ctx,
-		cancelFn:           cancelFn,
-		parent:             parent,
-		LastPong:           time.Now().UnixNano(),
-		BaseFlags:          parent.baseFlags,
-		clientPingInterval: parent.clientPingInterval,
+		MuxID:     muxID,
+		ctx:       ctx,
+		cancelFn:  cancelFn,
+		parent:    parent,
+		LastPong:  time.Now().Unix(),
+		BaseFlags: parent.baseFlags,
 	}
 }
 
@@ -311,11 +309,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <
 		}
 	}()
 	var pingTimer <-chan time.Time
-	if m.deadline == 0 || m.deadline > m.clientPingInterval {
-		ticker := time.NewTicker(m.clientPingInterval)
+	if m.deadline == 0 || m.deadline > clientPingInterval {
+		ticker := time.NewTicker(clientPingInterval)
 		defer ticker.Stop()
 		pingTimer = ticker.C
-		atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
+		atomic.StoreInt64(&m.LastPong, time.Now().Unix())
 	}
 	defer m.parent.deleteMux(false, m.MuxID)
 	for {
@@ -369,7 +367,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
 	}
 
 	// Only check ping when not closed.
-	if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
+	if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
 		m.respMu.Unlock()
 		if debugPrint {
 			fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@@ -390,20 +388,15 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
 // responseCh is the channel to that goes to the requester.
 // internalResp is the channel that comes from the server.
 func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
-	defer func() {
-		m.parent.deleteMux(false, m.MuxID)
-		// addErrorNonBlockingClose will close the response channel.
-		xioutil.SafeClose(responseCh)
-	}()
-
-	// Cancelation and errors are handled by handleTwowayRequests below.
+	defer m.parent.deleteMux(false, m.MuxID)
+	defer xioutil.SafeClose(responseCh)
 	for resp := range internalResp {
-		m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
 		responseCh <- resp
+		m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
 	}
 }
 
-func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
+func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
 	var errState bool
 	if debugPrint {
 		start := time.Now()
@@ -412,30 +405,24 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Respons
 		}()
 	}
 
-	var pingTimer <-chan time.Time
-	if m.deadline == 0 || m.deadline > m.clientPingInterval {
-		ticker := time.NewTicker(m.clientPingInterval)
-		defer ticker.Stop()
-		pingTimer = ticker.C
-		atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
-	}
-
 	// Listen for client messages.
-reqLoop:
-	for !errState {
+	for {
+		if errState {
+			go func() {
+				// Drain requests.
+				for range requests {
+				}
+			}()
+			return
+		}
 		select {
 		case <-m.ctx.Done():
 			if debugPrint {
 				fmt.Println("Client sending disconnect to mux", m.MuxID)
 			}
-			m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
+			m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
 			errState = true
 			continue
-		case <-pingTimer:
-			if !m.doPing(errResp) {
-				errState = true
-				continue
-			}
 		case req, ok := <-requests:
 			if !ok {
 				// Done send EOF
@@ -445,28 +432,22 @@ reqLoop:
 				msg := message{
 					Op:    OpMuxClientMsg,
 					MuxID: m.MuxID,
+					Seq:   1,
 					Flags: FlagEOF,
 				}
 				msg.setZeroPayloadFlag()
 				err := m.send(msg)
 				if err != nil {
-					m.addErrorNonBlockingClose(errResp, err)
+					m.addErrorNonBlockingClose(internalResp, err)
 				}
-				break reqLoop
+				return
 			}
 			// Grab a send token.
-		sendReq:
 			select {
 			case <-m.ctx.Done():
-				m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
+				m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
 				errState = true
 				continue
-			case <-pingTimer:
-				if !m.doPing(errResp) {
-					errState = true
-					continue
-				}
-				goto sendReq
 			case <-m.outBlock:
 			}
 			msg := message{
@@ -479,41 +460,13 @@ reqLoop:
 			err := m.send(msg)
 			PutByteBuffer(req)
 			if err != nil {
-				m.addErrorNonBlockingClose(errResp, err)
+				m.addErrorNonBlockingClose(internalResp, err)
 				errState = true
 				continue
 			}
 			msg.Seq++
 		}
 	}
-
-	if errState {
-		// Drain requests.
-		for {
-			select {
-			case r, ok := <-requests:
-				if !ok {
-					return
-				}
-				PutByteBuffer(r)
-			default:
-				return
-			}
-		}
-	}
-
-	for !errState {
-		select {
-		case <-m.ctx.Done():
-			if debugPrint {
-				fmt.Println("Client sending disconnect to mux", m.MuxID)
-			}
-			m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
-			return
-		case <-pingTimer:
-			errState = !m.doPing(errResp)
-		}
-	}
 }
 
 // checkSeq will check if sequence number is correct and increment it by 1.
@@ -549,7 +502,7 @@ func (m *muxClient) response(seq uint32, r Response) {
 		m.addResponse(r)
 		return
 	}
-	atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
+	atomic.StoreInt64(&m.LastPong, time.Now().Unix())
 	ok := m.addResponse(r)
 	if !ok {
 		PutByteBuffer(r.Msg)
@@ -600,7 +553,7 @@ func (m *muxClient) pong(msg pongMsg) {
 		m.addResponse(Response{Err: err})
 		return
 	}
-	atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
+	atomic.StoreInt64(&m.LastPong, time.Now().Unix())
 }
 
 // addResponse will add a response to the response channel.
diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go
index 8a835a16f..e9d1db659 100644
--- a/internal/grid/muxserver.go
+++ b/internal/grid/muxserver.go
@@ -28,20 +28,21 @@ import (
 	xioutil "github.com/minio/minio/internal/ioutil"
 )
 
+const lastPingThreshold = 4 * clientPingInterval
+
 type muxServer struct {
-	ID                 uint64
-	LastPing           int64
-	SendSeq, RecvSeq   uint32
-	Resp               chan []byte
-	BaseFlags          Flags
-	ctx                context.Context
-	cancel             context.CancelFunc
-	inbound            chan []byte
-	parent             *Connection
-	sendMu             sync.Mutex
-	recvMu             sync.Mutex
-	outBlock           chan struct{}
-	clientPingInterval time.Duration
+	ID               uint64
+	LastPing         int64
+	SendSeq, RecvSeq uint32
+	Resp             chan []byte
+	BaseFlags        Flags
+	ctx              context.Context
+	cancel           context.CancelFunc
+	inbound          chan []byte
+	parent           *Connection
+	sendMu           sync.Mutex
+	recvMu           sync.Mutex
+	outBlock         chan struct{}
 }
 
 func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
@@ -88,17 +89,16 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
 	}
 
 	m := muxServer{
-		ID:                 msg.MuxID,
-		RecvSeq:            msg.Seq + 1,
-		SendSeq:            msg.Seq,
-		ctx:                ctx,
-		cancel:             cancel,
-		parent:             c,
-		inbound:            nil,
-		outBlock:           make(chan struct{}, outboundCap),
-		LastPing:           time.Now().Unix(),
-		BaseFlags:          c.baseFlags,
-		clientPingInterval: c.clientPingInterval,
+		ID:        msg.MuxID,
+		RecvSeq:   msg.Seq + 1,
+		SendSeq:   msg.Seq,
+		ctx:       ctx,
+		cancel:    cancel,
+		parent:    c,
+		inbound:   nil,
+		outBlock:  make(chan struct{}, outboundCap),
+		LastPing:  time.Now().Unix(),
+		BaseFlags: c.baseFlags,
 	}
 	// Acknowledge Mux created.
 	// Send async.
@@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
 	}(m.outBlock)
 
 	// Remote aliveness check if needed.
-	if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) {
+	if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
 		go func() {
 			wg.Wait()
 			m.checkRemoteAlive()
@@ -234,7 +234,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<-
 
 // checkRemoteAlive will check if the remote is alive.
 func (m *muxServer) checkRemoteAlive() {
-	t := time.NewTicker(m.clientPingInterval)
+	t := time.NewTicker(lastPingThreshold / 4)
 	defer t.Stop()
 	for {
 		select {
@@ -242,7 +242,7 @@ func (m *muxServer) checkRemoteAlive() {
 			return
 		case <-t.C:
 			last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
-			if last > 4*m.clientPingInterval {
+			if last > lastPingThreshold {
 				gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
 				m.close()
 				return
diff --git a/internal/grid/stream.go b/internal/grid/stream.go
index cbc69c1ba..29e0cebf4 100644
--- a/internal/grid/stream.go
+++ b/internal/grid/stream.go
@@ -89,26 +89,12 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
 				return nil
 			}
 			if resp.Err != nil {
-				s.cancel(resp.Err)
 				return resp.Err
 			}
 			err = next(resp.Msg)
 			if err != nil {
-				s.cancel(err)
 				return err
 			}
 		}
 	}
 }
-
-// Done will return a channel that will be closed when the stream is done.
-// This mirrors context.Done().
-func (s *Stream) Done() <-chan struct{} {
-	return s.ctx.Done()
-}
-
-// Err will return the error that caused the stream to end.
-// This mirrors context.Err().
-func (s *Stream) Err() error {
-	return s.ctx.Err()
-}
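
Annotation (illustration only, not part of the patch): the revert above drops the per-connection clientPingInterval field and returns to the package-level clientPingInterval constant, deriving the cutoff as lastPingThreshold = 4 * clientPingInterval. It also keeps the LastPing/LastPong timestamps in Unix seconds end to end: time.Now().Unix() on the store side pairs with time.Unix(sec, 0) on the load side, whereas the removed code paired time.Now().UnixNano() with time.Unix(0, nsec). Mixing the two pairings throws the computed age off completely, so the aliveness check would fire on every tick. Below is a minimal self-contained sketch of the restored checkRemoteAlive shape; watchdog, onStale, and the 15s interval are hypothetical stand-ins, not names or values from this repository.

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

const (
	clientPingInterval = 15 * time.Second // illustrative value only
	lastPingThreshold  = 4 * clientPingInterval
)

type watchdog struct {
	lastPing int64 // Unix seconds; always written and read atomically
}

// ping records remote liveness, as the mux server does on each ping message.
func (w *watchdog) ping() {
	atomic.StoreInt64(&w.lastPing, time.Now().Unix())
}

// run polls at a quarter of the threshold (mirroring
// time.NewTicker(lastPingThreshold / 4) in the patch) and reports once the
// remote has not pinged within lastPingThreshold.
func (w *watchdog) run(stop <-chan struct{}, onStale func(last time.Duration)) {
	t := time.NewTicker(lastPingThreshold / 4)
	defer t.Stop()
	for {
		select {
		case <-stop:
			return
		case <-t.C:
			// Seconds in, seconds out: time.Unix(sec, 0) matches Unix().
			last := time.Since(time.Unix(atomic.LoadInt64(&w.lastPing), 0))
			if last > lastPingThreshold {
				onStale(last)
				return
			}
		}
	}
}

func main() {
	w := &watchdog{}
	w.ping()
	stop := make(chan struct{})
	defer close(stop) // in the grid this corresponds to mux shutdown
	go w.run(stop, func(last time.Duration) {
		fmt.Println("remote not seen for", last)
	})
	time.Sleep(50 * time.Millisecond) // watchdog stays quiet while pings are fresh
}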