state/nodestore: in-memory representation of nodes

Initial work on a nodestore which stores all of the nodes
and their relations in memory, with the peer relationships
precalculated.

It is a copy-on-write structure: the current "snapshot" is
replaced whenever the structure changes. It is optimised for
reads; writes are not fast, but they are grouped into batches
so that the expensive peer calculation runs less often when
many changes arrive in quick succession.
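
Roughly, the read path looks like the sketch below (hypothetical
Snapshot/NodeStore/calcPeers names, not the code added by this commit):
readers load an immutable snapshot through an atomic pointer and get the
precalculated peer lists without recomputing anything.

package nodestore

import "sync/atomic"

type NodeID uint64

type Node struct {
    ID   NodeID
    Name string
}

// Snapshot is an immutable view of all nodes with peer relations
// precalculated, so reads never recompute them.
type Snapshot struct {
    nodes map[NodeID]Node
    peers map[NodeID][]NodeID // precalculated peer lists
}

// NodeStore keeps the current snapshot behind an atomic pointer;
// readers load it without taking any lock.
type NodeStore struct {
    snap atomic.Pointer[Snapshot]
}

func NewNodeStore() *NodeStore {
    s := &NodeStore{}
    s.apply(map[NodeID]Node{})
    return s
}

// Peers returns the precalculated peer list from the current snapshot.
func (s *NodeStore) Peers(id NodeID) []NodeID {
    return s.snap.Load().peers[id]
}

// apply builds a fresh snapshot from the new node set and swaps it in;
// the previous snapshot stays valid for readers still holding it.
func (s *NodeStore) apply(nodes map[NodeID]Node) {
    s.snap.Store(&Snapshot{
        nodes: nodes,
        peers: calcPeers(nodes), // the expensive part, done once per batch
    })
}

// calcPeers precomputes which nodes can see each other; this sketch
// simply makes every node a peer of every other node.
func calcPeers(nodes map[NodeID]Node) map[NodeID][]NodeID {
    peers := make(map[NodeID][]NodeID, len(nodes))
    for a := range nodes {
        for b := range nodes {
            if a != b {
                peers[a] = append(peers[a], b)
            }
        }
    }
    return peers
}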

Writes block until they are committed, while reads are never
blocked.
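
The write path of the same sketch, again with illustrative names only
(work, batchedWriter, PutNode): a single goroutine drains whatever writes
are queued, rebuilds the snapshot once per batch, and only then unblocks
the writers.

type work struct {
    update func(map[NodeID]Node) // mutation to apply to the node set
    done   chan struct{}         // closed once the new snapshot is live
}

// batchedWriter serializes writes to a NodeStore; a hypothetical helper
// for this sketch, not part of the commit.
type batchedWriter struct {
    store   *NodeStore
    writeCh chan work
}

func newBatchedWriter(store *NodeStore) *batchedWriter {
    w := &batchedWriter{store: store, writeCh: make(chan work, 64)}
    go w.batchLoop()
    return w
}

// PutNode submits a change and blocks until it is committed; readers
// keep using the previous snapshot in the meantime.
func (w *batchedWriter) PutNode(n Node) {
    wk := work{
        update: func(nodes map[NodeID]Node) { nodes[n.ID] = n },
        done:   make(chan struct{}),
    }
    w.writeCh <- wk
    <-wk.done
}

// batchLoop drains all queued writes, applies them to a copy of the
// current node set, recalculates peers once for the whole batch, and
// only then releases the blocked writers.
func (w *batchedWriter) batchLoop() {
    for wk := range w.writeCh {
        batch := []work{wk}
    drain:
        for {
            select {
            case more := <-w.writeCh:
                batch = append(batch, more)
            default:
                break drain
            }
        }
        cur := w.store.snap.Load().nodes
        next := make(map[NodeID]Node, len(cur))
        for id, n := range cur {
            next[id] = n
        }
        for _, b := range batch {
            b.update(next)
        }
        w.store.apply(next) // one expensive peer calculation per batch
        for _, b := range batch {
            close(b.done)
        }
    }
}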

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Author:    Kristoffer Dalby <kristoffer@tailscale.com>
Date:      2025-07-05 23:30:47 +02:00
Committer: Kristoffer Dalby
parent 38be30b6d4
commit 9d236571f4
35 changed files with 3960 additions and 1317 deletions

@@ -27,6 +27,60 @@ type batcherTestCase struct {
fn batcherFunc
}
// testBatcherWrapper wraps a real batcher to add online/offline notifications
// that would normally be sent by poll.go in production.
type testBatcherWrapper struct {
Batcher
state *state.State
}
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
// Mark node as online in state before AddNode to match production behavior
// This ensures the NodeStore has correct online status for change processing
if t.state != nil {
// Use Connect to properly mark node online in NodeStore but don't send its changes
_ = t.state.Connect(id)
}
// First add the node to the real batcher
err := t.Batcher.AddNode(id, c, version)
if err != nil {
return err
}
// Send the online notification that poll.go would normally send
// This ensures other nodes get notified about this node coming online
t.AddWork(change.NodeOnline(id))
return nil
}
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
// Mark node as offline in state BEFORE removing from batcher
// This ensures the NodeStore has correct offline status when the change is processed
if t.state != nil {
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
_, _ = t.state.Disconnect(id)
}
// Send the offline notification that poll.go would normally send
// Do this BEFORE removing from batcher so the change can be processed
t.AddWork(change.NodeOffline(id))
// Finally remove from the real batcher
removed := t.Batcher.RemoveNode(id, c)
if !removed {
return false
}
return true
}
// wrapBatcherForTest wraps a batcher with test-specific behavior.
func wrapBatcherForTest(b Batcher, state *state.State) Batcher {
return &testBatcherWrapper{Batcher: b, state: state}
}
// allBatcherFunctions contains all batcher implementations to test.
var allBatcherFunctions = []batcherTestCase{
{"LockFree", NewBatcherAndMapper},
@@ -183,8 +237,8 @@ func setupBatcherWithTestData(
"acls": [
{
"action": "accept",
"users": ["*"],
"ports": ["*:*"]
"src": ["*"],
"dst": ["*:*"]
}
]
}`
@@ -194,8 +248,8 @@ func setupBatcherWithTestData(
t.Fatalf("Failed to set allow-all policy: %v", err)
}
// Create batcher with the state
batcher := bf(cfg, state)
// Create batcher with the state and wrap it for testing
batcher := wrapBatcherForTest(bf(cfg, state), state)
batcher.Start()
testData := &TestData{
@@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
testNode.start()
// Connect the node to the batcher
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
time.Sleep(100 * time.Millisecond) // Let connection settle
// Generate some work
@@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
for i := range allNodes {
node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullSet)
@@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Disconnect all nodes
for i := range allNodes {
node := &allNodes[i]
batcher.RemoveNode(node.n.ID, node.ch, false)
batcher.RemoveNode(node.n.ID, node.ch)
}
// Give time for final updates to process
@@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) {
tn2 := testData.Nodes[1]
// Test AddNode with real node ID
batcher.AddNode(tn.n.ID, tn.ch, false, 100)
batcher.AddNode(tn.n.ID, tn.ch, 100)
if !batcher.IsConnected(tn.n.ID) {
t.Error("Node should be connected after AddNode")
}
@@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) {
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
// Add the second node and verify update message
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
assert.True(t, batcher.IsConnected(tn2.n.ID))
// First node should get an update that second node has connected.
select {
case data := <-tn.ch:
assertOnlineMapResponse(t, data, true)
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
t.Error("Did not receive expected Online response update")
}
@@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) {
}
// Disconnect the second node
batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
assert.False(t, batcher.IsConnected(tn2.n.ID))
batcher.RemoveNode(tn2.n.ID, tn2.ch)
// Note: IsConnected may return true during grace period for DNS resolution
// First node should get update that second has disconnected.
select {
case data := <-tn.ch:
assertOnlineMapResponse(t, data, false)
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
t.Error("Did not receive expected Online response update")
}
@@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) {
// }
// Test RemoveNode
batcher.RemoveNode(tn.n.ID, tn.ch, false)
if batcher.IsConnected(tn.n.ID) {
t.Error("Node should be disconnected after RemoveNode")
}
batcher.RemoveNode(tn.n.ID, tn.ch)
// Note: IsConnected may return true during grace period for DNS resolution
// The node is actually removed from active connections but grace period allows DNS lookups
})
}
}
@@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
testNodes := testData.Nodes
ch := make(chan *tailcfg.MapResponse, 10)
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
// Track update content for validation
var receivedUpdates []*tailcfg.MapResponse
@@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
}()
// Add real work during connection chaos
@@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
go func() {
defer wg.Done()
time.Sleep(1 * time.Microsecond)
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
}()
// Remove second connection
@@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
go func() {
defer wg.Done()
time.Sleep(2 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch2, false)
batcher.RemoveNode(testNode.n.ID, ch2)
}()
wg.Wait()
@@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, 5)
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPSet)
// Consumer goroutine to validate data and detect channel issues
@@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Rapid removal creates race between worker and removal
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch, false)
batcher.RemoveNode(testNode.n.ID, ch)
// Give workers time to process and close channels
time.Sleep(5 * time.Millisecond)
@@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for _, node := range stableNodes {
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
stableChannels[node.n.ID] = ch
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
// Monitor updates for each stable client
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
// Consume updates to prevent blocking
go func() {
@@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch, exists := churningChannels[nodeID]
churningChannelsMutex.Unlock()
if exists {
batcher.RemoveNode(nodeID, ch, false)
batcher.RemoveNode(nodeID, ch)
}
}(node.n.ID)
}
@@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
var connectedNodesMutex sync.RWMutex
for i := range testNodes {
node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true
connectedNodesMutex.Unlock()
@@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) {
connectedNodesMutex.RUnlock()
if isConnected {
batcher.RemoveNode(nodeID, channel, false)
batcher.RemoveNode(nodeID, channel)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = false
connectedNodesMutex.Unlock()
@@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) {
// Now disconnect all nodes from batcher to stop new updates
for i := range testNodes {
node := &testNodes[i]
batcher.RemoveNode(node.n.ID, node.ch, false)
batcher.RemoveNode(node.n.ID, node.ch)
}
// Give time for enhanced tracking goroutines to process any remaining data in channels
@@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Connect nodes one at a time to avoid overwhelming the work queue
for i, node := range allNodes {
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Small delay between connections to allow NodeCameOnline processing
time.Sleep(50 * time.Millisecond)
@@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Check how many peers each node should see
for i, node := range allNodes {
peers, err := testData.State.ListPeers(node.n.ID)
if err != nil {
t.Errorf("Error listing peers for node %d: %v", i, err)
} else {
t.Logf("Node %d should see %d peers from state", i, len(peers))
}
peers := testData.State.ListPeers(node.n.ID)
t.Logf("Node %d should see %d peers from state", i, peers.Len())
}
// Send a full update - this should generate full peer lists
@@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
foundFullUpdate := false
// Read all available updates for each node
for i := range len(allNodes) {
for i := range allNodes {
nodeUpdates := 0
t.Logf("Reading updates for node %d:", i)
@@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
t.Logf("=== WORK QUEUE TRACING TEST ===")
// Connect first node
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d", nodes[0].n.ID)
time.Sleep(100 * time.Millisecond) // Let connections settle
// Wait for initial NodeCameOnline to be processed
time.Sleep(200 * time.Millisecond)
@@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
t.Errorf("ERROR: Received unknown update type!")
}
// Check if there should be peers available
peers, err := testData.State.ListPeers(nodes[0].n.ID)
if err != nil {
t.Errorf("Error getting peers from state: %v", err)
} else {
t.Logf("State shows %d peers available for this node", len(peers))
if len(peers) > 0 && len(data.Peers) == 0 {
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
batcher := testData.Batcher
node1 := testData.Nodes[0]
node2 := testData.Nodes[1]
t.Logf("=== MULTI-CONNECTION TEST ===")
// Phase 1: Connect first node with initial connection
t.Logf("Phase 1: Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node1: %v", err)
}
// Connect second node for comparison
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node2: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add second connection for node1: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add third connection for node1: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
if info, exists := debugInfo[node1.n.ID]; exists {
t.Logf("Node1 debug info: %+v", info)
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 3 {
t.Errorf("Node1 should have 3 active connections, got %d", activeConnections)
} else {
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
}
}
if connected, ok := infoMap["connected"].(bool); ok && !connected {
t.Errorf("Node1 should show as connected with 3 active connections")
}
}
}
if info, exists := debugInfo[node2.n.ID]; exists {
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 1 {
t.Errorf("Node2 should have 1 active connection, got %d", activeConnections)
}
}
}
}
}
// Phase 5: Send update and verify ALL connections receive it
t.Logf("Phase 5: Testing update distribution to all connections...")
// Clear any existing updates from all channels
clearChannel := func(ch chan *tailcfg.MapResponse) {
for {
select {
case <-ch:
// drain
default:
return
}
}
}
clearChannel(node1.ch)
clearChannel(secondChannel)
clearChannel(thirdChannel)
clearChannel(node2.ch)
// Send a change notification from node2 (so node1 should receive it on all connections)
testChangeSet := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
batcher.AddWork(testChangeSet)
time.Sleep(100 * time.Millisecond) // Let updates propagate
// Verify all three connections for node1 receive the update
connection1Received := false
connection2Received := false
connection3Received := false
select {
case mapResp := <-node1.ch:
connection1Received = (mapResp != nil)
t.Logf("Node1 connection 1 received update: %t", connection1Received)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 1 did not receive update")
}
select {
case mapResp := <-secondChannel:
connection2Received = (mapResp != nil)
t.Logf("Node1 connection 2 received update: %t", connection2Received)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 2 did not receive update")
}
select {
case mapResp := <-thirdChannel:
connection3Received = (mapResp != nil)
t.Logf("Node1 connection 3 received update: %t", connection3Received)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 3 did not receive update")
}
if connection1Received && connection2Received && connection3Received {
t.Logf("SUCCESS: All three connections for node1 received the update")
} else {
t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t",
connection1Received, connection2Received, connection3Received)
}
// Phase 6: Test connection removal and verify remaining connections still work
t.Logf("Phase 6: Testing connection removal...")
// Remove the second connection
removed := batcher.RemoveNode(node1.n.ID, secondChannel)
if !removed {
t.Errorf("Failed to remove second connection for node1")
}
time.Sleep(50 * time.Millisecond)
// Verify debug status shows 2 connections now
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
if info, exists := debugInfo[node1.n.ID]; exists {
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 2 {
t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections)
} else {
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
}
}
}
} else {