diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go
index 4d2e277b..8d66f182 100644
--- a/hscontrol/notifier/notifier.go
+++ b/hscontrol/notifier/notifier.go
@@ -63,9 +63,27 @@ func (n *Notifier) Close() {
 	n.closed = true
 	n.b.close()
 
-	for _, c := range n.nodes {
-		close(c)
+	// Close every registered channel, tolerating ones that are already closed
+	for nodeID, c := range n.nodes {
+		n.safeCloseChannel(nodeID, c)
 	}
+
+	// Reset the node map so no closed channel stays registered
+	n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
+}
+
+// safeCloseChannel closes c and recovers from a panic raised by close (for
+// example because c is already closed), logging the recovered value with nodeID.
+func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
+	defer func() {
+		if r := recover(); r != nil {
+			log.Error().
+				Uint64("node.id", nodeID.Uint64()).
+				Any("recover", r).
+				Msg("recovered from panic when closing channel")
+		}
+	}()
+	close(c)
 }
 
 func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
@@ -90,7 +108,9 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
 	// connection. Close the old channel and replace it.
 	if curr, ok := n.nodes[nodeID]; ok {
 		n.tracef(nodeID, "channel present, closing and replacing")
-		close(curr)
+		// close never blocks, so no goroutine is needed; the helper only
+		// guards against curr having been closed already
+		n.safeCloseChannel(nodeID, curr)
 	}
 
 	n.nodes[nodeID] = c
@@ -161,6 +181,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
 	return false
 }
 
+// LikelyConnectedMap returns the notifier's own map of likely connected nodes (not a copy)
 func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
 	return n.connected
 }
diff --git a/hscontrol/notifier/notifier_test.go b/hscontrol/notifier/notifier_test.go
index d11bc26c..a7369740 100644
--- a/hscontrol/notifier/notifier_test.go
+++ b/hscontrol/notifier/notifier_test.go
@@ -2,8 +2,10 @@ package notifier
 
 import (
 	"context"
+	"math/rand"
 	"net/netip"
 	"sort"
+	"sync"
 	"testing"
 	"time"
 
@@ -263,3 +265,64 @@ func TestBatcher(t *testing.T) {
 		})
 	}
 }
+
+// TestIsLikelyConnectedRaceCondition exercises AddNode, RemoveNode and
+// IsLikelyConnected concurrently for a single node. Re-registering and removing
+// a node at the same time used to panic when a channel that was already closed
+// got closed again, which can happen when a node changes network transport
+// quickly (eg mobile->wifi) and reconnects whilst also disconnecting. The test
+// passes if every call returns without panicking.
+func TestIsLikelyConnectedRaceCondition(t *testing.T) {
+	// mock config for the notifier
+	cfg := &types.Config{
+		Tuning: types.Tuning{
+			NotifierSendTimeout:            1 * time.Second,
+			BatchChangeDelay:               1 * time.Second,
+			NodeMapSessionBufferedChanSize: 30,
+		},
+	}
+
+	notifier := NewNotifier(cfg)
+	defer notifier.Close()
+
+	nodeID := types.NodeID(1)
+	updateChan := make(chan types.StateUpdate, 10)
+
+	// Number of goroutines to spawn and calls made by each of them
+	const (
+		concurrentAccessors = 100
+		iterations          = 100
+	)
+
+	// Add node to notifier
+	notifier.AddNode(nodeID, updateChan)
+
+	// Start goroutines to cause a race
+	var wg sync.WaitGroup
+	wg.Add(concurrentAccessors)
+	for i := 0; i < concurrentAccessors; i++ {
+		go func(routineID int) {
+			defer wg.Done()
+
+			for j := 0; j < iterations; j++ {
+				// The goroutine's role is fixed by routineID: query,
+				// remove or (re-)add the node
+				switch routineID % 3 {
+				case 0:
+					// Either answer is valid here; the call only has to
+					// return without panicking
+					_ = notifier.IsLikelyConnected(nodeID)
+				case 1:
+					notifier.RemoveNode(nodeID, updateChan)
+				default:
+					notifier.AddNode(nodeID, updateChan)
+				}
+
+				// Small random delay to increase chance of races
+				time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+			}
+		}(i)
+	}
+
+	wg.Wait()
+}