Add separate grid reconnection mutex (#18862)

Add separate reconnection mutex

Improve safety around reconnects and make sure a state change isn't missed.

Tested with several runs of `go test -race -v -count=500`.

Adds a separate mutex and doesn't mix in the testing mutex.
This commit is contained in:
Klaus Post 2024-01-24 11:49:39 -08:00 committed by GitHub
parent 4a6c97463f
commit 6968f7237a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 18 deletions

View File

@ -68,7 +68,8 @@ type Connection struct {
id uuid.UUID
// Remote uuid, if we have been connected.
remoteID *uuid.UUID
remoteID *uuid.UUID
reconnectMu sync.Mutex
// Context for the server.
ctx context.Context
@ -697,6 +698,7 @@ func (c *Connection) connect() {
retry(fmt.Errorf("connection rejected: %s", r.RejectedReason))
continue
}
c.reconnectMu.Lock()
remoteUUID := uuid.UUID(r.ID)
if c.remoteID != nil {
c.reconnected()
@ -792,11 +794,6 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn
Op: OpConnectResponse,
}
if c.remoteID != nil {
c.reconnected()
}
rid := uuid.UUID(req.ID)
c.remoteID = &rid
resp := connectResp{
ID: c.id,
Accepted: true,
@ -805,13 +802,26 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn
if debugPrint {
fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side)
}
if err == nil {
c.updateState(StateConnected)
c.handleMessages(ctx, conn)
if err != nil {
return err
}
return err
// Signal that we are reconnected, update state and handle messages.
// Prevent other connections from connecting while we process.
c.reconnectMu.Lock()
if c.remoteID != nil {
c.reconnected()
}
rid := uuid.UUID(req.ID)
c.remoteID = &rid
c.updateState(StateConnected)
c.handleMessages(ctx, conn)
return nil
}
// reconnected signals the connection has been reconnected.
// It will close all active requests and streams.
// caller *must* hold reconnectMu.
func (c *Connection) reconnected() {
c.updateState(StateConnectionError)
// Close all active requests.
@ -831,9 +841,7 @@ func (c *Connection) reconnected() {
c.outgoing.Clear()
// Wait for existing to exit
c.connMu.Lock()
c.handleMsgWg.Wait()
c.connMu.Unlock()
}
func (c *Connection) updateState(s State) {
@ -855,12 +863,13 @@ func (c *Connection) updateState(s State) {
c.connChange.Broadcast()
}
// handleMessages will handle incoming messages on conn.
// caller *must* hold reconnectMu.
func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
// Read goroutine
c.connMu.Lock()
c.handleMsgWg.Add(2)
c.connMu.Unlock()
c.reconnectMu.Unlock()
ctx, cancel := context.WithCancelCause(ctx)
// Read goroutine
go func() {
defer func() {
if rec := recover(); rec != nil {
@ -1538,9 +1547,9 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.debugInConn.Close()
}
case debugWaitForExit:
c.connMu.Lock()
c.reconnectMu.Lock()
c.handleMsgWg.Wait()
c.connMu.Unlock()
c.reconnectMu.Unlock()
case debugSetConnPingDuration:
c.connMu.Lock()
defer c.connMu.Unlock()

View File

@ -535,7 +535,6 @@ func testStreamDeadline(t *testing.T, local, remote *Manager) {
err = resp.Err
}
clientCanceled <- time.Since(started)
t.Log("Client Context canceled")
}()
serverEnd := <-serverCanceled
clientEnd := <-clientCanceled