Split handleMessages (cosmetic) (#20095)

Split the read and write sides of handleMessages into two separate functions

Cosmetic. The only non-copy-and-paste change is that `cancel(ErrDisconnected)` is moved 
into the defer on `readStream`.
This commit is contained in:
Klaus Post 2024-07-15 12:02:30 -07:00 committed by GitHub
parent e8c54c3d6c
commit ded373e600
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -925,13 +925,22 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
c.handleMsgWg.Add(2) c.handleMsgWg.Add(2)
c.reconnectMu.Unlock() c.reconnectMu.Unlock()
// Read goroutine // Start reader and writer
go func() { go c.readStream(ctx, conn, cancel)
c.writeStream(ctx, conn, cancel)
}
// readStream handles the read side of the connection.
// It will read messages and send them to c.handleMsg.
// If an error occurs the cancel function will be called and conn be closed.
// The function will block until the connection is closed or an error occurs.
func (c *Connection) readStream(ctx context.Context, conn net.Conn, cancel context.CancelCauseFunc) {
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
debug.PrintStack() debug.PrintStack()
} }
cancel(ErrDisconnected)
c.connChange.L.Lock() c.connChange.L.Lock()
if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
c.connChange.Broadcast() c.connChange.Broadcast()
@ -977,11 +986,7 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
// Keep reusing the same buffer. // Keep reusing the same buffer.
var msg []byte var msg []byte
for { for atomic.LoadUint32((*uint32)(&c.state)) == StateConnected {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
cancel(ErrDisconnected)
return
}
if cap(msg) > readBufferSize*4 { if cap(msg) > readBufferSize*4 {
// Don't keep too much memory around. // Don't keep too much memory around.
msg = nil msg = nil
@ -990,7 +995,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
var err error var err error
msg, err = readDataInto(msg, conn, c.side, ws.OpBinary) msg, err = readDataInto(msg, conn, c.side, ws.OpBinary)
if err != nil { if err != nil {
cancel(ErrDisconnected)
if !xnet.IsNetworkOrHostDown(err, true) { if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF) gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
} }
@ -1012,7 +1016,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
if !xnet.IsNetworkOrHostDown(err, true) { if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws parse package: %w", err)) gridLogIf(ctx, fmt.Errorf("ws parse package: %w", err))
} }
cancel(ErrDisconnected)
return return
} }
if debugPrint { if debugPrint {
@ -1028,7 +1031,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
c.inMessages.Add(int64(messages)) c.inMessages.Add(int64(messages))
for i := 0; i < messages; i++ { for i := 0; i < messages; i++ {
if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
cancel(ErrDisconnected)
return return
} }
var next []byte var next []byte
@ -1037,7 +1039,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
if !xnet.IsNetworkOrHostDown(err, true) { if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws read merged: %w", err)) gridLogIf(ctx, fmt.Errorf("ws read merged: %w", err))
} }
cancel(ErrDisconnected)
return return
} }
@ -1047,15 +1048,18 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
if !xnet.IsNetworkOrHostDown(err, true) { if !xnet.IsNetworkOrHostDown(err, true) {
gridLogIf(ctx, fmt.Errorf("ws parse merged: %w", err)) gridLogIf(ctx, fmt.Errorf("ws parse merged: %w", err))
} }
cancel(ErrDisconnected)
return return
} }
c.handleMsg(ctx, m, subID) c.handleMsg(ctx, m, subID)
} }
} }
}() }
// Write function. // writeStream handles the read side of the connection.
// It will grab messages from c.outQueue and write them to the connection.
// If an error occurs the cancel function will be called and conn be closed.
// The function will block until the connection is closed or an error occurs.
func (c *Connection) writeStream(ctx context.Context, conn net.Conn, cancel context.CancelCauseFunc) {
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))