diff --git a/internal/grid/connection.go b/internal/grid/connection.go index 3d312b1ae..91e5006d8 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -531,10 +531,11 @@ func (c *Connection) shouldConnect() bool { return h0 < h1 } -func (c *Connection) send(msg []byte) error { +func (c *Connection) send(ctx context.Context, msg []byte) error { select { - case <-c.ctx.Done(): - return context.Cause(c.ctx) + case <-ctx.Done(): + // Returning error here is too noisy. + return nil case c.outQueue <- msg: return nil } @@ -570,7 +571,7 @@ func (c *Connection) queueMsg(msg message, payload sender) error { h := xxh3.Hash(dst) dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) } - return c.send(dst) + return c.send(c.ctx, dst) } // sendMsg will send diff --git a/internal/grid/grid.go b/internal/grid/grid.go index a0813a7f8..5034e1a8e 100644 --- a/internal/grid/grid.go +++ b/internal/grid/grid.go @@ -184,12 +184,12 @@ func (m *lockedClientMap) Delete(id uint64) { func (m *lockedClientMap) Range(fn func(key uint64, value *muxClient) bool) { m.mu.Lock() + defer m.mu.Unlock() for k, v := range m.m { if !fn(k, v) { break } } - m.mu.Unlock() } func (m *lockedClientMap) Clear() { diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index 2078dc19a..bcc17cba9 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -173,7 +173,7 @@ func (m *muxClient) sendLocked(msg message) error { h := xxh3.Hash(dst) dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) } - return m.parent.send(dst) + return m.parent.send(m.ctx, dst) } // RequestStateless will send a single payload request and stream back results. @@ -552,7 +552,15 @@ func (m *muxClient) close() { if debugPrint { fmt.Println("closing outgoing mux", m.MuxID) } - m.respMu.Lock() + if !m.respMu.TryLock() { + // Cancel before locking - will unblock any pending sends. + if m.cancelFn != nil { + m.cancelFn(context.Canceled) + } + // Wait for senders to release. + m.respMu.Lock() + } + defer m.respMu.Unlock() m.closeLocked() }