Fix blocked streams blocking reconnects (#19017)

We have observed cases where a blocked stream will block for cancellations.

This happens when response channel is blocked and we want to push an error.
This will have the response mutex locked, which will prevent all other operations until upstream is unblocked.

Make this behavior non-blocking and if blocked spawn a goroutine that will send the response and close the output.

Still a lot of "dancing". Added a test for this and reviewed.
This commit is contained in:
Klaus Post 2024-02-08 10:15:27 -08:00 committed by GitHub
parent a29c66ed74
commit 7ec43bd177
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 42 deletions

View File

@ -1587,6 +1587,18 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.clientPingInterval = args[0].(time.Duration)
case debugAddToDeadline:
c.addDeadline = args[0].(time.Duration)
case debugIsOutgoingClosed:
// params: muxID uint64, isClosed func(bool)
muxID := args[0].(uint64)
resp := args[1].(func(b bool))
mid, ok := c.outgoing.Load(muxID)
if !ok || mid == nil {
resp(true)
return
}
mid.respMu.Lock()
resp(mid.closed)
mid.respMu.Unlock()
}
}

View File

@ -49,6 +49,7 @@ const (
debugSetConnPingDuration
debugSetClientPingDuration
debugAddToDeadline
debugIsOutgoingClosed
)
// TestGrid contains a grid of servers for testing purposes.

View File

@ -372,6 +372,12 @@ func TestStreamSuite(t *testing.T) {
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamResponseBlocked", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamResponseBlocked(t, local, remote)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
}
func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@ -929,6 +935,96 @@ func testGenericsStreamRoundtripSubroute(t *testing.T, local, remote *Manager) {
t.Log("EOF.", payloads, " Roundtrips:", time.Since(start))
}
// testServerStreamResponseBlocked will test if server can handle a blocked response stream
func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
serverSent := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
// Send many responses.
// Test that this doesn't block.
for i := byte(0); i < 100; i++ {
select {
case resp <- []byte{i}:
// ok
case <-ctx.Done():
close(serverCanceled)
return NewRemoteErr(ctx.Err())
}
if i == 1 {
close(serverSent)
}
}
return nil
},
OutCapacity: 1,
InCapacity: 0,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server to send the first response.
<-serverSent
// Read back from the stream and block.
nowBlocking := make(chan struct{})
stopBlocking := make(chan struct{})
defer close(stopBlocking)
go func() {
st.Results(func(b []byte) error {
close(nowBlocking)
// Block until test is done.
<-stopBlocking
return nil
})
}()
<-nowBlocking
// Wait for the receiver channel to fill.
for len(st.responses) != cap(st.responses) {
time.Sleep(time.Millisecond)
}
cancel()
<-serverCanceled
local.debugMsg(debugIsOutgoingClosed, st.muxID, func(closed bool) {
if !closed {
t.Error("expected outgoing closed")
} else {
t.Log("outgoing was closed")
}
})
// Drain responses and check if error propagated.
err = st.Results(func(b []byte) error {
return nil
})
if !errors.Is(err, context.Canceled) {
t.Error("expected context.Canceled, got", err)
}
}
func timeout(after time.Duration) (cancel func()) {
c := time.After(after)
cc := make(chan struct{})

View File

@ -50,6 +50,7 @@ type muxClient struct {
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
}
// Response is a response from the server.
@ -250,25 +251,52 @@ func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []b
// Spawn simple disconnect
if requests == nil {
start := time.Now()
go m.handleOneWayStream(start, responseCh, responses)
return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn}, nil
go m.handleOneWayStream(responseCh, responses)
return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
}
// Deliver responses and send unblocks back to the server.
go m.handleTwowayResponses(responseCh, responses)
go m.handleTwowayRequests(responses, requests)
return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn}, nil
return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
}
func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Response, respServer <-chan Response) {
func (m *muxClient) addErrorNonBlockingClose(respHandler chan<- Response, err error) {
m.respMu.Lock()
defer m.respMu.Unlock()
if !m.closed {
m.respErr.Store(&err)
// Do not block.
select {
case respHandler <- Response{Err: err}:
xioutil.SafeClose(respHandler)
default:
go func() {
respHandler <- Response{Err: err}
xioutil.SafeClose(respHandler)
}()
}
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
m.closed = true
}
}
// respHandler
func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
if debugPrint {
start := time.Now()
defer func() {
fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
}()
}
defer xioutil.SafeClose(respHandler)
defer func() {
// addErrorNonBlockingClose will close the response channel
// - maybe async, so we shouldn't do it here.
if m.respErr.Load() == nil {
xioutil.SafeClose(respHandler)
}
}()
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > clientPingInterval {
ticker := time.NewTicker(clientPingInterval)
@ -283,13 +311,7 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.respMu.Lock()
defer m.respMu.Unlock() // We always return in this path.
if !m.closed {
respHandler <- Response{Err: context.Cause(m.ctx)}
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
m.closeLocked()
}
m.addErrorNonBlockingClose(respHandler, context.Cause(m.ctx))
return
case resp, ok := <-respServer:
if !ok {
@ -308,13 +330,7 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo
}
case <-pingTimer:
if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 {
m.respMu.Lock()
defer m.respMu.Unlock() // We always return in this path.
if !m.closed {
respHandler <- Response{Err: ErrDisconnected}
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
m.closeLocked()
}
m.addErrorNonBlockingClose(respHandler, ErrDisconnected)
return
}
// Send new ping.
@ -323,19 +339,21 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo
}
}
func (m *muxClient) handleTwowayResponses(responseCh chan Response, responses chan Response) {
// responseCh is the channel to that goes to the requester.
// internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
defer m.parent.deleteMux(false, m.MuxID)
defer xioutil.SafeClose(responseCh)
for resp := range responses {
for resp := range internalResp {
responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
}
}
func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests chan []byte) {
func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
var errState bool
start := time.Now()
if debugPrint {
start := time.Now()
defer func() {
fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
}()
@ -343,19 +361,22 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha
// Listen for client messages.
for {
if errState {
go func() {
// Drain requests.
for range requests {
}
}()
return
}
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.respMu.Lock()
defer m.respMu.Unlock()
logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
if !m.closed {
responses <- Response{Err: context.Cause(m.ctx)}
m.closeLocked()
}
return
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true
continue
case req, ok := <-requests:
if !ok {
// Done send EOF
@ -371,19 +392,14 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha
msg.setZeroPayloadFlag()
err := m.send(msg)
if err != nil {
m.respMu.Lock()
responses <- Response{Err: err}
m.closeLocked()
m.respMu.Unlock()
m.addErrorNonBlockingClose(internalResp, err)
}
return
}
if errState {
continue
}
// Grab a send token.
select {
case <-m.ctx.Done():
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true
continue
case <-m.outBlock:
@ -398,8 +414,7 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha
err := m.send(msg)
PutByteBuffer(req)
if err != nil {
responses <- Response{Err: err}
m.close()
m.addErrorNonBlockingClose(internalResp, err)
errState = true
continue
}
@ -534,6 +549,7 @@ func (m *muxClient) closeLocked() {
if m.closed {
return
}
// We hold the lock, so nobody can modify m.respWait while we're closing.
if m.respWait != nil {
xioutil.SafeClose(m.respWait)
m.respWait = nil

View File

@ -41,7 +41,8 @@ type Stream struct {
// Requests sent cannot be used any further by the called.
Requests chan<- []byte
ctx context.Context
muxID uint64
ctx context.Context
}
// Send a payload to the remote server.