Fix two-way stream cancelation and pings (#19763)

Do not log errors on oneway streams when sending ping fails. Instead, cancel the stream.

This also makes sure pings are sent when blocked on sending responses.
This commit is contained in:
Klaus Post
2024-05-22 01:25:25 -07:00
committed by GitHub
parent 9906b3ade9
commit 4d698841f4
7 changed files with 396 additions and 88 deletions

View File

@@ -32,24 +32,25 @@ import (
// muxClient is a stateful connection to a remote.
type muxClient struct {
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
clientPingInterval time.Duration
}
// Response is a response from the server.
@@ -61,12 +62,13 @@ type Response struct {
func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
ctx, cancelFn := context.WithCancelCause(ctx)
return &muxClient{
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
LastPong: time.Now().Unix(),
BaseFlags: parent.baseFlags,
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
LastPong: time.Now().UnixNano(),
BaseFlags: parent.baseFlags,
clientPingInterval: parent.clientPingInterval,
}
}
@@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <
}
}()
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > clientPingInterval {
ticker := time.NewTicker(clientPingInterval)
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
defer m.parent.deleteMux(false, m.MuxID)
for {
@@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
}
// Only check ping when not closed.
if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
m.respMu.Unlock()
if debugPrint {
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
// responseCh is the channel to that goes to the requester.
// internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
defer m.parent.deleteMux(false, m.MuxID)
defer xioutil.SafeClose(responseCh)
defer func() {
m.parent.deleteMux(false, m.MuxID)
// addErrorNonBlockingClose will close the response channel.
xioutil.SafeClose(responseCh)
}()
// Cancelation and errors are handled by handleTwowayRequests below.
for resp := range internalResp {
responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
responseCh <- resp
}
}
func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
var errState bool
if debugPrint {
start := time.Now()
@@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
}()
}
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// Listen for client messages.
for {
if errState {
go func() {
// Drain requests.
for range requests {
}
}()
return
}
reqLoop:
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
case req, ok := <-requests:
if !ok {
// Done send EOF
@@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
msg := message{
Op: OpMuxClientMsg,
MuxID: m.MuxID,
Seq: 1,
Flags: FlagEOF,
}
msg.setZeroPayloadFlag()
err := m.send(msg)
if err != nil {
m.addErrorNonBlockingClose(internalResp, err)
m.addErrorNonBlockingClose(errResp, err)
}
return
break reqLoop
}
// Grab a send token.
sendReq:
select {
case <-m.ctx.Done():
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
goto sendReq
case <-m.outBlock:
}
msg := message{
@@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests
err := m.send(msg)
PutByteBuffer(req)
if err != nil {
m.addErrorNonBlockingClose(internalResp, err)
m.addErrorNonBlockingClose(errResp, err)
errState = true
continue
}
msg.Seq++
}
}
if errState {
// Drain requests.
for {
select {
case r, ok := <-requests:
if !ok {
return
}
PutByteBuffer(r)
default:
return
}
}
}
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
return
case <-pingTimer:
errState = !m.doPing(errResp)
}
}
}
// checkSeq will check if sequence number is correct and increment it by 1.
@@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) {
m.addResponse(r)
return
}
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
ok := m.addResponse(r)
if !ok {
PutByteBuffer(r.Msg)
@@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) {
m.addResponse(Response{Err: err})
return
}
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// addResponse will add a response to the response channel.