mirror of
https://github.com/minio/minio.git
synced 2025-04-27 05:15:01 -04:00
Two way streams for upcoming locking enhancements (#19796)
This commit is contained in:
parent
c5141d65ac
commit
f00187033d
@ -200,6 +200,7 @@ func BenchmarkStream(b *testing.B) {
|
|||||||
}{
|
}{
|
||||||
{name: "request", fn: benchmarkGridStreamReqOnly},
|
{name: "request", fn: benchmarkGridStreamReqOnly},
|
||||||
{name: "responses", fn: benchmarkGridStreamRespOnly},
|
{name: "responses", fn: benchmarkGridStreamRespOnly},
|
||||||
|
{name: "twoway", fn: benchmarkGridStreamTwoway},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
b.Run(test.name, func(b *testing.B) {
|
b.Run(test.name, func(b *testing.B) {
|
||||||
@ -438,3 +439,134 @@ func benchmarkGridStreamReqOnly(b *testing.B, n int) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func benchmarkGridStreamTwoway(b *testing.B, n int) {
|
||||||
|
defer testlogger.T.SetErrorTB(b)()
|
||||||
|
|
||||||
|
errFatal := func(err error) {
|
||||||
|
b.Helper()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
grid, err := SetupTestGrid(n)
|
||||||
|
errFatal(err)
|
||||||
|
b.Cleanup(grid.Cleanup)
|
||||||
|
const messages = 10
|
||||||
|
// Create n managers.
|
||||||
|
const payloadSize = 512
|
||||||
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
_, err = rng.Read(payload)
|
||||||
|
errFatal(err)
|
||||||
|
|
||||||
|
for _, remote := range grid.Managers {
|
||||||
|
// Register a single handler which echos the payload.
|
||||||
|
errFatal(remote.RegisterStreamingHandler(handlerTest, StreamHandler{
|
||||||
|
// Send 10x requests.
|
||||||
|
Handle: func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr {
|
||||||
|
got := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case b, ok := <-in:
|
||||||
|
if !ok {
|
||||||
|
if got != messages {
|
||||||
|
return NewRemoteErrf("wrong number of requests. want %d, got %d", messages, got)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out <- b
|
||||||
|
got++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
Subroute: "some-subroute",
|
||||||
|
OutCapacity: 1,
|
||||||
|
InCapacity: 1, // Only one message buffered.
|
||||||
|
}))
|
||||||
|
errFatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all to connect
|
||||||
|
// Parallel writes per server.
|
||||||
|
for par := 1; par <= 32; par *= 2 {
|
||||||
|
b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) {
|
||||||
|
defer timeout(30 * time.Second)()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(payload) * (2*messages + 1)))
|
||||||
|
b.ResetTimer()
|
||||||
|
t := time.Now()
|
||||||
|
var ops int64
|
||||||
|
var lat int64
|
||||||
|
b.SetParallelism(par)
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
n := 0
|
||||||
|
var latency int64
|
||||||
|
managers := grid.Managers
|
||||||
|
hosts := grid.Hosts
|
||||||
|
for pb.Next() {
|
||||||
|
// Pick a random manager.
|
||||||
|
src, dst := rng.Intn(len(managers)), rng.Intn(len(managers))
|
||||||
|
if src == dst {
|
||||||
|
dst = (dst + 1) % len(managers)
|
||||||
|
}
|
||||||
|
local := managers[src]
|
||||||
|
conn := local.Connection(hosts[dst]).Subroute("some-subroute")
|
||||||
|
if conn == nil {
|
||||||
|
b.Fatal("No connection")
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// Send the payload.
|
||||||
|
t := time.Now()
|
||||||
|
st, err := conn.NewStream(ctx, handlerTest, payload)
|
||||||
|
if err != nil {
|
||||||
|
if debugReqs {
|
||||||
|
fmt.Println(err.Error())
|
||||||
|
}
|
||||||
|
b.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
got := 0
|
||||||
|
sent := 0
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < messages; i++ {
|
||||||
|
st.Requests <- append(GetByteBuffer()[:0], payload...)
|
||||||
|
if sent++; sent == messages {
|
||||||
|
close(st.Requests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = st.Results(func(b []byte) error {
|
||||||
|
got++
|
||||||
|
PutByteBuffer(b)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if debugReqs {
|
||||||
|
fmt.Println(err.Error())
|
||||||
|
}
|
||||||
|
b.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
if got != messages {
|
||||||
|
b.Fatalf("wrong number of responses. want %d, got %d", messages, got)
|
||||||
|
}
|
||||||
|
latency += time.Since(t).Nanoseconds()
|
||||||
|
cancel()
|
||||||
|
n += got
|
||||||
|
}
|
||||||
|
atomic.AddInt64(&ops, int64(n*2))
|
||||||
|
atomic.AddInt64(&lat, latency)
|
||||||
|
})
|
||||||
|
spent := time.Since(t)
|
||||||
|
if spent > 0 && n > 0 {
|
||||||
|
// Since we are benchmarking n parallel servers we need to multiply by n.
|
||||||
|
// This will give an estimate of the total ops/s.
|
||||||
|
latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond)
|
||||||
|
b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s")
|
||||||
|
b.ReportMetric(latency/float64(ops), "ms/op")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -130,6 +130,7 @@ type Connection struct {
|
|||||||
// For testing only
|
// For testing only
|
||||||
debugInConn net.Conn
|
debugInConn net.Conn
|
||||||
debugOutConn net.Conn
|
debugOutConn net.Conn
|
||||||
|
blockMessages atomic.Pointer[<-chan struct{}]
|
||||||
addDeadline time.Duration
|
addDeadline time.Duration
|
||||||
connMu sync.Mutex
|
connMu sync.Mutex
|
||||||
}
|
}
|
||||||
@ -883,7 +884,7 @@ func (c *Connection) updateState(s State) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s == StateConnected {
|
if s == StateConnected {
|
||||||
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
|
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
atomic.StoreUint32((*uint32)(&c.state), uint32(s))
|
atomic.StoreUint32((*uint32)(&c.state), uint32(s))
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
@ -993,6 +994,11 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
block := c.blockMessages.Load()
|
||||||
|
if block != nil && *block != nil {
|
||||||
|
<-*block
|
||||||
|
}
|
||||||
|
|
||||||
if c.incomingBytes != nil {
|
if c.incomingBytes != nil {
|
||||||
c.incomingBytes(int64(len(msg)))
|
c.incomingBytes(int64(len(msg)))
|
||||||
}
|
}
|
||||||
@ -1094,7 +1100,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
lastPong := atomic.LoadInt64(&c.LastPong)
|
lastPong := atomic.LoadInt64(&c.LastPong)
|
||||||
if lastPong > 0 {
|
if lastPong > 0 {
|
||||||
lastPongTime := time.Unix(lastPong, 0)
|
lastPongTime := time.Unix(0, lastPong)
|
||||||
if d := time.Since(lastPongTime); d > connPingInterval*2 {
|
if d := time.Since(lastPongTime); d > connPingInterval*2 {
|
||||||
gridLogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond)))
|
gridLogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond)))
|
||||||
return
|
return
|
||||||
@ -1105,7 +1111,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
gridLogIf(ctx, err)
|
gridLogIf(ctx, err)
|
||||||
// Fake it...
|
// Fake it...
|
||||||
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
|
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case toSend = <-c.outQueue:
|
case toSend = <-c.outQueue:
|
||||||
@ -1406,12 +1412,16 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHan
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) handlePong(ctx context.Context, m message) {
|
func (c *Connection) handlePong(ctx context.Context, m message) {
|
||||||
|
if m.MuxID == 0 && m.Payload == nil {
|
||||||
|
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
|
||||||
|
return
|
||||||
|
}
|
||||||
var pong pongMsg
|
var pong pongMsg
|
||||||
_, err := pong.UnmarshalMsg(m.Payload)
|
_, err := pong.UnmarshalMsg(m.Payload)
|
||||||
PutByteBuffer(m.Payload)
|
PutByteBuffer(m.Payload)
|
||||||
gridLogIf(ctx, err)
|
gridLogIf(ctx, err)
|
||||||
if m.MuxID == 0 {
|
if m.MuxID == 0 {
|
||||||
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
|
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v, ok := c.outgoing.Load(m.MuxID); ok {
|
if v, ok := c.outgoing.Load(m.MuxID); ok {
|
||||||
@ -1425,7 +1435,9 @@ func (c *Connection) handlePong(ctx context.Context, m message) {
|
|||||||
|
|
||||||
func (c *Connection) handlePing(ctx context.Context, m message) {
|
func (c *Connection) handlePing(ctx context.Context, m message) {
|
||||||
if m.MuxID == 0 {
|
if m.MuxID == 0 {
|
||||||
gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
|
m.Flags.Clear(FlagPayloadIsZero)
|
||||||
|
m.Op = OpPong
|
||||||
|
gridLogIf(ctx, c.queueMsg(m, nil))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Single calls do not support pinging.
|
// Single calls do not support pinging.
|
||||||
@ -1560,9 +1572,15 @@ func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) {
|
|||||||
}
|
}
|
||||||
if m.Flags&FlagEOF != 0 {
|
if m.Flags&FlagEOF != 0 {
|
||||||
if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 {
|
if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 {
|
||||||
|
// We must obtain the lock before calling cancelFn
|
||||||
|
// Otherwise others may pick up the error before close is called.
|
||||||
|
v.respMu.Lock()
|
||||||
v.cancelFn(errStreamEOF)
|
v.cancelFn(errStreamEOF)
|
||||||
}
|
v.closeLocked()
|
||||||
|
v.respMu.Unlock()
|
||||||
|
} else {
|
||||||
v.close()
|
v.close()
|
||||||
|
}
|
||||||
if debugReqs {
|
if debugReqs {
|
||||||
fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX")
|
fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX")
|
||||||
}
|
}
|
||||||
@ -1617,7 +1635,7 @@ func (c *Connection) Stats() madmin.RPCMetrics {
|
|||||||
IncomingMessages: c.inMessages.Load(),
|
IncomingMessages: c.inMessages.Load(),
|
||||||
OutgoingMessages: c.outMessages.Load(),
|
OutgoingMessages: c.outMessages.Load(),
|
||||||
OutQueue: len(c.outQueue),
|
OutQueue: len(c.outQueue),
|
||||||
LastPongTime: time.Unix(c.LastPong, 0).UTC(),
|
LastPongTime: time.Unix(0, c.LastPong).UTC(),
|
||||||
}
|
}
|
||||||
m.ByDestination = map[string]madmin.RPCMetrics{
|
m.ByDestination = map[string]madmin.RPCMetrics{
|
||||||
c.Remote: m,
|
c.Remote: m,
|
||||||
@ -1659,7 +1677,12 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
|
|||||||
c.connMu.Lock()
|
c.connMu.Lock()
|
||||||
defer c.connMu.Unlock()
|
defer c.connMu.Unlock()
|
||||||
c.connPingInterval = args[0].(time.Duration)
|
c.connPingInterval = args[0].(time.Duration)
|
||||||
|
if c.connPingInterval < time.Second {
|
||||||
|
panic("CONN ping interval too low")
|
||||||
|
}
|
||||||
case debugSetClientPingDuration:
|
case debugSetClientPingDuration:
|
||||||
|
c.connMu.Lock()
|
||||||
|
defer c.connMu.Unlock()
|
||||||
c.clientPingInterval = args[0].(time.Duration)
|
c.clientPingInterval = args[0].(time.Duration)
|
||||||
case debugAddToDeadline:
|
case debugAddToDeadline:
|
||||||
c.addDeadline = args[0].(time.Duration)
|
c.addDeadline = args[0].(time.Duration)
|
||||||
@ -1675,6 +1698,11 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
|
|||||||
mid.respMu.Lock()
|
mid.respMu.Lock()
|
||||||
resp(mid.closed)
|
resp(mid.closed)
|
||||||
mid.respMu.Unlock()
|
mid.respMu.Unlock()
|
||||||
|
case debugBlockInboundMessages:
|
||||||
|
c.connMu.Lock()
|
||||||
|
block := (<-chan struct{})(args[0].(chan struct{}))
|
||||||
|
c.blockMessages.Store(&block)
|
||||||
|
c.connMu.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ const (
|
|||||||
debugSetClientPingDuration
|
debugSetClientPingDuration
|
||||||
debugAddToDeadline
|
debugAddToDeadline
|
||||||
debugIsOutgoingClosed
|
debugIsOutgoingClosed
|
||||||
|
debugBlockInboundMessages
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestGrid contains a grid of servers for testing purposes.
|
// TestGrid contains a grid of servers for testing purposes.
|
||||||
|
@ -16,11 +16,12 @@ func _() {
|
|||||||
_ = x[debugSetClientPingDuration-5]
|
_ = x[debugSetClientPingDuration-5]
|
||||||
_ = x[debugAddToDeadline-6]
|
_ = x[debugAddToDeadline-6]
|
||||||
_ = x[debugIsOutgoingClosed-7]
|
_ = x[debugIsOutgoingClosed-7]
|
||||||
|
_ = x[debugBlockInboundMessages-8]
|
||||||
}
|
}
|
||||||
|
|
||||||
const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
|
const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages"
|
||||||
|
|
||||||
var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
|
var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176}
|
||||||
|
|
||||||
func (i debugMsg) String() string {
|
func (i debugMsg) String() string {
|
||||||
if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {
|
if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {
|
||||||
|
@ -26,6 +26,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -378,6 +379,54 @@ func TestStreamSuite(t *testing.T) {
|
|||||||
assertNoActive(t, connRemoteLocal)
|
assertNoActive(t, connRemoteLocal)
|
||||||
assertNoActive(t, connLocalToRemote)
|
assertNoActive(t, connLocalToRemote)
|
||||||
})
|
})
|
||||||
|
t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamNoPing(t, local, remote, 0)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamNoPing(t, local, remote, 1)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamTwowayPing", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamPingRunning(t, local, remote, 1, false, false)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamPingRunning(t, local, remote, 1, false, true)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamPingRunning(t, local, remote, 1, true, false)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamPingRunning(t, local, remote, 1, true, true)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamOnewayPing", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamPingRunning(t, local, remote, 0, false, true)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
|
t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
|
||||||
|
defer timeout(1 * time.Minute)()
|
||||||
|
testServerStreamPingRunning(t, local, remote, 0, false, false)
|
||||||
|
assertNoActive(t, connRemoteLocal)
|
||||||
|
assertNoActive(t, connLocalToRemote)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
|
func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
|
||||||
@ -491,12 +540,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
|
|||||||
register(remote)
|
register(remote)
|
||||||
|
|
||||||
// local to remote
|
// local to remote
|
||||||
testHandler := func(t *testing.T, handler HandlerID) {
|
testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
|
||||||
remoteConn := local.Connection(remoteHost)
|
remoteConn := local.Connection(remoteHost)
|
||||||
const testPayload = "Hello Grid World!"
|
const testPayload = "Hello Grid World!"
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
|
st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
|
||||||
errFatal(err)
|
errFatal(err)
|
||||||
clientCanceled := make(chan time.Time, 1)
|
clientCanceled := make(chan time.Time, 1)
|
||||||
err = nil
|
err = nil
|
||||||
@ -513,6 +562,18 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
|
|||||||
clientCanceled <- time.Now()
|
clientCanceled <- time.Now()
|
||||||
}(t)
|
}(t)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
if st.Requests != nil {
|
||||||
|
defer close(st.Requests)
|
||||||
|
}
|
||||||
|
// Fill up queue.
|
||||||
|
for sendReq {
|
||||||
|
select {
|
||||||
|
case st.Requests <- []byte("Hello"):
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
default:
|
||||||
|
sendReq = false
|
||||||
|
}
|
||||||
|
}
|
||||||
cancel()
|
cancel()
|
||||||
<-serverCanceled
|
<-serverCanceled
|
||||||
t.Log("server cancel time:", time.Since(start))
|
t.Log("server cancel time:", time.Since(start))
|
||||||
@ -524,11 +585,13 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
|
|||||||
}
|
}
|
||||||
// local to remote, unbuffered
|
// local to remote, unbuffered
|
||||||
t.Run("unbuffered", func(t *testing.T) {
|
t.Run("unbuffered", func(t *testing.T) {
|
||||||
testHandler(t, handlerTest)
|
testHandler(t, handlerTest, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("buffered", func(t *testing.T) {
|
t.Run("buffered", func(t *testing.T) {
|
||||||
testHandler(t, handlerTest2)
|
testHandler(t, handlerTest2, false)
|
||||||
|
})
|
||||||
|
t.Run("buffered", func(t *testing.T) {
|
||||||
|
testHandler(t, handlerTest2, true)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1025,6 +1088,170 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testServerStreamNoPing will test if server and client handle no pings.
|
||||||
|
func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
|
||||||
|
defer testlogger.T.SetErrorTB(t)()
|
||||||
|
errFatal := func(err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We fake a local and remote server.
|
||||||
|
remoteHost := remote.HostName()
|
||||||
|
|
||||||
|
// 1: Echo
|
||||||
|
reqStarted := make(chan struct{})
|
||||||
|
serverCanceled := make(chan struct{})
|
||||||
|
register := func(manager *Manager) {
|
||||||
|
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
|
||||||
|
Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
|
||||||
|
close(reqStarted)
|
||||||
|
// Just wait for it to cancel.
|
||||||
|
<-ctx.Done()
|
||||||
|
close(serverCanceled)
|
||||||
|
return NewRemoteErr(ctx.Err())
|
||||||
|
},
|
||||||
|
OutCapacity: 1,
|
||||||
|
InCapacity: inCap,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
register(local)
|
||||||
|
register(remote)
|
||||||
|
|
||||||
|
remoteConn := local.Connection(remoteHost)
|
||||||
|
const testPayload = "Hello Grid World!"
|
||||||
|
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
|
||||||
|
defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
|
||||||
|
errFatal(err)
|
||||||
|
|
||||||
|
// Wait for the server start the request.
|
||||||
|
<-reqStarted
|
||||||
|
|
||||||
|
// Stop processing requests
|
||||||
|
nowBlocking := make(chan struct{})
|
||||||
|
remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
|
||||||
|
|
||||||
|
// Check that local returned.
|
||||||
|
err = st.Results(func(b []byte) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
t.Logf("response: %v", err)
|
||||||
|
|
||||||
|
// Check that remote is canceled.
|
||||||
|
<-serverCanceled
|
||||||
|
close(nowBlocking)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testServerStreamPingRunning will test if server and client handle ping even when blocked.
|
||||||
|
func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
|
||||||
|
defer testlogger.T.SetErrorTB(t)()
|
||||||
|
errFatal := func(err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We fake a local and remote server.
|
||||||
|
remoteHost := remote.HostName()
|
||||||
|
|
||||||
|
// 1: Echo
|
||||||
|
reqStarted := make(chan struct{})
|
||||||
|
serverCanceled := make(chan struct{})
|
||||||
|
register := func(manager *Manager) {
|
||||||
|
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
|
||||||
|
Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
|
||||||
|
close(reqStarted)
|
||||||
|
// Just wait for it to cancel.
|
||||||
|
for blockResp {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(serverCanceled)
|
||||||
|
return NewRemoteErr(ctx.Err())
|
||||||
|
case resp <- []byte{1}:
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Just wait for it to cancel.
|
||||||
|
<-ctx.Done()
|
||||||
|
close(serverCanceled)
|
||||||
|
return NewRemoteErr(ctx.Err())
|
||||||
|
},
|
||||||
|
OutCapacity: 1,
|
||||||
|
InCapacity: inCap,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
register(local)
|
||||||
|
register(remote)
|
||||||
|
|
||||||
|
remoteConn := local.Connection(remoteHost)
|
||||||
|
const testPayload = "Hello Grid World!"
|
||||||
|
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
|
||||||
|
defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
|
||||||
|
errFatal(err)
|
||||||
|
|
||||||
|
// Wait for the server start the request.
|
||||||
|
<-reqStarted
|
||||||
|
|
||||||
|
// Block until we have exceeded the deadline several times over.
|
||||||
|
nowBlocking := make(chan struct{})
|
||||||
|
var mu sync.Mutex
|
||||||
|
time.AfterFunc(time.Second, func() {
|
||||||
|
mu.Lock()
|
||||||
|
cancel()
|
||||||
|
close(nowBlocking)
|
||||||
|
mu.Unlock()
|
||||||
|
})
|
||||||
|
if inCap > 0 {
|
||||||
|
go func() {
|
||||||
|
defer close(st.Requests)
|
||||||
|
if !blockReq {
|
||||||
|
<-nowBlocking
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-nowBlocking:
|
||||||
|
return
|
||||||
|
case <-st.Done():
|
||||||
|
case st.Requests <- []byte{1}:
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
// Check that local returned.
|
||||||
|
err = st.Results(func(b []byte) error {
|
||||||
|
<-st.Done()
|
||||||
|
return ctx.Err()
|
||||||
|
})
|
||||||
|
mu.Lock()
|
||||||
|
select {
|
||||||
|
case <-nowBlocking:
|
||||||
|
default:
|
||||||
|
t.Fatal("expected to be blocked. got err", err)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
t.Logf("response: %v", err)
|
||||||
|
// Check that remote is canceled.
|
||||||
|
<-serverCanceled
|
||||||
|
}
|
||||||
|
|
||||||
func timeout(after time.Duration) (cancel func()) {
|
func timeout(after time.Duration) (cancel func()) {
|
||||||
c := time.After(after)
|
c := time.After(after)
|
||||||
cc := make(chan struct{})
|
cc := make(chan struct{})
|
||||||
|
@ -50,6 +50,7 @@ type muxClient struct {
|
|||||||
outBlock chan struct{}
|
outBlock chan struct{}
|
||||||
subroute *subHandlerID
|
subroute *subHandlerID
|
||||||
respErr atomic.Pointer[error]
|
respErr atomic.Pointer[error]
|
||||||
|
clientPingInterval time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Response is a response from the server.
|
// Response is a response from the server.
|
||||||
@ -65,8 +66,9 @@ func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxCli
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancelFn: cancelFn,
|
cancelFn: cancelFn,
|
||||||
parent: parent,
|
parent: parent,
|
||||||
LastPong: time.Now().Unix(),
|
LastPong: time.Now().UnixNano(),
|
||||||
BaseFlags: parent.baseFlags,
|
BaseFlags: parent.baseFlags,
|
||||||
|
clientPingInterval: parent.clientPingInterval,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var pingTimer <-chan time.Time
|
var pingTimer <-chan time.Time
|
||||||
if m.deadline == 0 || m.deadline > clientPingInterval {
|
if m.deadline == 0 || m.deadline > m.clientPingInterval {
|
||||||
ticker := time.NewTicker(clientPingInterval)
|
ticker := time.NewTicker(m.clientPingInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
pingTimer = ticker.C
|
pingTimer = ticker.C
|
||||||
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
|
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
defer m.parent.deleteMux(false, m.MuxID)
|
defer m.parent.deleteMux(false, m.MuxID)
|
||||||
for {
|
for {
|
||||||
@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Only check ping when not closed.
|
// Only check ping when not closed.
|
||||||
if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
|
if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
|
||||||
m.respMu.Unlock()
|
m.respMu.Unlock()
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
|
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
|
||||||
@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
|
|||||||
// responseCh is the channel to that goes to the requester.
|
// responseCh is the channel to that goes to the requester.
|
||||||
// internalResp is the channel that comes from the server.
|
// internalResp is the channel that comes from the server.
|
||||||
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
|
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
|
||||||
defer m.parent.deleteMux(false, m.MuxID)
|
defer func() {
|
||||||
defer xioutil.SafeClose(responseCh)
|
m.parent.deleteMux(false, m.MuxID)
|
||||||
|
// addErrorNonBlockingClose will close the response channel.
|
||||||
|
xioutil.SafeClose(responseCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Cancelation and errors are handled by handleTwowayRequests below.
|
||||||
for resp := range internalResp {
|
for resp := range internalResp {
|
||||||
responseCh <- resp
|
|
||||||
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
|
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
|
||||||
|
responseCh <- resp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
|
func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
|
||||||
var errState bool
|
var errState bool
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var pingTimer <-chan time.Time
|
||||||
|
if m.deadline == 0 || m.deadline > m.clientPingInterval {
|
||||||
|
ticker := time.NewTicker(m.clientPingInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
pingTimer = ticker.C
|
||||||
|
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
// Listen for client messages.
|
// Listen for client messages.
|
||||||
for {
|
reqLoop:
|
||||||
if errState {
|
for !errState {
|
||||||
go func() {
|
|
||||||
// Drain requests.
|
|
||||||
for range requests {
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Println("Client sending disconnect to mux", m.MuxID)
|
fmt.Println("Client sending disconnect to mux", m.MuxID)
|
||||||
}
|
}
|
||||||
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
|
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
|
||||||
errState = true
|
errState = true
|
||||||
continue
|
continue
|
||||||
|
case <-pingTimer:
|
||||||
|
if !m.doPing(errResp) {
|
||||||
|
errState = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
case req, ok := <-requests:
|
case req, ok := <-requests:
|
||||||
if !ok {
|
if !ok {
|
||||||
// Done send EOF
|
// Done send EOF
|
||||||
@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
|
|||||||
msg := message{
|
msg := message{
|
||||||
Op: OpMuxClientMsg,
|
Op: OpMuxClientMsg,
|
||||||
MuxID: m.MuxID,
|
MuxID: m.MuxID,
|
||||||
Seq: 1,
|
|
||||||
Flags: FlagEOF,
|
Flags: FlagEOF,
|
||||||
}
|
}
|
||||||
msg.setZeroPayloadFlag()
|
msg.setZeroPayloadFlag()
|
||||||
err := m.send(msg)
|
err := m.send(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.addErrorNonBlockingClose(internalResp, err)
|
m.addErrorNonBlockingClose(errResp, err)
|
||||||
}
|
}
|
||||||
return
|
break reqLoop
|
||||||
}
|
}
|
||||||
// Grab a send token.
|
// Grab a send token.
|
||||||
|
sendReq:
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
|
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
|
||||||
errState = true
|
errState = true
|
||||||
continue
|
continue
|
||||||
|
case <-pingTimer:
|
||||||
|
if !m.doPing(errResp) {
|
||||||
|
errState = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
goto sendReq
|
||||||
case <-m.outBlock:
|
case <-m.outBlock:
|
||||||
}
|
}
|
||||||
msg := message{
|
msg := message{
|
||||||
@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
|
|||||||
err := m.send(msg)
|
err := m.send(msg)
|
||||||
PutByteBuffer(req)
|
PutByteBuffer(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.addErrorNonBlockingClose(internalResp, err)
|
m.addErrorNonBlockingClose(errResp, err)
|
||||||
errState = true
|
errState = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
msg.Seq++
|
msg.Seq++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if errState {
|
||||||
|
// Drain requests.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case r, ok := <-requests:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
PutByteBuffer(r)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for !errState {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
if debugPrint {
|
||||||
|
fmt.Println("Client sending disconnect to mux", m.MuxID)
|
||||||
|
}
|
||||||
|
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
|
||||||
|
return
|
||||||
|
case <-pingTimer:
|
||||||
|
errState = !m.doPing(errResp)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkSeq will check if sequence number is correct and increment it by 1.
|
// checkSeq will check if sequence number is correct and increment it by 1.
|
||||||
@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) {
|
|||||||
m.addResponse(r)
|
m.addResponse(r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
|
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
|
||||||
ok := m.addResponse(r)
|
ok := m.addResponse(r)
|
||||||
if !ok {
|
if !ok {
|
||||||
PutByteBuffer(r.Msg)
|
PutByteBuffer(r.Msg)
|
||||||
@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) {
|
|||||||
m.addResponse(Response{Err: err})
|
m.addResponse(Response{Err: err})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
|
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
// addResponse will add a response to the response channel.
|
// addResponse will add a response to the response channel.
|
||||||
|
@ -28,8 +28,6 @@ import (
|
|||||||
xioutil "github.com/minio/minio/internal/ioutil"
|
xioutil "github.com/minio/minio/internal/ioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const lastPingThreshold = 4 * clientPingInterval
|
|
||||||
|
|
||||||
type muxServer struct {
|
type muxServer struct {
|
||||||
ID uint64
|
ID uint64
|
||||||
LastPing int64
|
LastPing int64
|
||||||
@ -43,6 +41,7 @@ type muxServer struct {
|
|||||||
sendMu sync.Mutex
|
sendMu sync.Mutex
|
||||||
recvMu sync.Mutex
|
recvMu sync.Mutex
|
||||||
outBlock chan struct{}
|
outBlock chan struct{}
|
||||||
|
clientPingInterval time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
|
func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
|
||||||
@ -99,6 +98,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
|
|||||||
outBlock: make(chan struct{}, outboundCap),
|
outBlock: make(chan struct{}, outboundCap),
|
||||||
LastPing: time.Now().Unix(),
|
LastPing: time.Now().Unix(),
|
||||||
BaseFlags: c.baseFlags,
|
BaseFlags: c.baseFlags,
|
||||||
|
clientPingInterval: c.clientPingInterval,
|
||||||
}
|
}
|
||||||
// Acknowledge Mux created.
|
// Acknowledge Mux created.
|
||||||
// Send async.
|
// Send async.
|
||||||
@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
|
|||||||
}(m.outBlock)
|
}(m.outBlock)
|
||||||
|
|
||||||
// Remote aliveness check if needed.
|
// Remote aliveness check if needed.
|
||||||
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
|
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) {
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
m.checkRemoteAlive()
|
m.checkRemoteAlive()
|
||||||
@ -164,11 +164,23 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
|
|||||||
|
|
||||||
// handleInbound sends unblocks when we have delivered the message to the handler.
|
// handleInbound sends unblocks when we have delivered the message to the handler.
|
||||||
func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) {
|
func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) {
|
||||||
for in := range inbound {
|
for {
|
||||||
handlerIn <- in
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
return
|
||||||
|
case in, ok := <-inbound:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
return
|
||||||
|
case handlerIn <- in:
|
||||||
m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
|
m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// sendResponses will send responses to the client.
|
// sendResponses will send responses to the client.
|
||||||
func (m *muxServer) sendResponses(ctx context.Context, toSend <-chan []byte, c *Connection, handlerErr *atomic.Pointer[RemoteErr], outBlock <-chan struct{}) {
|
func (m *muxServer) sendResponses(ctx context.Context, toSend <-chan []byte, c *Connection, handlerErr *atomic.Pointer[RemoteErr], outBlock <-chan struct{}) {
|
||||||
@ -234,7 +246,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<-
|
|||||||
|
|
||||||
// checkRemoteAlive will check if the remote is alive.
|
// checkRemoteAlive will check if the remote is alive.
|
||||||
func (m *muxServer) checkRemoteAlive() {
|
func (m *muxServer) checkRemoteAlive() {
|
||||||
t := time.NewTicker(lastPingThreshold / 4)
|
t := time.NewTicker(m.clientPingInterval)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -242,7 +254,7 @@ func (m *muxServer) checkRemoteAlive() {
|
|||||||
return
|
return
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
|
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
|
||||||
if last > lastPingThreshold {
|
if last > 4*m.clientPingInterval {
|
||||||
gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
|
gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
|
||||||
m.close()
|
m.close()
|
||||||
return
|
return
|
||||||
@ -257,7 +269,7 @@ func (m *muxServer) checkSeq(seq uint32) (ok bool) {
|
|||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq)
|
fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq)
|
||||||
}
|
}
|
||||||
m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq))
|
m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq), false)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
m.RecvSeq++
|
m.RecvSeq++
|
||||||
@ -268,13 +280,13 @@ func (m *muxServer) message(msg message) {
|
|||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Printf("muxServer: received message %d, length %d\n", msg.Seq, len(msg.Payload))
|
fmt.Printf("muxServer: received message %d, length %d\n", msg.Seq, len(msg.Payload))
|
||||||
}
|
}
|
||||||
|
if !m.checkSeq(msg.Seq) {
|
||||||
|
return
|
||||||
|
}
|
||||||
m.recvMu.Lock()
|
m.recvMu.Lock()
|
||||||
defer m.recvMu.Unlock()
|
defer m.recvMu.Unlock()
|
||||||
if cap(m.inbound) == 0 {
|
if cap(m.inbound) == 0 {
|
||||||
m.disconnect("did not expect inbound message")
|
m.disconnect("did not expect inbound message", true)
|
||||||
return
|
|
||||||
}
|
|
||||||
if !m.checkSeq(msg.Seq) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Note, on EOF no value can be sent.
|
// Note, on EOF no value can be sent.
|
||||||
@ -296,7 +308,7 @@ func (m *muxServer) message(msg message) {
|
|||||||
fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq)
|
fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
m.disconnect("handler blocked")
|
m.disconnect("handler blocked", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -332,7 +344,9 @@ func (m *muxServer) ping(seq uint32) pongMsg {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *muxServer) disconnect(msg string) {
|
// disconnect will disconnect the mux.
|
||||||
|
// m.recvMu must be locked when calling this function.
|
||||||
|
func (m *muxServer) disconnect(msg string, locked bool) {
|
||||||
if debugPrint {
|
if debugPrint {
|
||||||
fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg)
|
fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg)
|
||||||
}
|
}
|
||||||
@ -341,6 +355,11 @@ func (m *muxServer) disconnect(msg string) {
|
|||||||
} else {
|
} else {
|
||||||
m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID})
|
m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID})
|
||||||
}
|
}
|
||||||
|
// Unlock, since we are calling deleteMux, which will call close - which will lock recvMu.
|
||||||
|
if locked {
|
||||||
|
m.recvMu.Unlock()
|
||||||
|
defer m.recvMu.Lock()
|
||||||
|
}
|
||||||
m.parent.deleteMux(true, m.ID)
|
m.parent.deleteMux(true, m.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,12 +89,26 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if resp.Err != nil {
|
if resp.Err != nil {
|
||||||
|
s.cancel(resp.Err)
|
||||||
return resp.Err
|
return resp.Err
|
||||||
}
|
}
|
||||||
err = next(resp.Msg)
|
err = next(resp.Msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.cancel(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Done will return a channel that will be closed when the stream is done.
|
||||||
|
// This mirrors context.Done().
|
||||||
|
func (s *Stream) Done() <-chan struct{} {
|
||||||
|
return s.ctx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err will return the error that caused the stream to end.
|
||||||
|
// This mirrors context.Err().
|
||||||
|
func (s *Stream) Err() error {
|
||||||
|
return s.ctx.Err()
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user