From 40fb3371fa8235b12dc4172afbf0dffc9c9e0515 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 28 Feb 2024 10:05:18 -0800 Subject: [PATCH] Mux: Send async mux ack and fix stream error responses (#19149) Streams can return errors if the cancelation is picked up before the response stream close is picked up. Under extreme load, this could lead to missing responses. Send server mux ack async so a blocked send cannot block newMuxStream call. Stream will not progress until mux has been acked. --- internal/grid/connection.go | 3 +++ internal/grid/muxclient.go | 8 +++++++- internal/grid/muxserver.go | 26 ++++++++++++++++++-------- internal/grid/stream.go | 9 +++++++-- 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/internal/grid/connection.go b/internal/grid/connection.go index d5bbdcaa0..6a6979c74 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -1504,6 +1504,9 @@ func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) { }) } if m.Flags&FlagEOF != 0 { + if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 { + v.cancelFn(errStreamEOF) + } v.close() if debugReqs { fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX") diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index bcc17cba9..7fa4ce29a 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -323,7 +323,10 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer < if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } - m.addErrorNonBlockingClose(respHandler, context.Cause(m.ctx)) + err := context.Cause(m.ctx) + if !errors.Is(err, errStreamEOF) { + m.addErrorNonBlockingClose(respHandler, err) + } return case resp, ok := <-respServer: if !ok { @@ -463,6 +466,7 @@ func (m *muxClient) response(seq uint32, r Response) { fmt.Println(m.MuxID, m.parent.String(), "CHECKSEQ FAIL", m.RecvSeq, seq) } PutByteBuffer(r.Msg) + r.Msg = nil r.Err = ErrIncorrectSequence m.addResponse(r) return @@ -474,6 +478,8 @@ func (m *muxClient) response(seq uint32, r Response) { } } +var errStreamEOF = errors.New("stream EOF") + // error is a message from the server to disconnect. func (m *muxClient) error(err RemoteErr) { if debugPrint { diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go index 183a463b3..f4917bafa 100644 --- a/internal/grid/muxserver.go +++ b/internal/grid/muxserver.go @@ -103,14 +103,20 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea BaseFlags: c.baseFlags, } // Acknowledge Mux created. - var ack message - ack.Op = OpAckMux - ack.Flags = m.BaseFlags - ack.MuxID = m.ID - m.send(ack) - if debugPrint { - fmt.Println("connected stream mux:", ack.MuxID) - } + // Send async. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var ack message + ack.Op = OpAckMux + ack.Flags = m.BaseFlags + ack.MuxID = m.ID + m.send(ack) + if debugPrint { + fmt.Println("connected stream mux:", ack.MuxID) + } + }() // Data inbound to the handler var handlerIn chan []byte @@ -118,6 +124,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea m.inbound = make(chan []byte, inboundCap) handlerIn = make(chan []byte, 1) go func(inbound <-chan []byte) { + wg.Wait() defer xioutil.SafeClose(handlerIn) // Send unblocks when we have delivered the message to the handler. for in := range inbound { @@ -133,6 +140,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea // Handler goroutine. var handlerErr *RemoteErr go func() { + wg.Wait() start := time.Now() defer func() { if debugPrint { @@ -154,6 +162,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea }() // Response sender gorutine... go func(outBlock <-chan struct{}) { + wg.Wait() defer m.parent.deleteMux(true, m.ID) for { // Process outgoing message. @@ -196,6 +205,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea // Remote aliveness check. if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { go func() { + wg.Wait() t := time.NewTicker(lastPingThreshold / 4) defer t.Stop() for { diff --git a/internal/grid/stream.go b/internal/grid/stream.go index a99b66643..29e0cebf4 100644 --- a/internal/grid/stream.go +++ b/internal/grid/stream.go @@ -74,10 +74,15 @@ func (s *Stream) Results(next func(b []byte) error) (err error) { } } }() + doneCh := s.ctx.Done() for { select { - case <-s.ctx.Done(): - return context.Cause(s.ctx) + case <-doneCh: + if err := context.Cause(s.ctx); !errors.Is(err, errStreamEOF) { + return err + } + // Fall through to be sure we have returned all responses. + doneCh = nil case resp, ok := <-s.responses: if !ok { done = true