Revert "Fix two-way stream cancelation and pings (#19763)"

This reverts commit 4d698841f4ea563cec3b2824db22bf928a1d6273.
This commit is contained in:
Harshavardhana 2024-05-22 03:00:00 -07:00
parent 4d698841f4
commit ae14681c3e
7 changed files with 88 additions and 396 deletions

View File

@ -125,7 +125,6 @@ type Connection struct {
// For testing only // For testing only
debugInConn net.Conn debugInConn net.Conn
debugOutConn net.Conn debugOutConn net.Conn
blockMessages atomic.Pointer[<-chan struct{}]
addDeadline time.Duration addDeadline time.Duration
connMu sync.Mutex connMu sync.Mutex
} }
@ -976,11 +975,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF) gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
return return
} }
block := c.blockMessages.Load()
if block != nil && *block != nil {
<-*block
}
if c.incomingBytes != nil { if c.incomingBytes != nil {
c.incomingBytes(int64(len(msg))) c.incomingBytes(int64(len(msg)))
} }
@ -1369,10 +1363,6 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHan
} }
func (c *Connection) handlePong(ctx context.Context, m message) { func (c *Connection) handlePong(ctx context.Context, m message) {
if m.MuxID == 0 && m.Payload == nil {
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
return
}
var pong pongMsg var pong pongMsg
_, err := pong.UnmarshalMsg(m.Payload) _, err := pong.UnmarshalMsg(m.Payload)
PutByteBuffer(m.Payload) PutByteBuffer(m.Payload)
@ -1392,9 +1382,7 @@ func (c *Connection) handlePong(ctx context.Context, m message) {
func (c *Connection) handlePing(ctx context.Context, m message) { func (c *Connection) handlePing(ctx context.Context, m message) {
if m.MuxID == 0 { if m.MuxID == 0 {
m.Flags.Clear(FlagPayloadIsZero) gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
m.Op = OpPong
gridLogIf(ctx, c.queueMsg(m, nil))
return return
} }
// Single calls do not support pinging. // Single calls do not support pinging.
@ -1611,12 +1599,7 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.connMu.Lock() c.connMu.Lock()
defer c.connMu.Unlock() defer c.connMu.Unlock()
c.connPingInterval = args[0].(time.Duration) c.connPingInterval = args[0].(time.Duration)
if c.connPingInterval < time.Second {
panic("CONN ping interval too low")
}
case debugSetClientPingDuration: case debugSetClientPingDuration:
c.connMu.Lock()
defer c.connMu.Unlock()
c.clientPingInterval = args[0].(time.Duration) c.clientPingInterval = args[0].(time.Duration)
case debugAddToDeadline: case debugAddToDeadline:
c.addDeadline = args[0].(time.Duration) c.addDeadline = args[0].(time.Duration)
@ -1632,11 +1615,6 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
mid.respMu.Lock() mid.respMu.Lock()
resp(mid.closed) resp(mid.closed)
mid.respMu.Unlock() mid.respMu.Unlock()
case debugBlockInboundMessages:
c.connMu.Lock()
block := (<-chan struct{})(args[0].(chan struct{}))
c.blockMessages.Store(&block)
c.connMu.Unlock()
} }
} }

View File

@ -50,7 +50,6 @@ const (
debugSetClientPingDuration debugSetClientPingDuration
debugAddToDeadline debugAddToDeadline
debugIsOutgoingClosed debugIsOutgoingClosed
debugBlockInboundMessages
) )
// TestGrid contains a grid of servers for testing purposes. // TestGrid contains a grid of servers for testing purposes.

View File

@ -16,12 +16,11 @@ func _() {
_ = x[debugSetClientPingDuration-5] _ = x[debugSetClientPingDuration-5]
_ = x[debugAddToDeadline-6] _ = x[debugAddToDeadline-6]
_ = x[debugIsOutgoingClosed-7] _ = x[debugIsOutgoingClosed-7]
_ = x[debugBlockInboundMessages-8]
} }
const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages" const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176} var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
func (i debugMsg) String() string { func (i debugMsg) String() string {
if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) { if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {

View File

@ -378,54 +378,6 @@ func TestStreamSuite(t *testing.T) {
assertNoActive(t, connRemoteLocal) assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote) assertNoActive(t, connLocalToRemote)
}) })
t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamNoPing(t, local, remote, 0)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamNoPing(t, local, remote, 1)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, false, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, false, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, true, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, true, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 0, false, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 0, false, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
} }
func testStreamRoundtrip(t *testing.T, local, remote *Manager) { func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@ -539,12 +491,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
register(remote) register(remote)
// local to remote // local to remote
testHandler := func(t *testing.T, handler HandlerID, sendReq bool) { testHandler := func(t *testing.T, handler HandlerID) {
remoteConn := local.Connection(remoteHost) remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!" const testPayload = "Hello Grid World!"
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload)) st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err) errFatal(err)
clientCanceled := make(chan time.Time, 1) clientCanceled := make(chan time.Time, 1)
err = nil err = nil
@ -561,18 +513,6 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
clientCanceled <- time.Now() clientCanceled <- time.Now()
}(t) }(t)
start := time.Now() start := time.Now()
if st.Requests != nil {
defer close(st.Requests)
}
// Fill up queue.
for sendReq {
select {
case st.Requests <- []byte("Hello"):
time.Sleep(10 * time.Millisecond)
default:
sendReq = false
}
}
cancel() cancel()
<-serverCanceled <-serverCanceled
t.Log("server cancel time:", time.Since(start)) t.Log("server cancel time:", time.Since(start))
@ -584,13 +524,11 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
} }
// local to remote, unbuffered // local to remote, unbuffered
t.Run("unbuffered", func(t *testing.T) { t.Run("unbuffered", func(t *testing.T) {
testHandler(t, handlerTest, false) testHandler(t, handlerTest)
}) })
t.Run("buffered", func(t *testing.T) { t.Run("buffered", func(t *testing.T) {
testHandler(t, handlerTest2, false) testHandler(t, handlerTest2)
})
t.Run("buffered", func(t *testing.T) {
testHandler(t, handlerTest2, true)
}) })
} }
@ -1087,167 +1025,6 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
} }
} }
// testServerStreamNoPing will test if server and client handle no pings.
func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
reqStarted := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
close(reqStarted)
// Just wait for it to cancel.
<-ctx.Done()
close(serverCanceled)
return NewRemoteErr(ctx.Err())
},
OutCapacity: 1,
InCapacity: inCap,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server start the request.
<-reqStarted
// Stop processing requests
nowBlocking := make(chan struct{})
remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
// Check that local returned.
err = st.Results(func(b []byte) error {
return nil
})
if err == nil {
t.Fatal("expected error, got nil")
}
t.Logf("error: %v", err)
// Check that remote is canceled.
<-serverCanceled
close(nowBlocking)
}
// testServerStreamNoPing will test if server and client handle ping even when blocked.
func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
reqStarted := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
close(reqStarted)
// Just wait for it to cancel.
for blockResp {
select {
case <-ctx.Done():
close(serverCanceled)
return NewRemoteErr(ctx.Err())
case resp <- []byte{1}:
time.Sleep(10 * time.Millisecond)
}
}
// Just wait for it to cancel.
<-ctx.Done()
close(serverCanceled)
return NewRemoteErr(ctx.Err())
},
OutCapacity: 1,
InCapacity: inCap,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server start the request.
<-reqStarted
// Block until we have exceeded the deadline several times over.
nowBlocking := make(chan struct{})
time.AfterFunc(time.Second, func() {
cancel()
close(nowBlocking)
})
if inCap > 0 {
go func() {
defer close(st.Requests)
if !blockReq {
<-nowBlocking
return
}
for {
select {
case <-nowBlocking:
return
case <-st.Done():
case st.Requests <- []byte{1}:
time.Sleep(10 * time.Millisecond)
}
}
}()
}
// Check that local returned.
err = st.Results(func(b []byte) error {
select {
case <-nowBlocking:
case <-st.Done():
return ctx.Err()
}
return nil
})
select {
case <-nowBlocking:
default:
t.Fatal("expected to be blocked. got err", err)
}
if err == nil {
t.Fatal("expected error, got nil")
}
t.Logf("error: %v", err)
// Check that remote is canceled.
<-serverCanceled
}
func timeout(after time.Duration) (cancel func()) { func timeout(after time.Duration) (cancel func()) {
c := time.After(after) c := time.After(after)
cc := make(chan struct{}) cc := make(chan struct{})

View File

@ -50,7 +50,6 @@ type muxClient struct {
outBlock chan struct{} outBlock chan struct{}
subroute *subHandlerID subroute *subHandlerID
respErr atomic.Pointer[error] respErr atomic.Pointer[error]
clientPingInterval time.Duration
} }
// Response is a response from the server. // Response is a response from the server.
@ -66,9 +65,8 @@ func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxCli
ctx: ctx, ctx: ctx,
cancelFn: cancelFn, cancelFn: cancelFn,
parent: parent, parent: parent,
LastPong: time.Now().UnixNano(), LastPong: time.Now().Unix(),
BaseFlags: parent.baseFlags, BaseFlags: parent.baseFlags,
clientPingInterval: parent.clientPingInterval,
} }
} }
@ -311,11 +309,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <
} }
}() }()
var pingTimer <-chan time.Time var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval { if m.deadline == 0 || m.deadline > clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval) ticker := time.NewTicker(clientPingInterval)
defer ticker.Stop() defer ticker.Stop()
pingTimer = ticker.C pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) atomic.StoreInt64(&m.LastPong, time.Now().Unix())
} }
defer m.parent.deleteMux(false, m.MuxID) defer m.parent.deleteMux(false, m.MuxID)
for { for {
@ -369,7 +367,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
} }
// Only check ping when not closed. // Only check ping when not closed.
if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 { if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
m.respMu.Unlock() m.respMu.Unlock()
if debugPrint { if debugPrint {
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got) fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@ -390,20 +388,15 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
// responseCh is the channel to that goes to the requester. // responseCh is the channel to that goes to the requester.
// internalResp is the channel that comes from the server. // internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) { func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
defer func() { defer m.parent.deleteMux(false, m.MuxID)
m.parent.deleteMux(false, m.MuxID) defer xioutil.SafeClose(responseCh)
// addErrorNonBlockingClose will close the response channel.
xioutil.SafeClose(responseCh)
}()
// Cancelation and errors are handled by handleTwowayRequests below.
for resp := range internalResp { for resp := range internalResp {
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
responseCh <- resp responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
} }
} }
func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) { func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
var errState bool var errState bool
if debugPrint { if debugPrint {
start := time.Now() start := time.Now()
@ -412,30 +405,24 @@ func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-cha
}() }()
} }
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// Listen for client messages. // Listen for client messages.
reqLoop: for {
for !errState { if errState {
go func() {
// Drain requests.
for range requests {
}
}()
return
}
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
if debugPrint { if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID) fmt.Println("Client sending disconnect to mux", m.MuxID)
} }
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true errState = true
continue continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
case req, ok := <-requests: case req, ok := <-requests:
if !ok { if !ok {
// Done send EOF // Done send EOF
@ -445,28 +432,22 @@ reqLoop:
msg := message{ msg := message{
Op: OpMuxClientMsg, Op: OpMuxClientMsg,
MuxID: m.MuxID, MuxID: m.MuxID,
Seq: 1,
Flags: FlagEOF, Flags: FlagEOF,
} }
msg.setZeroPayloadFlag() msg.setZeroPayloadFlag()
err := m.send(msg) err := m.send(msg)
if err != nil { if err != nil {
m.addErrorNonBlockingClose(errResp, err) m.addErrorNonBlockingClose(internalResp, err)
} }
break reqLoop return
} }
// Grab a send token. // Grab a send token.
sendReq:
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx)) m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true errState = true
continue continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
goto sendReq
case <-m.outBlock: case <-m.outBlock:
} }
msg := message{ msg := message{
@ -479,41 +460,13 @@ reqLoop:
err := m.send(msg) err := m.send(msg)
PutByteBuffer(req) PutByteBuffer(req)
if err != nil { if err != nil {
m.addErrorNonBlockingClose(errResp, err) m.addErrorNonBlockingClose(internalResp, err)
errState = true errState = true
continue continue
} }
msg.Seq++ msg.Seq++
} }
} }
if errState {
// Drain requests.
for {
select {
case r, ok := <-requests:
if !ok {
return
}
PutByteBuffer(r)
default:
return
}
}
}
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
return
case <-pingTimer:
errState = !m.doPing(errResp)
}
}
} }
// checkSeq will check if sequence number is correct and increment it by 1. // checkSeq will check if sequence number is correct and increment it by 1.
@ -549,7 +502,7 @@ func (m *muxClient) response(seq uint32, r Response) {
m.addResponse(r) m.addResponse(r)
return return
} }
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) atomic.StoreInt64(&m.LastPong, time.Now().Unix())
ok := m.addResponse(r) ok := m.addResponse(r)
if !ok { if !ok {
PutByteBuffer(r.Msg) PutByteBuffer(r.Msg)
@ -600,7 +553,7 @@ func (m *muxClient) pong(msg pongMsg) {
m.addResponse(Response{Err: err}) m.addResponse(Response{Err: err})
return return
} }
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano()) atomic.StoreInt64(&m.LastPong, time.Now().Unix())
} }
// addResponse will add a response to the response channel. // addResponse will add a response to the response channel.

View File

@ -28,6 +28,8 @@ import (
xioutil "github.com/minio/minio/internal/ioutil" xioutil "github.com/minio/minio/internal/ioutil"
) )
const lastPingThreshold = 4 * clientPingInterval
type muxServer struct { type muxServer struct {
ID uint64 ID uint64
LastPing int64 LastPing int64
@ -41,7 +43,6 @@ type muxServer struct {
sendMu sync.Mutex sendMu sync.Mutex
recvMu sync.Mutex recvMu sync.Mutex
outBlock chan struct{} outBlock chan struct{}
clientPingInterval time.Duration
} }
func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer { func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
@ -98,7 +99,6 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
outBlock: make(chan struct{}, outboundCap), outBlock: make(chan struct{}, outboundCap),
LastPing: time.Now().Unix(), LastPing: time.Now().Unix(),
BaseFlags: c.baseFlags, BaseFlags: c.baseFlags,
clientPingInterval: c.clientPingInterval,
} }
// Acknowledge Mux created. // Acknowledge Mux created.
// Send async. // Send async.
@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
}(m.outBlock) }(m.outBlock)
// Remote aliveness check if needed. // Remote aliveness check if needed.
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) { if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
go func() { go func() {
wg.Wait() wg.Wait()
m.checkRemoteAlive() m.checkRemoteAlive()
@ -234,7 +234,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<-
// checkRemoteAlive will check if the remote is alive. // checkRemoteAlive will check if the remote is alive.
func (m *muxServer) checkRemoteAlive() { func (m *muxServer) checkRemoteAlive() {
t := time.NewTicker(m.clientPingInterval) t := time.NewTicker(lastPingThreshold / 4)
defer t.Stop() defer t.Stop()
for { for {
select { select {
@ -242,7 +242,7 @@ func (m *muxServer) checkRemoteAlive() {
return return
case <-t.C: case <-t.C:
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0)) last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
if last > 4*m.clientPingInterval { if last > lastPingThreshold {
gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last)) gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
m.close() m.close()
return return

View File

@ -89,26 +89,12 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
return nil return nil
} }
if resp.Err != nil { if resp.Err != nil {
s.cancel(resp.Err)
return resp.Err return resp.Err
} }
err = next(resp.Msg) err = next(resp.Msg)
if err != nil { if err != nil {
s.cancel(err)
return err return err
} }
} }
} }
} }
// Done will return a channel that will be closed when the stream is done.
// This mirrors context.Done().
func (s *Stream) Done() <-chan struct{} {
return s.ctx.Done()
}
// Err will return the error that caused the stream to end.
// This mirrors context.Err().
func (s *Stream) Err() error {
return s.ctx.Err()
}