cosmetic: Move request goroutines to methods (#19241)

Cosmetic change, but breaks up a big code block and will make a goroutine 
dumps of streams are more readable, so it is clearer what each goroutine is doing.
This commit is contained in:
Klaus Post 2024-03-13 19:43:58 +01:00 committed by GitHub
parent 24b4f9d748
commit 5c32058ff3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,7 +21,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"runtime/debug"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -123,109 +122,136 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea
if inboundCap > 0 { if inboundCap > 0 {
m.inbound = make(chan []byte, inboundCap) m.inbound = make(chan []byte, inboundCap)
handlerIn = make(chan []byte, 1) handlerIn = make(chan []byte, 1)
go func(inbound <-chan []byte) { go func(inbound chan []byte) {
wg.Wait() wg.Wait()
defer xioutil.SafeClose(handlerIn) defer xioutil.SafeClose(handlerIn)
// Send unblocks when we have delivered the message to the handler. m.handleInbound(c, inbound, handlerIn)
for in := range inbound {
handlerIn <- in
m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
}
}(m.inbound) }(m.inbound)
} }
// Fill outbound block.
// Each token represents a message that can be sent to the client without blocking.
// The client will refill the tokens as they confirm delivery of the messages.
for i := 0; i < outboundCap; i++ { for i := 0; i < outboundCap; i++ {
m.outBlock <- struct{}{} m.outBlock <- struct{}{}
} }
// Handler goroutine. // Handler goroutine.
var handlerErr *RemoteErr var handlerErr atomic.Pointer[RemoteErr]
go func() { go func() {
wg.Wait() wg.Wait()
start := time.Now() defer xioutil.SafeClose(send)
defer func() { err := m.handleRequests(ctx, msg, send, handler, handlerIn)
if debugPrint { if err != nil {
fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond)) handlerErr.Store(err)
} }
if r := recover(); r != nil {
logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r))
debug.PrintStack()
err := RemoteErr(fmt.Sprintf("remote call panic: %v", r))
handlerErr = &err
}
if debugPrint {
fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr)
}
xioutil.SafeClose(send)
}()
// handlerErr is guarded by 'send' channel.
handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send)
}() }()
// Response sender gorutine...
// Response sender goroutine...
go func(outBlock <-chan struct{}) { go func(outBlock <-chan struct{}) {
wg.Wait() wg.Wait()
defer m.parent.deleteMux(true, m.ID) defer m.parent.deleteMux(true, m.ID)
for { m.sendResponses(ctx, send, c, &handlerErr, outBlock)
// Process outgoing message.
var payload []byte
var ok bool
select {
case payload, ok = <-send:
case <-ctx.Done():
return
}
select {
case <-ctx.Done():
return
case <-outBlock:
}
msg := message{
MuxID: m.ID,
Op: OpMuxServerMsg,
Flags: c.baseFlags,
}
if !ok {
if debugPrint {
fmt.Println("muxServer: Mux", m.ID, "send EOF", handlerErr)
}
msg.Flags |= FlagEOF
if handlerErr != nil {
msg.Flags |= FlagPayloadIsErr
msg.Payload = []byte(*handlerErr)
}
msg.setZeroPayloadFlag()
m.send(msg)
return
}
msg.Payload = payload
msg.setZeroPayloadFlag()
m.send(msg)
}
}(m.outBlock) }(m.outBlock)
// Remote aliveness check. // Remote aliveness check if needed.
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
go func() { go func() {
wg.Wait() wg.Wait()
t := time.NewTicker(lastPingThreshold / 4) m.checkRemoteAlive()
defer t.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-t.C:
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
if last > lastPingThreshold {
logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
m.close()
return
}
}
}
}() }()
} }
return &m return &m
} }
// handleInbound sends unblocks when we have delivered the message to the handler.
func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) {
for in := range inbound {
handlerIn <- in
m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
}
}
// sendResponses will send responses to the client.
func (m *muxServer) sendResponses(ctx context.Context, toSend <-chan []byte, c *Connection, handlerErr *atomic.Pointer[RemoteErr], outBlock <-chan struct{}) {
for {
// Process outgoing message.
var payload []byte
var ok bool
select {
case payload, ok = <-toSend:
case <-ctx.Done():
return
}
select {
case <-ctx.Done():
return
case <-outBlock:
}
msg := message{
MuxID: m.ID,
Op: OpMuxServerMsg,
Flags: c.baseFlags,
}
if !ok {
hErr := handlerErr.Load()
if debugPrint {
fmt.Println("muxServer: Mux", m.ID, "send EOF", hErr)
}
msg.Flags |= FlagEOF
if hErr != nil {
msg.Flags |= FlagPayloadIsErr
msg.Payload = []byte(*hErr)
}
msg.setZeroPayloadFlag()
m.send(msg)
return
}
msg.Payload = payload
msg.setZeroPayloadFlag()
m.send(msg)
}
}
// handleRequests will handle the requests from the client and call the handler function.
func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- []byte, handler StreamHandler, handlerIn <-chan []byte) (handlerErr *RemoteErr) {
start := time.Now()
defer func() {
if debugPrint {
fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond))
}
if r := recover(); r != nil {
logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r))
err := RemoteErr(fmt.Sprintf("handler panic: %v", r))
handlerErr = &err
}
if debugPrint {
fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr)
}
}()
// handlerErr is guarded by 'send' channel.
handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send)
return handlerErr
}
// checkRemoteAlive will check if the remote is alive.
func (m *muxServer) checkRemoteAlive() {
t := time.NewTicker(lastPingThreshold / 4)
defer t.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-t.C:
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
if last > lastPingThreshold {
logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
m.close()
return
}
}
}
}
// checkSeq will check if sequence number is correct and increment it by 1. // checkSeq will check if sequence number is correct and increment it by 1.
func (m *muxServer) checkSeq(seq uint32) (ok bool) { func (m *muxServer) checkSeq(seq uint32) (ok bool) {
if seq != m.RecvSeq { if seq != m.RecvSeq {