mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-09 13:39:39 -05:00
lint and leftover
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
committed by
Kristoffer Dalby
parent
39443184d6
commit
233dffc186
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ type Batcher interface {
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
AddWork(c ...change.ChangeSet)
|
||||
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
|
||||
}
|
||||
@@ -36,7 +37,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
|
||||
|
||||
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
|
||||
}
|
||||
@@ -47,6 +48,7 @@ func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
||||
m := newMapper(cfg, state)
|
||||
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
||||
m.batcher = b
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -72,8 +74,10 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
}
|
||||
|
||||
var mapResp *tailcfg.MapResponse
|
||||
var err error
|
||||
var (
|
||||
mapResp *tailcfg.MapResponse
|
||||
err error
|
||||
)
|
||||
|
||||
switch c.Change {
|
||||
case change.DERP:
|
||||
@@ -84,10 +88,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
} else {
|
||||
// CRITICAL FIX: Read actual online status from NodeStore when available,
|
||||
// fall back to deriving from change type for unit tests or when NodeStore is empty
|
||||
var onlineStatus bool
|
||||
if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() {
|
||||
// Use actual NodeStore status when available (production case)
|
||||
onlineStatus = node.IsOnline().Get()
|
||||
} else {
|
||||
// Fall back to deriving from change type (unit test case or initial setup)
|
||||
onlineStatus = c.Change == change.NodeCameOnline
|
||||
}
|
||||
|
||||
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: c.NodeID.NodeID(),
|
||||
Online: ptr.To(c.Change == change.NodeCameOnline),
|
||||
Online: ptr.To(onlineStatus),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -125,7 +140,12 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
|
||||
|
||||
var data *tailcfg.MapResponse
|
||||
var err error
|
||||
data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
||||
}
|
||||
@@ -136,7 +156,8 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
|
||||
}
|
||||
|
||||
// Send the map response
|
||||
if err := nc.send(data); err != nil {
|
||||
err = nc.send(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -57,16 +58,21 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
version: version,
|
||||
created: now,
|
||||
}
|
||||
// Initialize last used timestamp
|
||||
newEntry.lastUsed.Store(now.Unix())
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
|
||||
nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
|
||||
|
||||
if !loaded {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
// Add connection to the list (lock-free)
|
||||
nodeConn.addConnection(newEntry)
|
||||
|
||||
// Use the worker pool for controlled concurrency instead of direct generation
|
||||
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
||||
|
||||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
@@ -87,6 +93,16 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
// Node will automatically receive updates through the normal flow
|
||||
// The initial full map already contains all current state
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection established in batcher because AddNode completed successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -101,10 +117,11 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
return false
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
connData.closed.Store(true)
|
||||
}
|
||||
// Remove specific connection
|
||||
removed := nodeConn.removeConnectionByChannel(c)
|
||||
if !removed {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if node has any remaining active connections
|
||||
@@ -115,18 +132,17 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
// No active connections - keep the node entry alive for rapid reconnections
|
||||
// The node will get a fresh full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||
b.addWork(c)
|
||||
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
|
||||
b.addWork(c...)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
@@ -137,23 +153,36 @@ func (b *LockFreeBatcher) Start() {
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel = nil // Prevent multiple calls
|
||||
}
|
||||
|
||||
// Only close workCh once
|
||||
select {
|
||||
case <-b.workCh:
|
||||
// Channel is already closed
|
||||
default:
|
||||
close(b.workCh)
|
||||
}
|
||||
close(b.workCh)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
log.Debug().Msg("batcher doWork loop started")
|
||||
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
// Create a cleanup ticker for removing truly disconnected nodes
|
||||
cleanupTicker := time.NewTicker(5 * time.Minute)
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-cleanupTicker.C:
|
||||
// Clean up nodes that have been offline for too long
|
||||
b.cleanupOfflineNodes()
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -161,8 +190,6 @@ func (b *LockFreeBatcher) doWork() {
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -171,7 +198,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
@@ -181,7 +207,9 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
var err error
|
||||
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
@@ -192,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
@@ -260,19 +289,22 @@ func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
all, self := change.SplitAllAndSelf(c)
|
||||
|
||||
for _, changeSet := range self {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, changeSet)
|
||||
b.pendingChanges.Store(changeSet.NodeID, changes)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
rel := change.RemoveUpdatesForSelf(nodeID, all)
|
||||
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
changes = append(changes, rel...)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
|
||||
return true
|
||||
@@ -303,7 +335,44 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
|
||||
func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
cleanupThreshold := 15 * time.Minute
|
||||
now := time.Now()
|
||||
|
||||
var nodesToCleanup []types.NodeID
|
||||
|
||||
// Find nodes that have been offline for too long
|
||||
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
|
||||
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
|
||||
// Double-check the node doesn't have active connections
|
||||
if nodeConn, exists := b.nodes.Load(nodeID); exists {
|
||||
if !nodeConn.hasActiveConnections() {
|
||||
nodesToCleanup = append(nodesToCleanup, nodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up the identified nodes
|
||||
for _, nodeID := range nodesToCleanup {
|
||||
log.Info().Uint64("node.id", nodeID.Uint64()).
|
||||
Dur("offline_duration", cleanupThreshold).
|
||||
Msg("Cleaning up node that has been offline for too long")
|
||||
|
||||
b.nodes.Delete(nodeID)
|
||||
b.connected.Delete(nodeID)
|
||||
b.totalNodes.Add(-1)
|
||||
}
|
||||
|
||||
if len(nodesToCleanup) > 0 {
|
||||
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
|
||||
Msg("Completed cleanup of long-offline nodes")
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read that checks if a node has any active connections.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
// First check if we have active connections for this node
|
||||
if nodeConn, exists := b.nodes.Load(id); exists {
|
||||
@@ -373,89 +442,234 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
|
||||
}
|
||||
}
|
||||
|
||||
// connectionData holds the channel and connection parameters.
|
||||
type connectionData struct {
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
closed atomic.Bool // Track if this connection has been closed
|
||||
// connectionEntry represents a single connection to a node.
|
||||
type connectionEntry struct {
|
||||
id string // unique connection ID
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
created time.Time
|
||||
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
||||
}
|
||||
|
||||
// nodeConn described the node connection and its associated data.
|
||||
type nodeConn struct {
|
||||
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
||||
type multiChannelNodeConn struct {
|
||||
id types.NodeID
|
||||
mapper *mapper
|
||||
|
||||
// Atomic pointer to connection data - allows lock-free updates
|
||||
connData atomic.Pointer[connectionData]
|
||||
mutex sync.RWMutex
|
||||
connections []*connectionEntry
|
||||
|
||||
updateCount atomic.Int64
|
||||
}
|
||||
|
||||
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||
nc := &nodeConn{
|
||||
// generateConnectionID generates a unique connection identifier.
|
||||
func generateConnectionID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return fmt.Sprintf("%x", bytes)
|
||||
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
|
||||
return &multiChannelNodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
}
|
||||
|
||||
// Initialize connection data
|
||||
data := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(data)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// updateConnection atomically updates connection parameters.
|
||||
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||
newData := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(newData)
|
||||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Dur("mutex_wait_time", mutexWaitDur).
|
||||
Msg("Successfully added connection after mutex wait")
|
||||
}
|
||||
|
||||
// matchesChannel checks if the given channel matches current connection.
|
||||
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return false
|
||||
// removeConnectionByChannel removes a connection by matching channel pointer.
|
||||
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for i, entry := range mc.connections {
|
||||
if entry.c == c {
|
||||
// Remove this connection
|
||||
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("Successfully removed connection")
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Compare channel pointers directly
|
||||
return data.c == c
|
||||
return false
|
||||
}
|
||||
|
||||
// compressAndVersion atomically reads connection settings.
|
||||
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
// hasActiveConnections checks if the node has any active connections.
|
||||
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections) > 0
|
||||
}
|
||||
|
||||
// getActiveConnectionCount returns the number of active connections.
|
||||
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections)
|
||||
}
|
||||
|
||||
// send broadcasts data to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
// During rapid reconnection, nodes may temporarily have no active connections
|
||||
// This is not an error - the node will receive a full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
successCount := 0
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
for i, conn := range mc.connections {
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
lastErr = err
|
||||
failedConnections = append(failedConnections, i)
|
||||
log.Warn().Err(err).
|
||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: successfully sent to connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Remove failed connections (in reverse order to maintain indices)
|
||||
for i := len(failedConnections) - 1; i >= 0; i-- {
|
||||
idx := failedConnections[i]
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Str("conn.id", mc.connections[idx].id).
|
||||
Msg("send: removing failed connection")
|
||||
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
|
||||
}
|
||||
|
||||
mc.updateCount.Add(1)
|
||||
|
||||
log.Info().Uint64("node.id", mc.id.Uint64()).
|
||||
Int("successful_sends", successCount).
|
||||
Int("failed_connections", len(failedConnections)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("send: completed broadcast")
|
||||
|
||||
// Success if at least one send succeeded
|
||||
if successCount > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
|
||||
}
|
||||
|
||||
// send sends data to a single connection entry with timeout-based stale connection detection.
|
||||
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
// Use a short timeout to detect stale connections where the client isn't reading the channel.
|
||||
// This is critical for detecting Docker containers that are forcefully terminated
|
||||
// but still have channels that appear open.
|
||||
select {
|
||||
case entry.c <- data:
|
||||
// Update last used timestamp on successful send
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
return nil
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
||||
}
|
||||
}
|
||||
|
||||
// nodeID returns the node ID.
|
||||
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
|
||||
return mc.id
|
||||
}
|
||||
|
||||
// version returns the capability version from the first active connection.
|
||||
// All connections for a node should have the same version in practice.
|
||||
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return data.version
|
||||
return mc.connections[0].version
|
||||
}
|
||||
|
||||
func (nc *nodeConn) nodeID() types.NodeID {
|
||||
return nc.id
|
||||
// change applies a change to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(mc, mc.mapper, c)
|
||||
}
|
||||
|
||||
func (nc *nodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(nc, nc.mapper, c)
|
||||
// DebugNodeInfo contains debug information about a node's connections.
|
||||
type DebugNodeInfo struct {
|
||||
Connected bool `json:"connected"`
|
||||
ActiveConnections int `json:"active_connections"`
|
||||
}
|
||||
|
||||
// send sends data to the node's channel.
|
||||
// The node will pick it up and send it to the HTTP handler.
|
||||
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
connData := nc.connData.Load()
|
||||
if connData == nil {
|
||||
return fmt.Errorf("node %d: no connection data", nc.id)
|
||||
}
|
||||
// Debug returns a pre-baked map of node debug information for the debug interface.
|
||||
func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
result := make(map[types.NodeID]DebugNodeInfo)
|
||||
|
||||
// Check if connection has been closed
|
||||
if connData.closed.Load() {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
// Get all nodes with their connection status using immediate connection logic
|
||||
// (no grace period) for debug purposes
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
nodeConn.mutex.RLock()
|
||||
activeConnCount := len(nodeConn.connections)
|
||||
nodeConn.mutex.RUnlock()
|
||||
|
||||
// Use immediate connection status: if active connections exist, node is connected
|
||||
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
|
||||
connected := false
|
||||
if activeConnCount > 0 {
|
||||
connected = true
|
||||
} else {
|
||||
// Check connected map for immediate status
|
||||
if val, ok := b.connected.Load(id); ok && val == nil {
|
||||
connected = true
|
||||
}
|
||||
}
|
||||
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Add all entries from the connected map to capture both connected and disconnected nodes
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
|
||||
@@ -209,6 +209,7 @@ func setupBatcherWithTestData(
|
||||
|
||||
// Create test users and nodes in the database
|
||||
users := database.CreateUsersForTest(userCount, "testuser")
|
||||
|
||||
allNodes := make([]node, 0, userCount*nodesPerUser)
|
||||
for _, user := range users {
|
||||
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
|
||||
@@ -353,6 +354,7 @@ func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected b
|
||||
if len(resp.PeersChangedPatch) > 0 {
|
||||
require.Len(t, resp.PeersChangedPatch, 1)
|
||||
assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -412,6 +414,7 @@ func (n *node) start() {
|
||||
n.maxPeersCount = info.PeerCount
|
||||
}
|
||||
}
|
||||
|
||||
if info.IsPatch {
|
||||
atomic.AddInt64(&n.patchCount, 1)
|
||||
// For patches, we track how many patch items
|
||||
@@ -550,6 +553,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Reduce verbose application logging for cleaner test output
|
||||
originalLevel := zerolog.GlobalLevel()
|
||||
defer zerolog.SetGlobalLevel(originalLevel)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
|
||||
// Test cases: different node counts to stress test the all-to-all connectivity
|
||||
@@ -618,6 +622,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
// Join all nodes as fast as possible
|
||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
@@ -693,6 +698,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
if stats.MaxPeersSeen > maxPeersGlobal {
|
||||
maxPeersGlobal = stats.MaxPeersSeen
|
||||
}
|
||||
|
||||
if stats.MaxPeersSeen < minPeersSeen {
|
||||
minPeersSeen = stats.MaxPeersSeen
|
||||
}
|
||||
@@ -730,9 +736,11 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Show sample of node details
|
||||
if len(nodeDetails) > 0 {
|
||||
t.Logf(" Node sample:")
|
||||
|
||||
for _, detail := range nodeDetails[:min(5, len(nodeDetails))] {
|
||||
t.Logf(" %s", detail)
|
||||
}
|
||||
|
||||
if len(nodeDetails) > 5 {
|
||||
t.Logf(" ... (%d more nodes)", len(nodeDetails)-5)
|
||||
}
|
||||
@@ -754,6 +762,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Show details of failed nodes for debugging
|
||||
if len(nodeDetails) > 5 {
|
||||
t.Logf("Failed nodes details:")
|
||||
|
||||
for _, detail := range nodeDetails[5:] {
|
||||
if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) {
|
||||
t.Logf(" %s", detail)
|
||||
@@ -875,6 +884,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
||||
count := 0
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
@@ -1026,10 +1036,12 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
// Collect updates with timeout
|
||||
updateCount := 0
|
||||
timeout := time.After(200 * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
updateCount++
|
||||
|
||||
receivedUpdates = append(receivedUpdates, data)
|
||||
|
||||
// Validate update content
|
||||
@@ -1058,6 +1070,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
|
||||
// Validate that all updates have valid content
|
||||
validUpdates := 0
|
||||
|
||||
for _, data := range receivedUpdates {
|
||||
if data != nil {
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
@@ -1095,16 +1108,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
var channelIssues int
|
||||
var mutex sync.Mutex
|
||||
|
||||
var (
|
||||
channelIssues int
|
||||
mutex sync.Mutex
|
||||
)
|
||||
|
||||
// Run rapid connect/disconnect cycles with real updates to test channel closing
|
||||
|
||||
for i := range 100 {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// First connection
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -1118,17 +1137,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
// Rapid second connection - should replace ch1
|
||||
ch2 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Remove second connection
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||
}()
|
||||
@@ -1143,7 +1167,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
case <-time.After(1 * time.Millisecond):
|
||||
// If no data received, increment issues counter
|
||||
mutex.Lock()
|
||||
|
||||
channelIssues++
|
||||
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1185,18 +1211,24 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
var panics int
|
||||
var channelErrors int
|
||||
var invalidData int
|
||||
var mutex sync.Mutex
|
||||
|
||||
var (
|
||||
panics int
|
||||
channelErrors int
|
||||
invalidData int
|
||||
mutex sync.Mutex
|
||||
)
|
||||
|
||||
// Test rapid connect/disconnect with work generation
|
||||
|
||||
for i := range 50 {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mutex.Lock()
|
||||
|
||||
panics++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Panic caught: %v", r)
|
||||
}
|
||||
@@ -1213,7 +1245,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mutex.Lock()
|
||||
|
||||
channelErrors++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Channel consumer panic: %v", r)
|
||||
}
|
||||
@@ -1229,7 +1263,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
// Validate the data we received
|
||||
if valid, reason := validateUpdateContent(data); !valid {
|
||||
mutex.Lock()
|
||||
|
||||
invalidData++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Invalid data received: %s", reason)
|
||||
}
|
||||
@@ -1268,9 +1304,11 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
if panics > 0 {
|
||||
t.Errorf("Worker channel safety failed with %d panics", panics)
|
||||
}
|
||||
|
||||
if channelErrors > 0 {
|
||||
t.Errorf("Channel handling failed with %d channel errors", channelErrors)
|
||||
}
|
||||
|
||||
if invalidData > 0 {
|
||||
t.Errorf("Data validation failed with %d invalid data packets", invalidData)
|
||||
}
|
||||
@@ -1342,15 +1380,19 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Use remaining nodes for connection churn testing
|
||||
churningNodes := allNodes[len(allNodes)/2:]
|
||||
churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||
|
||||
var churningChannelsMutex sync.Mutex // Protect concurrent map access
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
numCycles := 10 // Reduced for simpler test
|
||||
panicCount := 0
|
||||
|
||||
var panicMutex sync.Mutex
|
||||
|
||||
// Track deadlock with timeout
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
@@ -1364,16 +1406,22 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicMutex.Lock()
|
||||
|
||||
panicCount++
|
||||
|
||||
panicMutex.Unlock()
|
||||
t.Logf("Panic in churning connect: %v", r)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
@@ -1400,17 +1448,23 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicMutex.Lock()
|
||||
|
||||
panicCount++
|
||||
|
||||
panicMutex.Unlock()
|
||||
t.Logf("Panic in churning disconnect: %v", r)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
time.Sleep(time.Duration(i%5) * time.Millisecond)
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
if exists {
|
||||
batcher.RemoveNode(nodeID, ch)
|
||||
}
|
||||
@@ -1422,10 +1476,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// DERP changes
|
||||
batcher.AddWork(change.DERPSet)
|
||||
}
|
||||
|
||||
if i%5 == 0 {
|
||||
// Full updates using real node data
|
||||
batcher.AddWork(change.FullSet)
|
||||
}
|
||||
|
||||
if i%7 == 0 && len(allNodes) > 0 {
|
||||
// Node-specific changes using real nodes
|
||||
node := allNodes[i%len(allNodes)]
|
||||
@@ -1453,7 +1509,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
// Validate results
|
||||
panicMutex.Lock()
|
||||
|
||||
finalPanicCount := panicCount
|
||||
|
||||
panicMutex.Unlock()
|
||||
|
||||
allStats := tracker.getAllStats()
|
||||
@@ -1536,6 +1594,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// Reduce verbose application logging for cleaner test output
|
||||
originalLevel := zerolog.GlobalLevel()
|
||||
defer zerolog.SetGlobalLevel(originalLevel)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
|
||||
// Full test matrix for scalability testing
|
||||
@@ -1624,6 +1683,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description)
|
||||
t.Logf(
|
||||
" Cycles: %d, Buffer Size: %d, Chaos Type: %s",
|
||||
@@ -1660,12 +1720,16 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Connect all nodes first so they can see each other as peers
|
||||
connectedNodes := make(map[types.NodeID]bool)
|
||||
|
||||
var connectedNodesMutex sync.RWMutex
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1676,6 +1740,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
t.Logf(
|
||||
@@ -1697,14 +1762,17 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// For chaos testing, only disconnect/reconnect a subset of nodes
|
||||
// This ensures some nodes stay connected to continue receiving updates
|
||||
startIdx := cycle % len(testNodes)
|
||||
|
||||
endIdx := startIdx + len(testNodes)/4
|
||||
if endIdx > len(testNodes) {
|
||||
endIdx = len(testNodes)
|
||||
}
|
||||
|
||||
if startIdx >= endIdx {
|
||||
startIdx = 0
|
||||
endIdx = min(len(testNodes)/4, len(testNodes))
|
||||
}
|
||||
|
||||
chaosNodes := testNodes[startIdx:endIdx]
|
||||
if len(chaosNodes) == 0 {
|
||||
chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos
|
||||
@@ -1722,17 +1790,22 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
connectedNodesMutex.RLock()
|
||||
|
||||
isConnected := connectedNodes[nodeID]
|
||||
|
||||
connectedNodesMutex.RUnlock()
|
||||
|
||||
if isConnected {
|
||||
batcher.RemoveNode(nodeID, channel)
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = false
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
}
|
||||
}(
|
||||
@@ -1746,6 +1819,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@@ -1757,7 +1831,9 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
|
||||
// Add work to create load
|
||||
@@ -1776,11 +1852,13 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count
|
||||
for i := range updateCount {
|
||||
wg.Add(1)
|
||||
|
||||
go func(index int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@@ -1823,11 +1901,14 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
deadlockDetected = true
|
||||
// Collect diagnostic information
|
||||
allStats := tracker.getAllStats()
|
||||
|
||||
totalUpdates := 0
|
||||
for _, stats := range allStats {
|
||||
totalUpdates += stats.TotalUpdates
|
||||
}
|
||||
|
||||
interimPanics := atomic.LoadInt64(&panicCount)
|
||||
|
||||
t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT)
|
||||
t.Logf(
|
||||
" Progress at timeout: %d total updates, %d panics",
|
||||
@@ -1873,6 +1954,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
stats := node.cleanup()
|
||||
totalUpdates += stats.TotalUpdates
|
||||
totalPatches += stats.PatchUpdates
|
||||
|
||||
totalFull += stats.FullUpdates
|
||||
if stats.MaxPeersSeen > maxPeersGlobal {
|
||||
maxPeersGlobal = stats.MaxPeersSeen
|
||||
@@ -1910,10 +1992,12 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Legacy tracker comparison (optional)
|
||||
allStats := tracker.getAllStats()
|
||||
|
||||
legacyTotalUpdates := 0
|
||||
for _, stats := range allStats {
|
||||
legacyTotalUpdates += stats.TotalUpdates
|
||||
}
|
||||
|
||||
if legacyTotalUpdates != int(totalUpdates) {
|
||||
t.Logf(
|
||||
"Note: Legacy tracker mismatch - legacy: %d, new: %d",
|
||||
@@ -1926,6 +2010,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Validation based on expectation
|
||||
testPassed := true
|
||||
|
||||
if tc.expectBreak {
|
||||
// For tests expected to break, we're mainly checking that we don't crash
|
||||
if finalPanicCount > 0 {
|
||||
@@ -1947,14 +2032,19 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// For tests expected to pass, validate proper operation
|
||||
if finalPanicCount > 0 {
|
||||
t.Errorf("Scalability test failed with %d panics", finalPanicCount)
|
||||
|
||||
testPassed = false
|
||||
}
|
||||
|
||||
if deadlockDetected {
|
||||
t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes))
|
||||
|
||||
testPassed = false
|
||||
}
|
||||
|
||||
if totalUpdates == 0 {
|
||||
t.Error("No updates received - system may be completely stalled")
|
||||
|
||||
testPassed = false
|
||||
}
|
||||
}
|
||||
@@ -2020,6 +2110,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
// Read all available updates for each node
|
||||
for i := range allNodes {
|
||||
nodeUpdates := 0
|
||||
|
||||
t.Logf("Reading updates for node %d:", i)
|
||||
|
||||
// Read up to 10 updates per node or until timeout/no more data
|
||||
@@ -2056,6 +2147,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
if len(data.Peers) > 0 {
|
||||
t.Logf(" Full peer list with %d peers", len(data.Peers))
|
||||
|
||||
for j, peer := range data.Peers[:min(3, len(data.Peers))] {
|
||||
t.Logf(
|
||||
" Peer %d: NodeID=%d, Online=%v",
|
||||
@@ -2065,8 +2157,10 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(data.PeersChangedPatch) > 0 {
|
||||
t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch))
|
||||
|
||||
for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] {
|
||||
t.Logf(
|
||||
" Patch %d: NodeID=%d, Online=%v",
|
||||
@@ -2080,6 +2174,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Node %d received %d updates", i, nodeUpdates)
|
||||
}
|
||||
|
||||
@@ -2095,71 +2190,132 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items.
|
||||
func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID
|
||||
// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected.
|
||||
// This specifically tests the multi-channel batcher implementation issue.
|
||||
func TestBatcherRapidReconnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
t.Logf("=== RAPID RECONNECTION TEST ===")
|
||||
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
|
||||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
for i, node := range allNodes {
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
for i, node := range allNodes {
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
}
|
||||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let reconnections settle
|
||||
|
||||
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
|
||||
t.Logf("Phase 4: Checking debug status...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
disconnectedCount := 0
|
||||
|
||||
for i, node := range allNodes {
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
|
||||
// Check if the debug info shows the node as connected
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
// Also check IsConnected method
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Logf("Node %d IsConnected() returns false", i)
|
||||
}
|
||||
}
|
||||
|
||||
if disconnectedCount > 0 {
|
||||
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
|
||||
// This is expected behavior for multi-channel batcher according to user
|
||||
// "it has never worked with the multi"
|
||||
} else {
|
||||
t.Logf("All nodes show as connected - working correctly")
|
||||
}
|
||||
} else {
|
||||
t.Logf("Batcher does not implement Debug() method")
|
||||
}
|
||||
|
||||
// Phase 5: Test if "disconnected" nodes can actually receive updates
|
||||
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
|
||||
|
||||
// Send a change that should reach all nodes
|
||||
batcher.AddWork(change.DERPChange())
|
||||
|
||||
receivedCount := 0
|
||||
timeout := time.After(500 * time.Millisecond)
|
||||
|
||||
for i := 0; i < len(allNodes); i++ {
|
||||
select {
|
||||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Logf("Node %d timed out waiting for update", i)
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
t.Logf("Update delivery test: %d/%d nodes received updates", receivedCount, len(allNodes))
|
||||
|
||||
if receivedCount < len(allNodes) {
|
||||
t.Logf("Some nodes failed to receive updates - confirming the issue")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatcherMultiConnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
nodes := testData.Nodes
|
||||
|
||||
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||
|
||||
// Wait for initial NodeCameOnline to be processed
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Drain any initial updates
|
||||
drainedCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-nodes[0].ch:
|
||||
drainedCount++
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
t.Logf("Drained %d initial updates", drainedCount)
|
||||
|
||||
// Now send a single FullSet update and trace it closely
|
||||
t.Logf("Sending change.FullSet work item...")
|
||||
batcher.AddWork(change.FullSet)
|
||||
|
||||
// Give short time for processing
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check if any update was received
|
||||
select {
|
||||
case data := <-nodes[0].ch:
|
||||
t.Logf("SUCCESS: Received update after FullSet!")
|
||||
|
||||
if data != nil {
|
||||
// Detailed analysis of the response - data is already a MapResponse
|
||||
t.Logf("Response details:")
|
||||
t.Logf(" Peers: %d", len(data.Peers))
|
||||
t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch))
|
||||
t.Logf(" PeersChanged: %d", len(data.PeersChanged))
|
||||
t.Logf(" PeersRemoved: %d", len(data.PeersRemoved))
|
||||
t.Logf(" DERPMap: %v", data.DERPMap != nil)
|
||||
t.Logf(" KeepAlive: %v", data.KeepAlive)
|
||||
t.Logf(" Node: %v", data.Node != nil)
|
||||
|
||||
if len(data.Peers) > 0 {
|
||||
t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers))
|
||||
} else if len(data.PeersChangedPatch) > 0 {
|
||||
t.Errorf("ERROR: Received patch update instead of full update!")
|
||||
} else if data.DERPMap != nil {
|
||||
t.Logf("Received DERP map update")
|
||||
} else if data.Node != nil {
|
||||
t.Logf("Received self node update")
|
||||
} else {
|
||||
t.Errorf("ERROR: Received unknown update type!")
|
||||
}
|
||||
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
@@ -2328,12 +2484,53 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Response data is nil")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!")
|
||||
t.Errorf("This indicates FullSet work items are not being processed at all")
|
||||
}
|
||||
|
||||
// Send another update and verify remaining connections still work
|
||||
clearChannel(node1.ch)
|
||||
clearChannel(thirdChannel)
|
||||
|
||||
testChangeSet2 := change.ChangeSet{
|
||||
NodeID: node2.n.ID,
|
||||
Change: change.NodeNewOrUpdate,
|
||||
SelfUpdateOnly: false,
|
||||
}
|
||||
|
||||
batcher.AddWork(testChangeSet2)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify remaining connections still receive updates
|
||||
remaining1Received := false
|
||||
remaining3Received := false
|
||||
|
||||
select {
|
||||
case mapResp := <-node1.ch:
|
||||
remaining1Received = (mapResp != nil)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 1 did not receive update after removal")
|
||||
}
|
||||
|
||||
select {
|
||||
case mapResp := <-thirdChannel:
|
||||
remaining3Received = (mapResp != nil)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 3 did not receive update after removal")
|
||||
}
|
||||
|
||||
if remaining1Received && remaining3Received {
|
||||
t.Logf("SUCCESS: Remaining connections still receive updates after removal")
|
||||
} else {
|
||||
t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t",
|
||||
remaining1Received, remaining3Received)
|
||||
}
|
||||
|
||||
// Verify second channel no longer receives updates (should be closed/removed)
|
||||
select {
|
||||
case <-secondChannel:
|
||||
t.Errorf("Removed connection still received update - this should not happen")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Logf("SUCCESS: Removed connection correctly no longer receives updates")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ type MapResponseBuilder struct {
|
||||
nodeID types.NodeID
|
||||
capVer tailcfg.CapabilityVersion
|
||||
errs []error
|
||||
|
||||
debugType debugType
|
||||
}
|
||||
|
||||
type debugType string
|
||||
|
||||
@@ -139,11 +139,11 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers := m.state.ListPeers(nodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(fullResponseDebug).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
@@ -162,6 +162,7 @@ func (m *mapper) derpMapResponse(
|
||||
nodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(derpResponseDebug).
|
||||
WithDERPMap().
|
||||
Build()
|
||||
}
|
||||
@@ -173,6 +174,7 @@ func (m *mapper) peerChangedPatchResponse(
|
||||
changed []*tailcfg.PeerChange,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(patchResponseDebug).
|
||||
WithPeerChangedPatch(changed).
|
||||
Build()
|
||||
}
|
||||
@@ -186,6 +188,7 @@ func (m *mapper) peerChangeResponse(
|
||||
peers := m.state.ListPeers(nodeID, changedNodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(changeResponseDebug).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithUserProfiles(peers).
|
||||
@@ -199,6 +202,7 @@ func (m *mapper) peerRemovedResponse(
|
||||
removedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(removeResponseDebug).
|
||||
WithPeersRemoved(removedNodeID).
|
||||
Build()
|
||||
}
|
||||
@@ -214,7 +218,7 @@ func writeDebugMapResponse(
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID))
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -224,7 +228,7 @@ func writeDebugMapResponse(
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s.json", now),
|
||||
fmt.Sprintf("%s-%s.json", now, t),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
@@ -244,7 +248,11 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
nodes, err := os.ReadDir(debugDumpMapResponsePath)
|
||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||
}
|
||||
|
||||
func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
nodes, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -263,7 +271,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
|
||||
|
||||
nodeID := types.NodeID(nodeIDu)
|
||||
|
||||
files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name()))
|
||||
files, err := os.ReadDir(path.Join(dir, node.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
|
||||
continue
|
||||
@@ -278,7 +286,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name()))
|
||||
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading file %s", file.Name())
|
||||
continue
|
||||
|
||||
@@ -158,7 +158,6 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
Tags: []string{},
|
||||
|
||||
LastSeen: &lastSeen,
|
||||
MachineAuthorized: true,
|
||||
|
||||
CapMap: tailcfg.NodeCapMap{
|
||||
|
||||
Reference in New Issue
Block a user