Add separate grid reconnection mutex (#18862)

Add separate reconnection mutex

Make reconnects safer and ensure that a state change isn't missed.

Tested with several runs of `go test -race -v -count=500`.

Adds a separate mutex (`reconnectMu`) and doesn't mix in the testing mutex.
This commit is contained in:
Klaus Post 2024-01-24 11:49:39 -08:00 committed by GitHub
parent 4a6c97463f
commit 6968f7237a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 18 deletions

View File

@ -68,7 +68,8 @@ type Connection struct {
id uuid.UUID id uuid.UUID
// Remote uuid, if we have been connected. // Remote uuid, if we have been connected.
remoteID *uuid.UUID remoteID *uuid.UUID
reconnectMu sync.Mutex
// Context for the server. // Context for the server.
ctx context.Context ctx context.Context
@ -697,6 +698,7 @@ func (c *Connection) connect() {
retry(fmt.Errorf("connection rejected: %s", r.RejectedReason)) retry(fmt.Errorf("connection rejected: %s", r.RejectedReason))
continue continue
} }
c.reconnectMu.Lock()
remoteUUID := uuid.UUID(r.ID) remoteUUID := uuid.UUID(r.ID)
if c.remoteID != nil { if c.remoteID != nil {
c.reconnected() c.reconnected()
@ -792,11 +794,6 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn
Op: OpConnectResponse, Op: OpConnectResponse,
} }
if c.remoteID != nil {
c.reconnected()
}
rid := uuid.UUID(req.ID)
c.remoteID = &rid
resp := connectResp{ resp := connectResp{
ID: c.id, ID: c.id,
Accepted: true, Accepted: true,
@ -805,13 +802,26 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn
if debugPrint { if debugPrint {
fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side) fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side)
} }
if err == nil { if err != nil {
c.updateState(StateConnected) return err
c.handleMessages(ctx, conn)
} }
return err // Signal that we are reconnected, update state and handle messages.
// Prevent other connections from connecting while we process.
c.reconnectMu.Lock()
if c.remoteID != nil {
c.reconnected()
}
rid := uuid.UUID(req.ID)
c.remoteID = &rid
c.updateState(StateConnected)
c.handleMessages(ctx, conn)
return nil
} }
// reconnected signals the connection has been reconnected.
// It will close all active requests and streams.
// caller *must* hold reconnectMu.
func (c *Connection) reconnected() { func (c *Connection) reconnected() {
c.updateState(StateConnectionError) c.updateState(StateConnectionError)
// Close all active requests. // Close all active requests.
@ -831,9 +841,7 @@ func (c *Connection) reconnected() {
c.outgoing.Clear() c.outgoing.Clear()
// Wait for existing to exit // Wait for existing to exit
c.connMu.Lock()
c.handleMsgWg.Wait() c.handleMsgWg.Wait()
c.connMu.Unlock()
} }
func (c *Connection) updateState(s State) { func (c *Connection) updateState(s State) {
@ -855,12 +863,13 @@ func (c *Connection) updateState(s State) {
c.connChange.Broadcast() c.connChange.Broadcast()
} }
// handleMessages will handle incoming messages on conn.
// caller *must* hold reconnectMu.
func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
// Read goroutine
c.connMu.Lock()
c.handleMsgWg.Add(2) c.handleMsgWg.Add(2)
c.connMu.Unlock() c.reconnectMu.Unlock()
ctx, cancel := context.WithCancelCause(ctx) ctx, cancel := context.WithCancelCause(ctx)
// Read goroutine
go func() { go func() {
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
@ -1538,9 +1547,9 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.debugInConn.Close() c.debugInConn.Close()
} }
case debugWaitForExit: case debugWaitForExit:
c.connMu.Lock() c.reconnectMu.Lock()
c.handleMsgWg.Wait() c.handleMsgWg.Wait()
c.connMu.Unlock() c.reconnectMu.Unlock()
case debugSetConnPingDuration: case debugSetConnPingDuration:
c.connMu.Lock() c.connMu.Lock()
defer c.connMu.Unlock() defer c.connMu.Unlock()

View File

@ -535,7 +535,6 @@ func testStreamDeadline(t *testing.T, local, remote *Manager) {
err = resp.Err err = resp.Err
} }
clientCanceled <- time.Since(started) clientCanceled <- time.Since(started)
t.Log("Client Context canceled")
}() }()
serverEnd := <-serverCanceled serverEnd := <-serverCanceled
clientEnd := <-clientCanceled clientEnd := <-clientCanceled