state/nodestore: in memory representation of nodes

Initial work on a nodestore which stores all of the nodes
and their relations in memory, with the peer relationships
precalculated.

It is a copy-on-write structure, replacing the "snapshot"
when a change to the structure occurs. It is optimised for reads,
and while batches are not fast, they are grouped together
to do less of the expensive peer calculation if there are many
changes rapidly.

Writes will block until committed, while reads are never
blocked.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2025-07-05 23:30:47 +02:00
committed by Kristoffer Dalby
parent 38be30b6d4
commit 9d236571f4
35 changed files with 3960 additions and 1317 deletions

View File

@@ -0,0 +1,403 @@
package state
import (
"fmt"
"maps"
"strings"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"tailscale.com/types/key"
"tailscale.com/types/views"
)
// Write batching parameters used by processWrite.
const (
// batchSize is the number of queued writes that triggers an immediate
// flush of the pending batch.
batchSize = 10
// batchTimeout is the ticker interval after which a non-empty pending
// batch is flushed even if it has not reached batchSize.
batchTimeout = 500 * time.Millisecond
)
// Operation codes carried in work.op; applyBatch dispatches on them.
const (
	put    = iota + 1 // insert or replace a node
	del               // remove a node
	update            // apply work.updateFn to an existing node
)
// prometheusNamespace is the Namespace set on every NodeStore metric below.
const prometheusNamespace = "headscale"
// Prometheus collectors for the NodeStore, created through promauto.
var (
// nodeStoreOperations counts NodeStore calls, labelled by operation name
// (for example "put", "get" or "list_peers").
nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "nodestore_operations_total",
Help: "Total number of NodeStore operations",
}, []string{"operation"})
// nodeStoreOperationDuration records how long each NodeStore call took,
// labelled by operation name. For writes this includes the time spent
// waiting for the batch to be applied.
nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_operation_duration_seconds",
Help: "Duration of NodeStore operations",
Buckets: prometheus.DefBuckets,
}, []string{"operation"})
// nodeStoreBatchSize records the number of work items in each applied batch.
nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_batch_size",
Help: "Size of NodeStore write batches",
Buckets: []float64{1, 2, 5, 10, 20, 50, 100},
})
// nodeStoreBatchDuration records how long applyBatch took per batch.
nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_batch_duration_seconds",
Help: "Duration of NodeStore batch processing",
Buckets: prometheus.DefBuckets,
})
// nodeStoreSnapshotBuildDuration records how long snapshotFromNodes took,
// including the peer calculation.
nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_snapshot_build_duration_seconds",
Help: "Duration of NodeStore snapshot building from nodes",
Buckets: prometheus.DefBuckets,
})
// nodeStoreNodesCount is set to the number of nodes each time a snapshot
// is installed (at construction and after every batch).
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_nodes_total",
Help: "Total number of nodes in the NodeStore",
})
// nodeStorePeersCalculationDuration records how long the PeersFunc call
// inside snapshotFromNodes took.
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_peers_calculation_duration_seconds",
Help: "Duration of peers calculation in NodeStore",
Buckets: prometheus.DefBuckets,
})
// nodeStoreQueueDepth is incremented just before a write is sent to the
// write queue and decremented once that write has been applied.
nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_queue_depth",
Help: "Current depth of NodeStore write queue",
})
)
// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
// and while batches are not fast, they are grouped together
// to do less of the expensive peer calculation if there are many
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked. This means that the caller of a write operation
// is responsible for ensuring an update depending on a write
// is not issued before the write is complete.
type NodeStore struct {
// data holds the current Snapshot. Readers Load it; applyBatch Stores a
// freshly built replacement.
data atomic.Pointer[Snapshot]
// peersFunc computes the peer map every time a Snapshot is built.
peersFunc PeersFunc
// writeQueue carries pending write operations to processWrite.
// It is nil until Start has been called.
writeQueue chan work
}
// NewNodeStore builds a NodeStore whose initial snapshot contains a copy of
// every node in allNodes, with peers precomputed by peersFunc.
//
// The returned store can serve reads immediately, but its write queue is not
// created here: Start must be called before PutNode, UpdateNode or DeleteNode.
func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
	byID := make(map[types.NodeID]types.Node, len(allNodes))
	for i := range allNodes {
		node := allNodes[i]
		byID[node.ID] = *node
	}

	initial := snapshotFromNodes(byID, peersFunc)

	s := new(NodeStore)
	s.peersFunc = peersFunc
	s.data.Store(&initial)

	// Seed the node count gauge with the initial population.
	nodeStoreNodesCount.Set(float64(len(byID)))

	return s
}
// Snapshot is the representation of the current state of the NodeStore.
// It contains all nodes and their relationships.
// It is a copy-on-write structure, meaning that when a write occurs,
// a new Snapshot is created with the updated state,
// and replaces the old one atomically.
type Snapshot struct {
// nodesByID is the main source of truth for nodes.
nodesByID map[types.NodeID]types.Node
// The fields below are calculated from nodesByID by snapshotFromNodes.

// nodesByNodeKey indexes a view of each node by its NodeKey.
nodesByNodeKey map[key.NodePublic]types.NodeView
// peersByNode is the result of calling the PeersFunc on allNodes.
peersByNode map[types.NodeID][]types.NodeView
// nodesByUser groups node views by the UserID of the owning node.
nodesByUser map[types.UserID][]types.NodeView
// allNodes holds a view of every node; it is filled by ranging over a
// map, so its order is not stable between snapshots.
allNodes []types.NodeView
}
// PeersFunc is a function that takes a list of nodes and returns a map
// with the relationships between nodes and their peers.
// This will typically be used to calculate which nodes can see each other
// based on the current policy.
//
// It is invoked once per snapshot build (see snapshotFromNodes) with views
// of every node in the store; the returned map is stored as-is and served
// by ListPeers.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
// work represents a single operation to be performed on the NodeStore.
type work struct {
// op is one of put, del or update.
op int
// nodeID identifies the node the operation targets.
nodeID types.NodeID
// node is the value stored by a put; the other operations leave it unset.
node types.Node
// updateFn is the mutation applied by an update; the other operations
// leave it unset.
updateFn UpdateNodeFunc
// result is closed by applyBatch once the batch containing this work has
// been applied, which unblocks the caller waiting on it.
result chan struct{}
}
// PutNode stores n under n.ID, inserting it when the ID is new and replacing
// the existing entry otherwise.
//
// The call enqueues the write and then blocks until the batch containing it
// has been applied and the new snapshot is visible to readers. Start must
// have been called on the store beforehand.
func (s *NodeStore) PutNode(n types.Node) {
	const opLabel = "put"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	done := make(chan struct{})

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work{op: put, nodeID: n.ID, node: n, result: done}
	// applyBatch closes done once the snapshot has been swapped.
	<-done
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()
}
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
// applyBatch calls it with a pointer to a copy of the stored node and writes
// that copy back into the new snapshot afterwards.
type UpdateNodeFunc func(n *types.Node)
// UpdateNode mutates the node identified by nodeID by running updateFn on it.
//
// The call enqueues the write and blocks until the batch containing it has
// been applied. updateFn is executed by the write-processing goroutine on a
// copy of the stored node, and the modified copy then replaces the stored
// value. If nodeID is not present when the batch is applied, updateFn is not
// called and the operation has no effect.
//
// Treat this like a database "transaction": gather everything that needs to
// change and apply it in a single call. Fewer calls are better.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
// This is because the main nodesByID map contains the struct, and every other map is using a
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node, which means we would need to implement read-write locks
// on all read operations.
//
// NOTE(review): snapshotFromNodes takes its views from range-loop copies of
// the map values, not from the map entries themselves, so the "pointer to the
// underlying struct" premise above should be re-checked before acting on it.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
	const opLabel = "update"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	done := make(chan struct{})

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work{op: update, nodeID: nodeID, updateFn: updateFn, result: done}
	// applyBatch closes done once the snapshot has been swapped.
	<-done
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()
}
// DeleteNode drops the node with the given ID from the store; deleting an ID
// that is not present is a no-op.
//
// The call enqueues the write and blocks until the batch containing it has
// been applied. Start must have been called on the store beforehand.
func (s *NodeStore) DeleteNode(id types.NodeID) {
	const opLabel = "delete"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	done := make(chan struct{})

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work{op: del, nodeID: id, result: done}
	// applyBatch closes done once the snapshot has been swapped.
	<-done
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()
}
// Start initializes the NodeStore and starts processing the write queue.
// It allocates the unbuffered write queue and launches the goroutine that
// runs processWrite. It must be called before PutNode, UpdateNode or
// DeleteNode: on a store that was never started the queue is nil, and a send
// on a nil channel blocks forever.
func (s *NodeStore) Start() {
s.writeQueue = make(chan work)
go s.processWrite()
}
// Stop stops the NodeStore.
// It closes the write queue, which makes processWrite apply any pending
// batch and return; Stop itself does not wait for that goroutine to exit.
//
// Stop must be called at most once and only after Start, because closing a
// nil or already closed channel panics. No write (PutNode, UpdateNode,
// DeleteNode) may be issued after or concurrently with Stop, because sending
// on a closed channel panics.
func (s *NodeStore) Stop() {
close(s.writeQueue)
}
// processWrite is the body of the writer goroutine started by Start. It
// drains the write queue, collecting work into a pending batch that is
// applied when either
//   - the batch reaches batchSize items, or
//   - the batchTimeout ticker fires while the batch is non-empty.
//
// When the queue is closed (see Stop) whatever is still pending is applied
// and the goroutine exits.
func (s *NodeStore) processWrite() {
	ticker := time.NewTicker(batchTimeout)
	defer ticker.Stop()

	pending := make([]work, 0, batchSize)

	// flush applies the pending work, if any, and empties the buffer while
	// keeping its capacity for reuse.
	flush := func() {
		if len(pending) == 0 {
			return
		}
		s.applyBatch(pending)
		pending = pending[:0]
	}

	for {
		select {
		case item, open := <-s.writeQueue:
			if !open {
				// Queue closed: apply the remainder and stop.
				flush()
				return
			}

			pending = append(pending, item)
			if len(pending) >= batchSize {
				flush()
				// A size-triggered flush restarts the timeout window.
				ticker.Reset(batchTimeout)
			}

		case <-ticker.C:
			flush()
			ticker.Reset(batchTimeout)
		}
	}
}
// applyBatch turns a batch of queued work into a new snapshot:
//
//  1. copy the node map of the current snapshot,
//  2. replay every operation of the batch, in order, on that copy,
//  3. rebuild all derived indexes (including the peer map) from the copy,
//  4. atomically install the result as the current snapshot,
//  5. close the result channel of every work item.
//
// Readers are never blocked; they keep using the old snapshot until step 4.
// Step 5 happens only after the swap, so a caller that is released from a
// write is guaranteed to observe that write on its next read.
func (s *NodeStore) applyBatch(batch []work) {
	defer prometheus.NewTimer(nodeStoreBatchDuration).ObserveDuration()

	nodeStoreBatchSize.Observe(float64(len(batch)))

	working := make(map[types.NodeID]types.Node)
	maps.Copy(working, s.data.Load().nodesByID)

	for i := range batch {
		item := batch[i]

		switch item.op {
		case del:
			delete(working, item.nodeID)
		case put:
			working[item.nodeID] = item.node
		case update:
			current, found := working[item.nodeID]
			if !found {
				// Updating an unknown node is silently ignored.
				continue
			}
			// Map values are not addressable: mutate a copy, then store it back.
			item.updateFn(&current)
			working[item.nodeID] = current
		}
	}

	rebuilt := snapshotFromNodes(working, s.peersFunc)
	s.data.Store(&rebuilt)

	// Keep the node count gauge in sync with the installed snapshot.
	nodeStoreNodesCount.Set(float64(len(working)))

	// Release the callers only now that the new snapshot is visible.
	for i := range batch {
		close(batch[i].result)
	}
}
// snapshotFromNodes creates a new Snapshot from the provided nodes.
// It builds the "indexes" that make frequently used lookups fast:
// allNodes, nodesByNodeKey, nodesByUser and peersByNode.
// This is not a fast operation; it is the "slow" part of our copy-on-write
// structure, but it allows us to have fast reads and efficient lookups.
//
// The nodes map is stored in the returned Snapshot as nodesByID without
// being copied, so the caller must not modify it afterwards. peersFunc must
// be non-nil; it is called exactly once, with the views of all nodes.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
	timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
	defer timer.ObserveDuration()

	allNodes := make([]types.NodeView, 0, len(nodes))
	nodesByNodeKey := make(map[key.NodePublic]types.NodeView, len(nodes))
	nodesByUser := make(map[types.UserID][]types.NodeView)

	// Single pass over the nodes: one view is created per node and shared by
	// every index, instead of ranging over the map twice and creating a
	// second view of each node.
	//
	// NOTE(review): if View() retains the address of n, this relies on
	// per-iteration loop variables (Go 1.22+) — confirm against go.mod.
	for _, n := range nodes {
		nodeView := n.View()
		userID := types.UserID(n.UserID)

		allNodes = append(allNodes, nodeView)
		nodesByNodeKey[n.NodeKey] = nodeView
		nodesByUser[userID] = append(nodesByUser[userID], nodeView)
	}

	// peersByNode is most likely the most expensive operation,
	// it will use the list of all nodes, combined with the
	// current policy to precalculate which nodes are peers and
	// can see each other.
	peersByNode := func() map[types.NodeID][]types.NodeView {
		peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
		defer peersTimer.ObserveDuration()

		return peersFunc(allNodes)
	}()

	return Snapshot{
		nodesByID:      nodes,
		allNodes:       allNodes,
		nodesByNodeKey: nodesByNodeKey,
		peersByNode:    peersByNode,
		nodesByUser:    nodesByUser,
	}
}
// GetNode looks up a node by ID in the current snapshot.
// The bool reports whether the node exists (think "err not found"); when it
// is false the returned NodeView is the zero value.
// Even when the bool is true the NodeView might be invalid, so callers must
// still check it with .Valid() (an invalid view points at a broken node
// rather than a missing one).
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
	const opLabel = "get"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()

	if node, found := s.data.Load().nodesByID[id]; found {
		return node.View(), true
	}

	return types.NodeView{}, false
}
// GetNodeByNodeKey retrieves a node by its NodeKey from the current snapshot.
// When no node has that key, the zero NodeView is returned, so callers must
// check the result (e.g. with .Valid()) before using it.
//
// Like the other read accessors it records the operation counter and the
// duration histogram, under the "get_by_key" label.
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_key"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("get_by_key").Inc()

	return s.data.Load().nodesByNodeKey[nodeKey]
}
// ListNodes returns a read-only slice view over every node in the current
// snapshot. The order of the nodes is not specified.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
	const opLabel = "list"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()

	snap := s.data.Load()

	return views.SliceOf(snap.allNodes)
}
// ListPeers returns a read-only slice view over the precomputed peers of the
// node with the given ID. The slice is empty when the current snapshot has
// no peer entry for that ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
	const opLabel = "list_peers"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()

	snap := s.data.Load()

	return views.SliceOf(snap.peersByNode[id])
}
// ListNodesByUser returns a read-only slice view over the nodes owned by the
// given user ID. The slice is empty when the user has no nodes in the
// current snapshot.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
	const opLabel = "list_by_user"
	defer prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues(opLabel)).ObserveDuration()

	nodeStoreOperations.WithLabelValues(opLabel).Inc()

	snap := s.data.Load()

	return views.SliceOf(snap.nodesByUser[uid])
}

View File

@@ -0,0 +1,501 @@
package state
import (
"net/netip"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
)
// TestSnapshotFromNodes checks that snapshotFromNodes derives the expected
// indexes (allNodes, peersByNode, nodesByUser) from a node map for several
// node/peer-function combinations.
//
// Length checks that guard a subsequent slice index use require, so a wrong
// length fails the case cleanly instead of panicking on the index.
func TestSnapshotFromNodes(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
		validate  func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
	}{
		{
			name: "empty nodes",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := make(map[types.NodeID]types.Node)
				peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
					return make(map[types.NodeID][]types.NodeView)
				}

				return nodes, peersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Empty(t, snapshot.nodesByID)
				assert.Empty(t, snapshot.allNodes)
				assert.Empty(t, snapshot.peersByNode)
				assert.Empty(t, snapshot.nodesByUser)
			},
		},
		{
			name: "single node",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 1)
				assert.Len(t, snapshot.allNodes, 1)
				assert.Len(t, snapshot.peersByNode, 1)
				assert.Len(t, snapshot.nodesByUser, 1)

				require.Contains(t, snapshot.nodesByID, types.NodeID(1))
				assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
				assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers

				require.Len(t, snapshot.nodesByUser[1], 1)
				assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
			},
		},
		{
			name: "multiple nodes same user",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 1, "user1", "node2"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 2)
				assert.Len(t, snapshot.allNodes, 2)
				assert.Len(t, snapshot.peersByNode, 2)
				assert.Len(t, snapshot.nodesByUser, 1)

				// Each node sees the other as peer (but not itself)
				require.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
				require.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())

				assert.Len(t, snapshot.nodesByUser[1], 2)
			},
		},
		{
			name: "multiple nodes different users",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 1, "user1", "node3"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 3)
				assert.Len(t, snapshot.allNodes, 3)
				assert.Len(t, snapshot.peersByNode, 3)
				assert.Len(t, snapshot.nodesByUser, 2)

				// Each node should have 2 peers (all others, but not itself)
				assert.Len(t, snapshot.peersByNode[1], 2)
				assert.Len(t, snapshot.peersByNode[2], 2)
				assert.Len(t, snapshot.peersByNode[3], 2)

				// User groupings
				assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
				assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
			},
		},
		{
			name: "odd-even peers filtering",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 3, "user3", "node3"),
					4: createTestNode(4, 4, "user4", "node4"),
				}

				return nodes, oddEvenPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 4)
				assert.Len(t, snapshot.allNodes, 4)
				assert.Len(t, snapshot.peersByNode, 4)
				assert.Len(t, snapshot.nodesByUser, 4)

				// Odd nodes should only see other odd nodes as peers
				require.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
				require.Len(t, snapshot.peersByNode[3], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

				// Even nodes should only see other even nodes as peers
				require.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
				require.Len(t, snapshot.peersByNode[4], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nodes, peersFunc := tt.setupFunc()
			snapshot := snapshotFromNodes(nodes, peersFunc)
			tt.validate(t, nodes, snapshot)
		})
	}
}
// Helper functions
// createTestNode returns a types.Node populated for tests: the given ID,
// user and hostname (also used as GivenName), freshly generated machine,
// node and disco keys, and creation/update timestamps set to now.
// Every node gets the same fixed IPv4 and IPv6 address.
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
	created := time.Now()

	v4 := netip.MustParseAddr("100.64.0.1")
	v6 := netip.MustParseAddr("fd7a:115c:a1e0::1")

	node := types.Node{
		ID:             nodeID,
		Hostname:       hostname,
		GivenName:      hostname,
		UserID:         userID,
		User:           types.User{Name: username, DisplayName: username},
		RegisterMethod: "test",
		IPv4:           &v4,
		IPv6:           &v6,
		CreatedAt:      created,
		UpdatedAt:      created,
	}

	// Fresh random keys so that every test node is distinguishable by key.
	node.MachineKey = key.NewMachine().Public()
	node.NodeKey = key.NewNode().Public()
	node.DiscoKey = key.NewDisco().Public()

	return node
}
// Peer functions
// allowAllPeersFunc is a PeersFunc for tests in which every node is a peer
// of every other node (but never of itself). A node without peers is mapped
// to a nil slice.
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))

	for i := range nodes {
		self := nodes[i].ID()

		var others []types.NodeView
		for j := range nodes {
			if nodes[j].ID() == self {
				continue
			}
			others = append(others, nodes[j])
		}

		peerMap[self] = others
	}

	return peerMap
}
// oddEvenPeersFunc is a PeersFunc for tests that partitions nodes by the
// parity of their ID: odd-numbered nodes only peer with other odd-numbered
// nodes, even-numbered nodes only with other even-numbered ones, and no node
// peers with itself. A node without peers is mapped to a nil slice.
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	isOdd := func(v types.NodeView) bool { return v.ID()%2 == 1 }

	peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))

	for i := range nodes {
		self := nodes[i]
		selfOdd := isOdd(self)

		var sameParity []types.NodeView
		for j := range nodes {
			candidate := nodes[j]
			if candidate.ID() != self.ID() && isOdd(candidate) == selfOdd {
				sameParity = append(sameParity, candidate)
			}
		}

		peerMap[self.ID()] = sameParity
	}

	return peerMap
}
// TestNodeStoreOperations drives a started NodeStore through sequences of
// writes (put, update, delete) and verifies the resulting snapshot after
// each step. The steps of a test case share one store and run in order, each
// as its own subtest.
//
// Every step receives the *testing.T of its own subtest. Assertions must not
// use the *testing.T of an enclosing test: failures would be reported
// against the wrong test, and require (which calls FailNow) must only be
// used from the goroutine running the test it belongs to.
func TestNodeStoreOperations(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func(t *testing.T) *NodeStore
		steps     []testStep
	}{
		{
			name: "create empty store and add single node",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify empty store",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
				{
					name: "add first node",
					action: func(t *testing.T, store *NodeStore) {
						node := createTestNode(1, 1, "user1", "node1")
						store.PutNode(node)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)

						require.Contains(t, snapshot.nodesByID, types.NodeID(1))
						assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
						assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
						assert.Len(t, snapshot.nodesByUser[1], 1)
					},
				},
			},
		},
		{
			name: "create store with initial node and add more",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				initialNodes := types.Nodes{&node1}

				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial state",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)
						assert.Empty(t, snapshot.peersByNode[1])
					},
				},
				{
					name: "add second node same user",
					action: func(t *testing.T, store *NodeStore) {
						node2 := createTestNode(2, 1, "user1", "node2")
						store.PutNode(node2)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 1)

						// Now both nodes should see each other as peers
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())

						assert.Len(t, snapshot.nodesByUser[1], 2)
					},
				},
				{
					name: "add third node different user",
					action: func(t *testing.T, store *NodeStore) {
						node3 := createTestNode(3, 2, "user2", "node3")
						store.PutNode(node3)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)

						// All nodes should see the other 2 as peers
						assert.Len(t, snapshot.peersByNode[1], 2)
						assert.Len(t, snapshot.peersByNode[2], 2)
						assert.Len(t, snapshot.peersByNode[3], 2)

						// User groupings
						assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
					},
				},
			},
		},
		{
			name: "test node deletion",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				node3 := createTestNode(3, 2, "user2", "node3")
				initialNodes := types.Nodes{&node1, &node2, &node3}

				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial 3 nodes",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)
					},
				},
				{
					name: "delete middle node",
					action: func(t *testing.T, store *NodeStore) {
						store.DeleteNode(2)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 2)

						// Node 2 should be gone
						assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))

						// Remaining nodes should see each other as peers
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

						// User groupings updated
						assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
					},
				},
				{
					name: "delete all remaining nodes",
					action: func(t *testing.T, store *NodeStore) {
						store.DeleteNode(1)
						store.DeleteNode(3)

						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
			},
		},
		{
			name: "test node updates",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				initialNodes := types.Nodes{&node1, &node2}

				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial hostnames",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
					},
				},
				{
					name: "update node hostname",
					action: func(t *testing.T, store *NodeStore) {
						store.UpdateNode(1, func(n *types.Node) {
							n.Hostname = "updated-node1"
							n.GivenName = "updated-node1"
						})

						snapshot := store.data.Load()
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged

						// Peers should still work correctly
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Len(t, snapshot.peersByNode[2], 1)
					},
				},
			},
		},
		{
			name: "test with odd-even peers filtering",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, oddEvenPeersFunc)
			},
			steps: []testStep{
				{
					name: "add nodes with odd-even filtering",
					action: func(t *testing.T, store *NodeStore) {
						// Add nodes in sequence
						store.PutNode(createTestNode(1, 1, "user1", "node1"))
						store.PutNode(createTestNode(2, 2, "user2", "node2"))
						store.PutNode(createTestNode(3, 3, "user3", "node3"))
						store.PutNode(createTestNode(4, 4, "user4", "node4"))

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 4)

						// Verify odd-even peer relationships
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
				{
					name: "delete odd node and verify even nodes unaffected",
					action: func(t *testing.T, store *NodeStore) {
						store.DeleteNode(1)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)

						// Node 3 (odd) should now have no peers
						assert.Empty(t, snapshot.peersByNode[3])

						// Even nodes should still see each other
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := tt.setupFunc(t)
			store.Start()
			defer store.Stop()

			for _, step := range tt.steps {
				t.Run(step.name, func(t *testing.T) {
					// Hand the subtest's own t to the step.
					step.action(t, store)
				})
			}
		})
	}
}

// testStep is one named stage of a TestNodeStoreOperations case.
type testStep struct {
	// name is used as the subtest name.
	name string
	// action performs the step's writes and assertions against the shared
	// store, reporting through the *testing.T of the step's own subtest.
	action func(t *testing.T, store *NodeStore)
}

File diff suppressed because it is too large Load Diff