state/nodestore: in memory representation of nodes

Initial work on a nodestore which stores all of the nodes
and their relations in memory with relationship for peers
precalculated.

It is a copy-on-write structure, replacing the "snapshot"
when a change to the structure occurs. It is optimised for reads,
and while batches are not fast, they are grouped together
to do less of the expensive peer calculation if there are many
changes rapidly.

Writes will block until commited, while reads are never
blocked.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2025-07-05 23:30:47 +02:00
committed by Kristoffer Dalby
parent 38be30b6d4
commit 9d236571f4
35 changed files with 3960 additions and 1317 deletions

View File

@@ -15,7 +15,6 @@ import (
"strings"
"time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"google.golang.org/grpc/codes"
@@ -25,6 +24,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/views"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/state"
@@ -59,9 +59,10 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
c := change.UserAdded(types.UserID(user.ID))
if policyChanged {
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
if !policyChanged.Empty() {
c.Change = change.Policy
}
@@ -79,15 +80,13 @@ func (api headscaleV1APIServer) RenameUser(
return nil, err
}
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
_, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
api.h.Change(change.PolicyChange())
}
api.h.Change(c)
newUser, err := api.h.state.GetUserByName(request.GetNewName())
if err != nil {
@@ -288,17 +287,13 @@ func (api headscaleV1APIServer) GetNode(
ctx context.Context,
request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) {
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if !ok {
return nil, status.Errorf(codes.NotFound, "node not found")
}
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
return &v1.GetNodeResponse{Node: resp}, nil
}
@@ -323,7 +318,8 @@ func (api headscaleV1APIServer) SetTags(
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
Caller().
Str("node", node.Hostname()).
Strs("tags", request.GetTags()).
Msg("Changing tags of node")
@@ -334,7 +330,13 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
ctx context.Context,
request *v1.SetApprovedRoutesRequest,
) (*v1.SetApprovedRoutesResponse, error) {
var routes []netip.Prefix
log.Debug().
Caller().
Uint64("node.id", request.GetNodeId()).
Strs("requestedRoutes", request.GetRoutes()).
Msg("gRPC SetApprovedRoutes called")
var newApproved []netip.Prefix
for _, route := range request.GetRoutes() {
prefix, err := netip.ParsePrefix(route)
if err != nil {
@@ -344,31 +346,35 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
// If the prefix is an exit route, add both. The client expect both
// to annotate the node as an exit node.
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
} else {
routes = append(routes, prefix)
newApproved = append(newApproved, prefix)
}
}
tsaddr.SortPrefixes(routes)
routes = slices.Compact(routes)
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
// Always propagate node changes from SetApprovedRoutes
api.h.Change(nodeChange)
// If routes changed, propagate those changes too
if !routeChange.Empty() {
api.h.Change(routeChange)
}
proto := node.Proto()
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
// Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
// routes that are actively served from the node (per architectural requirement in types/node.go)
primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
log.Debug().
Caller().
Uint64("node.id", node.ID().Uint64()).
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
Strs("finalSubnetRoutes", proto.SubnetRoutes).
Msg("gRPC SetApprovedRoutes completed")
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
}
@@ -390,9 +396,9 @@ func (api headscaleV1APIServer) DeleteNode(
ctx context.Context,
request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) {
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if !ok {
return nil, status.Errorf(codes.NotFound, "node not found")
}
nodeChange, err := api.h.state.DeleteNode(node)
@@ -420,8 +426,9 @@ func (api headscaleV1APIServer) ExpireNode(
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
Time("expiry", *node.Expiry).
Caller().
Str("node", node.Hostname()).
Time("expiry", *node.AsStruct().Expiry).
Msg("node expired")
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
@@ -440,7 +447,8 @@ func (api headscaleV1APIServer) RenameNode(
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
Caller().
Str("node", node.Hostname()).
Str("new_name", request.GetNewName()).
Msg("node renamed")
@@ -455,58 +463,45 @@ func (api headscaleV1APIServer) ListNodes(
// the filtering of nodes by user, vs nodes as a whole can
// probably be done once.
// TODO(kradalby): This should be done in one tx.
IsConnected := api.h.mapBatcher.ConnectedMap()
if request.GetUser() != "" {
user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
if err != nil {
return nil, err
}
nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))
response := nodesToProto(api.h.state, IsConnected, nodes)
response := nodesToProto(api.h.state, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
nodes, err := api.h.state.ListNodes()
if err != nil {
return nil, err
}
nodes := api.h.state.ListNodes()
sort.Slice(nodes, func(i, j int) bool {
return nodes[i].ID < nodes[j].ID
})
response := nodesToProto(api.h.state, IsConnected, nodes)
response := nodesToProto(api.h.state, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node {
response := make([]*v1.Node, nodes.Len())
for index, node := range nodes.All() {
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
if val, ok := IsConnected.Load(node.ID); ok && val {
resp.Online = true
}
var tags []string
for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node.View(), tag) {
if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
})
return response
}
@@ -674,17 +669,15 @@ func (api headscaleV1APIServer) SetPolicy(
// a scenario where they might be allowed if the server has no nodes
// yet, but it should help for the general case and for hot reloading
// configurations.
nodes, err := api.h.state.ListNodes()
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
changed, err := api.h.state.SetPolicy([]byte(p))
nodes := api.h.state.ListNodes()
_, err := api.h.state.SetPolicy([]byte(p))
if err != nil {
return nil, fmt.Errorf("setting policy: %w", err)
}
if len(nodes) > 0 {
_, err = api.h.state.SSHPolicy(nodes[0].View())
if nodes.Len() > 0 {
_, err = api.h.state.SSHPolicy(nodes.At(0))
if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err)
}
@@ -695,14 +688,20 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
// Only send update if the packet filter has changed.
if changed {
err = api.h.state.AutoApproveNodes()
if err != nil {
return nil, err
}
// Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
// This ensures that routes are re-evaluated for auto-approval in cases where routes
// were manually disabled but could now be auto-approved with the current policy.
cs, err := api.h.state.ReloadPolicy()
if err != nil {
return nil, fmt.Errorf("reloading policy: %w", err)
}
api.h.Change(change.PolicyChange())
if len(cs) > 0 {
api.h.Change(cs...)
} else {
log.Debug().
Caller().
Msg("No policy changes to distribute because ReloadPolicy returned empty changeset")
}
response := &v1.SetPolicyResponse{