Fix panic on fast reconnection of node (#2536)

* Fix panic on fast reconnection of node

* Use parameter captured in closure as per review request
This commit is contained in:
Relihan Myburgh 2025-04-23 21:52:24 +12:00 committed by GitHub
parent 92e587a82c
commit 56d085bd08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 3 deletions

View File

@ -63,9 +63,26 @@ func (n *Notifier) Close() {
n.closed = true
n.b.close()
for _, c := range n.nodes {
close(c)
// Close channels safely using the helper method
for nodeID, c := range n.nodes {
n.safeCloseChannel(nodeID, c)
}
// Clear node map after closing channels
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes a channel and panic recovers if already closed
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() {
if r := recover(); r != nil {
log.Error().
Uint64("node.id", nodeID.Uint64()).
Any("recover", r).
Msg("recovered from panic when closing channel in Close()")
}
}()
close(c)
}
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
@ -90,7 +107,11 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
// connection. Close the old channel and replace it.
if curr, ok := n.nodes[nodeID]; ok {
n.tracef(nodeID, "channel present, closing and replacing")
close(curr)
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
// if/when someone is waiting to send on this channel
go func(ch chan<- types.StateUpdate) {
n.safeCloseChannel(nodeID, ch)
}(curr)
}
n.nodes[nodeID] = c
@ -161,6 +182,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return false
}
// LikelyConnectedMap returns a thread safe map of connected nodes
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
}

View File

@ -2,8 +2,11 @@ package notifier
import (
"context"
"fmt"
"math/rand"
"net/netip"
"sort"
"sync"
"testing"
"time"
@ -263,3 +266,78 @@ func TestBatcher(t *testing.T) {
})
}
}
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
Tuning: types.Tuning{
NotifierSendTimeout: 1 * time.Second,
BatchChangeDelay: 1 * time.Second,
NodeMapSessionBufferedChanSize: 30,
},
}
notifier := NewNotifier(cfg)
defer notifier.Close()
nodeID := types.NodeID(1)
updateChan := make(chan types.StateUpdate, 10)
var wg sync.WaitGroup
// Number of goroutines to spawn for concurrent access
concurrentAccessors := 100
iterations := 100
// Add node to notifier
notifier.AddNode(nodeID, updateChan)
// Track errors
errChan := make(chan string, concurrentAccessors*iterations)
// Start goroutines to cause a race
wg.Add(concurrentAccessors)
for i := 0; i < concurrentAccessors; i++ {
go func(routineID int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
if routineID%3 == 0 {
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
} else if routineID%3 == 1 {
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
} else {
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}
// Small random delay to increase chance of races
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
}
}(i)
}
wg.Wait()
close(errChan)
// Collate errors
var errors []string
for err := range errChan {
errors = append(errors, err)
}
if len(errors) > 0 {
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
}
}