Mirror of https://github.com/juanfont/headscale.git (synced 2025-11-09 13:39:39 -05:00)
state/nodestore: in-memory representation of nodes
Initial work on a NodeStore, which keeps all nodes and their relations in memory, with the peer relationships precalculated. It is a copy-on-write structure, replacing the "snapshot" whenever the structure changes. It is optimised for reads: writes are not fast, but they are batched so that the expensive peer calculation runs less often when many changes arrive in quick succession. Writes block until committed, while reads are never blocked. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
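For orientation, a minimal usage sketch of the API added below (not part of the commit; existingNodes and node are placeholder values, and the trivial peersFunc here stands in for one derived from the current policy):

	peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
		// No peers at all; real callers derive this map from policy evaluation.
		return make(map[types.NodeID][]types.NodeView)
	}

	store := NewNodeStore(existingNodes, peersFunc)
	store.Start()
	defer store.Stop()

	// Writes block until the batch containing them has been committed.
	store.PutNode(node)
	store.UpdateNode(node.ID, func(n *types.Node) {
		n.GivenName = "renamed-node"
	})

	// Reads are served from the current snapshot and never block on writers.
	if nv, ok := store.GetNode(node.ID); ok {
		_ = store.ListPeers(nv.ID())
	}

This mirrors how the tests below drive the store: Start before the first write, Stop when done.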
committed by Kristoffer Dalby
parent 38be30b6d4
commit 9d236571f4
403
hscontrol/state/node_store.go
Normal file
@@ -0,0 +1,403 @@
package state

import (
	"fmt"
	"maps"
	"strings"
	"sync/atomic"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"tailscale.com/types/key"
	"tailscale.com/types/views"
)

const (
	batchSize    = 10
	batchTimeout = 500 * time.Millisecond
)

const (
	put    = 1
	del    = 2
	update = 3
)

const prometheusNamespace = "headscale"

var (
	nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_operations_total",
		Help:      "Total number of NodeStore operations",
	}, []string{"operation"})
	nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_operation_duration_seconds",
		Help:      "Duration of NodeStore operations",
		Buckets:   prometheus.DefBuckets,
	}, []string{"operation"})
	nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_batch_size",
		Help:      "Size of NodeStore write batches",
		Buckets:   []float64{1, 2, 5, 10, 20, 50, 100},
	})
	nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_batch_duration_seconds",
		Help:      "Duration of NodeStore batch processing",
		Buckets:   prometheus.DefBuckets,
	})
	nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_snapshot_build_duration_seconds",
		Help:      "Duration of NodeStore snapshot building from nodes",
		Buckets:   prometheus.DefBuckets,
	})
	nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_nodes_total",
		Help:      "Total number of nodes in the NodeStore",
	})
	nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_peers_calculation_duration_seconds",
		Help:      "Duration of peers calculation in NodeStore",
		Buckets:   prometheus.DefBuckets,
	})
	nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_queue_depth",
		Help:      "Current depth of NodeStore write queue",
	})
)

// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
// and while batches are not fast, they are grouped together
// to do less of the expensive peer calculation if there are many
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked. This means that the caller of a write operation
// is responsible for ensuring an update depending on a write
// is not issued before the write is complete.
type NodeStore struct {
	data atomic.Pointer[Snapshot]

	peersFunc  PeersFunc
	writeQueue chan work
}

func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
	nodes := make(map[types.NodeID]types.Node, len(allNodes))
	for _, n := range allNodes {
		nodes[n.ID] = *n
	}
	snap := snapshotFromNodes(nodes, peersFunc)

	store := &NodeStore{
		peersFunc: peersFunc,
	}
	store.data.Store(&snap)

	// Initialize node count gauge
	nodeStoreNodesCount.Set(float64(len(nodes)))

	return store
}

// Snapshot is the representation of the current state of the NodeStore.
// It contains all nodes and their relationships.
// It is a copy-on-write structure, meaning that when a write occurs,
// a new Snapshot is created with the updated state,
// and replaces the old one atomically.
type Snapshot struct {
	// nodesByID is the main source of truth for nodes.
	nodesByID map[types.NodeID]types.Node

	// calculated from nodesByID
	nodesByNodeKey map[key.NodePublic]types.NodeView
	peersByNode    map[types.NodeID][]types.NodeView
	nodesByUser    map[types.UserID][]types.NodeView
	allNodes       []types.NodeView
}

// PeersFunc is a function that takes a list of nodes and returns a map
// with the relationships between nodes and their peers.
// This will typically be used to calculate which nodes can see each other
// based on the current policy.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView

// work represents a single operation to be performed on the NodeStore.
type work struct {
	op       int
	nodeID   types.NodeID
	node     types.Node
	updateFn UpdateNodeFunc
	result   chan struct{}
}

// PutNode adds or updates a node in the store.
// If the node already exists, it will be replaced.
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
	defer timer.ObserveDuration()

	work := work{
		op:     put,
		nodeID: n.ID,
		node:   n,
		result: make(chan struct{}),
	}

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work
	<-work.result
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues("put").Inc()
}

// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
type UpdateNodeFunc func(n *types.Node)

// UpdateNode applies a function to modify a specific node in the store.
// This is a blocking operation that waits for the write to complete.
// It is analogous to a database "transaction": the caller should collect
// all the data they want to change and then call this function once.
// Fewer calls are better.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
// This is because the main nodesByID map contains the struct, and every other map is using a
// pointer to the underlying struct. The gotcha with this is that we would need to introduce
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node, which means we would need to implement read-write locks
// on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
	defer timer.ObserveDuration()

	work := work{
		op:       update,
		nodeID:   nodeID,
		updateFn: updateFn,
		result:   make(chan struct{}),
	}

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work
	<-work.result
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues("update").Inc()
}

// DeleteNode removes a node from the store by its ID.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) DeleteNode(id types.NodeID) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
	defer timer.ObserveDuration()

	work := work{
		op:     del,
		nodeID: id,
		result: make(chan struct{}),
	}

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work
	<-work.result
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues("delete").Inc()
}

// Start initializes the NodeStore and starts processing the write queue.
func (s *NodeStore) Start() {
	s.writeQueue = make(chan work)
	go s.processWrite()
}

// Stop stops the NodeStore.
func (s *NodeStore) Stop() {
	close(s.writeQueue)
}

// processWrite processes the write queue in batches.
func (s *NodeStore) processWrite() {
	c := time.NewTicker(batchTimeout)
	defer c.Stop()
	batch := make([]work, 0, batchSize)

	for {
		select {
		case w, ok := <-s.writeQueue:
			if !ok {
				// Channel closed, apply any remaining batch and exit
				if len(batch) != 0 {
					s.applyBatch(batch)
				}
				return
			}
			batch = append(batch, w)
			if len(batch) >= batchSize {
				s.applyBatch(batch)
				batch = batch[:0]
				c.Reset(batchTimeout)
			}
		case <-c.C:
			if len(batch) != 0 {
				s.applyBatch(batch)
				batch = batch[:0]
			}
			c.Reset(batchTimeout)
		}
	}
}

// applyBatch applies a batch of work to the node store.
// It takes a copy of the current nodes,
// applies the batch of operations to that copy,
// runs any precomputation needed (like calculating peers),
// and finally replaces the snapshot in the store with the new one.
// The replacement of the snapshot is atomic, ensuring that reads
// are never blocked by writes.
// Each write item blocks until the batch is applied, so the caller
// knows the operation is complete and does not send any updates
// that depend on state that has not yet been written.
func (s *NodeStore) applyBatch(batch []work) {
	timer := prometheus.NewTimer(nodeStoreBatchDuration)
	defer timer.ObserveDuration()

	nodeStoreBatchSize.Observe(float64(len(batch)))

	nodes := make(map[types.NodeID]types.Node)
	maps.Copy(nodes, s.data.Load().nodesByID)

	for _, w := range batch {
		switch w.op {
		case put:
			nodes[w.nodeID] = w.node
		case update:
			// Update the specific node identified by nodeID
			if n, exists := nodes[w.nodeID]; exists {
				w.updateFn(&n)
				nodes[w.nodeID] = n
			}
		case del:
			delete(nodes, w.nodeID)
		}
	}

	newSnap := snapshotFromNodes(nodes, s.peersFunc)
	s.data.Store(&newSnap)

	// Update node count gauge
	nodeStoreNodesCount.Set(float64(len(nodes)))

	for _, w := range batch {
		close(w.result)
	}
}

// snapshotFromNodes creates a new Snapshot from the provided nodes.
// It builds a number of "indexes" to make frequently used lookups fast,
// like nodesByNodeKey, peersByNode, and nodesByUser.
// This is not a fast operation; it is the "slow" part of our copy-on-write
// structure, but it allows us to have fast reads and efficient lookups.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
	timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
	defer timer.ObserveDuration()

	allNodes := make([]types.NodeView, 0, len(nodes))
	for _, n := range nodes {
		allNodes = append(allNodes, n.View())
	}

	newSnap := Snapshot{
		nodesByID:      nodes,
		allNodes:       allNodes,
		nodesByNodeKey: make(map[key.NodePublic]types.NodeView),

		// peersByNode is most likely the most expensive operation,
		// it will use the list of all nodes, combined with the
		// current policy to precalculate which nodes are peers and
		// can see each other.
		peersByNode: func() map[types.NodeID][]types.NodeView {
			peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
			defer peersTimer.ObserveDuration()
			return peersFunc(allNodes)
		}(),
		nodesByUser: make(map[types.UserID][]types.NodeView),
	}

	// Build nodesByUser and nodesByNodeKey maps
	for _, n := range nodes {
		nodeView := n.View()
		newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
		newSnap.nodesByNodeKey[n.NodeKey] = nodeView
	}

	return newSnap
}

// GetNode retrieves a node by its ID.
// The returned bool indicates whether the node exists (similar to an "err not found").
// The returned NodeView might still be invalid, so it should be checked with .Valid()
// to ensure it is not a broken or otherwise invalid node.
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("get").Inc()

	n, exists := s.data.Load().nodesByID[id]
	if !exists {
		return types.NodeView{}, false
	}

	return n.View(), true
}

// GetNodeByNodeKey retrieves a node by its NodeKey.
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
	return s.data.Load().nodesByNodeKey[nodeKey]
}

// ListNodes returns a slice of all nodes in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("list").Inc()

	return views.SliceOf(s.data.Load().allNodes)
}

// ListPeers returns a slice of all peers for a given node ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("list_peers").Inc()

	return views.SliceOf(s.data.Load().peersByNode[id])
}

// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("list_by_user").Inc()

	return views.SliceOf(s.data.Load().nodesByUser[uid])
}
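A note on the batching behaviour above: up to batchSize writes arriving within batchTimeout are folded into one applyBatch call, meaning one snapshot copy and one peers calculation, and every blocked caller is released together when its result channel is closed. A sketch of what that means for callers (hypothetical values; the exact grouping is timing-dependent, so "tend to" rather than "always"):

	// Concurrent writes tend to be coalesced into a single snapshot rebuild
	// rather than one rebuild per call.
	var wg sync.WaitGroup
	for _, n := range newNodes { // newNodes: a hypothetical []types.Node
		wg.Add(1)
		go func(n types.Node) {
			defer wg.Done()
			store.PutNode(n) // blocks until the batch containing it is committed
		}(n)
	}
	wg.Wait()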
501
hscontrol/state/node_store_test.go
Normal file
@@ -0,0 +1,501 @@
package state

import (
	"net/netip"
	"testing"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"tailscale.com/types/key"
)

func TestSnapshotFromNodes(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
		validate  func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
	}{
		{
			name: "empty nodes",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := make(map[types.NodeID]types.Node)
				peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
					return make(map[types.NodeID][]types.NodeView)
				}

				return nodes, peersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Empty(t, snapshot.nodesByID)
				assert.Empty(t, snapshot.allNodes)
				assert.Empty(t, snapshot.peersByNode)
				assert.Empty(t, snapshot.nodesByUser)
			},
		},
		{
			name: "single node",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
				}
				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 1)
				assert.Len(t, snapshot.allNodes, 1)
				assert.Len(t, snapshot.peersByNode, 1)
				assert.Len(t, snapshot.nodesByUser, 1)

				require.Contains(t, snapshot.nodesByID, types.NodeID(1))
				assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
				assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
				assert.Len(t, snapshot.nodesByUser[1], 1)
				assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
			},
		},
		{
			name: "multiple nodes same user",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 1, "user1", "node2"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 2)
				assert.Len(t, snapshot.allNodes, 2)
				assert.Len(t, snapshot.peersByNode, 2)
				assert.Len(t, snapshot.nodesByUser, 1)

				// Each node sees the other as peer (but not itself)
				assert.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
				assert.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
				assert.Len(t, snapshot.nodesByUser[1], 2)
			},
		},
		{
			name: "multiple nodes different users",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 1, "user1", "node3"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 3)
				assert.Len(t, snapshot.allNodes, 3)
				assert.Len(t, snapshot.peersByNode, 3)
				assert.Len(t, snapshot.nodesByUser, 2)

				// Each node should have 2 peers (all others, but not itself)
				assert.Len(t, snapshot.peersByNode[1], 2)
				assert.Len(t, snapshot.peersByNode[2], 2)
				assert.Len(t, snapshot.peersByNode[3], 2)

				// User groupings
				assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
				assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
			},
		},
		{
			name: "odd-even peers filtering",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 3, "user3", "node3"),
					4: createTestNode(4, 4, "user4", "node4"),
				}
				peersFunc := oddEvenPeersFunc

				return nodes, peersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 4)
				assert.Len(t, snapshot.allNodes, 4)
				assert.Len(t, snapshot.peersByNode, 4)
				assert.Len(t, snapshot.nodesByUser, 4)

				// Odd nodes should only see other odd nodes as peers
				require.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())

				require.Len(t, snapshot.peersByNode[3], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

				// Even nodes should only see other even nodes as peers
				require.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())

				require.Len(t, snapshot.peersByNode[4], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nodes, peersFunc := tt.setupFunc()
			snapshot := snapshotFromNodes(nodes, peersFunc)
			tt.validate(t, nodes, snapshot)
		})
	}
}

// Helper functions

func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
	now := time.Now()
	machineKey := key.NewMachine()
	nodeKey := key.NewNode()
	discoKey := key.NewDisco()

	ipv4 := netip.MustParseAddr("100.64.0.1")
	ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1")

	return types.Node{
		ID:         nodeID,
		MachineKey: machineKey.Public(),
		NodeKey:    nodeKey.Public(),
		DiscoKey:   discoKey.Public(),
		Hostname:   hostname,
		GivenName:  hostname,
		UserID:     userID,
		User: types.User{
			Name:        username,
			DisplayName: username,
		},
		RegisterMethod: "test",
		IPv4:           &ipv4,
		IPv6:           &ipv6,
		CreatedAt:      now,
		UpdatedAt:      now,
	}
}

// Peer functions

func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	ret := make(map[types.NodeID][]types.NodeView, len(nodes))
	for _, node := range nodes {
		var peers []types.NodeView
		for _, n := range nodes {
			if n.ID() != node.ID() {
				peers = append(peers, n)
			}
		}
		ret[node.ID()] = peers
	}

	return ret
}

func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	ret := make(map[types.NodeID][]types.NodeView, len(nodes))
	for _, node := range nodes {
		var peers []types.NodeView
		nodeIsOdd := node.ID()%2 == 1

		for _, n := range nodes {
			if n.ID() == node.ID() {
				continue
			}

			peerIsOdd := n.ID()%2 == 1

			// Only add peer if both are odd or both are even
			if nodeIsOdd == peerIsOdd {
				peers = append(peers, n)
			}
		}
		ret[node.ID()] = peers
	}

	return ret
}

func TestNodeStoreOperations(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func(t *testing.T) *NodeStore
		steps     []testStep
	}{
		{
			name: "create empty store and add single node",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify empty store",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
				{
					name: "add first node",
					action: func(store *NodeStore) {
						node := createTestNode(1, 1, "user1", "node1")
						store.PutNode(node)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)

						require.Contains(t, snapshot.nodesByID, types.NodeID(1))
						assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
						assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
						assert.Len(t, snapshot.nodesByUser[1], 1)
					},
				},
			},
		},
		{
			name: "create store with initial node and add more",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				initialNodes := types.Nodes{&node1}
				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial state",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)
						assert.Empty(t, snapshot.peersByNode[1])
					},
				},
				{
					name: "add second node same user",
					action: func(store *NodeStore) {
						node2 := createTestNode(2, 1, "user1", "node2")
						store.PutNode(node2)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 1)

						// Now both nodes should see each other as peers
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
						assert.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
						assert.Len(t, snapshot.nodesByUser[1], 2)
					},
				},
				{
					name: "add third node different user",
					action: func(store *NodeStore) {
						node3 := createTestNode(3, 2, "user2", "node3")
						store.PutNode(node3)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)

						// All nodes should see the other 2 as peers
						assert.Len(t, snapshot.peersByNode[1], 2)
						assert.Len(t, snapshot.peersByNode[2], 2)
						assert.Len(t, snapshot.peersByNode[3], 2)

						// User groupings
						assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
					},
				},
			},
		},
		{
			name: "test node deletion",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				node3 := createTestNode(3, 2, "user2", "node3")
				initialNodes := types.Nodes{&node1, &node2, &node3}

				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial 3 nodes",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)
					},
				},
				{
					name: "delete middle node",
					action: func(store *NodeStore) {
						store.DeleteNode(2)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 2)

						// Node 2 should be gone
						assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))

						// Remaining nodes should see each other as peers
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						assert.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

						// User groupings updated
						assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
					},
				},
				{
					name: "delete all remaining nodes",
					action: func(store *NodeStore) {
						store.DeleteNode(1)
						store.DeleteNode(3)

						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
			},
		},
		{
			name: "test node updates",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				initialNodes := types.Nodes{&node1, &node2}
				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial hostnames",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
					},
				},
				{
					name: "update node hostname",
					action: func(store *NodeStore) {
						store.UpdateNode(1, func(n *types.Node) {
							n.Hostname = "updated-node1"
							n.GivenName = "updated-node1"
						})

						snapshot := store.data.Load()
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged

						// Peers should still work correctly
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Len(t, snapshot.peersByNode[2], 1)
					},
				},
			},
		},
		{
			name: "test with odd-even peers filtering",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, oddEvenPeersFunc)
			},
			steps: []testStep{
				{
					name: "add nodes with odd-even filtering",
					action: func(store *NodeStore) {
						// Add nodes in sequence
						store.PutNode(createTestNode(1, 1, "user1", "node1"))
						store.PutNode(createTestNode(2, 2, "user2", "node2"))
						store.PutNode(createTestNode(3, 3, "user3", "node3"))
						store.PutNode(createTestNode(4, 4, "user4", "node4"))

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 4)

						// Verify odd-even peer relationships
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())

						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())

						require.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
				{
					name: "delete odd node and verify even nodes unaffected",
					action: func(store *NodeStore) {
						store.DeleteNode(1)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)

						// Node 3 (odd) should now have no peers
						assert.Empty(t, snapshot.peersByNode[3])

						// Even nodes should still see each other
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := tt.setupFunc(t)
			store.Start()
			defer store.Stop()

			for _, step := range tt.steps {
				t.Run(step.name, func(t *testing.T) {
					step.action(store)
				})
			}
		})
	}
}

type testStep struct {
	name   string
	action func(store *NodeStore)
}
File diff suppressed because it is too large