Two way streams for upcoming locking enhancements (#19796)

Klaus Post 2024-06-07 08:51:52 -07:00 committed by GitHub
parent c5141d65ac
commit f00187033d
8 changed files with 573 additions and 104 deletions
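
For orientation before the file diffs: the two-way stream API exercised by this change can be used roughly as sketched below. The sketch is not part of the commit; it is assembled from the benchmark and tests further down, the import path and handler ID are assumptions, and error handling is abbreviated.

// Sketch only (not part of this commit): registering a two-way streaming
// handler and driving it from a client. Import path and handler ID are
// assumed for illustration.
package example

import (
	"context"

	"github.com/minio/minio/internal/grid" // assumed path of the grid package
)

func useTwowayStream(ctx context.Context, server *grid.Manager, conn *grid.Connection, id grid.HandlerID) error {
	// Server side: echo every request back as a response.
	err := server.RegisterStreamingHandler(id, grid.StreamHandler{
		Handle: func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *grid.RemoteErr {
			for req := range in {
				out <- req
			}
			return nil
		},
		OutCapacity: 1,
		InCapacity:  1, // non-zero InCapacity enables the client-to-server (request) direction
	})
	if err != nil {
		return err
	}
	// Client side: open the stream, feed requests, then consume responses.
	st, err := conn.NewStream(ctx, id, []byte("initial payload"))
	if err != nil {
		return err
	}
	go func() {
		defer close(st.Requests) // closing Requests signals EOF to the server
		st.Requests <- []byte("hello")
	}()
	return st.Results(func(resp []byte) error {
		// Handle each response; returning a non-nil error cancels the stream.
		return nil
	})
}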

View File

@@ -200,6 +200,7 @@ func BenchmarkStream(b *testing.B) {
}{
{name: "request", fn: benchmarkGridStreamReqOnly},
{name: "responses", fn: benchmarkGridStreamRespOnly},
{name: "twoway", fn: benchmarkGridStreamTwoway},
}
for _, test := range tests {
b.Run(test.name, func(b *testing.B) {
@@ -438,3 +439,134 @@ func benchmarkGridStreamReqOnly(b *testing.B, n int) {
})
}
}
func benchmarkGridStreamTwoway(b *testing.B, n int) {
defer testlogger.T.SetErrorTB(b)()
errFatal := func(err error) {
b.Helper()
if err != nil {
b.Fatal(err)
}
}
grid, err := SetupTestGrid(n)
errFatal(err)
b.Cleanup(grid.Cleanup)
const messages = 10
// Create n managers.
const payloadSize = 512
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
payload := make([]byte, payloadSize)
_, err = rng.Read(payload)
errFatal(err)
for _, remote := range grid.Managers {
// Register a single handler which echoes the payload.
errFatal(remote.RegisterStreamingHandler(handlerTest, StreamHandler{
// Send 10x requests.
Handle: func(ctx context.Context, payload []byte, in <-chan []byte, out chan<- []byte) *RemoteErr {
got := 0
for {
select {
case b, ok := <-in:
if !ok {
if got != messages {
return NewRemoteErrf("wrong number of requests. want %d, got %d", messages, got)
}
return nil
}
out <- b
got++
}
}
},
Subroute: "some-subroute",
OutCapacity: 1,
InCapacity: 1, // Only one message buffered.
}))
errFatal(err)
}
// Wait for all to connect
// Parallel writes per server.
for par := 1; par <= 32; par *= 2 {
b.Run("par="+strconv.Itoa(par*runtime.GOMAXPROCS(0)), func(b *testing.B) {
defer timeout(30 * time.Second)()
b.ReportAllocs()
b.SetBytes(int64(len(payload) * (2*messages + 1)))
b.ResetTimer()
t := time.Now()
var ops int64
var lat int64
b.SetParallelism(par)
b.RunParallel(func(pb *testing.PB) {
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
n := 0
var latency int64
managers := grid.Managers
hosts := grid.Hosts
for pb.Next() {
// Pick a random manager.
src, dst := rng.Intn(len(managers)), rng.Intn(len(managers))
if src == dst {
dst = (dst + 1) % len(managers)
}
local := managers[src]
conn := local.Connection(hosts[dst]).Subroute("some-subroute")
if conn == nil {
b.Fatal("No connection")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// Send the payload.
t := time.Now()
st, err := conn.NewStream(ctx, handlerTest, payload)
if err != nil {
if debugReqs {
fmt.Println(err.Error())
}
b.Fatal(err.Error())
}
got := 0
sent := 0
go func() {
for i := 0; i < messages; i++ {
st.Requests <- append(GetByteBuffer()[:0], payload...)
if sent++; sent == messages {
close(st.Requests)
return
}
}
}()
err = st.Results(func(b []byte) error {
got++
PutByteBuffer(b)
return nil
})
if err != nil {
if debugReqs {
fmt.Println(err.Error())
}
b.Fatal(err.Error())
}
if got != messages {
b.Fatalf("wrong number of responses. want %d, got %d", messages, got)
}
latency += time.Since(t).Nanoseconds()
cancel()
n += got
}
atomic.AddInt64(&ops, int64(n*2))
atomic.AddInt64(&lat, latency)
})
spent := time.Since(t)
if spent > 0 && n > 0 {
// Since we are benchmarking n parallel servers we need to multiply by n.
// This will give an estimate of the total ops/s.
latency := float64(atomic.LoadInt64(&lat)) / float64(time.Millisecond)
b.ReportMetric(float64(n)*float64(ops)/spent.Seconds(), "vops/s")
b.ReportMetric(latency/float64(ops), "ms/op")
}
})
}
}

View File

@@ -128,10 +128,11 @@ type Connection struct {
outMessages atomic.Int64
// For testing only
debugInConn net.Conn
debugOutConn net.Conn
-addDeadline time.Duration
-connMu sync.Mutex
blockMessages atomic.Pointer[<-chan struct{}]
addDeadline time.Duration
connMu sync.Mutex
}
// Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
@@ -883,7 +884,7 @@ func (c *Connection) updateState(s State) {
return
}
if s == StateConnected {
-atomic.StoreInt64(&c.LastPong, time.Now().Unix())
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
}
atomic.StoreUint32((*uint32)(&c.state), uint32(s))
if debugPrint {
@@ -993,6 +994,11 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
}
return
}
block := c.blockMessages.Load()
if block != nil && *block != nil {
<-*block
}
if c.incomingBytes != nil {
c.incomingBytes(int64(len(msg)))
}
@@ -1094,7 +1100,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
}
lastPong := atomic.LoadInt64(&c.LastPong)
if lastPong > 0 {
-lastPongTime := time.Unix(lastPong, 0)
lastPongTime := time.Unix(0, lastPong)
if d := time.Since(lastPongTime); d > connPingInterval*2 {
gridLogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond)))
return
@@ -1105,7 +1111,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
if err != nil {
gridLogIf(ctx, err)
// Fake it...
-atomic.StoreInt64(&c.LastPong, time.Now().Unix())
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
continue
}
case toSend = <-c.outQueue:
@@ -1406,12 +1412,16 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) {
}
func (c *Connection) handlePong(ctx context.Context, m message) {
if m.MuxID == 0 && m.Payload == nil {
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
return
}
var pong pongMsg
_, err := pong.UnmarshalMsg(m.Payload)
PutByteBuffer(m.Payload)
gridLogIf(ctx, err)
if m.MuxID == 0 {
-atomic.StoreInt64(&c.LastPong, time.Now().Unix())
atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
return
}
if v, ok := c.outgoing.Load(m.MuxID); ok {
@@ -1425,7 +1435,9 @@ func (c *Connection) handlePong(ctx context.Context, m message) {
func (c *Connection) handlePing(ctx context.Context, m message) {
if m.MuxID == 0 {
-gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
m.Flags.Clear(FlagPayloadIsZero)
m.Op = OpPong
gridLogIf(ctx, c.queueMsg(m, nil))
return
}
// Single calls do not support pinging.
@@ -1560,9 +1572,15 @@ func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) {
}
if m.Flags&FlagEOF != 0 {
if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 {
// We must obtain the lock before calling cancelFn
// Otherwise others may pick up the error before close is called.
v.respMu.Lock()
v.cancelFn(errStreamEOF)
v.closeLocked()
v.respMu.Unlock()
} else {
v.close()
}
-v.close()
if debugReqs {
fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX")
}
@@ -1617,7 +1635,7 @@ func (c *Connection) Stats() madmin.RPCMetrics {
IncomingMessages: c.inMessages.Load(),
OutgoingMessages: c.outMessages.Load(),
OutQueue: len(c.outQueue),
-LastPongTime: time.Unix(c.LastPong, 0).UTC(),
LastPongTime: time.Unix(0, c.LastPong).UTC(),
}
m.ByDestination = map[string]madmin.RPCMetrics{
c.Remote: m,
@@ -1659,7 +1677,12 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.connMu.Lock()
defer c.connMu.Unlock()
c.connPingInterval = args[0].(time.Duration)
if c.connPingInterval < time.Second {
panic("CONN ping interval too low")
}
case debugSetClientPingDuration:
c.connMu.Lock()
defer c.connMu.Unlock()
c.clientPingInterval = args[0].(time.Duration)
case debugAddToDeadline:
c.addDeadline = args[0].(time.Duration)
@@ -1675,6 +1698,11 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
mid.respMu.Lock()
resp(mid.closed)
mid.respMu.Unlock()
case debugBlockInboundMessages:
c.connMu.Lock()
block := (<-chan struct{})(args[0].(chan struct{}))
c.blockMessages.Store(&block)
c.connMu.Unlock()
}
}
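
Aside, a sketch (not from the commit) of how the new debugBlockInboundMessages hook is meant to be driven from within the grid package's own tests, as testServerStreamNoPing does further down:

// Sketch only: pause a connection's inbound message processing in a test.
// conn is an established *Connection inside a grid package test.
block := make(chan struct{})
conn.debugMsg(debugBlockInboundMessages, block) // handleMessages now waits on this channel before handling reads
// ... exercise the scenario that must see no inbound traffic ...
close(block) // resume normal processing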

View File

@@ -50,6 +50,7 @@ const (
debugSetClientPingDuration
debugAddToDeadline
debugIsOutgoingClosed
debugBlockInboundMessages
)
// TestGrid contains a grid of servers for testing purposes.

View File

@@ -16,11 +16,12 @@ func _() {
_ = x[debugSetClientPingDuration-5]
_ = x[debugAddToDeadline-6]
_ = x[debugIsOutgoingClosed-7]
_ = x[debugBlockInboundMessages-8]
}
-const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages"
-var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176}
func (i debugMsg) String() string {
if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {

View File

@@ -26,6 +26,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
@@ -378,6 +379,54 @@ func TestStreamSuite(t *testing.T) {
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamNoPing(t, local, remote, 0)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamNoPing(t, local, remote, 1)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, false, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, false, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, true, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, true, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 0, false, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 0, false, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
}
func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@@ -491,12 +540,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
register(remote)
// local to remote
-testHandler := func(t *testing.T, handler HandlerID) {
testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
ctx, cancel := context.WithCancel(context.Background())
-st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
errFatal(err)
clientCanceled := make(chan time.Time, 1)
err = nil
@@ -513,6 +562,18 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
clientCanceled <- time.Now()
}(t)
start := time.Now()
if st.Requests != nil {
defer close(st.Requests)
}
// Fill up queue.
for sendReq {
select {
case st.Requests <- []byte("Hello"):
time.Sleep(10 * time.Millisecond)
default:
sendReq = false
}
}
cancel()
<-serverCanceled
t.Log("server cancel time:", time.Since(start))
@@ -524,11 +585,13 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
}
// local to remote, unbuffered
t.Run("unbuffered", func(t *testing.T) {
-testHandler(t, handlerTest)
testHandler(t, handlerTest, false)
})
t.Run("buffered", func(t *testing.T) {
-testHandler(t, handlerTest2)
testHandler(t, handlerTest2, false)
})
t.Run("buffered", func(t *testing.T) {
testHandler(t, handlerTest2, true)
})
}
@@ -1025,6 +1088,170 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
}
}
// testServerStreamNoPing will test if server and client handle no pings.
func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
reqStarted := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
close(reqStarted)
// Just wait for it to cancel.
<-ctx.Done()
close(serverCanceled)
return NewRemoteErr(ctx.Err())
},
OutCapacity: 1,
InCapacity: inCap,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server to start the request.
<-reqStarted
// Stop processing requests
nowBlocking := make(chan struct{})
remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
// Check that local returned.
err = st.Results(func(b []byte) error {
return nil
})
if err == nil {
t.Fatal("expected error, got nil")
}
t.Logf("response: %v", err)
// Check that remote is canceled.
<-serverCanceled
close(nowBlocking)
}
// testServerStreamPingRunning will test if server and client handle ping even when blocked.
func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
reqStarted := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
close(reqStarted)
// Just wait for it to cancel.
for blockResp {
select {
case <-ctx.Done():
close(serverCanceled)
return NewRemoteErr(ctx.Err())
case resp <- []byte{1}:
time.Sleep(10 * time.Millisecond)
}
}
// Just wait for it to cancel.
<-ctx.Done()
close(serverCanceled)
return NewRemoteErr(ctx.Err())
},
OutCapacity: 1,
InCapacity: inCap,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server to start the request.
<-reqStarted
// Block until we have exceeded the deadline several times over.
nowBlocking := make(chan struct{})
var mu sync.Mutex
time.AfterFunc(time.Second, func() {
mu.Lock()
cancel()
close(nowBlocking)
mu.Unlock()
})
if inCap > 0 {
go func() {
defer close(st.Requests)
if !blockReq {
<-nowBlocking
return
}
for {
select {
case <-nowBlocking:
return
case <-st.Done():
case st.Requests <- []byte{1}:
time.Sleep(10 * time.Millisecond)
}
}
}()
}
// Check that local returned.
err = st.Results(func(b []byte) error {
<-st.Done()
return ctx.Err()
})
mu.Lock()
select {
case <-nowBlocking:
default:
t.Fatal("expected to be blocked. got err", err)
}
if err == nil {
t.Fatal("expected error, got nil")
}
t.Logf("response: %v", err)
// Check that remote is canceled.
<-serverCanceled
}
func timeout(after time.Duration) (cancel func()) {
c := time.After(after)
cc := make(chan struct{})

View File

@@ -32,24 +32,25 @@ import (
// muxClient is a stateful connection to a remote.
type muxClient struct {
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
clientPingInterval time.Duration
}
// Response is a response from the server.
@@ -61,12 +62,13 @@ type Response struct {
func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
ctx, cancelFn := context.WithCancelCause(ctx)
return &muxClient{
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
-LastPong: time.Now().Unix(),
LastPong: time.Now().UnixNano(),
BaseFlags: parent.baseFlags,
clientPingInterval: parent.clientPingInterval,
}
}
@@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
}
}()
var pingTimer <-chan time.Time
-if m.deadline == 0 || m.deadline > clientPingInterval {
if m.deadline == 0 || m.deadline > m.clientPingInterval {
-ticker := time.NewTicker(clientPingInterval)
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
-atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
defer m.parent.deleteMux(false, m.MuxID)
for {
@@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
}
// Only check ping when not closed.
-if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
m.respMu.Unlock()
if debugPrint {
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
// responseCh is the channel that goes to the requester.
// internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
-defer m.parent.deleteMux(false, m.MuxID)
-defer xioutil.SafeClose(responseCh)
defer func() {
m.parent.deleteMux(false, m.MuxID)
// addErrorNonBlockingClose will close the response channel.
xioutil.SafeClose(responseCh)
}()
// Cancelation and errors are handled by handleTwowayRequests below.
for resp := range internalResp {
-responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
responseCh <- resp
}
}
-func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
var errState bool
if debugPrint {
start := time.Now()
@@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
}()
}
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// Listen for client messages.
-for {
-if errState {
-go func() {
-// Drain requests.
-for range requests {
-}
-}()
-return
-}
reqLoop:
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
-m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
case req, ok := <-requests:
if !ok {
// Done send EOF
@@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
msg := message{
Op: OpMuxClientMsg,
MuxID: m.MuxID,
Seq: 1,
Flags: FlagEOF,
}
msg.setZeroPayloadFlag()
err := m.send(msg)
if err != nil {
-m.addErrorNonBlockingClose(internalResp, err)
m.addErrorNonBlockingClose(errResp, err)
}
-return
break reqLoop
}
// Grab a send token.
sendReq:
select {
case <-m.ctx.Done():
-m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
goto sendReq
case <-m.outBlock:
}
msg := message{
@@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
err := m.send(msg)
PutByteBuffer(req)
if err != nil {
-m.addErrorNonBlockingClose(internalResp, err)
m.addErrorNonBlockingClose(errResp, err)
errState = true
continue
}
msg.Seq++
}
}
if errState {
// Drain requests.
for {
select {
case r, ok := <-requests:
if !ok {
return
}
PutByteBuffer(r)
default:
return
}
}
}
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
return
case <-pingTimer:
errState = !m.doPing(errResp)
}
}
}
// checkSeq will check if sequence number is correct and increment it by 1.
@@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) {
m.addResponse(r)
return
}
-atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
ok := m.addResponse(r)
if !ok {
PutByteBuffer(r.Msg)
@@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) {
m.addResponse(Response{Err: err})
return
}
-atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// addResponse will add a response to the response channel.

View File

@@ -28,21 +28,20 @@ import (
xioutil "github.com/minio/minio/internal/ioutil"
)
-const lastPingThreshold = 4 * clientPingInterval
type muxServer struct {
ID uint64
LastPing int64
SendSeq, RecvSeq uint32
Resp chan []byte
BaseFlags Flags
ctx context.Context
cancel context.CancelFunc
inbound chan []byte
parent *Connection
sendMu sync.Mutex
recvMu sync.Mutex
outBlock chan struct{}
clientPingInterval time.Duration
}
func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
@@ -89,16 +88,17 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
}
m := muxServer{
ID: msg.MuxID,
RecvSeq: msg.Seq + 1,
SendSeq: msg.Seq,
ctx: ctx,
cancel: cancel,
parent: c,
inbound: nil,
outBlock: make(chan struct{}, outboundCap),
LastPing: time.Now().Unix(),
BaseFlags: c.baseFlags,
clientPingInterval: c.clientPingInterval,
}
// Acknowledge Mux created.
// Send async.
@@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
}(m.outBlock)
// Remote aliveness check if needed.
-if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) {
go func() {
wg.Wait()
m.checkRemoteAlive()
@@ -164,9 +164,21 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
// handleInbound sends unblocks when we have delivered the message to the handler.
func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) {
-for in := range inbound {
-handlerIn <- in
-m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
for {
select {
case <-m.ctx.Done():
return
case in, ok := <-inbound:
if !ok {
return
}
select {
case <-m.ctx.Done():
return
case handlerIn <- in:
m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
}
}
}
}
@@ -234,7 +246,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<-
// checkRemoteAlive will check if the remote is alive.
func (m *muxServer) checkRemoteAlive() {
-t := time.NewTicker(lastPingThreshold / 4)
t := time.NewTicker(m.clientPingInterval)
defer t.Stop()
for {
select {
@@ -242,7 +254,7 @@ func (m *muxServer) checkRemoteAlive() {
return
case <-t.C:
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
-if last > lastPingThreshold {
if last > 4*m.clientPingInterval {
gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
m.close()
return
@@ -257,7 +269,7 @@ func (m *muxServer) checkSeq(seq uint32) (ok bool) {
if debugPrint {
fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq)
}
-m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq))
m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq), false)
return false
}
m.RecvSeq++
@@ -268,13 +280,13 @@ func (m *muxServer) message(msg message) {
if debugPrint {
fmt.Printf("muxServer: received message %d, length %d\n", msg.Seq, len(msg.Payload))
}
-if !m.checkSeq(msg.Seq) {
-return
-}
m.recvMu.Lock()
defer m.recvMu.Unlock()
if cap(m.inbound) == 0 {
-m.disconnect("did not expect inbound message")
m.disconnect("did not expect inbound message", true)
return
}
if !m.checkSeq(msg.Seq) {
return
}
// Note, on EOF no value can be sent.
@@ -296,7 +308,7 @@ func (m *muxServer) message(msg message) {
fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq)
}
default:
-m.disconnect("handler blocked")
m.disconnect("handler blocked", true)
}
}
@@ -332,7 +344,9 @@ func (m *muxServer) ping(seq uint32) pongMsg {
}
}
-func (m *muxServer) disconnect(msg string) {
// disconnect will disconnect the mux.
// m.recvMu must be locked when calling this function.
func (m *muxServer) disconnect(msg string, locked bool) {
if debugPrint {
fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg)
}
@@ -341,6 +355,11 @@ func (m *muxServer) disconnect(msg string) {
} else {
m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID})
}
// Unlock, since we are calling deleteMux, which will call close - which will lock recvMu.
if locked {
m.recvMu.Unlock()
defer m.recvMu.Lock()
}
m.parent.deleteMux(true, m.ID)
}

View File

@@ -89,12 +89,26 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
return nil
}
if resp.Err != nil {
s.cancel(resp.Err)
return resp.Err
}
err = next(resp.Msg)
if err != nil {
s.cancel(err)
return err
}
}
}
}
// Done will return a channel that will be closed when the stream is done.
// This mirrors context.Done().
func (s *Stream) Done() <-chan struct{} {
return s.ctx.Done()
}
// Err will return the error that caused the stream to end.
// This mirrors context.Err().
func (s *Stream) Err() error {
return s.ctx.Err()
}
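
A usage sketch (not part of the commit) of the new accessors: a caller consuming Results can watch Done to react to stream cancelation, much as the ping tests above do. The grid package qualifier is assumed for use from outside the package.

// Sketch only: stop handling results once the stream's context ends.
func drainUntilDone(st *grid.Stream) error {
	return st.Results(func(b []byte) error {
		select {
		case <-st.Done():
			// The stream context ended; surface the underlying error.
			return st.Err()
		default:
			// Handle the response payload in b here.
			return nil
		}
	})
}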