Fix two-way stream cancelation and pings (#19763)
Do not log errors on one-way streams when sending a ping fails; cancel the stream instead. This also makes sure pings are sent while a stream is blocked on sending responses.
parent 9906b3ade9
commit 4d698841f4
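
Two ideas recur throughout the hunks below: a failed ping now cancels the stream instead of being logged, and the client ping timer participates in every select that can block, so keep-alives still flow while a sender waits for queue space. A minimal standalone sketch of that pattern (hypothetical names, not MinIO's actual grid types):

package main

import (
	"context"
	"fmt"
	"time"
)

// sendLoop forwards requests but keeps pinging even while a send blocks;
// a failed ping cancels the loop instead of being logged.
func sendLoop(ctx context.Context, requests <-chan []byte, out chan<- []byte, ping func() bool) {
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if !ping() {
				return
			}
		case req, ok := <-requests:
			if !ok {
				return
			}
		send:
			select { // still service the ping timer while blocked on out
			case out <- req:
			case <-ticker.C:
				if !ping() {
					return
				}
				goto send
			case <-ctx.Done():
				return
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	reqs := make(chan []byte, 1)
	out := make(chan []byte) // unbuffered: the forward below blocks
	reqs <- []byte("hello")
	go sendLoop(ctx, reqs, out, func() bool { fmt.Println("ping"); return true })
	time.Sleep(350 * time.Millisecond) // a few pings fire while blocked
	fmt.Printf("got %s\n", <-out)
}

The backward `goto` into a labeled select mirrors the `sendReq:`/`goto sendReq` construction the commit adds to `handleTwowayRequests` below.
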
@@ -123,10 +123,11 @@ type Connection struct {
     baseFlags Flags

     // For testing only
     debugInConn  net.Conn
     debugOutConn net.Conn
+    blockMessages atomic.Pointer[<-chan struct{}]
     addDeadline  time.Duration
     connMu       sync.Mutex
 }

 // Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
@@ -975,6 +976,11 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
             gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
             return
         }
+        block := c.blockMessages.Load()
+        if block != nil && *block != nil {
+            <-*block
+        }
+
         if c.incomingBytes != nil {
             c.incomingBytes(int64(len(msg)))
         }
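
The new `blockMessages` field consulted above is a test hook: an `atomic.Pointer` to a receive-only channel. A nil pointer (the production case) costs one atomic load on the hot read path; a stored channel parks the loop until a test closes it. A standalone sketch of the same gate, with hypothetical names:

package main

import "sync/atomic"

// gate mirrors the new Connection.blockMessages field: nil (the production
// case) costs one atomic load; a stored channel parks the loop until closed.
var gate atomic.Pointer[<-chan struct{}]

func handleOne(msg []byte) {
	if block := gate.Load(); block != nil && *block != nil {
		<-*block // parked until the test closes the channel
	}
	_ = msg // ... process the message ...
}

// pause installs the gate, as the new debugBlockInboundMessages case does
// under connMu, and returns a resume func.
func pause() (resume func()) {
	ch := make(chan struct{})
	blocked := (<-chan struct{})(ch)
	gate.Store(&blocked)
	return func() { close(ch) }
}

func main() {
	resume := pause()
	done := make(chan struct{})
	go func() { handleOne([]byte("x")); close(done) }()
	resume() // unblock the handler
	<-done
}
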
@@ -1363,6 +1369,10 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) {
 }

 func (c *Connection) handlePong(ctx context.Context, m message) {
+    if m.MuxID == 0 && m.Payload == nil {
+        atomic.StoreInt64(&c.LastPong, time.Now().Unix())
+        return
+    }
     var pong pongMsg
     _, err := pong.UnmarshalMsg(m.Payload)
     PutByteBuffer(m.Payload)
@@ -1382,7 +1392,9 @@ func (c *Connection) handlePong(ctx context.Context, m message) {

 func (c *Connection) handlePing(ctx context.Context, m message) {
     if m.MuxID == 0 {
-        gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
+        m.Flags.Clear(FlagPayloadIsZero)
+        m.Op = OpPong
+        gridLogIf(ctx, c.queueMsg(m, nil))
         return
     }
     // Single calls do not support pinging.
@@ -1599,7 +1611,12 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
         c.connMu.Lock()
         defer c.connMu.Unlock()
         c.connPingInterval = args[0].(time.Duration)
+        if c.connPingInterval < time.Second {
+            panic("CONN ping interval too low")
+        }
     case debugSetClientPingDuration:
+        c.connMu.Lock()
+        defer c.connMu.Unlock()
         c.clientPingInterval = args[0].(time.Duration)
     case debugAddToDeadline:
         c.addDeadline = args[0].(time.Duration)
@@ -1615,6 +1632,11 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
         mid.respMu.Lock()
         resp(mid.closed)
         mid.respMu.Unlock()
+    case debugBlockInboundMessages:
+        c.connMu.Lock()
+        block := (<-chan struct{})(args[0].(chan struct{}))
+        c.blockMessages.Store(&block)
+        c.connMu.Unlock()
     }
 }

@@ -50,6 +50,7 @@ const (
     debugSetClientPingDuration
     debugAddToDeadline
     debugIsOutgoingClosed
+    debugBlockInboundMessages
 )

 // TestGrid contains a grid of servers for testing purposes.
@@ -16,11 +16,12 @@ func _() {
     _ = x[debugSetClientPingDuration-5]
     _ = x[debugAddToDeadline-6]
     _ = x[debugIsOutgoingClosed-7]
+    _ = x[debugBlockInboundMessages-8]
 }

-const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
+const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages"

-var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
+var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176}

 func (i debugMsg) String() string {
     if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {
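
This file is `stringer`-generated: the assertions in `func _()` break the build if the constants are renumbered, `_debugMsg_name` concatenates every name into one backing string, and `_debugMsg_index` holds cumulative end offsets. Adding `debugBlockInboundMessages` (25 characters) therefore appends the final offset 176 = 151 + 25. A standalone sketch of the decode scheme:

package main

import "fmt"

// Same layout the generated stringer code uses: one backing string plus
// cumulative end offsets, so name i is names[index[i]:index[i+1]].
const names = "debugShutdowndebugKillInbound" // first two entries only

var index = [...]uint8{0, 13, 29}

func nameOf(i int) string {
	if i < 0 || i >= len(index)-1 {
		return fmt.Sprintf("debugMsg(%d)", i)
	}
	return names[index[i]:index[i+1]]
}

func main() {
	fmt.Println(nameOf(0), nameOf(1)) // debugShutdown debugKillInbound
}
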
@@ -378,6 +378,54 @@ func TestStreamSuite(t *testing.T) {
         assertNoActive(t, connRemoteLocal)
         assertNoActive(t, connLocalToRemote)
     })
+    t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamNoPing(t, local, remote, 0)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamNoPing(t, local, remote, 1)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamTwowayPing", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamPingRunning(t, local, remote, 1, false, false)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamPingRunning(t, local, remote, 1, false, true)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamPingRunning(t, local, remote, 1, true, false)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamPingRunning(t, local, remote, 1, true, true)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamOnewayPing", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamPingRunning(t, local, remote, 0, false, true)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
+    t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
+        defer timeout(1 * time.Minute)()
+        testServerStreamPingRunning(t, local, remote, 0, false, false)
+        assertNoActive(t, connRemoteLocal)
+        assertNoActive(t, connLocalToRemote)
+    })
 }

 func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@@ -491,12 +539,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
     register(remote)

     // local to remote
-    testHandler := func(t *testing.T, handler HandlerID) {
+    testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
         remoteConn := local.Connection(remoteHost)
         const testPayload = "Hello Grid World!"

         ctx, cancel := context.WithCancel(context.Background())
-        st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+        st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
         errFatal(err)
         clientCanceled := make(chan time.Time, 1)
         err = nil
@@ -513,6 +561,18 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
             clientCanceled <- time.Now()
         }(t)
         start := time.Now()
+        if st.Requests != nil {
+            defer close(st.Requests)
+        }
+        // Fill up queue.
+        for sendReq {
+            select {
+            case st.Requests <- []byte("Hello"):
+                time.Sleep(10 * time.Millisecond)
+            default:
+                sendReq = false
+            }
+        }
         cancel()
         <-serverCanceled
         t.Log("server cancel time:", time.Since(start))
@@ -524,11 +584,13 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
     }
     // local to remote, unbuffered
     t.Run("unbuffered", func(t *testing.T) {
-        testHandler(t, handlerTest)
+        testHandler(t, handlerTest, false)
     })

     t.Run("buffered", func(t *testing.T) {
-        testHandler(t, handlerTest2)
+        testHandler(t, handlerTest2, false)
+    })
+    t.Run("buffered", func(t *testing.T) {
+        testHandler(t, handlerTest2, true)
     })
 }

@@ -1025,6 +1087,167 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
         }
     }
 }
+
+// testServerStreamNoPing will test if server and client handle no pings.
+func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
+    defer testlogger.T.SetErrorTB(t)()
+    errFatal := func(err error) {
+        t.Helper()
+        if err != nil {
+            t.Fatal(err)
+        }
+    }
+
+    // We fake a local and remote server.
+    remoteHost := remote.HostName()
+
+    // 1: Echo
+    reqStarted := make(chan struct{})
+    serverCanceled := make(chan struct{})
+    register := func(manager *Manager) {
+        errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
+            Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
+                close(reqStarted)
+                // Just wait for it to cancel.
+                <-ctx.Done()
+                close(serverCanceled)
+                return NewRemoteErr(ctx.Err())
+            },
+            OutCapacity: 1,
+            InCapacity:  inCap,
+        }))
+    }
+    register(local)
+    register(remote)
+
+    remoteConn := local.Connection(remoteHost)
+    const testPayload = "Hello Grid World!"
+    remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
+
+    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+    defer cancel()
+    st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+    errFatal(err)
+
+    // Wait for the server to start the request.
+    <-reqStarted
+
+    // Stop processing requests.
+    nowBlocking := make(chan struct{})
+    remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
+
+    // Check that local returned.
+    err = st.Results(func(b []byte) error {
+        return nil
+    })
+    if err == nil {
+        t.Fatal("expected error, got nil")
+    }
+    t.Logf("error: %v", err)
+    // Check that remote is canceled.
+    <-serverCanceled
+    close(nowBlocking)
+}
+
+// testServerStreamPingRunning will test if server and client handle ping even when blocked.
+func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
+    defer testlogger.T.SetErrorTB(t)()
+    errFatal := func(err error) {
+        t.Helper()
+        if err != nil {
+            t.Fatal(err)
+        }
+    }
+
+    // We fake a local and remote server.
+    remoteHost := remote.HostName()
+
+    // 1: Echo
+    reqStarted := make(chan struct{})
+    serverCanceled := make(chan struct{})
+    register := func(manager *Manager) {
+        errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
+            Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
+                close(reqStarted)
+                for blockResp {
+                    select {
+                    case <-ctx.Done():
+                        close(serverCanceled)
+                        return NewRemoteErr(ctx.Err())
+                    case resp <- []byte{1}:
+                        time.Sleep(10 * time.Millisecond)
+                    }
+                }
+                // Just wait for it to cancel.
+                <-ctx.Done()
+                close(serverCanceled)
+                return NewRemoteErr(ctx.Err())
+            },
+            OutCapacity: 1,
+            InCapacity:  inCap,
+        }))
+    }
+    register(local)
+    register(remote)
+
+    remoteConn := local.Connection(remoteHost)
+    const testPayload = "Hello Grid World!"
+    remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
+
+    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+    defer cancel()
+    st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+    errFatal(err)
+
+    // Wait for the server to start the request.
+    <-reqStarted
+
+    // Block until we have exceeded the deadline several times over.
+    nowBlocking := make(chan struct{})
+    time.AfterFunc(time.Second, func() {
+        cancel()
+        close(nowBlocking)
+    })
+    if inCap > 0 {
+        go func() {
+            defer close(st.Requests)
+            if !blockReq {
+                <-nowBlocking
+                return
+            }
+            for {
+                select {
+                case <-nowBlocking:
+                    return
+                case <-st.Done():
+                case st.Requests <- []byte{1}:
+                    time.Sleep(10 * time.Millisecond)
+                }
+            }
+        }()
+    }
+    // Check that local returned.
+    err = st.Results(func(b []byte) error {
+        select {
+        case <-nowBlocking:
+        case <-st.Done():
+            return ctx.Err()
+        }
+        return nil
+    })
+    select {
+    case <-nowBlocking:
+    default:
+        t.Fatal("expected to be blocked. got err", err)
+    }
+    if err == nil {
+        t.Fatal("expected error, got nil")
+    }
+    t.Logf("error: %v", err)
+    // Check that remote is canceled.
+    <-serverCanceled
+}

 func timeout(after time.Duration) (cancel func()) {
     c := time.After(after)
     cc := make(chan struct{})
@@ -32,24 +32,25 @@ import (

 // muxClient is a stateful connection to a remote.
 type muxClient struct {
     MuxID            uint64
     SendSeq, RecvSeq uint32
     LastPong         int64
     BaseFlags        Flags
     ctx              context.Context
     cancelFn         context.CancelCauseFunc
     parent           *Connection
     respWait         chan<- Response
     respMu           sync.Mutex
     singleResp       bool
     closed           bool
     stateless        bool
     acked            bool
     init             bool
     deadline         time.Duration
     outBlock         chan struct{}
     subroute         *subHandlerID
     respErr          atomic.Pointer[error]
+    clientPingInterval time.Duration
 }

 // Response is a response from the server.
@@ -61,12 +62,13 @@ type Response struct {
 func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
     ctx, cancelFn := context.WithCancelCause(ctx)
     return &muxClient{
         MuxID:     muxID,
         ctx:       ctx,
         cancelFn:  cancelFn,
         parent:    parent,
-        LastPong:  time.Now().Unix(),
+        LastPong:  time.Now().UnixNano(),
         BaseFlags: parent.baseFlags,
+        clientPingInterval: parent.clientPingInterval,
     }
 }

@@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
         }
     }()
     var pingTimer <-chan time.Time
-    if m.deadline == 0 || m.deadline > clientPingInterval {
-        ticker := time.NewTicker(clientPingInterval)
+    if m.deadline == 0 || m.deadline > m.clientPingInterval {
+        ticker := time.NewTicker(m.clientPingInterval)
         defer ticker.Stop()
         pingTimer = ticker.C
-        atomic.StoreInt64(&m.LastPong, time.Now().Unix())
+        atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
     }
     defer m.parent.deleteMux(false, m.MuxID)
     for {
@@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
     }

     // Only check ping when not closed.
-    if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
+    if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
         m.respMu.Unlock()
         if debugPrint {
             fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
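
`LastPong` switches from `Unix()` seconds to `UnixNano()` throughout, decoded with `time.Unix(0, n)`. With the 100 ms ping interval the tests install, whole-second truncation alone could exceed the 2x-interval staleness cutoff. A quick illustration:

package main

import (
	"fmt"
	"time"
)

func main() {
	now := time.Now()
	// Second resolution truncates; at a 100ms ping interval the
	// truncation error (up to ~1s) dwarfs the 2x-interval cutoff.
	sec := time.Unix(now.Unix(), 0)      // old decode of Unix()
	nano := time.Unix(0, now.UnixNano()) // new decode of UnixNano()
	fmt.Println("lost to second truncation:", now.Sub(sec))
	fmt.Println("lost at nanosecond resolution:", now.Sub(nano))
}
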
@@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
 // responseCh is the channel that goes to the requester.
 // internalResp is the channel that comes from the server.
 func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
-    defer m.parent.deleteMux(false, m.MuxID)
-    defer xioutil.SafeClose(responseCh)
+    defer func() {
+        m.parent.deleteMux(false, m.MuxID)
+        // addErrorNonBlockingClose will close the response channel.
+        xioutil.SafeClose(responseCh)
+    }()
+
+    // Cancelation and errors are handled by handleTwowayRequests below.
     for resp := range internalResp {
-        responseCh <- resp
         m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
+        responseCh <- resp
     }
 }

-func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
+func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
     var errState bool
     if debugPrint {
         start := time.Now()
@@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
         }()
     }

+    var pingTimer <-chan time.Time
+    if m.deadline == 0 || m.deadline > m.clientPingInterval {
+        ticker := time.NewTicker(m.clientPingInterval)
+        defer ticker.Stop()
+        pingTimer = ticker.C
+        atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
+    }
+
     // Listen for client messages.
-    for {
-        if errState {
-            go func() {
-                // Drain requests.
-                for range requests {
-                }
-            }()
-            return
-        }
+reqLoop:
+    for !errState {
         select {
         case <-m.ctx.Done():
             if debugPrint {
                 fmt.Println("Client sending disconnect to mux", m.MuxID)
             }
-            m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
+            m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
             errState = true
             continue
+        case <-pingTimer:
+            if !m.doPing(errResp) {
+                errState = true
+                continue
+            }
         case req, ok := <-requests:
             if !ok {
                 // Done send EOF
@@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
                 msg := message{
                     Op:    OpMuxClientMsg,
                     MuxID: m.MuxID,
-                    Seq:   1,
                     Flags: FlagEOF,
                 }
                 msg.setZeroPayloadFlag()
                 err := m.send(msg)
                 if err != nil {
-                    m.addErrorNonBlockingClose(internalResp, err)
+                    m.addErrorNonBlockingClose(errResp, err)
                 }
-                return
+                break reqLoop
             }
             // Grab a send token.
+        sendReq:
             select {
             case <-m.ctx.Done():
-                m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
+                m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
                 errState = true
                 continue
+            case <-pingTimer:
+                if !m.doPing(errResp) {
+                    errState = true
+                    continue
+                }
+                goto sendReq
             case <-m.outBlock:
             }
             msg := message{
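
Two control-flow devices above are worth noting: `break reqLoop` is needed because a bare `break` inside a `select` leaves only the `select`, and `goto sendReq` re-arms the token wait after the ping timer fires. A standalone demo of both:

package main

import "fmt"

func main() {
	work := make(chan int, 2)
	tick := make(chan struct{}, 1)
	work <- 1
	work <- 2
	close(work)
	tick <- struct{}{}

loop:
	for {
	again:
		select {
		case <-tick:
			fmt.Println("serviced timer")
			goto again // back to waiting, like `goto sendReq` above
		case v, ok := <-work:
			if !ok {
				break loop // leaves the for, not just the select
			}
			fmt.Println("work", v)
		}
	}
	fmt.Println("done")
}
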
@@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
             err := m.send(msg)
             PutByteBuffer(req)
             if err != nil {
-                m.addErrorNonBlockingClose(internalResp, err)
+                m.addErrorNonBlockingClose(errResp, err)
                 errState = true
                 continue
             }
             msg.Seq++
         }
     }
+
+    if errState {
+        // Drain requests.
+        for {
+            select {
+            case r, ok := <-requests:
+                if !ok {
+                    return
+                }
+                PutByteBuffer(r)
+            default:
+                return
+            }
+        }
+    }
+
+    for !errState {
+        select {
+        case <-m.ctx.Done():
+            if debugPrint {
+                fmt.Println("Client sending disconnect to mux", m.MuxID)
+            }
+            m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
+            return
+        case <-pingTimer:
+            errState = !m.doPing(errResp)
+        }
+    }
 }

 // checkSeq will check if sequence number is correct and increment it by 1.
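
The new tail replaces the old drain goroutine (`go func() { for range requests {} }()`), which parked on the channel until the producer closed it and recycled nothing. The replacement drains without blocking and returns buffers to the pool. A standalone sketch of the shape, with a hypothetical `recycle` standing in for `PutByteBuffer`:

package main

import "fmt"

// drain recycles whatever is already queued, but returns rather than
// block on an open channel.
func drain(requests <-chan []byte, recycle func([]byte)) {
	for {
		select {
		case r, ok := <-requests:
			if !ok {
				return // closed and empty
			}
			recycle(r)
		default:
			return // open but empty right now: don't block
		}
	}
}

func main() {
	ch := make(chan []byte, 2)
	ch <- []byte("a")
	ch <- []byte("b")
	drain(ch, func(b []byte) { fmt.Printf("recycled %s\n", b) })
}
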
@@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) {
         m.addResponse(r)
         return
     }
-    atomic.StoreInt64(&m.LastPong, time.Now().Unix())
+    atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
     ok := m.addResponse(r)
     if !ok {
         PutByteBuffer(r.Msg)
@@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) {
         m.addResponse(Response{Err: err})
         return
     }
-    atomic.StoreInt64(&m.LastPong, time.Now().Unix())
+    atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
 }

 // addResponse will add a response to the response channel.
@@ -28,21 +28,20 @@ import (
     xioutil "github.com/minio/minio/internal/ioutil"
 )

-const lastPingThreshold = 4 * clientPingInterval
-
 type muxServer struct {
     ID               uint64
     LastPing         int64
     SendSeq, RecvSeq uint32
     Resp             chan []byte
     BaseFlags        Flags
     ctx              context.Context
     cancel           context.CancelFunc
     inbound          chan []byte
     parent           *Connection
     sendMu           sync.Mutex
     recvMu           sync.Mutex
     outBlock         chan struct{}
+    clientPingInterval time.Duration
 }

 func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
@@ -89,16 +88,17 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
     }

     m := muxServer{
         ID:        msg.MuxID,
         RecvSeq:   msg.Seq + 1,
         SendSeq:   msg.Seq,
         ctx:       ctx,
         cancel:    cancel,
         parent:    c,
         inbound:   nil,
         outBlock:  make(chan struct{}, outboundCap),
         LastPing:  time.Now().Unix(),
         BaseFlags: c.baseFlags,
+        clientPingInterval: c.clientPingInterval,
     }
     // Acknowledge Mux created.
     // Send async.
@@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
     }(m.outBlock)

     // Remote aliveness check if needed.
-    if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
+    if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) {
         go func() {
             wg.Wait()
             m.checkRemoteAlive()
@@ -234,7 +234,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- message, handler StreamHandler) {

 // checkRemoteAlive will check if the remote is alive.
 func (m *muxServer) checkRemoteAlive() {
-    t := time.NewTicker(lastPingThreshold / 4)
+    t := time.NewTicker(m.clientPingInterval)
     defer t.Stop()
     for {
         select {
@@ -242,7 +242,7 @@ func (m *muxServer) checkRemoteAlive() {
             return
         case <-t.C:
             last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
-            if last > lastPingThreshold {
+            if last > 4*m.clientPingInterval {
                 gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
                 m.close()
                 return
@@ -89,12 +89,26 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
             return nil
         }
         if resp.Err != nil {
+            s.cancel(resp.Err)
             return resp.Err
         }
         err = next(resp.Msg)
         if err != nil {
+            s.cancel(err)
             return err
         }
     }
     }
 }
+
+// Done will return a channel that will be closed when the stream is done.
+// This mirrors context.Done().
+func (s *Stream) Done() <-chan struct{} {
+    return s.ctx.Done()
+}
+
+// Err will return the error that caused the stream to end.
+// This mirrors context.Err().
+func (s *Stream) Err() error {
+    return s.ctx.Err()
+}
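
`Done` and `Err` deliberately mirror `context.Context`, letting callers select on stream termination exactly as they would on a context; the new tests rely on this via `st.Done()`. A standalone analogue:

package main

import (
	"context"
	"fmt"
)

// stream wraps a context and exposes Done/Err, as the new accessors do.
type stream struct {
	ctx      context.Context
	Requests chan []byte
}

func (s *stream) Done() <-chan struct{} { return s.ctx.Done() }
func (s *stream) Err() error            { return s.ctx.Err() }

func sendOrBail(s *stream, payload []byte) error {
	select {
	case s.Requests <- payload:
		return nil
	case <-s.Done(): // stream ended; Err says why
		return s.Err()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	s := &stream{ctx: ctx, Requests: make(chan []byte)} // unbuffered
	cancel()                                            // nobody reading: bail via Done
	fmt.Println("err:", sendOrBail(s, []byte("x")))
}
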