mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-09 13:39:39 -05:00
state/nodestore: in memory representation of nodes
Initial work on a nodestore which stores all of the nodes and their relations in memory with relationship for peers precalculated. It is a copy-on-write structure, replacing the "snapshot" when a change to the structure occurs. It is optimised for reads, and while batches are not fast, they are grouped together to do less of the expensive peer calculation if there are many changes rapidly. Writes will block until commited, while reads are never blocked. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
committed by
Kristoffer Dalby
parent
38be30b6d4
commit
9d236571f4
@@ -1,6 +1,7 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||
type Batcher interface {
|
||||
Start()
|
||||
Close()
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
@@ -120,7 +121,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||
if nc == nil {
|
||||
return fmt.Errorf("nodeConnection is nil")
|
||||
return errors.New("nodeConnection is nil")
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
|
||||
@@ -21,8 +21,7 @@ type LockFreeBatcher struct {
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
// Lock-free concurrent maps
|
||||
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||
nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
@@ -32,7 +31,6 @@ type LockFreeBatcher struct {
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
batchMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
@@ -45,65 +43,63 @@ type LockFreeBatcher struct {
|
||||
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||
// First validate that we can generate initial map before doing anything else
|
||||
fullSelfChange := change.FullSelf(id)
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||
addNodeStart := time.Now()
|
||||
|
||||
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||
// This currently means that the goroutine for the node connection will do the processing
|
||||
// which means that we might have uncontrolled concurrency.
|
||||
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||
// it to be processed in a more controlled manner.
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
// Generate connection ID
|
||||
connID := generateConnectionID()
|
||||
|
||||
// Create new connection entry
|
||||
now := time.Now()
|
||||
newEntry := &connectionEntry{
|
||||
id: connID,
|
||||
c: c,
|
||||
version: version,
|
||||
created: now,
|
||||
}
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
|
||||
var conn *nodeConn
|
||||
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
|
||||
// Update existing connection
|
||||
existing.updateConnection(c, version)
|
||||
conn = existing
|
||||
} else {
|
||||
if !loaded {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
// Mark as connected only after validation succeeds
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Send the validated initial map
|
||||
if initialMap != nil {
|
||||
if err := conn.send(initialMap); err != nil {
|
||||
// Clean up the connection state on send failure
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Delete(id)
|
||||
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Notify other nodes that this node came online
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||
// Use a blocking send with timeout for initial map since the channel should be ready
|
||||
// and we want to avoid the race condition where the receiver isn't ready yet
|
||||
select {
|
||||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
|
||||
Msg("Initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||
// It validates the connection channel matches the current one, closes the connection,
|
||||
// and notifies other nodes that this node has gone offline.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
|
||||
// Check if this is the current connection and mark it as closed
|
||||
if existing, ok := b.nodes.Load(id); ok {
|
||||
if !existing.matchesChannel(c) {
|
||||
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
|
||||
return // Not the current connection, not an error
|
||||
}
|
||||
// It validates the connection channel matches one of the current connections, closes that specific connection,
|
||||
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
|
||||
// Reports if the node still has active connections after removal.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||
nodeConn, exists := b.nodes.Load(id)
|
||||
if !exists {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
|
||||
return false
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
@@ -111,15 +107,20 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
|
||||
// Check if node has any remaining active connections
|
||||
if nodeConn.hasActiveConnections() {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection removed but keeping online because other connections remain")
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
// Notify other nodes that this node went offline
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
@@ -205,15 +206,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow synchronous work processing")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -221,16 +213,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
// that should be processed and sent to the node instead of
|
||||
// returned to the caller.
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
// Check if this connection is still active before processing
|
||||
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("skipping work for closed connection")
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply change to node - this will handle offline nodes gracefully
|
||||
// and queue work for when they reconnect
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
@@ -240,52 +224,18 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("node not found for asynchronous work - node may have disconnected")
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow asynchronous work processing")
|
||||
}
|
||||
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||
// For critical changes that need immediate processing, send directly
|
||||
if b.shouldProcessImmediately(c) {
|
||||
if c.SelfUpdateOnly {
|
||||
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
|
||||
return
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For non-critical changes, add to batch
|
||||
b.addToBatch(c)
|
||||
func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
|
||||
b.addToBatch(c...)
|
||||
}
|
||||
|
||||
// queueWork safely queues work
|
||||
// queueWork safely queues work.
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
@@ -298,26 +248,21 @@ func (b *LockFreeBatcher) queueWork(w work) {
|
||||
}
|
||||
}
|
||||
|
||||
// shouldProcessImmediately determines if a change should bypass batching
|
||||
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
// Process these changes immediately to avoid delaying critical functionality
|
||||
switch c.Change {
|
||||
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
// addToBatch adds a change to the pending batch.
|
||||
func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
|
||||
// Short circuit if any of the changes is a full update, which
|
||||
// means we can skip sending individual changes.
|
||||
if change.HasFull(c) {
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
|
||||
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch adds a change to the pending batch
|
||||
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if c.SelfUpdateOnly {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(c.NodeID, changes)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -329,15 +274,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// processBatchedChanges processes all pending batched changes
|
||||
// processBatchedChanges processes all pending batched changes.
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if b.pendingChanges == nil {
|
||||
return
|
||||
}
|
||||
@@ -355,16 +298,31 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
|
||||
// Clear the pending changes for this node
|
||||
b.pendingChanges.Delete(nodeID)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
if val, ok := b.connected.Load(id); ok {
|
||||
// nil means connected
|
||||
return val == nil
|
||||
// First check if we have active connections for this node
|
||||
if nodeConn, exists := b.nodes.Load(id); exists {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check disconnected timestamp with grace period
|
||||
val, ok := b.connected.Load(id)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// nil means connected
|
||||
if val == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -372,9 +330,26 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret := xsync.NewMap[types.NodeID, bool]()
|
||||
|
||||
// First, add all nodes with active connections
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Then add all entries from the connected map
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// nil means connected
|
||||
ret.Store(id, val == nil)
|
||||
// Only add if not already added as connected above
|
||||
if _, exists := ret.Load(id); !exists {
|
||||
if val == nil {
|
||||
// nil means connected
|
||||
ret.Store(id, true)
|
||||
} else {
|
||||
// timestamp means disconnected
|
||||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -482,12 +457,21 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
|
||||
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
|
||||
// the channel. That might mean that we are sending to a node that has gone offline, but
|
||||
// the channel is still open.
|
||||
connData.c <- data
|
||||
nc.updateCount.Add(1)
|
||||
return nil
|
||||
// Add all entries from the connected map to capture both connected and disconnected nodes
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// Only add if not already processed above
|
||||
if _, exists := result[id]; !exists {
|
||||
// Use immediate connection status for debug (no grace period)
|
||||
connected := (val == nil) // nil means connected, timestamp means disconnected
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
|
||||
@@ -27,6 +27,60 @@ type batcherTestCase struct {
|
||||
fn batcherFunc
|
||||
}
|
||||
|
||||
// testBatcherWrapper wraps a real batcher to add online/offline notifications
|
||||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
state *state.State
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||
// Mark node as online in state before AddNode to match production behavior
|
||||
// This ensures the NodeStore has correct online status for change processing
|
||||
if t.state != nil {
|
||||
// Use Connect to properly mark node online in NodeStore but don't send its changes
|
||||
_ = t.state.Connect(id)
|
||||
}
|
||||
|
||||
// First add the node to the real batcher
|
||||
err := t.Batcher.AddNode(id, c, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send the online notification that poll.go would normally send
|
||||
// This ensures other nodes get notified about this node coming online
|
||||
t.AddWork(change.NodeOnline(id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||
// Mark node as offline in state BEFORE removing from batcher
|
||||
// This ensures the NodeStore has correct offline status when the change is processed
|
||||
if t.state != nil {
|
||||
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
|
||||
_, _ = t.state.Disconnect(id)
|
||||
}
|
||||
|
||||
// Send the offline notification that poll.go would normally send
|
||||
// Do this BEFORE removing from batcher so the change can be processed
|
||||
t.AddWork(change.NodeOffline(id))
|
||||
|
||||
// Finally remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
func wrapBatcherForTest(b Batcher, state *state.State) Batcher {
|
||||
return &testBatcherWrapper{Batcher: b, state: state}
|
||||
}
|
||||
|
||||
// allBatcherFunctions contains all batcher implementations to test.
|
||||
var allBatcherFunctions = []batcherTestCase{
|
||||
{"LockFree", NewBatcherAndMapper},
|
||||
@@ -183,8 +237,8 @@ func setupBatcherWithTestData(
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"users": ["*"],
|
||||
"ports": ["*:*"]
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
@@ -194,8 +248,8 @@ func setupBatcherWithTestData(
|
||||
t.Fatalf("Failed to set allow-all policy: %v", err)
|
||||
}
|
||||
|
||||
// Create batcher with the state
|
||||
batcher := bf(cfg, state)
|
||||
// Create batcher with the state and wrap it for testing
|
||||
batcher := wrapBatcherForTest(bf(cfg, state), state)
|
||||
batcher.Start()
|
||||
|
||||
testData := &TestData{
|
||||
@@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
||||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
time.Sleep(100 * time.Millisecond) // Let connection settle
|
||||
|
||||
// Generate some work
|
||||
@@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullSet)
|
||||
@@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Disconnect all nodes
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
// Give time for final updates to process
|
||||
@@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
tn2 := testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, false, 100)
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
}
|
||||
@@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
select {
|
||||
case data := <-tn.ch:
|
||||
assertOnlineMapResponse(t, data, true)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Did not receive expected Online response update")
|
||||
}
|
||||
|
||||
@@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Disconnect the second node
|
||||
batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
|
||||
assert.False(t, batcher.IsConnected(tn2.n.ID))
|
||||
batcher.RemoveNode(tn2.n.ID, tn2.ch)
|
||||
// Note: IsConnected may return true during grace period for DNS resolution
|
||||
|
||||
// First node should get update that second has disconnected.
|
||||
select {
|
||||
case data := <-tn.ch:
|
||||
assertOnlineMapResponse(t, data, false)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Did not receive expected Online response update")
|
||||
}
|
||||
|
||||
@@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
// }
|
||||
|
||||
// Test RemoveNode
|
||||
batcher.RemoveNode(tn.n.ID, tn.ch, false)
|
||||
if batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be disconnected after RemoveNode")
|
||||
}
|
||||
batcher.RemoveNode(tn.n.ID, tn.ch)
|
||||
// Note: IsConnected may return true during grace period for DNS resolution
|
||||
// The node is actually removed from active connections but grace period allows DNS lookups
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
@@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
|
||||
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Add real work during connection chaos
|
||||
@@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Remove second connection
|
||||
@@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2, false)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
@@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPSet)
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
@@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
|
||||
// Rapid removal creates race between worker and removal
|
||||
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch, false)
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Give workers time to process and close channels
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
@@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
for _, node := range stableNodes {
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
@@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
churningChannelsMutex.Lock()
|
||||
churningChannels[nodeID] = ch
|
||||
churningChannelsMutex.Unlock()
|
||||
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
@@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
ch, exists := churningChannels[nodeID]
|
||||
churningChannelsMutex.Unlock()
|
||||
if exists {
|
||||
batcher.RemoveNode(nodeID, ch, false)
|
||||
batcher.RemoveNode(nodeID, ch)
|
||||
}
|
||||
}(node.n.ID)
|
||||
}
|
||||
@@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
var connectedNodesMutex sync.RWMutex
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
connectedNodesMutex.Lock()
|
||||
connectedNodes[node.n.ID] = true
|
||||
connectedNodesMutex.Unlock()
|
||||
@@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
connectedNodesMutex.RUnlock()
|
||||
|
||||
if isConnected {
|
||||
batcher.RemoveNode(nodeID, channel, false)
|
||||
batcher.RemoveNode(nodeID, channel)
|
||||
connectedNodesMutex.Lock()
|
||||
connectedNodes[nodeID] = false
|
||||
connectedNodesMutex.Unlock()
|
||||
@@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// Now disconnect all nodes from batcher to stop new updates
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
// Give time for enhanced tracking goroutines to process any remaining data in channels
|
||||
@@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
// Connect nodes one at a time to avoid overwhelming the work queue
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
// Small delay between connections to allow NodeCameOnline processing
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
@@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
// Check how many peers each node should see
|
||||
for i, node := range allNodes {
|
||||
peers, err := testData.State.ListPeers(node.n.ID)
|
||||
if err != nil {
|
||||
t.Errorf("Error listing peers for node %d: %v", i, err)
|
||||
} else {
|
||||
t.Logf("Node %d should see %d peers from state", i, len(peers))
|
||||
}
|
||||
peers := testData.State.ListPeers(node.n.ID)
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
|
||||
// Send a full update - this should generate full peer lists
|
||||
@@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
foundFullUpdate := false
|
||||
|
||||
// Read all available updates for each node
|
||||
for i := range len(allNodes) {
|
||||
for i := range allNodes {
|
||||
nodeUpdates := 0
|
||||
t.Logf("Reading updates for node %d:", i)
|
||||
|
||||
@@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
|
||||
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
||||
|
||||
// Connect first node
|
||||
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d", nodes[0].n.ID)
|
||||
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||
|
||||
// Wait for initial NodeCameOnline to be processed
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
@@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
t.Errorf("ERROR: Received unknown update type!")
|
||||
}
|
||||
|
||||
// Check if there should be peers available
|
||||
peers, err := testData.State.ListPeers(nodes[0].n.ID)
|
||||
if err != nil {
|
||||
t.Errorf("Error getting peers from state: %v", err)
|
||||
} else {
|
||||
t.Logf("State shows %d peers available for this node", len(peers))
|
||||
if len(peers) > 0 && len(data.Peers) == 0 {
|
||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
|
||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
}
|
||||
|
||||
// Connect second node for comparison
|
||||
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node2: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
t.Errorf("Node1 should have 3 active connections, got %d", activeConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if info, exists := debugInfo[node2.n.ID]; exists {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 1 {
|
||||
t.Errorf("Node2 should have 1 active connection, got %d", activeConnections)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Send update and verify ALL connections receive it
|
||||
t.Logf("Phase 5: Testing update distribution to all connections...")
|
||||
|
||||
// Clear any existing updates from all channels
|
||||
clearChannel := func(ch chan *tailcfg.MapResponse) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
// drain
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clearChannel(node1.ch)
|
||||
clearChannel(secondChannel)
|
||||
clearChannel(thirdChannel)
|
||||
clearChannel(node2.ch)
|
||||
|
||||
// Send a change notification from node2 (so node1 should receive it on all connections)
|
||||
testChangeSet := change.ChangeSet{
|
||||
NodeID: node2.n.ID,
|
||||
Change: change.NodeNewOrUpdate,
|
||||
SelfUpdateOnly: false,
|
||||
}
|
||||
|
||||
batcher.AddWork(testChangeSet)
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let updates propagate
|
||||
|
||||
// Verify all three connections for node1 receive the update
|
||||
connection1Received := false
|
||||
connection2Received := false
|
||||
connection3Received := false
|
||||
|
||||
select {
|
||||
case mapResp := <-node1.ch:
|
||||
connection1Received = (mapResp != nil)
|
||||
t.Logf("Node1 connection 1 received update: %t", connection1Received)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 1 did not receive update")
|
||||
}
|
||||
|
||||
select {
|
||||
case mapResp := <-secondChannel:
|
||||
connection2Received = (mapResp != nil)
|
||||
t.Logf("Node1 connection 2 received update: %t", connection2Received)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 2 did not receive update")
|
||||
}
|
||||
|
||||
select {
|
||||
case mapResp := <-thirdChannel:
|
||||
connection3Received = (mapResp != nil)
|
||||
t.Logf("Node1 connection 3 received update: %t", connection3Received)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 3 did not receive update")
|
||||
}
|
||||
|
||||
if connection1Received && connection2Received && connection3Received {
|
||||
t.Logf("SUCCESS: All three connections for node1 received the update")
|
||||
} else {
|
||||
t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t",
|
||||
connection1Received, connection2Received, connection3Received)
|
||||
}
|
||||
|
||||
// Phase 6: Test connection removal and verify remaining connections still work
|
||||
t.Logf("Phase 6: Testing connection removal...")
|
||||
|
||||
// Remove the second connection
|
||||
removed := batcher.RemoveNode(node1.n.ID, secondChannel)
|
||||
if !removed {
|
||||
t.Errorf("Failed to remove second connection for node1")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify debug status shows 2 connections now
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 2 {
|
||||
t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
|
||||
type MapResponseBuilder struct {
|
||||
resp *tailcfg.MapResponse
|
||||
mapper *mapper
|
||||
@@ -21,7 +22,17 @@ type MapResponseBuilder struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set
|
||||
type debugType string
|
||||
|
||||
const (
|
||||
fullResponseDebug debugType = "full"
|
||||
patchResponseDebug debugType = "patch"
|
||||
removeResponseDebug debugType = "remove"
|
||||
changeResponseDebug debugType = "change"
|
||||
derpResponseDebug debugType = "derp"
|
||||
)
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
return &MapResponseBuilder{
|
||||
@@ -35,32 +46,39 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds an error to the builder's error list
|
||||
// addError adds an error to the builder's error list.
|
||||
func (b *MapResponseBuilder) addError(err error) {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasErrors returns true if the builder has accumulated any errors
|
||||
// hasErrors returns true if the builder has accumulated any errors.
|
||||
func (b *MapResponseBuilder) hasErrors() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
// WithCapabilityVersion sets the capability version for the response
|
||||
// WithCapabilityVersion sets the capability version for the response.
|
||||
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||
b.capVer = capVer
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSelfNode adds the requesting node to the response
|
||||
// WithSelfNode adds the requesting node to the response.
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
// Always use batcher's view of online status for self node
|
||||
// The batcher respects grace periods for logout scenarios
|
||||
node := nodeView.AsStruct()
|
||||
// if b.mapper.batcher != nil {
|
||||
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||
// }
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node.View(), b.capVer, b.mapper.state,
|
||||
@@ -74,29 +92,38 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
}
|
||||
|
||||
b.resp.Node = tailnode
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response
|
||||
func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder {
|
||||
if debugDumpMapResponsePath != "" {
|
||||
b.debugType = t
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response.
|
||||
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||
b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDomain adds the domain configuration
|
||||
// WithDomain adds the domain configuration.
|
||||
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||
b.resp.Domain = b.mapper.cfg.Domain()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithCollectServicesDisabled sets the collect services flag to false
|
||||
// WithCollectServicesDisabled sets the collect services flag to false.
|
||||
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||
b.resp.CollectServices.Set(false)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDebugConfig adds debug configuration
|
||||
// It disables log tailing if the mapper's LogTail is not enabled
|
||||
// It disables log tailing if the mapper's LogTail is not enabled.
|
||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
@@ -104,53 +131,56 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node.
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.SSHPolicy = sshPolicy
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDNSConfig adds DNS configuration for the requesting node
|
||||
// WithDNSConfig adds DNS configuration for the requesting node.
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers.
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.UserProfiles = generateUserProfiles(node, peers)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPacketFilters adds packet filter rules based on policy
|
||||
// WithPacketFilters adds packet filter rules based on policy.
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -161,15 +191,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node.View(), filter),
|
||||
"base": policy.ReduceFilterRules(node, filter),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response).
|
||||
func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@@ -177,12 +206,12 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
}
|
||||
|
||||
b.resp.Peers = tailPeers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@@ -190,14 +219,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
|
||||
}
|
||||
|
||||
b.resp.PeersChanged = tailPeers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
}
|
||||
|
||||
filter, matchers := b.mapper.state.Filter()
|
||||
@@ -206,15 +236,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
|
||||
// access each-other at all and remove them from the peers.
|
||||
var changedViews views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
|
||||
changedViews = policy.ReduceNodes(node, peers, matchers)
|
||||
} else {
|
||||
changedViews = peers.ViewSlice()
|
||||
changedViews = peers
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
changedViews, b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
@@ -229,19 +259,20 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
|
||||
return tailPeers, nil
|
||||
}
|
||||
|
||||
// WithPeerChangedPatch adds peer change patches
|
||||
// WithPeerChangedPatch adds peer change patches.
|
||||
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
|
||||
b.resp.PeersChangedPatch = changes
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
// WithPeersRemoved adds removed peer IDs.
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -251,11 +282,7 @@ func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
if debugDumpMapResponsePath != "" {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeDebugMapResponse(b.resp, node)
|
||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||
}
|
||||
|
||||
return b.resp, nil
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -69,16 +70,18 @@ func newMapper(
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@@ -95,7 +98,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@@ -115,12 +118,12 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
}
|
||||
|
||||
if len(node.IPs()) > 0 {
|
||||
@@ -138,10 +141,7 @@ func (m *mapper) fullMapResponse(
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers := m.state.ListPeers(nodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
@@ -183,10 +183,7 @@ func (m *mapper) peerChangeResponse(
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers := m.state.ListPeers(nodeID, changedNodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
@@ -208,7 +205,8 @@ func (m *mapper) peerRemovedResponse(
|
||||
|
||||
func writeDebugMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
node *types.Node,
|
||||
t debugType,
|
||||
nodeID types.NodeID,
|
||||
) {
|
||||
body, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
@@ -236,25 +234,6 @@ func writeDebugMapResponse(
|
||||
}
|
||||
}
|
||||
|
||||
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Add back online via batcher. This was removed
|
||||
// to avoid a circular dependency between the mapper and the notification.
|
||||
for _, peer := range peers {
|
||||
online := m.batcher.IsConnected(peer.ID)
|
||||
peer.IsOnline = &online
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1,
|
||||
nodeInShared1.View(),
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
|
||||
@@ -133,13 +133,12 @@ func tailNode(
|
||||
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
|
||||
}
|
||||
|
||||
if !node.IsOnline().Valid() || !node.IsOnline().Get() {
|
||||
// LastSeen is only set when node is
|
||||
// not connected to the control server.
|
||||
if node.LastSeen().Valid() {
|
||||
lastSeen := node.LastSeen().Get()
|
||||
tNode.LastSeen = &lastSeen
|
||||
}
|
||||
// Set LastSeen only for offline nodes to avoid confusing Tailscale clients
|
||||
// during rapid reconnection cycles. Online nodes should not have LastSeen set
|
||||
// as this can make clients interpret them as "not online" despite Online=true.
|
||||
if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() {
|
||||
lastSeen := node.LastSeen().Get()
|
||||
tNode.LastSeen = &lastSeen
|
||||
}
|
||||
|
||||
return &tNode, nil
|
||||
|
||||
Reference in New Issue
Block a user