Fix blocked streams blocking reconnects (#19017)

We have observed cases where a blocked stream also blocks cancellation.

This happens when the response channel is blocked and we want to push an error.
The blocked send holds the response mutex, which prevents all other operations until the upstream is unblocked.

Make this behavior non-blocking and, if the channel is blocked, spawn a goroutine that will send the response and close the output.

Still a lot of "dancing". Added a test for this and reviewed.
This commit is contained in:
Klaus Post 2024-02-08 10:15:27 -08:00 committed by GitHub
parent a29c66ed74
commit 7ec43bd177
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 42 deletions

View File

@ -1587,6 +1587,18 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.clientPingInterval = args[0].(time.Duration) c.clientPingInterval = args[0].(time.Duration)
case debugAddToDeadline: case debugAddToDeadline:
c.addDeadline = args[0].(time.Duration) c.addDeadline = args[0].(time.Duration)
case debugIsOutgoingClosed:
// params: muxID uint64, isClosed func(bool)
muxID := args[0].(uint64)
resp := args[1].(func(b bool))
mid, ok := c.outgoing.Load(muxID)
if !ok || mid == nil {
resp(true)
return
}
mid.respMu.Lock()
resp(mid.closed)
mid.respMu.Unlock()
} }
} }

View File

@ -49,6 +49,7 @@ const (
debugSetConnPingDuration debugSetConnPingDuration
debugSetClientPingDuration debugSetClientPingDuration
debugAddToDeadline debugAddToDeadline
debugIsOutgoingClosed
) )
// TestGrid contains a grid of servers for testing purposes. // TestGrid contains a grid of servers for testing purposes.

View File

@ -372,6 +372,12 @@ func TestStreamSuite(t *testing.T) {
assertNoActive(t, connRemoteLocal) assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote) assertNoActive(t, connLocalToRemote)
}) })
t.Run("testServerStreamResponseBlocked", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamResponseBlocked(t, local, remote)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
} }
func testStreamRoundtrip(t *testing.T, local, remote *Manager) { func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@ -929,6 +935,96 @@ func testGenericsStreamRoundtripSubroute(t *testing.T, local, remote *Manager) {
t.Log("EOF.", payloads, " Roundtrips:", time.Since(start)) t.Log("EOF.", payloads, " Roundtrips:", time.Since(start))
} }
// testServerStreamResponseBlocked will test if server can handle a blocked response stream.
//
// Scenario: a streaming handler keeps sending responses, the client-side
// consumer is stalled so the client's response channel fills up, and then the
// request context is canceled. The test passes when the handler observes the
// cancellation, the outgoing mux is reported as closed, and context.Canceled
// is returned to the caller of Results.
func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
// NOTE(review): presumably redirects logged errors to t until the returned func runs — confirm in testlogger.
defer testlogger.T.SetErrorTB(t)()
// errFatal aborts the test on any non-nil error.
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// Handler: sends up to 100 single-byte responses and reports progress via
// the two channels below.
// serverSent is closed once the responses for i == 0 and i == 1 were accepted by resp.
serverSent := make(chan struct{})
// serverCanceled is closed when the handler sees ctx.Done() while trying to send.
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
// Send many responses.
// Test that this doesn't block.
for i := byte(0); i < 100; i++ {
select {
case resp <- []byte{i}:
// ok
case <-ctx.Done():
// Cancellation reached the handler; report it and stop sending.
close(serverCanceled)
return NewRemoteErr(ctx.Err())
}
if i == 1 {
close(serverSent)
}
}
return nil
},
// NOTE(review): presumably small buffers so the handler stalls quickly once the client stops reading — confirm against StreamHandler.
OutCapacity: 1,
InCapacity: 0,
}))
}
// The handler is registered on both managers; the stream below is opened from local towards remote.
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
// cancel is invoked explicitly further down to trigger the scenario under test.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait until the handler has had its first two responses accepted.
<-serverSent
// Read back from the stream and block.
nowBlocking := make(chan struct{})
stopBlocking := make(chan struct{})
// Releases the stalled consumer goroutine when the test function returns.
defer close(stopBlocking)
go func() {
st.Results(func(b []byte) error {
// Signals that one response was consumed and the consumer is now stalled.
// A second invocation would close an already-closed channel, so this relies
// on the callback staying parked on stopBlocking while the test body runs.
close(nowBlocking)
// Block until test is done.
<-stopBlocking
return nil
})
}()
<-nowBlocking
// Wait for the receiver channel to fill.
for len(st.responses) != cap(st.responses) {
time.Sleep(time.Millisecond)
}
// Cancel while the consumer is stalled and the response channel is full.
cancel()
// The handler must notice the cancellation even though responses are backed up.
<-serverCanceled
// Ask for the closed state of this stream's outgoing mux and fail if it is still open.
local.debugMsg(debugIsOutgoingClosed, st.muxID, func(closed bool) {
if !closed {
t.Error("expected outgoing closed")
} else {
t.Log("outgoing was closed")
}
})
// Drain responses and check if error propagated.
err = st.Results(func(b []byte) error {
return nil
})
if !errors.Is(err, context.Canceled) {
t.Error("expected context.Canceled, got", err)
}
}
func timeout(after time.Duration) (cancel func()) { func timeout(after time.Duration) (cancel func()) {
c := time.After(after) c := time.After(after)
cc := make(chan struct{}) cc := make(chan struct{})

View File

@ -50,6 +50,7 @@ type muxClient struct {
deadline time.Duration deadline time.Duration
outBlock chan struct{} outBlock chan struct{}
subroute *subHandlerID subroute *subHandlerID
respErr atomic.Pointer[error]
} }
// Response is a response from the server. // Response is a response from the server.
@ -250,25 +251,52 @@ func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []b
// Spawn simple disconnect // Spawn simple disconnect
if requests == nil { if requests == nil {
start := time.Now() go m.handleOneWayStream(responseCh, responses)
go m.handleOneWayStream(start, responseCh, responses) return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn}, nil
} }
// Deliver responses and send unblocks back to the server. // Deliver responses and send unblocks back to the server.
go m.handleTwowayResponses(responseCh, responses) go m.handleTwowayResponses(responseCh, responses)
go m.handleTwowayRequests(responses, requests) go m.handleTwowayRequests(responses, requests)
return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn}, nil return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
} }
func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Response, respServer <-chan Response) { func (m *muxClient) addErrorNonBlockingClose(respHandler chan<- Response, err error) {
// addErrorNonBlockingClose delivers err on respHandler, closes respHandler,
// asks the server to disconnect this mux and marks the mux closed.
// The send on respHandler never blocks while respMu is held.
// It does nothing if m.closed is already set.
m.respMu.Lock()
defer m.respMu.Unlock()
if !m.closed {
// Record the error first: handleOneWayStream's deferred cleanup checks
// respErr and leaves closing respHandler to this function when it is set.
m.respErr.Store(&err)
// Do not block.
select {
case respHandler <- Response{Err: err}:
// Receiver had room; the channel can be closed right away.
xioutil.SafeClose(respHandler)
default:
// Receiver is not ready: hand the send and the close to a goroutine
// so respMu is not held while waiting.
// NOTE(review): this goroutine stays parked until something reads from respHandler.
go func() {
respHandler <- Response{Err: err}
xioutil.SafeClose(respHandler)
}()
}
// Tell the server to drop its side of the mux; a send failure is only logged.
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
// NOTE(review): sets the flag directly instead of calling closeLocked, so the
// respWait handling in closeLocked is skipped here — confirm this is intended.
m.closed = true
}
}
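// handleOneWayStream serves a stream that has no client requests.
// respHandler is the channel that goes to the requester.
// respServer is the channel that comes from the server.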
func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
if debugPrint { if debugPrint {
start := time.Now()
defer func() { defer func() {
fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
}() }()
} }
defer xioutil.SafeClose(respHandler) defer func() {
// addErrorNonBlockingClose will close the response channel
// - maybe async, so we shouldn't do it here.
if m.respErr.Load() == nil {
xioutil.SafeClose(respHandler)
}
}()
var pingTimer <-chan time.Time var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > clientPingInterval { if m.deadline == 0 || m.deadline > clientPingInterval {
ticker := time.NewTicker(clientPingInterval) ticker := time.NewTicker(clientPingInterval)
@ -283,13 +311,7 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo
if debugPrint { if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID) fmt.Println("Client sending disconnect to mux", m.MuxID)
} }
m.respMu.Lock() m.addErrorNonBlockingClose(respHandler, context.Cause(m.ctx))
defer m.respMu.Unlock() // We always return in this path.
if !m.closed {
respHandler <- Response{Err: context.Cause(m.ctx)}
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
m.closeLocked()
}
return return
case resp, ok := <-respServer: case resp, ok := <-respServer:
if !ok { if !ok {
@ -308,13 +330,7 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo
} }
case <-pingTimer: case <-pingTimer:
if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 { if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 {
m.respMu.Lock() m.addErrorNonBlockingClose(respHandler, ErrDisconnected)
defer m.respMu.Unlock() // We always return in this path.
if !m.closed {
respHandler <- Response{Err: ErrDisconnected}
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
m.closeLocked()
}
return return
} }
// Send new ping. // Send new ping.
@ -323,19 +339,21 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo
} }
} }
func (m *muxClient) handleTwowayResponses(responseCh chan Response, responses chan Response) { // responseCh is the channel that goes to the requester.
// internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
defer m.parent.deleteMux(false, m.MuxID) defer m.parent.deleteMux(false, m.MuxID)
defer xioutil.SafeClose(responseCh) defer xioutil.SafeClose(responseCh)
for resp := range responses { for resp := range internalResp {
responseCh <- resp responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID}) m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
} }
} }
func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests chan []byte) { func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
var errState bool var errState bool
start := time.Now()
if debugPrint { if debugPrint {
start := time.Now()
defer func() { defer func() {
fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
}() }()
@ -343,19 +361,22 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha
// Listen for client messages. // Listen for client messages.
for { for {
if errState {
go func() {
// Drain requests.
for range requests {
}
}()
return
}
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
if debugPrint { if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID) fmt.Println("Client sending disconnect to mux", m.MuxID)
} }
m.respMu.Lock() m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
defer m.respMu.Unlock() errState = true
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) continue
if !m.closed {
responses <- Response{Err: context.Cause(m.ctx)}
m.closeLocked()
}
return
case req, ok := <-requests: case req, ok := <-requests:
if !ok { if !ok {
// Done send EOF // Done send EOF
@ -371,19 +392,14 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha
msg.setZeroPayloadFlag() msg.setZeroPayloadFlag()
err := m.send(msg) err := m.send(msg)
if err != nil { if err != nil {
m.respMu.Lock() m.addErrorNonBlockingClose(internalResp, err)
responses <- Response{Err: err}
m.closeLocked()
m.respMu.Unlock()
} }
return return
} }
if errState {
continue
}
// Grab a send token. // Grab a send token.
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true errState = true
continue continue
case <-m.outBlock: case <-m.outBlock:
@ -398,8 +414,7 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha
err := m.send(msg) err := m.send(msg)
PutByteBuffer(req) PutByteBuffer(req)
if err != nil { if err != nil {
responses <- Response{Err: err} m.addErrorNonBlockingClose(internalResp, err)
m.close()
errState = true errState = true
continue continue
} }
@ -534,6 +549,7 @@ func (m *muxClient) closeLocked() {
if m.closed { if m.closed {
return return
} }
// We hold the lock, so nobody can modify m.respWait while we're closing.
if m.respWait != nil { if m.respWait != nil {
xioutil.SafeClose(m.respWait) xioutil.SafeClose(m.respWait)
m.respWait = nil m.respWait = nil

View File

@ -41,7 +41,8 @@ type Stream struct {
// Requests sent cannot be used any further by the caller. // Requests sent cannot be used any further by the caller.
Requests chan<- []byte Requests chan<- []byte
ctx context.Context muxID uint64
ctx context.Context
} }
// Send a payload to the remote server. // Send a payload to the remote server.