diff --git a/internal/grid/connection.go b/internal/grid/connection.go index d5bbdcaa0..6a6979c74 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -1504,6 +1504,9 @@ func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) { }) } if m.Flags&FlagEOF != 0 { + if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 { + v.cancelFn(errStreamEOF) + } v.close() if debugReqs { fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX") diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index bcc17cba9..7fa4ce29a 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -323,7 +323,10 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer < if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } - m.addErrorNonBlockingClose(respHandler, context.Cause(m.ctx)) + err := context.Cause(m.ctx) + if !errors.Is(err, errStreamEOF) { + m.addErrorNonBlockingClose(respHandler, err) + } return case resp, ok := <-respServer: if !ok { @@ -463,6 +466,7 @@ func (m *muxClient) response(seq uint32, r Response) { fmt.Println(m.MuxID, m.parent.String(), "CHECKSEQ FAIL", m.RecvSeq, seq) } PutByteBuffer(r.Msg) + r.Msg = nil r.Err = ErrIncorrectSequence m.addResponse(r) return @@ -474,6 +478,8 @@ func (m *muxClient) response(seq uint32, r Response) { } } +var errStreamEOF = errors.New("stream EOF") + // error is a message from the server to disconnect. func (m *muxClient) error(err RemoteErr) { if debugPrint { diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go index 183a463b3..f4917bafa 100644 --- a/internal/grid/muxserver.go +++ b/internal/grid/muxserver.go @@ -103,14 +103,20 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea BaseFlags: c.baseFlags, } // Acknowledge Mux created. - var ack message - ack.Op = OpAckMux - ack.Flags = m.BaseFlags - ack.MuxID = m.ID - m.send(ack) - if debugPrint { - fmt.Println("connected stream mux:", ack.MuxID) - } + // Send async. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var ack message + ack.Op = OpAckMux + ack.Flags = m.BaseFlags + ack.MuxID = m.ID + m.send(ack) + if debugPrint { + fmt.Println("connected stream mux:", ack.MuxID) + } + }() // Data inbound to the handler var handlerIn chan []byte @@ -118,6 +124,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea m.inbound = make(chan []byte, inboundCap) handlerIn = make(chan []byte, 1) go func(inbound <-chan []byte) { + wg.Wait() defer xioutil.SafeClose(handlerIn) // Send unblocks when we have delivered the message to the handler. for in := range inbound { @@ -133,6 +140,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea // Handler goroutine. var handlerErr *RemoteErr go func() { + wg.Wait() start := time.Now() defer func() { if debugPrint { @@ -154,6 +162,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea }() // Response sender gorutine... go func(outBlock <-chan struct{}) { + wg.Wait() defer m.parent.deleteMux(true, m.ID) for { // Process outgoing message. @@ -196,6 +205,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea // Remote aliveness check. if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { go func() { + wg.Wait() t := time.NewTicker(lastPingThreshold / 4) defer t.Stop() for { diff --git a/internal/grid/stream.go b/internal/grid/stream.go index a99b66643..29e0cebf4 100644 --- a/internal/grid/stream.go +++ b/internal/grid/stream.go @@ -74,10 +74,15 @@ func (s *Stream) Results(next func(b []byte) error) (err error) { } } }() + doneCh := s.ctx.Done() for { select { - case <-s.ctx.Done(): - return context.Cause(s.ctx) + case <-doneCh: + if err := context.Cause(s.ctx); !errors.Is(err, errStreamEOF) { + return err + } + // Fall through to be sure we have returned all responses. + doneCh = nil case resp, ok := <-s.responses: if !ok { done = true