stability and race conditions in auth and node store (#2781)
This PR addresses some consistency issues that were introduced or discovered with the nodestore.

nodestore: `PutNode` and `UpdateNode` now return the node once the write has finished. This closes a race condition where reading the node back did not necessarily return it with the given change applied, and it ensures we also see all the other updates from that batch write. A sketch of this pattern follows below.

auth: The authentication paths have been unified and simplified. This removes a lot of bad branches and ensures we only do the minimal work. A comprehensive auth test set has been created so we do not have to run integration tests to validate auth, and it has allowed us to generate test cases for all the branches we currently know of.

integration: Added a lot more tooling and checks to validate that nodes reach the expected state when they come up and down, standardised across the different auth models. Much of this exists to support or detect issues in the changes to the nodestore (races) and auth (inconsistencies after login and in reaching the correct state).

This PR was assisted, particularly on the tests, by Claude Code.
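To make the read-back race concrete, here is a minimal, self-contained Go sketch of the pattern the PR adopts (illustrative only; every name in it is invented for the example, none of it is headscale code): a single writer goroutine owns the data, applies all mutations in a batch first, and then answers each request with the post-batch value, so a caller never has to read the store a second time to learn the outcome of its own write.

```go
package main

import "fmt"

type node struct {
	ID   int
	Name string
}

// work mirrors the shape of the PR's work item: a mutation plus a channel
// on which the post-batch node is handed back to the caller.
type work struct {
	updateFn   func(*node)
	nodeResult chan node
}

// writer is the single owner of the store. It applies every mutation in a
// batch, then answers each request with the value as of the end of the
// batch, so a caller observes its own write plus anything batched with it.
func writer(reqs <-chan []work, store map[int]node, id int) {
	for batch := range reqs {
		for _, w := range batch {
			n := store[id]
			w.updateFn(&n)
			store[id] = n
		}
		final := store[id]
		for _, w := range batch {
			w.nodeResult <- final
			close(w.nodeResult)
		}
	}
}

func main() {
	store := map[int]node{1: {ID: 1, Name: "old"}}
	reqs := make(chan []work, 1)
	go writer(reqs, store, 1)

	w := work{
		updateFn:   func(n *node) { n.Name = "new" },
		nodeResult: make(chan node, 1),
	}
	reqs <- []work{w}

	// No second read of the store is needed; the result already reflects
	// the applied batch, which is what closes the read-back race.
	fmt.Println(<-w.nodeResult) // {1 new}
}
```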
@@ -15,7 +15,7 @@ import (
 )
 
 const (
-	batchSize    = 10
+	batchSize    = 100
 	batchTimeout = 500 * time.Millisecond
 )
 
@@ -121,10 +121,11 @@ type Snapshot struct {
 	nodesByID map[types.NodeID]types.Node
 
 	// calculated from nodesByID
-	nodesByNodeKey map[key.NodePublic]types.NodeView
-	peersByNode    map[types.NodeID][]types.NodeView
-	nodesByUser    map[types.UserID][]types.NodeView
-	allNodes       []types.NodeView
+	nodesByNodeKey    map[key.NodePublic]types.NodeView
+	nodesByMachineKey map[key.MachinePublic]map[types.UserID]types.NodeView
+	peersByNode       map[types.NodeID][]types.NodeView
+	nodesByUser       map[types.UserID][]types.NodeView
+	allNodes          []types.NodeView
 }
 
 // PeersFunc is a function that takes a list of nodes and returns a map
@@ -135,26 +136,29 @@ type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
 
 // work represents a single operation to be performed on the NodeStore.
 type work struct {
-	op       int
-	nodeID   types.NodeID
-	node     types.Node
-	updateFn UpdateNodeFunc
-	result   chan struct{}
+	op         int
+	nodeID     types.NodeID
+	node       types.Node
+	updateFn   UpdateNodeFunc
+	result     chan struct{}
+	nodeResult chan types.NodeView // Channel to return the resulting node after batch application
 }
 
 // PutNode adds or updates a node in the store.
 // If the node already exists, it will be replaced.
 // If the node does not exist, it will be added.
 // This is a blocking operation that waits for the write to complete.
-func (s *NodeStore) PutNode(n types.Node) {
+// Returns the resulting node after all modifications in the batch have been applied.
+func (s *NodeStore) PutNode(n types.Node) types.NodeView {
 	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
 	defer timer.ObserveDuration()
 
 	work := work{
-		op:     put,
-		nodeID: n.ID,
-		node:   n,
-		result: make(chan struct{}),
+		op:         put,
+		nodeID:     n.ID,
+		node:       n,
+		result:     make(chan struct{}),
+		nodeResult: make(chan types.NodeView, 1),
 	}
 
 	nodeStoreQueueDepth.Inc()
@@ -162,7 +166,10 @@ func (s *NodeStore) PutNode(n types.Node) {
 	<-work.result
 	nodeStoreQueueDepth.Dec()
 
+	resultNode := <-work.nodeResult
 	nodeStoreOperations.WithLabelValues("put").Inc()
+
+	return resultNode
 }
 
 // UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
@@ -173,6 +180,7 @@ type UpdateNodeFunc func(n *types.Node)
 // This is analogous to a database "transaction", or, the caller should
 // rather collect all data they want to change, and then call this function.
 // Fewer calls are better.
+// Returns the resulting node after all modifications in the batch have been applied.
 //
 // TODO(kradalby): Technically we could have a version of this that modifies the node
 // in the current snapshot if _we know_ that the change will not affect the peer relationships.
@@ -181,15 +189,16 @@ type UpdateNodeFunc func(n *types.Node)
 // a lock around the nodesByID map to ensure that no other writes are happening
 // while we are modifying the node. Which mean we would need to implement read-write locks
 // on all read operations.
-func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
+func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) {
 	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
 	defer timer.ObserveDuration()
 
 	work := work{
-		op:       update,
-		nodeID:   nodeID,
-		updateFn: updateFn,
-		result:   make(chan struct{}),
+		op:         update,
+		nodeID:     nodeID,
+		updateFn:   updateFn,
+		result:     make(chan struct{}),
+		nodeResult: make(chan types.NodeView, 1),
 	}
 
 	nodeStoreQueueDepth.Inc()
@@ -197,7 +206,11 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
 	<-work.result
 	nodeStoreQueueDepth.Dec()
 
+	resultNode := <-work.nodeResult
 	nodeStoreOperations.WithLabelValues("update").Inc()
+
+	// Return the node and whether it exists (is valid)
+	return resultNode, resultNode.Valid()
 }
 
 // DeleteNode removes a node from the store by its ID.
@@ -282,18 +295,32 @@ func (s *NodeStore) applyBatch(batch []work) {
 	nodes := make(map[types.NodeID]types.Node)
 	maps.Copy(nodes, s.data.Load().nodesByID)
 
-	for _, w := range batch {
+	// Track which work items need node results
+	nodeResultRequests := make(map[types.NodeID][]*work)
+
+	for i := range batch {
+		w := &batch[i]
 		switch w.op {
 		case put:
 			nodes[w.nodeID] = w.node
+			if w.nodeResult != nil {
+				nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
+			}
 		case update:
 			// Update the specific node identified by nodeID
 			if n, exists := nodes[w.nodeID]; exists {
 				w.updateFn(&n)
 				nodes[w.nodeID] = n
 			}
+			if w.nodeResult != nil {
+				nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
+			}
 		case del:
 			delete(nodes, w.nodeID)
+			// For delete operations, send an invalid NodeView if requested
+			if w.nodeResult != nil {
+				nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
+			}
 		}
 	}
 
@@ -303,6 +330,24 @@ func (s *NodeStore) applyBatch(batch []work) {
 	// Update node count gauge
 	nodeStoreNodesCount.Set(float64(len(nodes)))
 
+	// Send the resulting nodes to all work items that requested them
+	for nodeID, workItems := range nodeResultRequests {
+		if node, exists := nodes[nodeID]; exists {
+			nodeView := node.View()
+			for _, w := range workItems {
+				w.nodeResult <- nodeView
+				close(w.nodeResult)
+			}
+		} else {
+			// Node was deleted or doesn't exist
+			for _, w := range workItems {
+				w.nodeResult <- types.NodeView{} // Send invalid view
+				close(w.nodeResult)
+			}
+		}
+	}
+
 	// Signal completion for all work items
 	for _, w := range batch {
 		close(w.result)
 	}
@@ -323,9 +368,10 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
 	}
 
 	newSnap := Snapshot{
-		nodesByID:      nodes,
-		allNodes:       allNodes,
-		nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
+		nodesByID:         nodes,
+		allNodes:          allNodes,
+		nodesByNodeKey:    make(map[key.NodePublic]types.NodeView),
+		nodesByMachineKey: make(map[key.MachinePublic]map[types.UserID]types.NodeView),
 
 		// peersByNode is most likely the most expensive operation,
 		// it will use the list of all nodes, combined with the
@@ -339,11 +385,19 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
 		nodesByUser: make(map[types.UserID][]types.NodeView),
 	}
 
-	// Build nodesByUser and nodesByNodeKey maps
+	// Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps
 	for _, n := range nodes {
 		nodeView := n.View()
-		newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
+		userID := types.UserID(n.UserID)
+
+		newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView)
 		newSnap.nodesByNodeKey[n.NodeKey] = nodeView
+
+		// Build machine key index
+		if newSnap.nodesByMachineKey[n.MachineKey] == nil {
+			newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
+		}
+		newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
 	}
 
 	return newSnap
@@ -382,19 +436,40 @@ func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bo
 	return nodeView, exists
 }
 
-// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists.
-func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
+// GetNodeByMachineKey returns a node by its machine key and user ID. The bool indicates if the node exists.
+func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) {
 	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key"))
 	defer timer.ObserveDuration()
 
 	nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc()
 
 	snapshot := s.data.Load()
-	// We don't have a byMachineKey map, so we need to iterate
-	// This could be optimized by adding a byMachineKey map if this becomes a hot path
-	for _, node := range snapshot.nodesByID {
-		if node.MachineKey == machineKey {
-			return node.View(), true
+	if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists {
+		if node, exists := userMap[userID]; exists {
+			return node, true
 		}
 	}
 
 	return types.NodeView{}, false
 }
+
+// GetNodeByMachineKeyAnyUser returns the first node with the given machine key,
+// regardless of which user it belongs to. This is useful for scenarios like
+// transferring a node to a different user when re-authenticating with a
+// different user's auth key.
+// If multiple nodes exist with the same machine key (different users), the
+// first one found is returned (order is not guaranteed).
+func (s *NodeStore) GetNodeByMachineKeyAnyUser(machineKey key.MachinePublic) (types.NodeView, bool) {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key_any_user"))
+	defer timer.ObserveDuration()
+
+	nodeStoreOperations.WithLabelValues("get_by_machine_key_any_user").Inc()
+
+	snapshot := s.data.Load()
+	if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists {
+		// Return the first node found (order not guaranteed due to map iteration)
+		for _, node := range userMap {
+			return node, true
+		}
+	}
+
+	return types.NodeView{}, false
+}
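The lookup side of the change can be sketched the same way. Below is a small, self-contained model (illustrative only; `machineKey`, `userID`, and `nodeView` are invented stand-ins for `key.MachinePublic`, `types.UserID`, and `types.NodeView`) of the two-level `nodesByMachineKey` index: the exact `(machine key, user ID)` lookup replaces the previous linear scan over `nodesByID`, and the any-user variant returns an arbitrary entry, matching the order-not-guaranteed wording in the diff.

```go
package main

import "fmt"

// Illustrative stand-ins for headscale's types; this sketches the index
// shape only, not the real NodeStore.
type (
	machineKey string
	userID     int
	nodeView   struct{ Name string }
)

// index mirrors the new two-level map: machine key -> user ID -> node view.
type index map[machineKey]map[userID]nodeView

func (ix index) add(mk machineKey, uid userID, n nodeView) {
	// Lazily create the inner per-user map, as snapshotFromNodes does.
	if ix[mk] == nil {
		ix[mk] = make(map[userID]nodeView)
	}
	ix[mk][uid] = n
}

// byMachineAndUser corresponds to the new GetNodeByMachineKey: both the
// machine key and the user ID must match, and the lookup is O(1) instead
// of a scan over all nodes.
func (ix index) byMachineAndUser(mk machineKey, uid userID) (nodeView, bool) {
	if users, ok := ix[mk]; ok {
		if n, ok := users[uid]; ok {
			return n, true
		}
	}
	return nodeView{}, false
}

// byMachineAnyUser corresponds to GetNodeByMachineKeyAnyUser: any entry
// for the machine key is returned, in unspecified map-iteration order,
// which is enough to find a node before transferring it to another user.
func (ix index) byMachineAnyUser(mk machineKey) (nodeView, bool) {
	for _, n := range ix[mk] {
		return n, true
	}
	return nodeView{}, false
}

func main() {
	ix := index{}
	ix.add("mkey-1", 10, nodeView{Name: "laptop-user10"})

	fmt.Println(ix.byMachineAndUser("mkey-1", 10)) // {laptop-user10} true
	fmt.Println(ix.byMachineAndUser("mkey-1", 99)) // {} false
	fmt.Println(ix.byMachineAnyUser("mkey-1"))     // {laptop-user10} true
}
```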