mirror of
https://github.com/juanfont/headscale.git
synced 2025-04-23 20:15:38 -04:00
Fix panic on fast reconnection of node (#2536)
* Fix panic on fast reconnection of node * Use parameter captured in closure as per review request
This commit is contained in:
parent
92e587a82c
commit
56d085bd08
@ -63,9 +63,26 @@ func (n *Notifier) Close() {
|
|||||||
n.closed = true
|
n.closed = true
|
||||||
n.b.close()
|
n.b.close()
|
||||||
|
|
||||||
for _, c := range n.nodes {
|
// Close channels safely using the helper method
|
||||||
close(c)
|
for nodeID, c := range n.nodes {
|
||||||
|
n.safeCloseChannel(nodeID, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear node map after closing channels
|
||||||
|
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeCloseChannel closes a channel and panic recovers if already closed
|
||||||
|
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Error().
|
||||||
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
|
Any("recover", r).
|
||||||
|
Msg("recovered from panic when closing channel in Close()")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
close(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
||||||
@ -90,7 +107,11 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
|||||||
// connection. Close the old channel and replace it.
|
// connection. Close the old channel and replace it.
|
||||||
if curr, ok := n.nodes[nodeID]; ok {
|
if curr, ok := n.nodes[nodeID]; ok {
|
||||||
n.tracef(nodeID, "channel present, closing and replacing")
|
n.tracef(nodeID, "channel present, closing and replacing")
|
||||||
close(curr)
|
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
|
||||||
|
// if/when someone is waiting to send on this channel
|
||||||
|
go func(ch chan<- types.StateUpdate) {
|
||||||
|
n.safeCloseChannel(nodeID, ch)
|
||||||
|
}(curr)
|
||||||
}
|
}
|
||||||
|
|
||||||
n.nodes[nodeID] = c
|
n.nodes[nodeID] = c
|
||||||
@ -161,6 +182,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LikelyConnectedMap returns the notifier's thread-safe map from node ID to
// whether that node is likely connected.
|
||||||
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
||||||
return n.connected
|
return n.connected
|
||||||
}
|
}
|
||||||
|
@ -2,8 +2,11 @@ package notifier
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -263,3 +266,78 @@ func TestBatcher(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
|
||||||
|
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
|
||||||
|
// close a channel that was already closed, which can happen when a node changes
|
||||||
|
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
|
||||||
|
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
|
||||||
|
// mock config for the notifier
|
||||||
|
cfg := &types.Config{
|
||||||
|
Tuning: types.Tuning{
|
||||||
|
NotifierSendTimeout: 1 * time.Second,
|
||||||
|
BatchChangeDelay: 1 * time.Second,
|
||||||
|
NodeMapSessionBufferedChanSize: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
notifier := NewNotifier(cfg)
|
||||||
|
defer notifier.Close()
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
updateChan := make(chan types.StateUpdate, 10)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Number of goroutines to spawn for concurrent access
|
||||||
|
concurrentAccessors := 100
|
||||||
|
iterations := 100
|
||||||
|
|
||||||
|
// Add node to notifier
|
||||||
|
notifier.AddNode(nodeID, updateChan)
|
||||||
|
|
||||||
|
// Track errors
|
||||||
|
errChan := make(chan string, concurrentAccessors*iterations)
|
||||||
|
|
||||||
|
// Start goroutines to cause a race
|
||||||
|
wg.Add(concurrentAccessors)
|
||||||
|
for i := 0; i < concurrentAccessors; i++ {
|
||||||
|
go func(routineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for j := 0; j < iterations; j++ {
|
||||||
|
// Simulate race by having some goroutines check IsLikelyConnected
|
||||||
|
// while others add/remove the node
|
||||||
|
if routineID%3 == 0 {
|
||||||
|
// This goroutine checks connection status
|
||||||
|
isConnected := notifier.IsLikelyConnected(nodeID)
|
||||||
|
if isConnected != true && isConnected != false {
|
||||||
|
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
|
||||||
|
}
|
||||||
|
} else if routineID%3 == 1 {
|
||||||
|
// This goroutine removes the node
|
||||||
|
notifier.RemoveNode(nodeID, updateChan)
|
||||||
|
} else {
|
||||||
|
// This goroutine adds the node back
|
||||||
|
notifier.AddNode(nodeID, updateChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small random delay to increase chance of races
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errChan)
|
||||||
|
|
||||||
|
// Collate errors
|
||||||
|
var errors []string
|
||||||
|
for err := range errChan {
|
||||||
|
errors = append(errors, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user