Revert "Fix two-way stream cancelation and pings (#19763)"

This reverts commit 4d698841f4.
This commit is contained in:
Harshavardhana
2024-05-22 03:00:00 -07:00
parent 4d698841f4
commit ae14681c3e
7 changed files with 88 additions and 396 deletions

View File

@@ -32,25 +32,24 @@ import (
// muxClient is a stateful connection to a remote.
type muxClient struct {
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
clientPingInterval time.Duration
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
}
// Response is a response from the server.
@@ -62,13 +61,12 @@ type Response struct {
func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
ctx, cancelFn := context.WithCancelCause(ctx)
return &muxClient{
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
LastPong: time.Now().UnixNano(),
BaseFlags: parent.baseFlags,
clientPingInterval: parent.clientPingInterval,
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
LastPong: time.Now().Unix(),
BaseFlags: parent.baseFlags,
}
}
@@ -311,11 +309,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <
}
}()
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
if m.deadline == 0 || m.deadline > clientPingInterval {
ticker := time.NewTicker(clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
}
defer m.parent.deleteMux(false, m.MuxID)
for {
@@ -369,7 +367,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
}
// Only check ping when not closed.
if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
m.respMu.Unlock()
if debugPrint {
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@@ -390,20 +388,15 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
// responseCh is the channel to that goes to the requester.
// internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
defer func() {
m.parent.deleteMux(false, m.MuxID)
// addErrorNonBlockingClose will close the response channel.
xioutil.SafeClose(responseCh)
}()
// Cancelation and errors are handled by handleTwowayRequests below.
defer m.parent.deleteMux(false, m.MuxID)
defer xioutil.SafeClose(responseCh)
for resp := range internalResp {
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
}
}
func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
var errState bool
if debugPrint {
start := time.Now()
@@ -412,30 +405,24 @@ func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-cha
}()
}
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// Listen for client messages.
reqLoop:
for !errState {
for {
if errState {
go func() {
// Drain requests.
for range requests {
}
}()
return
}
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
case req, ok := <-requests:
if !ok {
// Done send EOF
@@ -445,28 +432,22 @@ reqLoop:
msg := message{
Op: OpMuxClientMsg,
MuxID: m.MuxID,
Seq: 1,
Flags: FlagEOF,
}
msg.setZeroPayloadFlag()
err := m.send(msg)
if err != nil {
m.addErrorNonBlockingClose(errResp, err)
m.addErrorNonBlockingClose(internalResp, err)
}
break reqLoop
return
}
// Grab a send token.
sendReq:
select {
case <-m.ctx.Done():
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
goto sendReq
case <-m.outBlock:
}
msg := message{
@@ -479,41 +460,13 @@ reqLoop:
err := m.send(msg)
PutByteBuffer(req)
if err != nil {
m.addErrorNonBlockingClose(errResp, err)
m.addErrorNonBlockingClose(internalResp, err)
errState = true
continue
}
msg.Seq++
}
}
if errState {
// Drain requests.
for {
select {
case r, ok := <-requests:
if !ok {
return
}
PutByteBuffer(r)
default:
return
}
}
}
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
return
case <-pingTimer:
errState = !m.doPing(errResp)
}
}
}
// checkSeq will check if sequence number is correct and increment it by 1.
@@ -549,7 +502,7 @@ func (m *muxClient) response(seq uint32, r Response) {
m.addResponse(r)
return
}
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
ok := m.addResponse(r)
if !ok {
PutByteBuffer(r.Msg)
@@ -600,7 +553,7 @@ func (m *muxClient) pong(msg pongMsg) {
m.addResponse(Response{Err: err})
return
}
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
}
// addResponse will add a response to the response channel.