Split handleMessages (cosmetic) (#20095)

Split the read and write sides of handleMessages into two separate functions

Cosmetic. The only non-copy-and-paste change is that `cancel(ErrDisconnected)` is moved 
into the defer on `readStream`.
This commit is contained in:
Klaus Post 2024-07-15 12:02:30 -07:00 committed by GitHub
parent e8c54c3d6c
commit ded373e600
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -925,137 +925,141 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
c.handleMsgWg.Add(2) c.handleMsgWg.Add(2)
c.reconnectMu.Unlock() c.reconnectMu.Unlock()
// Read goroutine // Start reader and writer
go func() { go c.readStream(ctx, conn, cancel)
defer func() { c.writeStream(ctx, conn, cancel)
if rec := recover(); rec != nil { }
gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
debug.PrintStack()
}
c.connChange.L.Lock()
if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
c.connChange.Broadcast()
}
c.connChange.L.Unlock()
conn.Close()
c.handleMsgWg.Done()
}()
controlHandler := wsutil.ControlFrameHandler(conn, c.side) // readStream handles the read side of the connection.
wsReader := wsutil.Reader{ // It will read messages and send them to c.handleMsg.
Source: conn, // If an error occurs the cancel function will be called and conn be closed.
State: c.side, // The function will block until the connection is closed or an error occurs.
CheckUTF8: true, func (c *Connection) readStream(ctx context.Context, conn net.Conn, cancel context.CancelCauseFunc) {
SkipHeaderCheck: false, defer func() {
OnIntermediate: controlHandler, if rec := recover(); rec != nil {
gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
debug.PrintStack()
} }
readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) { cancel(ErrDisconnected)
dst = dst[:0] c.connChange.L.Lock()
for { if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
hdr, err := wsReader.NextFrame() c.connChange.Broadcast()
if err != nil {
return nil, err
}
if hdr.OpCode.IsControl() {
if err := controlHandler(hdr, &wsReader); err != nil {
return nil, err
}
continue
}
if hdr.OpCode&want == 0 {
if err := wsReader.Discard(); err != nil {
return nil, err
}
continue
}
if int64(cap(dst)) < hdr.Length+1 {
dst = make([]byte, 0, hdr.Length+hdr.Length>>3)
}
return readAllInto(dst[:0], &wsReader)
}
}
// Keep reusing the same buffer.
var msg []byte
for {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
cancel(ErrDisconnected)
return
}
if cap(msg) > readBufferSize*4 {
// Don't keep too much memory around.
msg = nil
}
var err error
msg, err = readDataInto(msg, conn, c.side, ws.OpBinary)
if err != nil {
cancel(ErrDisconnected)
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
}
return
}
block := c.blockMessages.Load()
if block != nil && *block != nil {
<-*block
}
if c.incomingBytes != nil {
c.incomingBytes(int64(len(msg)))
}
// Parse the received message
var m message
subID, remain, err := m.parse(msg)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws parse package: %w", err))
}
cancel(ErrDisconnected)
return
}
if debugPrint {
fmt.Printf("%s Got msg: %v\n", c.Local, m)
}
if m.Op != OpMerged {
c.inMessages.Add(1)
c.handleMsg(ctx, m, subID)
continue
}
// Handle merged messages.
messages := int(m.Seq)
c.inMessages.Add(int64(messages))
for i := 0; i < messages; i++ {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
cancel(ErrDisconnected)
return
}
var next []byte
next, remain, err = msgp.ReadBytesZC(remain)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws read merged: %w", err))
}
cancel(ErrDisconnected)
return
}
m.Payload = nil
subID, _, err = m.parse(next)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws parse merged: %w", err))
}
cancel(ErrDisconnected)
return
}
c.handleMsg(ctx, m, subID)
}
} }
c.connChange.L.Unlock()
conn.Close()
c.handleMsgWg.Done()
}() }()
// Write function. controlHandler := wsutil.ControlFrameHandler(conn, c.side)
wsReader := wsutil.Reader{
Source: conn,
State: c.side,
CheckUTF8: true,
SkipHeaderCheck: false,
OnIntermediate: controlHandler,
}
readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) {
dst = dst[:0]
for {
hdr, err := wsReader.NextFrame()
if err != nil {
return nil, err
}
if hdr.OpCode.IsControl() {
if err := controlHandler(hdr, &wsReader); err != nil {
return nil, err
}
continue
}
if hdr.OpCode&want == 0 {
if err := wsReader.Discard(); err != nil {
return nil, err
}
continue
}
if int64(cap(dst)) < hdr.Length+1 {
dst = make([]byte, 0, hdr.Length+hdr.Length>>3)
}
return readAllInto(dst[:0], &wsReader)
}
}
// Keep reusing the same buffer.
var msg []byte
for atomic.LoadUint32((*uint32)(&c.state)) == StateConnected {
if cap(msg) > readBufferSize*4 {
// Don't keep too much memory around.
msg = nil
}
var err error
msg, err = readDataInto(msg, conn, c.side, ws.OpBinary)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
}
return
}
block := c.blockMessages.Load()
if block != nil && *block != nil {
<-*block
}
if c.incomingBytes != nil {
c.incomingBytes(int64(len(msg)))
}
// Parse the received message
var m message
subID, remain, err := m.parse(msg)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws parse package: %w", err))
}
return
}
if debugPrint {
fmt.Printf("%s Got msg: %v\n", c.Local, m)
}
if m.Op != OpMerged {
c.inMessages.Add(1)
c.handleMsg(ctx, m, subID)
continue
}
// Handle merged messages.
messages := int(m.Seq)
c.inMessages.Add(int64(messages))
for i := 0; i < messages; i++ {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
return
}
var next []byte
next, remain, err = msgp.ReadBytesZC(remain)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws read merged: %w", err))
}
return
}
m.Payload = nil
subID, _, err = m.parse(next)
if err != nil {
if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws parse merged: %w", err))
}
return
}
c.handleMsg(ctx, m, subID)
}
}
}
// writeStream handles the read side of the connection.
// It will grab messages from c.outQueue and write them to the connection.
// If an error occurs the cancel function will be called and conn be closed.
// The function will block until the connection is closed or an error occurs.
func (c *Connection) writeStream(ctx context.Context, conn net.Conn, cancel context.CancelCauseFunc) {
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))