all: use immutable node view in read path

This commit changes most of our (*)types.Node to
types.NodeView, which is a readonly version of the
underlying node ensuring that there is no mutations
happening in the read path.

Based on the migration, there didnt seem to be any, but the
idea here is to prevent it in the future and simplify other
new implementations.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-07-05 23:31:13 +02:00 committed by Kristoffer Dalby
parent 5ba7120418
commit 73023c2ec3
24 changed files with 866 additions and 196 deletions

View File

@ -7,8 +7,10 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"tailscale.com/types/views"
) )
func init() { func init() {
@ -111,7 +113,7 @@ var checkPolicy = &cobra.Command{
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
} }
_, err = policy.NewPolicyManager(policyBytes, nil, nil) _, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{})
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output)
} }

View File

@ -485,7 +485,7 @@ func TestAutoApproveRoutes(t *testing.T) {
nodes, err := adb.ListNodes() nodes, err := adb.ListNodes()
assert.NoError(t, err) assert.NoError(t, err)
pm, err := pmf(users, nodes) pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, pm) require.NotNil(t, pm)

View File

@ -78,7 +78,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
sshPol := make(map[string]*tailcfg.SSHPolicy) sshPol := make(map[string]*tailcfg.SSHPolicy)
for _, node := range nodes { for _, node := range nodes {
pol, err := h.state.SSHPolicy(node) pol, err := h.state.SSHPolicy(node.View())
if err != nil { if err != nil {
httpError(w, err) httpError(w, err)
return return

View File

@ -537,7 +537,7 @@ func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeI
var tags []string var tags []string
for _, tag := range node.RequestTags() { for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node, tag) { if state.NodeCanHaveTag(node.View(), tag) {
tags = append(tags, tag) tags = append(tags, tag)
} }
} }
@ -733,7 +733,7 @@ func (api headscaleV1APIServer) SetPolicy(
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = api.h.state.SSHPolicy(nodes[0]) _, err = api.h.state.SSHPolicy(nodes[0].View())
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err) return nil, fmt.Errorf("verifying SSH rules: %w", err)
} }

View File

@ -27,6 +27,7 @@ import (
"tailscale.com/smallzstd" "tailscale.com/smallzstd"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
"tailscale.com/types/views"
) )
const ( const (
@ -88,16 +89,18 @@ func (m *Mapper) String() string {
} }
func generateUserProfiles( func generateUserProfiles(
node *types.Node, node types.NodeView,
peers types.Nodes, peers views.Slice[types.NodeView],
) []tailcfg.UserProfile { ) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User) userMap := make(map[uint]*types.User)
ids := make([]uint, 0, len(userMap)) ids := make([]uint, 0, peers.Len()+1)
userMap[node.User.ID] = &node.User user := node.User()
ids = append(ids, node.User.ID) userMap[user.ID] = &user
for _, peer := range peers { ids = append(ids, user.ID)
userMap[peer.User.ID] = &peer.User for _, peer := range peers.All() {
ids = append(ids, peer.User.ID) peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
} }
slices.Sort(ids) slices.Sort(ids)
@ -114,7 +117,7 @@ func generateUserProfiles(
func generateDNSConfig( func generateDNSConfig(
cfg *types.Config, cfg *types.Config,
node *types.Node, node types.NodeView,
) *tailcfg.DNSConfig { ) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil { if cfg.TailcfgDNSConfig == nil {
return nil return nil
@ -134,16 +137,17 @@ func generateDNSConfig(
// //
// This will produce a resolver like: // This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
for _, resolver := range resolvers { for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{ attrs := url.Values{
"device_name": []string{node.Hostname}, "device_name": []string{node.Hostname()},
"device_model": []string{node.Hostinfo.OS}, "device_model": []string{node.Hostinfo().OS()},
} }
if len(node.IPs()) > 0 { nodeIPs := node.IPs()
attrs.Add("device_ip", node.IPs()[0].String()) if len(nodeIPs) > 0 {
attrs.Add("device_ip", nodeIPs[0].String())
} }
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
@ -154,8 +158,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
// fullMapResponse creates a complete MapResponse for a node. // fullMapResponse creates a complete MapResponse for a node.
// It is a separate function to make testing easier. // It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse( func (m *Mapper) fullMapResponse(
node *types.Node, node types.NodeView,
peers types.Nodes, peers views.Slice[types.NodeView],
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer) resp, err := m.baseWithConfigMapResponse(node, capVer)
@ -182,15 +186,15 @@ func (m *Mapper) fullMapResponse(
// FullMapResponse returns a MapResponse for the given node. // FullMapResponse returns a MapResponse for the given node.
func (m *Mapper) FullMapResponse( func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node types.NodeView,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
peers, err := m.ListPeers(node.ID) peers, err := m.ListPeers(node.ID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err := m.fullMapResponse(node, peers, mapRequest.Version) resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -203,7 +207,7 @@ func (m *Mapper) FullMapResponse(
// to be used to answer MapRequests with OmitPeers set to true. // to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) ReadOnlyMapResponse( func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node types.NodeView,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
@ -216,7 +220,7 @@ func (m *Mapper) ReadOnlyMapResponse(
func (m *Mapper) KeepAliveResponse( func (m *Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node types.NodeView,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
resp.KeepAlive = true resp.KeepAlive = true
@ -226,7 +230,7 @@ func (m *Mapper) KeepAliveResponse(
func (m *Mapper) DERPMapResponse( func (m *Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node types.NodeView,
derpMap *tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
@ -237,7 +241,7 @@ func (m *Mapper) DERPMapResponse(
func (m *Mapper) PeerChangedResponse( func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node types.NodeView,
changed map[types.NodeID]bool, changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange, patches []*tailcfg.PeerChange,
messages ...string, messages ...string,
@ -249,7 +253,7 @@ func (m *Mapper) PeerChangedResponse(
var changedIDs []types.NodeID var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed { for nodeID, nodeChanged := range changed {
if nodeChanged { if nodeChanged {
if nodeID != node.ID { if nodeID != node.ID() {
changedIDs = append(changedIDs, nodeID) changedIDs = append(changedIDs, nodeID)
} }
} else { } else {
@ -270,7 +274,7 @@ func (m *Mapper) PeerChangedResponse(
m.state, m.state,
node, node,
mapRequest.Version, mapRequest.Version,
changedNodes, changedNodes.ViewSlice(),
m.cfg, m.cfg,
) )
if err != nil { if err != nil {
@ -315,7 +319,7 @@ func (m *Mapper) PeerChangedResponse(
// incoming update from a state change. // incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse( func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node types.NodeView,
changed []*tailcfg.PeerChange, changed []*tailcfg.PeerChange,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
@ -327,7 +331,7 @@ func (m *Mapper) PeerChangedPatchResponse(
func (m *Mapper) marshalMapResponse( func (m *Mapper) marshalMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse, resp *tailcfg.MapResponse,
node *types.Node, node types.NodeView,
compression string, compression string,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
@ -366,7 +370,7 @@ func (m *Mapper) marshalMapResponse(
} }
perms := fs.FileMode(debugMapResponsePerm) perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, node.Hostname) mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
err = os.MkdirAll(mPath, perms) err = os.MkdirAll(mPath, perms)
if err != nil { if err != nil {
panic(err) panic(err)
@ -444,7 +448,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
// It is used in for bigger updates, such as full and lite, not // It is used in for bigger updates, such as full and lite, not
// incremental. // incremental.
func (m *Mapper) baseWithConfigMapResponse( func (m *Mapper) baseWithConfigMapResponse(
node *types.Node, node types.NodeView,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
@ -523,9 +527,9 @@ func appendPeerChanges(
fullChange bool, fullChange bool,
state *state.State, state *state.State,
node *types.Node, node types.NodeView,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
changed types.Nodes, changed views.Slice[types.NodeView],
cfg *types.Config, cfg *types.Config,
) error { ) error {
filter, matchers := state.Filter() filter, matchers := state.Filter()
@ -537,16 +541,19 @@ func appendPeerChanges(
// If there are filter rules present, see if there are any nodes that cannot // If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers. // access each-other at all and remove them from the peers.
var reducedChanged views.Slice[types.NodeView]
if len(filter) > 0 { if len(filter) > 0 {
changed = policy.ReduceNodes(node, changed, matchers) reducedChanged = policy.ReduceNodes(node, changed, matchers)
} else {
reducedChanged = changed
} }
profiles := generateUserProfiles(node, changed) profiles := generateUserProfiles(node, reducedChanged)
dnsConfig := generateDNSConfig(cfg, node) dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes( tailPeers, err := tailNodes(
changed, capVer, state, reducedChanged, capVer, state,
func(id types.NodeID) []netip.Prefix { func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers) return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
}, },

View File

@ -70,7 +70,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
&types.Config{ &types.Config{
TailcfgDNSConfig: &dnsConfigOrig, TailcfgDNSConfig: &dnsConfigOrig,
}, },
nodeInShared1, nodeInShared1.View(),
) )
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
@ -100,14 +100,14 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
return m.polMan.Filter() return m.polMan.Filter()
} }
func (m *mockState) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
if m.polMan == nil { if m.polMan == nil {
return nil, nil return nil, nil
} }
return m.polMan.SSHPolicy(node) return m.polMan.SSHPolicy(node)
} }
func (m *mockState) NodeCanHaveTag(node *types.Node, tag string) bool { func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
if m.polMan == nil { if m.polMan == nil {
return false return false
} }

View File

@ -8,24 +8,25 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/views"
) )
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag // NodeCanHaveTagChecker is an interface for checking if a node can have a tag
type NodeCanHaveTagChecker interface { type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node *types.Node, tag string) bool NodeCanHaveTag(node types.NodeView, tag string) bool
} }
func tailNodes( func tailNodes(
nodes types.Nodes, nodes views.Slice[types.NodeView],
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
checker NodeCanHaveTagChecker, checker NodeCanHaveTagChecker,
primaryRouteFunc routeFilterFunc, primaryRouteFunc routeFilterFunc,
cfg *types.Config, cfg *types.Config,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes)) tNodes := make([]*tailcfg.Node, 0, nodes.Len())
for index, node := range nodes { for _, node := range nodes.All() {
node, err := tailNode( tNode, err := tailNode(
node, node,
capVer, capVer,
checker, checker,
@ -36,7 +37,7 @@ func tailNodes(
return nil, err return nil, err
} }
tNodes[index] = node tNodes = append(tNodes, tNode)
} }
return tNodes, nil return tNodes, nil
@ -44,7 +45,7 @@ func tailNodes(
// tailNode converts a Node into a Tailscale Node. // tailNode converts a Node into a Tailscale Node.
func tailNode( func tailNode(
node *types.Node, node types.NodeView,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
checker NodeCanHaveTagChecker, checker NodeCanHaveTagChecker,
primaryRouteFunc routeFilterFunc, primaryRouteFunc routeFilterFunc,
@ -57,61 +58,64 @@ func tailNode(
// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077 // TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077
// and should be removed after 111 is the minimum capver. // and should be removed after 111 is the minimum capver.
var legacyDERP string var legacyDERP string
if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil { if node.Hostinfo().Valid() && node.Hostinfo().NetInfo().Valid() {
legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo().NetInfo().PreferredDERP())
derp = node.Hostinfo.NetInfo.PreferredDERP derp = node.Hostinfo().NetInfo().PreferredDERP()
} else { } else {
legacyDERP = "127.3.3.40:0" // Zero means disconnected or unknown. legacyDERP = "127.3.3.40:0" // Zero means disconnected or unknown.
} }
var keyExpiry time.Time var keyExpiry time.Time
if node.Expiry != nil { if node.Expiry().Valid() {
keyExpiry = *node.Expiry keyExpiry = node.Expiry().Get()
} else { } else {
keyExpiry = time.Time{} keyExpiry = time.Time{}
} }
hostname, err := node.GetFQDN(cfg.BaseDomain) hostname, err := node.GetFQDN(cfg.BaseDomain)
if err != nil { if err != nil {
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) return nil, err
} }
var tags []string var tags []string
for _, tag := range node.RequestTags() { for _, tag := range node.RequestTagsSlice().All() {
if checker.NodeCanHaveTag(node, tag) { if checker.NodeCanHaveTag(node, tag) {
tags = append(tags, tag) tags = append(tags, tag)
} }
} }
tags = lo.Uniq(append(tags, node.ForcedTags...)) for _, tag := range node.ForcedTags().All() {
tags = append(tags, tag)
}
tags = lo.Uniq(tags)
routes := primaryRouteFunc(node.ID) routes := primaryRouteFunc(node.ID())
allowed := append(node.Prefixes(), routes...) allowed := append(addrs, routes...)
allowed = append(allowed, node.ExitRoutes()...) allowed = append(allowed, node.ExitRoutes()...)
tsaddr.SortPrefixes(allowed) tsaddr.SortPrefixes(allowed)
tNode := tailcfg.Node{ tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID ID: tailcfg.NodeID(node.ID()), // this is the actual ID
StableID: node.ID.StableID(), StableID: node.ID().StableID(),
Name: hostname, Name: hostname,
Cap: capVer, Cap: capVer,
User: tailcfg.UserID(node.UserID), User: tailcfg.UserID(node.UserID()),
Key: node.NodeKey, Key: node.NodeKey(),
KeyExpiry: keyExpiry.UTC(), KeyExpiry: keyExpiry.UTC(),
Machine: node.MachineKey, Machine: node.MachineKey(),
DiscoKey: node.DiscoKey, DiscoKey: node.DiscoKey(),
Addresses: addrs, Addresses: addrs,
PrimaryRoutes: routes, PrimaryRoutes: routes,
AllowedIPs: allowed, AllowedIPs: allowed,
Endpoints: node.Endpoints, Endpoints: node.Endpoints().AsSlice(),
HomeDERP: derp, HomeDERP: derp,
LegacyDERPString: legacyDERP, LegacyDERPString: legacyDERP,
Hostinfo: node.Hostinfo.View(), Hostinfo: node.Hostinfo(),
Created: node.CreatedAt.UTC(), Created: node.CreatedAt().UTC(),
Online: node.IsOnline, Online: node.IsOnline().Clone(),
Tags: tags, Tags: tags,
@ -129,10 +133,13 @@ func tailNode(
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
} }
if node.IsOnline == nil || !*node.IsOnline { if !node.IsOnline().Valid() || !node.IsOnline().Get() {
// LastSeen is only set when node is // LastSeen is only set when node is
// not connected to the control server. // not connected to the control server.
tNode.LastSeen = node.LastSeen if node.LastSeen().Valid() {
lastSeen := node.LastSeen().Get()
tNode.LastSeen = &lastSeen
}
} }
return &tNode, nil return &tNode, nil

View File

@ -202,7 +202,7 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node}) polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node}.ViewSlice())
require.NoError(t, err) require.NoError(t, err)
primary := routes.New() primary := routes.New()
cfg := &types.Config{ cfg := &types.Config{
@ -216,7 +216,7 @@ func TestTailNode(t *testing.T) {
// This should be baked into the test case proper if it is extended in the future. // This should be baked into the test case proper if it is extended in the future.
_ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24")) _ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24"))
got, err := tailNode( got, err := tailNode(
tt.node, tt.node.View(),
0, 0,
polMan, polMan,
func(id types.NodeID) []netip.Prefix { func(id types.NodeID) []netip.Prefix {
@ -272,11 +272,11 @@ func TestNodeExpiry(t *testing.T) {
GivenName: "test", GivenName: "test",
Expiry: tt.exp, Expiry: tt.exp,
} }
polMan, err := policy.NewPolicyManager(nil, nil, nil) polMan, err := policy.NewPolicyManager(nil, nil, types.Nodes{}.ViewSlice())
require.NoError(t, err) require.NoError(t, err)
tn, err := tailNode( tn, err := tailNode(
node, node.View(),
0, 0,
polMan, polMan,
func(id types.NodeID) []netip.Prefix { func(id types.NodeID) []netip.Prefix {

View File

@ -8,27 +8,28 @@ import (
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/views"
) )
type PolicyManager interface { type PolicyManager interface {
// Filter returns the current filter rules for the entire tailnet and the associated matchers. // Filter returns the current filter rules for the entire tailnet and the associated matchers.
Filter() ([]tailcfg.FilterRule, []matcher.Match) Filter() ([]tailcfg.FilterRule, []matcher.Match)
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error) SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error) SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
// NodeCanHaveTag reports whether the given node can have the given tag. // NodeCanHaveTag reports whether the given node can have the given tag.
NodeCanHaveTag(*types.Node, string) bool NodeCanHaveTag(types.NodeView, string) bool
// NodeCanApproveRoute reports whether the given node can approve the given route. // NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(*types.Node, netip.Prefix) bool NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
Version() int Version() int
DebugString() string DebugString() string
} }
// NewPolicyManager returns a new policy manager. // NewPolicyManager returns a new policy manager.
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
var polMan PolicyManager var polMan PolicyManager
var err error var err error
polMan, err = policyv2.NewPolicyManager(pol, users, nodes) polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
@ -42,7 +43,7 @@ func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (Policy
// PolicyManagersForTest returns all available PostureManagers to be used // PolicyManagersForTest returns all available PostureManagers to be used
// in tests to validate them in tests that try to determine that they // in tests to validate them in tests that try to determine that they
// behave the same. // behave the same.
func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) { func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) ([]PolicyManager, error) {
var polMans []PolicyManager var polMans []PolicyManager
for _, pmf := range PolicyManagerFuncsForTest(pol) { for _, pmf := range PolicyManagerFuncsForTest(pol) {
@ -56,10 +57,10 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([
return polMans, nil return polMans, nil
} }
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) { func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error) var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) { polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
return policyv2.NewPolicyManager(pol, u, n) return policyv2.NewPolicyManager(pol, u, n)
}) })

View File

@ -11,32 +11,33 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/views"
) )
// ReduceNodes returns the list of peers authorized to be accessed from a given node. // ReduceNodes returns the list of peers authorized to be accessed from a given node.
func ReduceNodes( func ReduceNodes(
node *types.Node, node types.NodeView,
nodes types.Nodes, nodes views.Slice[types.NodeView],
matchers []matcher.Match, matchers []matcher.Match,
) types.Nodes { ) views.Slice[types.NodeView] {
var result types.Nodes var result []types.NodeView
for index, peer := range nodes { for _, peer := range nodes.All() {
if peer.ID == node.ID { if peer.ID() == node.ID() {
continue continue
} }
if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) { if node.CanAccess(matchers, peer) || peer.CanAccess(matchers, node) {
result = append(result, peer) result = append(result, peer)
} }
} }
return result return views.SliceOf(result)
} }
// ReduceRoutes returns a reduced list of routes for a given node that it can access. // ReduceRoutes returns a reduced list of routes for a given node that it can access.
func ReduceRoutes( func ReduceRoutes(
node *types.Node, node types.NodeView,
routes []netip.Prefix, routes []netip.Prefix,
matchers []matcher.Match, matchers []matcher.Match,
) []netip.Prefix { ) []netip.Prefix {
@ -51,9 +52,36 @@ func ReduceRoutes(
return result return result
} }
// BuildPeerMap builds a map of all peers that can be accessed by each node.
func BuildPeerMap(
nodes views.Slice[types.NodeView],
matchers []matcher.Match,
) map[types.NodeID][]types.NodeView {
ret := make(map[types.NodeID][]types.NodeView, nodes.Len())
// Build the map of all peers according to the matchers.
// Compared to ReduceNodes, which builds the list per node, we end up with doing
// the full work for every node (On^2), while this will reduce the list as we see
// relationships while building the map, making it O(n^2/2) in the end, but with less work per node.
for i := range nodes.Len() {
for j := i + 1; j < nodes.Len(); j++ {
if nodes.At(i).ID() == nodes.At(j).ID() {
continue
}
if nodes.At(i).CanAccess(matchers, nodes.At(j)) || nodes.At(j).CanAccess(matchers, nodes.At(i)) {
ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j))
ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i))
}
}
}
return ret
}
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations // ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node. // that are not relevant to that particular node.
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule { func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
ret := []tailcfg.FilterRule{} ret := []tailcfg.FilterRule{}
for _, rule := range rules { for _, rule := range rules {
@ -75,9 +103,10 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
// If the node exposes routes, ensure they are note removed // If the node exposes routes, ensure they are note removed
// when the filters are reduced. // when the filters are reduced.
if node.Hostinfo != nil { if node.Hostinfo().Valid() {
if len(node.Hostinfo.RoutableIPs) > 0 { routableIPs := node.Hostinfo().RoutableIPs()
for _, routableIP := range node.Hostinfo.RoutableIPs { if routableIPs.Len() > 0 {
for _, routableIP := range routableIPs.All() {
if expanded.OverlapsPrefix(routableIP) { if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest) dests = append(dests, dest)
continue DEST_LOOP continue DEST_LOOP
@ -102,13 +131,15 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
// AutoApproveRoutes approves any route that can be autoapproved from // AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy. // the nodes perspective according to the given policy.
// It reports true if any routes were approved. // It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
if pm == nil { if pm == nil {
return false return false
} }
nodeView := node.View()
var newApproved []netip.Prefix var newApproved []netip.Prefix
for _, route := range node.AnnouncedRoutes() { for _, route := range nodeView.AnnouncedRoutes() {
if pm.NodeCanApproveRoute(node, route) { if pm.NodeCanApproveRoute(nodeView, route) {
newApproved = append(newApproved, route) newApproved = append(newApproved, route)
} }
} }

View File

@ -815,11 +815,11 @@ func TestReduceFilterRules(t *testing.T) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager var pm PolicyManager
var err error var err error
pm, err = pmf(users, append(tt.peers, tt.node)) pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err) require.NoError(t, err)
got, _ := pm.Filter() got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = ReduceFilterRules(tt.node, got) got = ReduceFilterRules(tt.node.View(), got)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
log.Trace().Interface("got", got).Msg("result") log.Trace().Interface("got", got).Msg("result")
@ -1576,11 +1576,16 @@ func TestReduceNodes(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
matchers := matcher.MatchesFromFilterRules(tt.args.rules) matchers := matcher.MatchesFromFilterRules(tt.args.rules)
got := ReduceNodes( gotViews := ReduceNodes(
tt.args.node, tt.args.node.View(),
tt.args.nodes, tt.args.nodes.ViewSlice(),
matchers, matchers,
) )
// Convert views back to nodes for comparison in tests
var got types.Nodes
for _, v := range gotViews.All() {
got = append(got, v.AsStruct())
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
} }
@ -1949,7 +1954,7 @@ func TestSSHPolicyRules(t *testing.T) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager var pm PolicyManager
var err error var err error
pm, err = pmf(users, append(tt.peers, &tt.targetNode)) pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
if tt.expectErr { if tt.expectErr {
require.Error(t, err) require.Error(t, err)
@ -1959,7 +1964,7 @@ func TestSSHPolicyRules(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
got, err := pm.SSHPolicy(&tt.targetNode) got, err := pm.SSHPolicy(tt.targetNode.View())
require.NoError(t, err) require.NoError(t, err)
if diff := cmp.Diff(tt.wantSSH, got); diff != "" { if diff := cmp.Diff(tt.wantSSH, got); diff != "" {
@ -2426,7 +2431,7 @@ func TestReduceRoutes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
matchers := matcher.MatchesFromFilterRules(tt.args.rules) matchers := matcher.MatchesFromFilterRules(tt.args.rules)
got := ReduceRoutes( got := ReduceRoutes(
tt.args.node, tt.args.node.View(),
tt.args.routes, tt.args.routes,
matchers, matchers,
) )

View File

@ -776,7 +776,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Initialize all policy manager implementations // Initialize all policy manager implementations
policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node}) policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node}.ViewSlice())
if tt.name == "empty policy" { if tt.name == "empty policy" {
// We expect this one to have a valid but empty policy // We expect this one to have a valid but empty policy
require.NoError(t, err) require.NoError(t, err)
@ -789,7 +789,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
for i, pm := range policyManagers { for i, pm := range policyManagers {
t.Run(fmt.Sprintf("policy-index%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("policy-index%d", i), func(t *testing.T) {
result := pm.NodeCanApproveRoute(&tt.node, tt.route) result := pm.NodeCanApproveRoute(tt.node.View(), tt.route)
if diff := cmp.Diff(tt.canApprove, result); diff != "" { if diff := cmp.Diff(tt.canApprove, result); diff != "" {
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)

View File

@ -10,6 +10,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/views"
) )
var ( var (
@ -20,7 +21,7 @@ var (
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) compileFilterRules( func (pol *Policy) compileFilterRules(
users types.Users, users types.Users,
nodes types.Nodes, nodes views.Slice[types.NodeView],
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
if pol == nil { if pol == nil {
return tailcfg.FilterAllowAll, nil return tailcfg.FilterAllowAll, nil
@ -97,8 +98,8 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
func (pol *Policy) compileSSHPolicy( func (pol *Policy) compileSSHPolicy(
users types.Users, users types.Users,
node *types.Node, node types.NodeView,
nodes types.Nodes, nodes views.Slice[types.NodeView],
) (*tailcfg.SSHPolicy, error) { ) (*tailcfg.SSHPolicy, error) {
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
return nil, nil return nil, nil

View File

@ -362,7 +362,7 @@ func TestParsing(t *testing.T) {
User: users[0], User: users[0],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}) }.ViewSlice())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)

View File

@ -16,13 +16,14 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/util/deephash" "tailscale.com/util/deephash"
"tailscale.com/types/views"
) )
type PolicyManager struct { type PolicyManager struct {
mu sync.Mutex mu sync.Mutex
pol *Policy pol *Policy
users []types.User users []types.User
nodes types.Nodes nodes views.Slice[types.NodeView]
filterHash deephash.Sum filterHash deephash.Sum
filter []tailcfg.FilterRule filter []tailcfg.FilterRule
@ -43,7 +44,7 @@ type PolicyManager struct {
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes. // NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid. // It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes. // The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.NodeView]) (*PolicyManager, error) {
policy, err := unmarshalPolicy(b) policy, err := unmarshalPolicy(b)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err) return nil, fmt.Errorf("parsing policy: %w", err)
@ -53,7 +54,7 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM
pol: policy, pol: policy,
users: users, users: users,
nodes: nodes, nodes: nodes,
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)), sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
} }
_, err = pm.updateLocked() _, err = pm.updateLocked()
@ -122,11 +123,11 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil return true, nil
} }
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok { if sshPol, ok := pm.sshPolicyMap[node.ID()]; ok {
return sshPol, nil return sshPol, nil
} }
@ -134,7 +135,7 @@ func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error)
if err != nil { if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err) return nil, fmt.Errorf("compiling SSH policy: %w", err)
} }
pm.sshPolicyMap[node.ID] = sshPol pm.sshPolicyMap[node.ID()] = sshPol
return sshPol, nil return sshPol, nil
} }
@ -181,7 +182,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
} }
// SetNodes updates the nodes in the policy manager and updates the filter rules. // SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) {
if pm == nil { if pm == nil {
return false, nil return false, nil
} }
@ -192,7 +193,7 @@ func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
return pm.updateLocked() return pm.updateLocked()
} }
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool { func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool {
if pm == nil { if pm == nil {
return false return false
} }
@ -209,7 +210,7 @@ func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
return false return false
} }
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool { func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool {
if pm == nil { if pm == nil {
return false return false
} }
@ -322,7 +323,11 @@ func (pm *PolicyManager) DebugString() string {
} }
sb.WriteString("\n\n") sb.WriteString("\n\n")
sb.WriteString(pm.nodes.DebugString()) sb.WriteString("Nodes:\n")
for _, node := range pm.nodes.All() {
sb.WriteString(node.String())
sb.WriteString("\n")
}
return sb.String() return sb.String()
} }

View File

@ -47,7 +47,7 @@ func TestPolicyManager(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes.ViewSlice())
require.NoError(t, err) require.NoError(t, err)
filter, matchers := pm.Filter() filter, matchers := pm.Filter()

View File

@ -18,6 +18,7 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
) )
@ -91,7 +92,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
// TODO(kradalby): // TODO(kradalby):
@ -179,7 +180,7 @@ func (u Username) resolveUser(users types.Users) (types.User, error) {
return potentialUsers[0], nil return potentialUsers[0], nil
} }
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error
@ -188,12 +189,13 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*net
errs = append(errs, err) errs = append(errs, err)
} }
for _, node := range nodes { for _, node := range nodes.All() {
// Skip tagged nodes
if node.IsTagged() { if node.IsTagged() {
continue continue
} }
if node.User.ID == user.ID { if node.User().ID == user.ID {
node.AppendToIPSet(&ips) node.AppendToIPSet(&ips)
} }
} }
@ -246,7 +248,7 @@ func (g Group) MarshalJSON() ([]byte, error) {
return json.Marshal(string(g)) return json.Marshal(string(g))
} }
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error
@ -280,7 +282,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
// TODO(kradalby): This is currently resolved twice, and should be resolved once. // TODO(kradalby): This is currently resolved twice, and should be resolved once.
@ -295,17 +297,19 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.I
return nil, err return nil, err
} }
for _, node := range nodes { for _, node := range nodes.All() {
if node.HasTag(string(t)) { // Check if node has this tag in all tags (ForcedTags + AuthKey.Tags)
if slices.Contains(node.Tags(), string(t)) {
node.AppendToIPSet(&ips) node.AppendToIPSet(&ips)
} }
// TODO(kradalby): remove as part of #2417, see comment above // TODO(kradalby): remove as part of #2417, see comment above
if tagMap != nil { if tagMap != nil {
if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo != nil { if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo().Valid() {
for _, tag := range node.Hostinfo.RequestTags { for _, tag := range node.RequestTagsSlice().All() {
if tag == string(t) { if tag == string(t) {
node.AppendToIPSet(&ips) node.AppendToIPSet(&ips)
break
} }
} }
} }
@ -346,7 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error
@ -371,7 +375,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSe
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
for _, node := range nodes { for _, node := range nodes.All() {
if node.InIPSet(ipsTemp) { if node.InIPSet(ipsTemp) {
node.AppendToIPSet(&ips) node.AppendToIPSet(&ips)
} }
@ -432,7 +436,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
// of the Prefix and the Policy, Users, and Nodes. // of the Prefix and the Policy, Users, and Nodes.
// //
// See [Policy], [types.Users], and [types.Nodes] for more details. // See [Policy], [types.Users], and [types.Nodes] for more details.
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error
@ -446,12 +450,12 @@ func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IP
// appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the // appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the
// IP address in the prefix. // IP address in the prefix.
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref netip.Prefix) { func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuilder, pref netip.Prefix) {
if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) { if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) {
return return
} }
for _, node := range nodes { for _, node := range nodes.All() {
if node.HasIP(pref.Addr()) { if node.HasIP(pref.Addr()) {
node.AppendToIPSet(ips) node.AppendToIPSet(ips)
} }
@ -499,7 +503,7 @@ func (ag AutoGroup) MarshalJSON() ([]byte, error) {
return json.Marshal(string(ag)) return json.Marshal(string(ag))
} }
func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
switch ag { switch ag {
@ -513,17 +517,17 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n
return nil, err return nil, err
} }
for _, node := range nodes { for _, node := range nodes.All() {
// Skip if node has forced tags // Skip if node is tagged
if len(node.ForcedTags) != 0 { if node.IsTagged() {
continue continue
} }
// Skip if node has any allowed requested tags // Skip if node has any allowed requested tags
hasAllowedTag := false hasAllowedTag := false
if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 { if node.RequestTagsSlice().Len() != 0 {
for _, tag := range node.Hostinfo.RequestTags { for _, tag := range node.RequestTagsSlice().All() {
if tagips, ok := tagMap[Tag(tag)]; ok && node.InIPSet(tagips) { if _, ok := tagMap[Tag(tag)]; ok {
hasAllowedTag = true hasAllowedTag = true
break break
} }
@ -546,16 +550,16 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n
return nil, err return nil, err
} }
for _, node := range nodes { for _, node := range nodes.All() {
// Include if node has forced tags // Include if node is tagged
if len(node.ForcedTags) != 0 { if node.IsTagged() {
node.AppendToIPSet(&build) node.AppendToIPSet(&build)
continue continue
} }
// Include if node has any allowed requested tags // Include if node has any allowed requested tags
if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 { if node.RequestTagsSlice().Len() != 0 {
for _, tag := range node.Hostinfo.RequestTags { for _, tag := range node.RequestTagsSlice().All() {
if _, ok := tagMap[Tag(tag)]; ok { if _, ok := tagMap[Tag(tag)]; ok {
node.AppendToIPSet(&build) node.AppendToIPSet(&build)
break break
@ -588,7 +592,7 @@ type Alias interface {
// of the Alias and the Policy, Users and Nodes. // of the Alias and the Policy, Users and Nodes.
// This is an interface definition and the implementation is independent of // This is an interface definition and the implementation is independent of
// the Alias type. // the Alias type.
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error) Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error)
} }
type AliasWithPorts struct { type AliasWithPorts struct {
@ -759,7 +763,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) {
return json.Marshal(aliases) return json.Marshal(aliases)
} }
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error
@ -1094,7 +1098,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error {
// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet. // resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet.
// The resulting map can be used to quickly look up the IPSet for a given Tag. // The resulting map can be used to quickly look up the IPSet for a given Tag.
// It is intended for internal use in a PolicyManager. // It is intended for internal use in a PolicyManager.
func resolveTagOwners(p *Policy, users types.Users, nodes types.Nodes) (map[Tag]*netipx.IPSet, error) { func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) {
if p == nil { if p == nil {
return nil, nil return nil, nil
} }
@ -1158,7 +1162,7 @@ func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) {
// resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet. // resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet.
// The resulting map can be used to quickly look up if a node can self-approve a route. // The resulting map can be used to quickly look up if a node can self-approve a route.
// It is intended for internal use in a PolicyManager. // It is intended for internal use in a PolicyManager.
func resolveAutoApprovers(p *Policy, users types.Users, nodes types.Nodes) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) { func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) {
if p == nil { if p == nil {
return nil, nil, nil return nil, nil, nil
} }
@ -1671,7 +1675,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
return json.Marshal(aliases) return json.Marshal(aliases)
} }
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error

View File

@ -1377,7 +1377,7 @@ func TestResolvePolicy(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol, ips, err := tt.toResolve.Resolve(tt.pol,
xmaps.Values(users), xmaps.Values(users),
tt.nodes) tt.nodes.ViewSlice())
if tt.wantErr == "" { if tt.wantErr == "" {
if err != nil { if err != nil {
t.Fatalf("got %v; want no error", err) t.Fatalf("got %v; want no error", err)
@ -1557,7 +1557,7 @@ func TestResolveAutoApprovers(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes) got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes.ViewSlice())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -1716,10 +1716,10 @@ func TestNodeCanApproveRoute(t *testing.T) {
b, err := json.Marshal(tt.policy) b, err := json.Marshal(tt.policy)
require.NoError(t, err) require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes) pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
require.NoErrorf(t, err, "NewPolicyManager() error = %v", err) require.NoErrorf(t, err, "NewPolicyManager() error = %v", err)
got := pm.NodeCanApproveRoute(tt.node, tt.route) got := pm.NodeCanApproveRoute(tt.node.View(), tt.route)
if got != tt.want { if got != tt.want {
t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want) t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want)
} }
@ -1800,7 +1800,7 @@ func TestResolveTagOwners(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := resolveTagOwners(tt.policy, users, nodes) got, err := resolveTagOwners(tt.policy, users, nodes.ViewSlice())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -1911,14 +1911,14 @@ func TestNodeCanHaveTag(t *testing.T) {
b, err := json.Marshal(tt.policy) b, err := json.Marshal(tt.policy)
require.NoError(t, err) require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes) pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
if tt.wantErr != "" { if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr) require.ErrorContains(t, err, tt.wantErr)
return return
} }
require.NoError(t, err) require.NoError(t, err)
got := pm.NodeCanHaveTag(tt.node, tt.tag) got := pm.NodeCanHaveTag(tt.node.View(), tt.tag)
if got != tt.want { if got != tt.want {
t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want) t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want)
} }

View File

@ -168,6 +168,10 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() { func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll() m.beforeServeLongPoll()
// For now, mapSession uses a normal node, but since serveLongPoll is a read operation,
// convert the node to a view at the beginning.
nv := m.node.View()
// Clean up the session when the client disconnects // Clean up the session when the client disconnects
defer func() { defer func() {
m.cancelChMu.Lock() m.cancelChMu.Lock()
@ -179,16 +183,16 @@ func (m *mapSession) serveLongPoll() {
// in principal, it will be removed, but the client rapidly // in principal, it will be removed, but the client rapidly
// reconnects, the channel might be of another connection. // reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online. // In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) { if m.h.nodeNotifier.RemoveNode(nv.ID(), m.ch) {
// TODO(kradalby): This can likely be made more effective, but likely most // TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal. // nodes has access to the same routes, so it might not be a big deal.
change, err := m.h.state.Disconnect(m.node) change, err := m.h.state.Disconnect(nv)
if err != nil { if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname) m.errf(err, "Failed to disconnect node %s", nv.Hostname())
} }
if change { if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
} }
@ -201,8 +205,8 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1) m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done() defer m.h.pollNetMapStreamWG.Done()
if m.h.state.Connect(m.node) { if m.h.state.Connect(nv) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
@ -213,17 +217,17 @@ func (m *mapSession) serveLongPoll() {
// so it needs to be disabled. // so it needs to be disabled.
rc.SetWriteDeadline(time.Time{}) rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, nv.Hostname()))
defer cancel() defer cancel()
m.keepAliveTicker = time.NewTicker(m.keepAlive) m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(m.node.ID, m.ch) m.h.nodeNotifier.AddNode(nv.ID(), m.ch)
go func() { go func() {
changed := m.h.state.Connect(m.node) changed := m.h.state.Connect(nv)
if changed { if changed {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
}() }()
@ -253,7 +257,7 @@ func (m *mapSession) serveLongPoll() {
} }
// If the node has been removed from headscale, close the stream // If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, m.node.ID) { if slices.Contains(update.Removed, nv.ID()) {
m.tracef("node removed, closing stream") m.tracef("node removed, closing stream")
return return
} }
@ -268,18 +272,22 @@ func (m *mapSession) serveLongPoll() {
// Ensure the node object is updated, for example, there // Ensure the node object is updated, for example, there
// might have been a hostinfo update in a sidechannel // might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response. // which contains data needed to generate a map response.
m.node, err = m.h.state.GetNodeByID(m.node.ID) m.node, err = m.h.state.GetNodeByID(nv.ID())
if err != nil { if err != nil {
m.errf(err, "Could not get machine from db") m.errf(err, "Could not get machine from db")
return return
} }
// Update the node view to reflect the latest node state
// TODO(kradalby): This should become a full read only path, with no update for the node view
// in the new mapper model.
nv = m.node.View()
updateType := "full" updateType := "full"
switch update.Type { switch update.Type {
case types.StateFullUpdate: case types.StateFullUpdate:
m.tracef("Sending Full MapResponse") m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) data, err = m.mapper.FullMapResponse(m.req, nv, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged: case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
@ -289,12 +297,12 @@ func (m *mapSession) serveLongPoll() {
lastMessage = update.Message lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage)
updateType = "change" updateType = "change"
case types.StatePeerChangedPatch: case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) data, err = m.mapper.PeerChangedPatchResponse(m.req, nv, update.ChangePatches)
updateType = "patch" updateType = "patch"
case types.StatePeerRemoved: case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed)) changed := make(map[types.NodeID]bool, len(update.Removed))
@ -303,17 +311,17 @@ func (m *mapSession) serveLongPoll() {
changed[nodeID] = false changed[nodeID] = false
} }
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage)
updateType = "remove" updateType = "remove"
case types.StateSelfUpdate: case types.StateSelfUpdate:
lastMessage = update.Message lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent // create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, nv, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove" updateType = "remove"
case types.StateDERPUpdated: case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse") m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap()) data, err = m.mapper.DERPMapResponse(m.req, nv, m.h.state.DERPMap())
updateType = "derp" updateType = "derp"
} }
@ -340,10 +348,10 @@ func (m *mapSession) serveLongPoll() {
return return
} }
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") log.Trace().Str("node", nv.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", nv.MachineKey().String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics { if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID.String()).Set(float64(time.Now().Unix())) mapResponseLastSentSeconds.WithLabelValues(updateType, nv.ID().String()).Set(float64(time.Now().Unix()))
} }
mapResponseSent.WithLabelValues("ok", updateType).Inc() mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent") m.tracef("update sent")
@ -351,7 +359,7 @@ func (m *mapSession) serveLongPoll() {
} }
case <-m.keepAliveTicker.C: case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node) data, err := m.mapper.KeepAliveResponse(m.req, nv)
if err != nil { if err != nil {
m.errf(err, "Error generating the keep alive msg") m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc() mapResponseSent.WithLabelValues("error", "keepalive").Inc()
@ -371,7 +379,7 @@ func (m *mapSession) serveLongPoll() {
} }
if debugHighCardinalityMetrics { if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) mapResponseLastSentSeconds.WithLabelValues("keepalive", nv.ID().String()).Set(float64(time.Now().Unix()))
} }
mapResponseSent.WithLabelValues("ok", "keepalive").Inc() mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
} }
@ -490,7 +498,7 @@ func (m *mapSession) handleEndpointUpdate() {
func (m *mapSession) handleReadOnlyRequest() { func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers") m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node.View())
if err != nil { if err != nil {
m.errf(err, "Failed to create MapResponse") m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)

View File

@ -104,7 +104,7 @@ func NewState(cfg *types.Config) (*State, error) {
return nil, fmt.Errorf("loading policy: %w", err) return nil, fmt.Errorf("loading policy: %w", err)
} }
polMan, err := policy.NewPolicyManager(pol, users, nodes) polMan, err := policy.NewPolicyManager(pol, users, nodes.ViewSlice())
if err != nil { if err != nil {
return nil, fmt.Errorf("init policy manager: %w", err) return nil, fmt.Errorf("init policy manager: %w", err)
} }
@ -400,22 +400,22 @@ func (s *State) DeleteNode(node *types.Node) (bool, error) {
return policyChanged, nil return policyChanged, nil
} }
func (s *State) Connect(node *types.Node) bool { func (s *State) Connect(node types.NodeView) bool {
_ = s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) changed := s.primaryRoutes.SetRoutes(node.ID(), node.SubnetRoutes()...)
// TODO(kradalby): this should be more granular, allowing us to // TODO(kradalby): this should be more granular, allowing us to
// only send a online update change. // only send a online update change.
return true return changed
} }
func (s *State) Disconnect(node *types.Node) (bool, error) { func (s *State) Disconnect(node types.NodeView) (bool, error) {
// TODO(kradalby): This node should update the in memory state // TODO(kradalby): This node should update the in memory state
_, polChanged, err := s.SetLastSeen(node.ID, time.Now()) _, polChanged, err := s.SetLastSeen(node.ID(), time.Now())
if err != nil { if err != nil {
return false, fmt.Errorf("disconnecting node: %w", err) return false, fmt.Errorf("disconnecting node: %w", err)
} }
changed := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) changed := s.primaryRoutes.SetRoutes(node.ID())
// TODO(kradalby): the returned change should be more nuanced allowing us to // TODO(kradalby): the returned change should be more nuanced allowing us to
// send more directed updates. // send more directed updates.
@ -512,7 +512,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateU
} }
// SSHPolicy returns the SSH access policy for a node. // SSHPolicy returns the SSH access policy for a node.
func (s *State) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
return s.polMan.SSHPolicy(node) return s.polMan.SSHPolicy(node)
} }
@ -522,7 +522,7 @@ func (s *State) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
} }
// NodeCanHaveTag checks if a node is allowed to have a specific tag. // NodeCanHaveTag checks if a node is allowed to have a specific tag.
func (s *State) NodeCanHaveTag(node *types.Node, tag string) bool { func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool {
return s.polMan.NodeCanHaveTag(node, tag) return s.polMan.NodeCanHaveTag(node, tag)
} }
@ -761,7 +761,7 @@ func (s *State) updatePolicyManagerNodes() (bool, error) {
return false, fmt.Errorf("listing nodes for policy update: %w", err) return false, fmt.Errorf("listing nodes for policy update: %w", err)
} }
changed, err := s.polMan.SetNodes(nodes) changed, err := s.polMan.SetNodes(nodes.ViewSlice())
if err != nil { if err != nil {
return false, fmt.Errorf("updating policy manager nodes: %w", err) return false, fmt.Errorf("updating policy manager nodes: %w", err)
} }

View File

@ -1,3 +1,5 @@
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
package types package types
import ( import (

View File

@ -18,6 +18,7 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/views"
) )
var ( var (
@ -115,6 +116,15 @@ type Node struct {
type Nodes []*Node type Nodes []*Node
func (ns Nodes) ViewSlice() views.Slice[NodeView] {
vs := make([]NodeView, len(ns))
for i, n := range ns {
vs[i] = n.View()
}
return views.SliceOf(vs)
}
// GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node. // GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node.
func (node *Node) GivenNameHasBeenChanged() bool { func (node *Node) GivenNameHasBeenChanged() bool {
return node.GivenName == util.ConvertWithFQDNRules(node.Hostname) return node.GivenName == util.ConvertWithFQDNRules(node.Hostname)
@ -582,3 +592,185 @@ func (node Node) DebugString() string {
sb.WriteString("\n") sb.WriteString("\n")
return sb.String() return sb.String()
} }
func (v NodeView) IPs() []netip.Addr {
if !v.Valid() {
return nil
}
return v.ж.IPs()
}
func (v NodeView) InIPSet(set *netipx.IPSet) bool {
if !v.Valid() {
return false
}
return v.ж.InIPSet(set)
}
func (v NodeView) CanAccess(matchers []matcher.Match, node2 NodeView) bool {
if !v.Valid() || !node2.Valid() {
return false
}
src := v.IPs()
allowedIPs := node2.IPs()
for _, matcher := range matchers {
if !matcher.SrcsContainsIPs(src...) {
continue
}
if matcher.DestsContainsIP(allowedIPs...) {
return true
}
if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) {
return true
}
}
return false
}
func (v NodeView) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool {
if !v.Valid() {
return false
}
src := v.IPs()
for _, matcher := range matchers {
if !matcher.SrcsContainsIPs(src...) {
continue
}
if matcher.DestsOverlapsPrefixes(route) {
return true
}
}
return false
}
func (v NodeView) AnnouncedRoutes() []netip.Prefix {
if !v.Valid() {
return nil
}
return v.ж.AnnouncedRoutes()
}
func (v NodeView) SubnetRoutes() []netip.Prefix {
if !v.Valid() {
return nil
}
return v.ж.SubnetRoutes()
}
func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
if !v.Valid() {
return
}
v.ж.AppendToIPSet(build)
}
func (v NodeView) RequestTagsSlice() views.Slice[string] {
if !v.Valid() || !v.Hostinfo().Valid() {
return views.Slice[string]{}
}
return v.Hostinfo().RequestTags()
}
func (v NodeView) Tags() []string {
if !v.Valid() {
return nil
}
return v.ж.Tags()
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (v NodeView) IsTagged() bool {
if !v.Valid() {
return false
}
return v.ж.IsTagged()
}
// IsExpired returns whether the node registration has expired.
func (v NodeView) IsExpired() bool {
if !v.Valid() {
return true
}
return v.ж.IsExpired()
}
// IsEphemeral returns if the node is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/
func (v NodeView) IsEphemeral() bool {
if !v.Valid() {
return false
}
return v.ж.IsEphemeral()
}
// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
// to produce a PeerChange struct that can be used to updated the node and
// inform peers about smaller changes to the node.
func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
if !v.Valid() {
return tailcfg.PeerChange{}
}
return v.ж.PeerChangeFromMapRequest(req)
}
// GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() {
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid")
}
return v.ж.GetFQDN(baseDomain)
}
// ExitRoutes returns a list of both exit routes if the
// node has any exit routes enabled.
// If none are enabled, it will return nil.
func (v NodeView) ExitRoutes() []netip.Prefix {
if !v.Valid() {
return nil
}
return v.ж.ExitRoutes()
}
// HasIP reports if a node has a given IP address.
func (v NodeView) HasIP(i netip.Addr) bool {
if !v.Valid() {
return false
}
return v.ж.HasIP(i)
}
// HasTag reports if a node has a given tag.
func (v NodeView) HasTag(tag string) bool {
if !v.Valid() {
return false
}
return v.ж.HasTag(tag)
}
// Prefixes returns the node IPs as netip.Prefix.
func (v NodeView) Prefixes() []netip.Prefix {
if !v.Valid() {
return nil
}
return v.ж.Prefixes()
}
// IPsAsString returns the node IPs as strings.
func (v NodeView) IPsAsString() []string {
if !v.Valid() {
return nil
}
return v.ж.IPsAsString()
}

View File

@ -0,0 +1,135 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT.
package types
import (
"database/sql"
"net/netip"
"time"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
// Clone makes a deep copy of User.
// The result aliases no memory with the original.
func (src *User) Clone() *User {
if src == nil {
return nil
}
dst := new(User)
*dst = *src
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _UserCloneNeedsRegeneration = User(struct {
gorm.Model
Name string
DisplayName string
Email string
ProviderIdentifier sql.NullString
Provider string
ProfilePicURL string
}{})
// Clone makes a deep copy of Node.
// The result aliases no memory with the original.
func (src *Node) Clone() *Node {
if src == nil {
return nil
}
dst := new(Node)
*dst = *src
dst.Endpoints = append(src.Endpoints[:0:0], src.Endpoints...)
dst.Hostinfo = src.Hostinfo.Clone()
if dst.IPv4 != nil {
dst.IPv4 = ptr.To(*src.IPv4)
}
if dst.IPv6 != nil {
dst.IPv6 = ptr.To(*src.IPv6)
}
dst.ForcedTags = append(src.ForcedTags[:0:0], src.ForcedTags...)
if dst.AuthKeyID != nil {
dst.AuthKeyID = ptr.To(*src.AuthKeyID)
}
dst.AuthKey = src.AuthKey.Clone()
if dst.Expiry != nil {
dst.Expiry = ptr.To(*src.Expiry)
}
if dst.LastSeen != nil {
dst.LastSeen = ptr.To(*src.LastSeen)
}
dst.ApprovedRoutes = append(src.ApprovedRoutes[:0:0], src.ApprovedRoutes...)
if dst.DeletedAt != nil {
dst.DeletedAt = ptr.To(*src.DeletedAt)
}
if dst.IsOnline != nil {
dst.IsOnline = ptr.To(*src.IsOnline)
}
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _NodeCloneNeedsRegeneration = Node(struct {
ID NodeID
MachineKey key.MachinePublic
NodeKey key.NodePublic
DiscoKey key.DiscoPublic
Endpoints []netip.AddrPort
Hostinfo *tailcfg.Hostinfo
IPv4 *netip.Addr
IPv6 *netip.Addr
Hostname string
GivenName string
UserID uint
User User
RegisterMethod string
ForcedTags []string
AuthKeyID *uint64
AuthKey *PreAuthKey
Expiry *time.Time
LastSeen *time.Time
ApprovedRoutes []netip.Prefix
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
IsOnline *bool
}{})
// Clone makes a deep copy of PreAuthKey.
// The result aliases no memory with the original.
func (src *PreAuthKey) Clone() *PreAuthKey {
if src == nil {
return nil
}
dst := new(PreAuthKey)
*dst = *src
dst.Tags = append(src.Tags[:0:0], src.Tags...)
if dst.CreatedAt != nil {
dst.CreatedAt = ptr.To(*src.CreatedAt)
}
if dst.Expiration != nil {
dst.Expiration = ptr.To(*src.Expiration)
}
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct {
ID uint64
Key string
UserID uint
User User
Reusable bool
Ephemeral bool
Used bool
Tags []string
CreatedAt *time.Time
Expiration *time.Time
}{})

View File

@ -0,0 +1,270 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Code generated by tailscale/cmd/viewer; DO NOT EDIT.
package types
import (
"database/sql"
"encoding/json"
"errors"
"net/netip"
"time"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/views"
)
//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=User,Node,PreAuthKey
// View returns a read-only view of User.
func (p *User) View() UserView {
return UserView{ж: p}
}
// UserView provides a read-only view over User.
//
// Its methods should only be called if `Valid()` returns true.
type UserView struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *User
}
// Valid reports whether v's underlying value is non-nil.
func (v UserView) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v UserView) AsStruct() *User {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v UserView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *UserView) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x User
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v UserView) Model() gorm.Model { return v.ж.Model }
func (v UserView) Name() string { return v.ж.Name }
func (v UserView) DisplayName() string { return v.ж.DisplayName }
func (v UserView) Email() string { return v.ж.Email }
func (v UserView) ProviderIdentifier() sql.NullString { return v.ж.ProviderIdentifier }
func (v UserView) Provider() string { return v.ж.Provider }
func (v UserView) ProfilePicURL() string { return v.ж.ProfilePicURL }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _UserViewNeedsRegeneration = User(struct {
gorm.Model
Name string
DisplayName string
Email string
ProviderIdentifier sql.NullString
Provider string
ProfilePicURL string
}{})
// View returns a read-only view of Node.
func (p *Node) View() NodeView {
return NodeView{ж: p}
}
// NodeView provides a read-only view over Node.
//
// Its methods should only be called if `Valid()` returns true.
type NodeView struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *Node
}
// Valid reports whether v's underlying value is non-nil.
func (v NodeView) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v NodeView) AsStruct() *Node {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v NodeView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *NodeView) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x Node
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v NodeView) ID() NodeID { return v.ж.ID }
func (v NodeView) MachineKey() key.MachinePublic { return v.ж.MachineKey }
func (v NodeView) NodeKey() key.NodePublic { return v.ж.NodeKey }
func (v NodeView) DiscoKey() key.DiscoPublic { return v.ж.DiscoKey }
func (v NodeView) Endpoints() views.Slice[netip.AddrPort] { return views.SliceOf(v.ж.Endpoints) }
func (v NodeView) Hostinfo() tailcfg.HostinfoView { return v.ж.Hostinfo.View() }
func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv4) }
func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) }
func (v NodeView) Hostname() string { return v.ж.Hostname }
func (v NodeView) GivenName() string { return v.ж.GivenName }
func (v NodeView) UserID() uint { return v.ж.UserID }
func (v NodeView) User() User { return v.ж.User }
func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod }
func (v NodeView) ForcedTags() views.Slice[string] { return views.SliceOf(v.ж.ForcedTags) }
func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) }
func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() }
func (v NodeView) Expiry() views.ValuePointer[time.Time] { return views.ValuePointerOf(v.ж.Expiry) }
func (v NodeView) LastSeen() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.LastSeen)
}
func (v NodeView) ApprovedRoutes() views.Slice[netip.Prefix] {
return views.SliceOf(v.ж.ApprovedRoutes)
}
func (v NodeView) CreatedAt() time.Time { return v.ж.CreatedAt }
func (v NodeView) UpdatedAt() time.Time { return v.ж.UpdatedAt }
func (v NodeView) DeletedAt() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.DeletedAt)
}
func (v NodeView) IsOnline() views.ValuePointer[bool] { return views.ValuePointerOf(v.ж.IsOnline) }
func (v NodeView) String() string { return v.ж.String() }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _NodeViewNeedsRegeneration = Node(struct {
ID NodeID
MachineKey key.MachinePublic
NodeKey key.NodePublic
DiscoKey key.DiscoPublic
Endpoints []netip.AddrPort
Hostinfo *tailcfg.Hostinfo
IPv4 *netip.Addr
IPv6 *netip.Addr
Hostname string
GivenName string
UserID uint
User User
RegisterMethod string
ForcedTags []string
AuthKeyID *uint64
AuthKey *PreAuthKey
Expiry *time.Time
LastSeen *time.Time
ApprovedRoutes []netip.Prefix
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
IsOnline *bool
}{})
// View returns a read-only view of PreAuthKey.
func (p *PreAuthKey) View() PreAuthKeyView {
return PreAuthKeyView{ж: p}
}
// PreAuthKeyView provides a read-only view over PreAuthKey.
//
// Its methods should only be called if `Valid()` returns true.
type PreAuthKeyView struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *PreAuthKey
}
// Valid reports whether v's underlying value is non-nil.
func (v PreAuthKeyView) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v PreAuthKeyView) AsStruct() *PreAuthKey {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v PreAuthKeyView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x PreAuthKey
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v PreAuthKeyView) ID() uint64 { return v.ж.ID }
func (v PreAuthKeyView) Key() string { return v.ж.Key }
func (v PreAuthKeyView) UserID() uint { return v.ж.UserID }
func (v PreAuthKeyView) User() User { return v.ж.User }
func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable }
func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral }
func (v PreAuthKeyView) Used() bool { return v.ж.Used }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.CreatedAt)
}
func (v PreAuthKeyView) Expiration() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.Expiration)
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct {
ID uint64
Key string
UserID uint
User User
Reusable bool
Ephemeral bool
Used bool
Tags []string
CreatedAt *time.Time
Expiration *time.Time
}{})