From f00187033d7977ee2da95e3b8542bd720b52628f Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 7 Jun 2024 08:51:52 -0700 Subject: [PATCH] Two way streams for upcoming locking enhancements (#19796) --- internal/grid/benchmark_test.go | 132 +++++++++++++++++ internal/grid/connection.go | 50 +++++-- internal/grid/debug.go | 1 + internal/grid/debugmsg_string.go | 5 +- internal/grid/grid_test.go | 237 ++++++++++++++++++++++++++++++- internal/grid/muxclient.go | 145 ++++++++++++------- internal/grid/muxserver.go | 93 +++++++----- internal/grid/stream.go | 14 ++ 8 files changed, 573 insertions(+), 104 deletions(-) diff --git a/internal/grid/benchmark_test.go b/internal/grid/benchmark_test.go index 54feb9aa2..e5cfd0c68 100644 --- a/internal/grid/benchmark_test.go +++ b/internal/grid/benchmark_test.go @@ -200,6 +200,7 @@ func BenchmarkStream(b *testing.B) { }{ {name: "request", fn: benchmarkGridStreamReqOnly}, {name: "responses", fn: benchmarkGridStreamRespOnly}, + {name: "twoway", fn: benchmarkGridStreamTwoway}, } for _, test := range tests { b.Run(test.name, func(b *testing.B) { @@ -438,3 +439,134 @@ func benchmarkGridStreamReqOnly(b *testing.B, n int) { }) } } + +func benchmarkGridStreamTwoway(b *testing.B, n int) { + defer testlogger.T.SetErrorTB(b)() + + errFatal := func(err error) { + b.Helper() + if err != nil { + b.Fatal(err) + } + } + grid, err := SetupTestGrid(n) + errFatal(err) + b.Cleanup(grid.Cleanup) + const messages = 10 + // Create n managers. + const payloadSize = 512 + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + payload := make([]byte, payloadSize) + _, err = rng.Read(payload) + errFatal(err) + + for _, remote := range grid.Managers { + // Register a single handler which echos the payload. + errFatal(remote.RegisterStreamingHandler(handlerTest, StreamHandler{ + // Send 10x requests. + Handle: func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr { + got := 0 + for { + select { + case b, ok := <-in: + if !ok { + if got != messages { + return NewRemoteErrf("wrong number of requests. want %d, got %d", messages, got) + } + return nil + } + out <- b + got++ + } + } + }, + + Subroute: "some-subroute", + OutCapacity: 1, + InCapacity: 1, // Only one message buffered. + })) + errFatal(err) + } + + // Wait for all to connect + // Parallel writes per server. + for par := 1; par <= 32; par *= 2 { + b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) { + defer timeout(30 * time.Second)() + b.ReportAllocs() + b.SetBytes(int64(len(payload) * (2*messages + 1))) + b.ResetTimer() + t := time.Now() + var ops int64 + var lat int64 + b.SetParallelism(par) + b.RunParallel(func(pb *testing.PB) { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + n := 0 + var latency int64 + managers := grid.Managers + hosts := grid.Hosts + for pb.Next() { + // Pick a random manager. + src, dst := rng.Intn(len(managers)), rng.Intn(len(managers)) + if src == dst { + dst = (dst + 1) % len(managers) + } + local := managers[src] + conn := local.Connection(hosts[dst]).Subroute("some-subroute") + if conn == nil { + b.Fatal("No connection") + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // Send the payload. + t := time.Now() + st, err := conn.NewStream(ctx, handlerTest, payload) + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + got := 0 + sent := 0 + go func() { + for i := 0; i < messages; i++ { + st.Requests <- append(GetByteBuffer()[:0], payload...) 
+ if sent++; sent == messages { + close(st.Requests) + return + } + } + }() + err = st.Results(func(b []byte) error { + got++ + PutByteBuffer(b) + return nil + }) + if err != nil { + if debugReqs { + fmt.Println(err.Error()) + } + b.Fatal(err.Error()) + } + if got != messages { + b.Fatalf("wrong number of responses. want %d, got %d", messages, got) + } + latency += time.Since(t).Nanoseconds() + cancel() + n += got + } + atomic.AddInt64(&ops, int64(n*2)) + atomic.AddInt64(&lat, latency) + }) + spent := time.Since(t) + if spent > 0 && n > 0 { + // Since we are benchmarking n parallel servers we need to multiply by n. + // This will give an estimate of the total ops/s. + latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond) + b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s") + b.ReportMetric(latency/float64(ops), "ms/op") + } + }) + } +} diff --git a/internal/grid/connection.go b/internal/grid/connection.go index a6b867e1b..3048e84a9 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -128,10 +128,11 @@ type Connection struct { outMessages atomic.Int64 // For testing only - debugInConn net.Conn - debugOutConn net.Conn - addDeadline time.Duration - connMu sync.Mutex + debugInConn net.Conn + debugOutConn net.Conn + blockMessages atomic.Pointer[<-chan struct{}] + addDeadline time.Duration + connMu sync.Mutex } // Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID. @@ -883,7 +884,7 @@ func (c *Connection) updateState(s State) { return } if s == StateConnected { - atomic.StoreInt64(&c.LastPong, time.Now().Unix()) + atomic.StoreInt64(&c.LastPong, time.Now().UnixNano()) } atomic.StoreUint32((*uint32)(&c.state), uint32(s)) if debugPrint { @@ -993,6 +994,11 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { } return } + block := c.blockMessages.Load() + if block != nil && *block != nil { + <-*block + } + if c.incomingBytes != nil { c.incomingBytes(int64(len(msg))) } @@ -1094,7 +1100,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { } lastPong := atomic.LoadInt64(&c.LastPong) if lastPong > 0 { - lastPongTime := time.Unix(lastPong, 0) + lastPongTime := time.Unix(0, lastPong) if d := time.Since(lastPongTime); d > connPingInterval*2 { gridLogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond))) return @@ -1105,7 +1111,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { if err != nil { gridLogIf(ctx, err) // Fake it... 
- atomic.StoreInt64(&c.LastPong, time.Now().Unix()) + atomic.StoreInt64(&c.LastPong, time.Now().UnixNano()) continue } case toSend = <-c.outQueue: @@ -1406,12 +1412,16 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHan } func (c *Connection) handlePong(ctx context.Context, m message) { + if m.MuxID == 0 && m.Payload == nil { + atomic.StoreInt64(&c.LastPong, time.Now().UnixNano()) + return + } var pong pongMsg _, err := pong.UnmarshalMsg(m.Payload) PutByteBuffer(m.Payload) gridLogIf(ctx, err) if m.MuxID == 0 { - atomic.StoreInt64(&c.LastPong, time.Now().Unix()) + atomic.StoreInt64(&c.LastPong, time.Now().UnixNano()) return } if v, ok := c.outgoing.Load(m.MuxID); ok { @@ -1425,7 +1435,9 @@ func (c *Connection) handlePong(ctx context.Context, m message) { func (c *Connection) handlePing(ctx context.Context, m message) { if m.MuxID == 0 { - gridLogIf(ctx, c.queueMsg(m, &pongMsg{})) + m.Flags.Clear(FlagPayloadIsZero) + m.Op = OpPong + gridLogIf(ctx, c.queueMsg(m, nil)) return } // Single calls do not support pinging. @@ -1560,9 +1572,15 @@ func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) { } if m.Flags&FlagEOF != 0 { if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 { + // We must obtain the lock before calling cancelFn + // Otherwise others may pick up the error before close is called. + v.respMu.Lock() v.cancelFn(errStreamEOF) + v.closeLocked() + v.respMu.Unlock() + } else { + v.close() } - v.close() if debugReqs { fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX") } @@ -1617,7 +1635,7 @@ func (c *Connection) Stats() madmin.RPCMetrics { IncomingMessages: c.inMessages.Load(), OutgoingMessages: c.outMessages.Load(), OutQueue: len(c.outQueue), - LastPongTime: time.Unix(c.LastPong, 0).UTC(), + LastPongTime: time.Unix(0, c.LastPong).UTC(), } m.ByDestination = map[string]madmin.RPCMetrics{ c.Remote: m, @@ -1659,7 +1677,12 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) { c.connMu.Lock() defer c.connMu.Unlock() c.connPingInterval = args[0].(time.Duration) + if c.connPingInterval < time.Second { + panic("CONN ping interval too low") + } case debugSetClientPingDuration: + c.connMu.Lock() + defer c.connMu.Unlock() c.clientPingInterval = args[0].(time.Duration) case debugAddToDeadline: c.addDeadline = args[0].(time.Duration) @@ -1675,6 +1698,11 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) { mid.respMu.Lock() resp(mid.closed) mid.respMu.Unlock() + case debugBlockInboundMessages: + c.connMu.Lock() + block := (<-chan struct{})(args[0].(chan struct{})) + c.blockMessages.Store(&block) + c.connMu.Unlock() } } diff --git a/internal/grid/debug.go b/internal/grid/debug.go index 0172f87e2..8110acb65 100644 --- a/internal/grid/debug.go +++ b/internal/grid/debug.go @@ -50,6 +50,7 @@ const ( debugSetClientPingDuration debugAddToDeadline debugIsOutgoingClosed + debugBlockInboundMessages ) // TestGrid contains a grid of servers for testing purposes. 
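
The connection.go and debug.go changes above switch LastPong from second to nanosecond resolution and add a test-only gate: blockMessages holds an atomic pointer to a receive-only channel, and handleMessages waits on it before processing inbound data once the new debugBlockInboundMessages hook installs it. A minimal stand-alone sketch of that gating pattern follows; reader, handle, gate and the message strings are illustrative names, not part of the package:

	package main

	import (
		"fmt"
		"sync/atomic"
		"time"
	)

	// reader models an inbound-message loop that tests can pause.
	type reader struct {
		// block holds an optional gate; while it is set, the loop waits on it
		// before handling each message (mirrors Connection.blockMessages).
		block atomic.Pointer[<-chan struct{}]
	}

	func (r *reader) handle(msgs <-chan string) {
		for msg := range msgs {
			if gate := r.block.Load(); gate != nil && *gate != nil {
				<-*gate // blocked until the test closes the gate
			}
			fmt.Println("handled:", msg)
		}
	}

	func main() {
		r := &reader{}

		// Install the gate, as the debugBlockInboundMessages hook does.
		gate := make(chan struct{})
		ro := (<-chan struct{})(gate)
		r.block.Store(&ro)

		msgs := make(chan string, 2)
		msgs <- "a"
		msgs <- "b"
		close(msgs)

		go r.handle(msgs)
		time.Sleep(50 * time.Millisecond) // nothing is handled while the gate is in place
		close(gate)                       // closing the gate unblocks the loop
		time.Sleep(50 * time.Millisecond) // both messages are handled now
	}
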
diff --git a/internal/grid/debugmsg_string.go b/internal/grid/debugmsg_string.go index a84f811b5..52c92cb4b 100644 --- a/internal/grid/debugmsg_string.go +++ b/internal/grid/debugmsg_string.go @@ -16,11 +16,12 @@ func _() { _ = x[debugSetClientPingDuration-5] _ = x[debugAddToDeadline-6] _ = x[debugIsOutgoingClosed-7] + _ = x[debugBlockInboundMessages-8] } -const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed" +const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages" -var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151} +var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176} func (i debugMsg) String() string { if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) { diff --git a/internal/grid/grid_test.go b/internal/grid/grid_test.go index fcf325168..35850c039 100644 --- a/internal/grid/grid_test.go +++ b/internal/grid/grid_test.go @@ -26,6 +26,7 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" "time" @@ -378,6 +379,54 @@ func TestStreamSuite(t *testing.T) { assertNoActive(t, connRemoteLocal) assertNoActive(t, connLocalToRemote) }) + t.Run("testServerStreamOnewayNoPing", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamNoPing(t, local, remote, 0) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamTwowayNoPing", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamNoPing(t, local, remote, 1) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamTwowayPing", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamPingRunning(t, local, remote, 1, false, false) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamTwowayPingReq", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamPingRunning(t, local, remote, 1, false, true) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamTwowayPingResp", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamPingRunning(t, local, remote, 1, true, false) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamPingRunning(t, local, remote, 1, true, true) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamOnewayPing", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamPingRunning(t, local, remote, 0, false, true) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) + t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamPingRunning(t, local, remote, 0, false, false) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) } func testStreamRoundtrip(t *testing.T, local, remote *Manager) { @@ -491,12 +540,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) { register(remote) // local to remote - testHandler := func(t *testing.T, handler HandlerID) { + testHandler := func(t *testing.T, handler 
HandlerID, sendReq bool) { remoteConn := local.Connection(remoteHost) const testPayload = "Hello Grid World!" ctx, cancel := context.WithCancel(context.Background()) - st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload)) errFatal(err) clientCanceled := make(chan time.Time, 1) err = nil @@ -513,6 +562,18 @@ func testStreamCancel(t *testing.T, local, remote *Manager) { clientCanceled <- time.Now() }(t) start := time.Now() + if st.Requests != nil { + defer close(st.Requests) + } + // Fill up queue. + for sendReq { + select { + case st.Requests <- []byte("Hello"): + time.Sleep(10 * time.Millisecond) + default: + sendReq = false + } + } cancel() <-serverCanceled t.Log("server cancel time:", time.Since(start)) @@ -524,11 +585,13 @@ func testStreamCancel(t *testing.T, local, remote *Manager) { } // local to remote, unbuffered t.Run("unbuffered", func(t *testing.T) { - testHandler(t, handlerTest) + testHandler(t, handlerTest, false) }) - t.Run("buffered", func(t *testing.T) { - testHandler(t, handlerTest2) + testHandler(t, handlerTest2, false) + }) + t.Run("buffered", func(t *testing.T) { + testHandler(t, handlerTest2, true) }) } @@ -1025,6 +1088,170 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) { } } +// testServerStreamNoPing will test if server and client handle no pings. +func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Echo + reqStarted := make(chan struct{}) + serverCanceled := make(chan struct{}) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr { + close(reqStarted) + // Just wait for it to cancel. + <-ctx.Done() + close(serverCanceled) + return NewRemoteErr(ctx.Err()) + }, + OutCapacity: 1, + InCapacity: inCap, + })) + } + register(local) + register(remote) + + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond) + defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + errFatal(err) + + // Wait for the server start the request. + <-reqStarted + + // Stop processing requests + nowBlocking := make(chan struct{}) + remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking) + + // Check that local returned. + err = st.Results(func(b []byte) error { + return nil + }) + if err == nil { + t.Fatal("expected error, got nil") + } + t.Logf("response: %v", err) + + // Check that remote is canceled. + <-serverCanceled + close(nowBlocking) +} + +// testServerStreamPingRunning will test if server and client handle ping even when blocked. +func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. 
+ remoteHost := remote.HostName() + + // 1: Echo + reqStarted := make(chan struct{}) + serverCanceled := make(chan struct{}) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr { + close(reqStarted) + // Just wait for it to cancel. + for blockResp { + select { + case <-ctx.Done(): + close(serverCanceled) + return NewRemoteErr(ctx.Err()) + case resp <- []byte{1}: + time.Sleep(10 * time.Millisecond) + } + } + // Just wait for it to cancel. + <-ctx.Done() + close(serverCanceled) + return NewRemoteErr(ctx.Err()) + }, + OutCapacity: 1, + InCapacity: inCap, + })) + } + register(local) + register(remote) + + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond) + defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + errFatal(err) + + // Wait for the server start the request. + <-reqStarted + + // Block until we have exceeded the deadline several times over. + nowBlocking := make(chan struct{}) + var mu sync.Mutex + time.AfterFunc(time.Second, func() { + mu.Lock() + cancel() + close(nowBlocking) + mu.Unlock() + }) + if inCap > 0 { + go func() { + defer close(st.Requests) + if !blockReq { + <-nowBlocking + return + } + for { + select { + case <-nowBlocking: + return + case <-st.Done(): + case st.Requests <- []byte{1}: + time.Sleep(10 * time.Millisecond) + } + } + }() + } + // Check that local returned. + err = st.Results(func(b []byte) error { + <-st.Done() + return ctx.Err() + }) + mu.Lock() + select { + case <-nowBlocking: + default: + t.Fatal("expected to be blocked. got err", err) + } + if err == nil { + t.Fatal("expected error, got nil") + } + t.Logf("response: %v", err) + // Check that remote is canceled. + <-serverCanceled +} + func timeout(after time.Duration) (cancel func()) { c := time.After(after) cc := make(chan struct{}) diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index dd5331a31..5ec8fb347 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -32,24 +32,25 @@ import ( // muxClient is a stateful connection to a remote. type muxClient struct { - MuxID uint64 - SendSeq, RecvSeq uint32 - LastPong int64 - BaseFlags Flags - ctx context.Context - cancelFn context.CancelCauseFunc - parent *Connection - respWait chan<- Response - respMu sync.Mutex - singleResp bool - closed bool - stateless bool - acked bool - init bool - deadline time.Duration - outBlock chan struct{} - subroute *subHandlerID - respErr atomic.Pointer[error] + MuxID uint64 + SendSeq, RecvSeq uint32 + LastPong int64 + BaseFlags Flags + ctx context.Context + cancelFn context.CancelCauseFunc + parent *Connection + respWait chan<- Response + respMu sync.Mutex + singleResp bool + closed bool + stateless bool + acked bool + init bool + deadline time.Duration + outBlock chan struct{} + subroute *subHandlerID + respErr atomic.Pointer[error] + clientPingInterval time.Duration } // Response is a response from the server. 
@@ -61,12 +62,13 @@ type Response struct { func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient { ctx, cancelFn := context.WithCancelCause(ctx) return &muxClient{ - MuxID: muxID, - ctx: ctx, - cancelFn: cancelFn, - parent: parent, - LastPong: time.Now().Unix(), - BaseFlags: parent.baseFlags, + MuxID: muxID, + ctx: ctx, + cancelFn: cancelFn, + parent: parent, + LastPong: time.Now().UnixNano(), + BaseFlags: parent.baseFlags, + clientPingInterval: parent.clientPingInterval, } } @@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer < } }() var pingTimer <-chan time.Time - if m.deadline == 0 || m.deadline > clientPingInterval { - ticker := time.NewTicker(clientPingInterval) + if m.deadline == 0 || m.deadline > m.clientPingInterval { + ticker := time.NewTicker(m.clientPingInterval) defer ticker.Stop() pingTimer = ticker.C - atomic.StoreInt64(&m.LastPong, time.Now().Unix()) + atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) } defer m.parent.deleteMux(false, m.MuxID) for { @@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) { } // Only check ping when not closed. - if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 { + if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 { m.respMu.Unlock() if debugPrint { fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got) @@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) { // responseCh is the channel to that goes to the requester. // internalResp is the channel that comes from the server. func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) { - defer m.parent.deleteMux(false, m.MuxID) - defer xioutil.SafeClose(responseCh) + defer func() { + m.parent.deleteMux(false, m.MuxID) + // addErrorNonBlockingClose will close the response channel. + xioutil.SafeClose(responseCh) + }() + + // Cancelation and errors are handled by handleTwowayRequests below. for resp := range internalResp { - responseCh <- resp m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID}) + responseCh <- resp } } -func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) { +func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) { var errState bool if debugPrint { start := time.Now() @@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests }() } + var pingTimer <-chan time.Time + if m.deadline == 0 || m.deadline > m.clientPingInterval { + ticker := time.NewTicker(m.clientPingInterval) + defer ticker.Stop() + pingTimer = ticker.C + atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) + } + // Listen for client messages. - for { - if errState { - go func() { - // Drain requests. 
- for range requests { - } - }() - return - } +reqLoop: + for !errState { select { case <-m.ctx.Done(): if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } - m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx)) + m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) errState = true continue + case <-pingTimer: + if !m.doPing(errResp) { + errState = true + continue + } case req, ok := <-requests: if !ok { // Done send EOF @@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests msg := message{ Op: OpMuxClientMsg, MuxID: m.MuxID, - Seq: 1, Flags: FlagEOF, } msg.setZeroPayloadFlag() err := m.send(msg) if err != nil { - m.addErrorNonBlockingClose(internalResp, err) + m.addErrorNonBlockingClose(errResp, err) } - return + break reqLoop } // Grab a send token. + sendReq: select { case <-m.ctx.Done(): - m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx)) + m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) errState = true continue + case <-pingTimer: + if !m.doPing(errResp) { + errState = true + continue + } + goto sendReq case <-m.outBlock: } msg := message{ @@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests err := m.send(msg) PutByteBuffer(req) if err != nil { - m.addErrorNonBlockingClose(internalResp, err) + m.addErrorNonBlockingClose(errResp, err) errState = true continue } msg.Seq++ } } + + if errState { + // Drain requests. + for { + select { + case r, ok := <-requests: + if !ok { + return + } + PutByteBuffer(r) + default: + return + } + } + } + + for !errState { + select { + case <-m.ctx.Done(): + if debugPrint { + fmt.Println("Client sending disconnect to mux", m.MuxID) + } + m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) + return + case <-pingTimer: + errState = !m.doPing(errResp) + } + } } // checkSeq will check if sequence number is correct and increment it by 1. @@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) { m.addResponse(r) return } - atomic.StoreInt64(&m.LastPong, time.Now().Unix()) + atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) ok := m.addResponse(r) if !ok { PutByteBuffer(r.Msg) @@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) { m.addResponse(Response{Err: err}) return } - atomic.StoreInt64(&m.LastPong, time.Now().Unix()) + atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) } // addResponse will add a response to the response channel. 
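
The muxclient.go changes above give each mux its own clientPingInterval (copied from the parent connection), store LastPong as UnixNano, and keep a ping timer running inside handleTwowayRequests so a two-way stream is torn down when the remote stops answering. A rough, self-contained sketch of the staleness check this relies on; pinger and its methods are illustrative names, only the 2x-interval cutoff is taken from muxClient.doPing:

	package main

	import (
		"fmt"
		"sync/atomic"
		"time"
	)

	// pinger tracks the last pong as UnixNano, matching the patch's move
	// from second to nanosecond resolution for LastPong.
	type pinger struct {
		lastPong     atomic.Int64
		pingInterval time.Duration
	}

	func (p *pinger) gotPong() { p.lastPong.Store(time.Now().UnixNano()) }

	// check fails once no pong has been seen for twice the ping interval,
	// mirroring the 2*clientPingInterval cutoff in muxClient.doPing.
	func (p *pinger) check() error {
		last := time.Unix(0, p.lastPong.Load())
		if d := time.Since(last); d > 2*p.pingInterval {
			return fmt.Errorf("remote not responding: last pong %v ago", d.Round(time.Millisecond))
		}
		return nil
	}

	func main() {
		p := &pinger{pingInterval: 100 * time.Millisecond}
		p.gotPong()
		fmt.Println("fresh:", p.check()) // <nil>

		time.Sleep(250 * time.Millisecond)
		fmt.Println("stale:", p.check()) // error: pong older than 2x the interval
	}
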
diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go index e9d1db659..87e81acda 100644 --- a/internal/grid/muxserver.go +++ b/internal/grid/muxserver.go @@ -28,21 +28,20 @@ import ( xioutil "github.com/minio/minio/internal/ioutil" ) -const lastPingThreshold = 4 * clientPingInterval - type muxServer struct { - ID uint64 - LastPing int64 - SendSeq, RecvSeq uint32 - Resp chan []byte - BaseFlags Flags - ctx context.Context - cancel context.CancelFunc - inbound chan []byte - parent *Connection - sendMu sync.Mutex - recvMu sync.Mutex - outBlock chan struct{} + ID uint64 + LastPing int64 + SendSeq, RecvSeq uint32 + Resp chan []byte + BaseFlags Flags + ctx context.Context + cancel context.CancelFunc + inbound chan []byte + parent *Connection + sendMu sync.Mutex + recvMu sync.Mutex + outBlock chan struct{} + clientPingInterval time.Duration } func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer { @@ -89,16 +88,17 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea } m := muxServer{ - ID: msg.MuxID, - RecvSeq: msg.Seq + 1, - SendSeq: msg.Seq, - ctx: ctx, - cancel: cancel, - parent: c, - inbound: nil, - outBlock: make(chan struct{}, outboundCap), - LastPing: time.Now().Unix(), - BaseFlags: c.baseFlags, + ID: msg.MuxID, + RecvSeq: msg.Seq + 1, + SendSeq: msg.Seq, + ctx: ctx, + cancel: cancel, + parent: c, + inbound: nil, + outBlock: make(chan struct{}, outboundCap), + LastPing: time.Now().Unix(), + BaseFlags: c.baseFlags, + clientPingInterval: c.clientPingInterval, } // Acknowledge Mux created. // Send async. @@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea }(m.outBlock) // Remote aliveness check if needed. - if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { + if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) { go func() { wg.Wait() m.checkRemoteAlive() @@ -164,9 +164,21 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea // handleInbound sends unblocks when we have delivered the message to the handler. func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) { - for in := range inbound { - handlerIn <- in - m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags}) + for { + select { + case <-m.ctx.Done(): + return + case in, ok := <-inbound: + if !ok { + return + } + select { + case <-m.ctx.Done(): + return + case handlerIn <- in: + m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags}) + } + } } } @@ -234,7 +246,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- // checkRemoteAlive will check if the remote is alive. func (m *muxServer) checkRemoteAlive() { - t := time.NewTicker(lastPingThreshold / 4) + t := time.NewTicker(m.clientPingInterval) defer t.Stop() for { select { @@ -242,7 +254,7 @@ func (m *muxServer) checkRemoteAlive() { return case <-t.C: last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0)) - if last > lastPingThreshold { + if last > 4*m.clientPingInterval { gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last)) m.close() return @@ -257,7 +269,7 @@ func (m *muxServer) checkSeq(seq uint32) (ok bool) { if debugPrint { fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq) } - m.disconnect(fmt.Sprintf("receive sequence number mismatch. 
want %d, got %d", m.RecvSeq, seq)) + m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq), false) return false } m.RecvSeq++ @@ -268,13 +280,13 @@ func (m *muxServer) message(msg message) { if debugPrint { fmt.Printf("muxServer: received message %d, length %d\n", msg.Seq, len(msg.Payload)) } + if !m.checkSeq(msg.Seq) { + return + } m.recvMu.Lock() defer m.recvMu.Unlock() if cap(m.inbound) == 0 { - m.disconnect("did not expect inbound message") - return - } - if !m.checkSeq(msg.Seq) { + m.disconnect("did not expect inbound message", true) return } // Note, on EOF no value can be sent. @@ -296,7 +308,7 @@ func (m *muxServer) message(msg message) { fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq) } default: - m.disconnect("handler blocked") + m.disconnect("handler blocked", true) } } @@ -332,7 +344,9 @@ func (m *muxServer) ping(seq uint32) pongMsg { } } -func (m *muxServer) disconnect(msg string) { +// disconnect will disconnect the mux. +// m.recvMu must be locked when calling this function. +func (m *muxServer) disconnect(msg string, locked bool) { if debugPrint { fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg) } @@ -341,6 +355,11 @@ func (m *muxServer) disconnect(msg string) { } else { m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID}) } + // Unlock, since we are calling deleteMux, which will call close - which will lock recvMu. + if locked { + m.recvMu.Unlock() + defer m.recvMu.Lock() + } m.parent.deleteMux(true, m.ID) } diff --git a/internal/grid/stream.go b/internal/grid/stream.go index 29e0cebf4..cbc69c1ba 100644 --- a/internal/grid/stream.go +++ b/internal/grid/stream.go @@ -89,12 +89,26 @@ func (s *Stream) Results(next func(b []byte) error) (err error) { return nil } if resp.Err != nil { + s.cancel(resp.Err) return resp.Err } err = next(resp.Msg) if err != nil { + s.cancel(err) return err } } } } + +// Done will return a channel that will be closed when the stream is done. +// This mirrors context.Done(). +func (s *Stream) Done() <-chan struct{} { + return s.ctx.Done() +} + +// Err will return the error that caused the stream to end. +// This mirrors context.Err(). +func (s *Stream) Err() error { + return s.ctx.Err() +}