From 73023c2ec398d5dea8bbf0a74532c07647c77c6d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 5 Jul 2025 23:31:13 +0200 Subject: [PATCH] all: use immutable node view in read path This commit changes most of our (*)types.Node to types.NodeView, which is a readonly version of the underlying node ensuring that there is no mutations happening in the read path. Based on the migration, there didnt seem to be any, but the idea here is to prevent it in the future and simplify other new implementations. Signed-off-by: Kristoffer Dalby --- cmd/headscale/cli/policy.go | 4 +- hscontrol/db/node_test.go | 2 +- hscontrol/debug.go | 2 +- hscontrol/grpcv1.go | 4 +- hscontrol/mapper/mapper.go | 75 ++++--- hscontrol/mapper/mapper_test.go | 6 +- hscontrol/mapper/tail.go | 65 +++--- hscontrol/mapper/tail_test.go | 8 +- hscontrol/policy/pm.go | 19 +- hscontrol/policy/policy.go | 61 ++++-- hscontrol/policy/policy_test.go | 21 +- hscontrol/policy/route_approval_test.go | 4 +- hscontrol/policy/v2/filter.go | 7 +- hscontrol/policy/v2/filter_test.go | 2 +- hscontrol/policy/v2/policy.go | 25 ++- hscontrol/policy/v2/policy_test.go | 2 +- hscontrol/policy/v2/types.go | 68 +++--- hscontrol/policy/v2/types_test.go | 14 +- hscontrol/poll.go | 54 +++-- hscontrol/state/state.go | 20 +- hscontrol/types/common.go | 2 + hscontrol/types/node.go | 192 +++++++++++++++++ hscontrol/types/types_clone.go | 135 ++++++++++++ hscontrol/types/types_view.go | 270 ++++++++++++++++++++++++ 24 files changed, 866 insertions(+), 196 deletions(-) create mode 100644 hscontrol/types/types_clone.go create mode 100644 hscontrol/types/types_view.go diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 63f4a6bf..caf9d436 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -7,8 +7,10 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "github.com/spf13/cobra" + "tailscale.com/types/views" ) func init() { @@ -111,7 +113,7 @@ var checkPolicy = &cobra.Command{ ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) } - _, err = policy.NewPolicyManager(policyBytes, nil, nil) + _, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{}) if err != nil { ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 9e302541..9f10fc1c 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -485,7 +485,7 @@ func TestAutoApproveRoutes(t *testing.T) { nodes, err := adb.ListNodes() assert.NoError(t, err) - pm, err := pmf(users, nodes) + pm, err := pmf(users, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, pm) diff --git a/hscontrol/debug.go b/hscontrol/debug.go index e711f3a2..038582c8 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -78,7 +78,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { sshPol := make(map[string]*tailcfg.SSHPolicy) for _, node := range nodes { - pol, err := h.state.SSHPolicy(node) + pol, err := h.state.SSHPolicy(node.View()) if err != nil { httpError(w, err) return diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 277e729d..e098b766 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -537,7 +537,7 @@ func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeI var tags []string for _, tag := range node.RequestTags() { - if state.NodeCanHaveTag(node, tag) { + if state.NodeCanHaveTag(node.View(), tag) { tags = append(tags, tag) } } @@ -733,7 +733,7 @@ func (api headscaleV1APIServer) SetPolicy( } if len(nodes) > 0 { - _, err = api.h.state.SSHPolicy(nodes[0]) + _, err = api.h.state.SSHPolicy(nodes[0].View()) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index cce1b870..49a99351 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -27,6 +27,7 @@ import ( "tailscale.com/smallzstd" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/views" ) const ( @@ -88,16 +89,18 @@ func (m *Mapper) String() string { } func generateUserProfiles( - node *types.Node, - peers types.Nodes, + node types.NodeView, + peers views.Slice[types.NodeView], ) []tailcfg.UserProfile { userMap := make(map[uint]*types.User) - ids := make([]uint, 0, len(userMap)) - userMap[node.User.ID] = &node.User - ids = append(ids, node.User.ID) - for _, peer := range peers { - userMap[peer.User.ID] = &peer.User - ids = append(ids, peer.User.ID) + ids := make([]uint, 0, peers.Len()+1) + user := node.User() + userMap[user.ID] = &user + ids = append(ids, user.ID) + for _, peer := range peers.All() { + peerUser := peer.User() + userMap[peerUser.ID] = &peerUser + ids = append(ids, peerUser.ID) } slices.Sort(ids) @@ -114,7 +117,7 @@ func generateUserProfiles( func generateDNSConfig( cfg *types.Config, - node *types.Node, + node types.NodeView, ) *tailcfg.DNSConfig { if cfg.TailcfgDNSConfig == nil { return nil @@ -134,16 +137,17 @@ func generateDNSConfig( // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ - "device_name": []string{node.Hostname}, - "device_model": []string{node.Hostinfo.OS}, + "device_name": []string{node.Hostname()}, + "device_model": []string{node.Hostinfo().OS()}, } - if len(node.IPs()) > 0 { - attrs.Add("device_ip", node.IPs()[0].String()) + nodeIPs := node.IPs() + if len(nodeIPs) > 0 { + attrs.Add("device_ip", nodeIPs[0].String()) } resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) @@ -154,8 +158,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { // fullMapResponse creates a complete MapResponse for a node. // It is a separate function to make testing easier. func (m *Mapper) fullMapResponse( - node *types.Node, - peers types.Nodes, + node types.NodeView, + peers views.Slice[types.NodeView], capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp, err := m.baseWithConfigMapResponse(node, capVer) @@ -182,15 +186,15 @@ func (m *Mapper) fullMapResponse( // FullMapResponse returns a MapResponse for the given node. func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, - node *types.Node, + node types.NodeView, messages ...string, ) ([]byte, error) { - peers, err := m.ListPeers(node.ID) + peers, err := m.ListPeers(node.ID()) if err != nil { return nil, err } - resp, err := m.fullMapResponse(node, peers, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version) if err != nil { return nil, err } @@ -203,7 +207,7 @@ func (m *Mapper) FullMapResponse( // to be used to answer MapRequests with OmitPeers set to true. func (m *Mapper) ReadOnlyMapResponse( mapRequest tailcfg.MapRequest, - node *types.Node, + node types.NodeView, messages ...string, ) ([]byte, error) { resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) @@ -216,7 +220,7 @@ func (m *Mapper) ReadOnlyMapResponse( func (m *Mapper) KeepAliveResponse( mapRequest tailcfg.MapRequest, - node *types.Node, + node types.NodeView, ) ([]byte, error) { resp := m.baseMapResponse() resp.KeepAlive = true @@ -226,7 +230,7 @@ func (m *Mapper) KeepAliveResponse( func (m *Mapper) DERPMapResponse( mapRequest tailcfg.MapRequest, - node *types.Node, + node types.NodeView, derpMap *tailcfg.DERPMap, ) ([]byte, error) { resp := m.baseMapResponse() @@ -237,7 +241,7 @@ func (m *Mapper) DERPMapResponse( func (m *Mapper) PeerChangedResponse( mapRequest tailcfg.MapRequest, - node *types.Node, + node types.NodeView, changed map[types.NodeID]bool, patches []*tailcfg.PeerChange, messages ...string, @@ -249,7 +253,7 @@ func (m *Mapper) PeerChangedResponse( var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { if nodeChanged { - if nodeID != node.ID { + if nodeID != node.ID() { changedIDs = append(changedIDs, nodeID) } } else { @@ -270,7 +274,7 @@ func (m *Mapper) PeerChangedResponse( m.state, node, mapRequest.Version, - changedNodes, + changedNodes.ViewSlice(), m.cfg, ) if err != nil { @@ -315,7 +319,7 @@ func (m *Mapper) PeerChangedResponse( // incoming update from a state change. func (m *Mapper) PeerChangedPatchResponse( mapRequest tailcfg.MapRequest, - node *types.Node, + node types.NodeView, changed []*tailcfg.PeerChange, ) ([]byte, error) { resp := m.baseMapResponse() @@ -327,7 +331,7 @@ func (m *Mapper) PeerChangedPatchResponse( func (m *Mapper) marshalMapResponse( mapRequest tailcfg.MapRequest, resp *tailcfg.MapResponse, - node *types.Node, + node types.NodeView, compression string, messages ...string, ) ([]byte, error) { @@ -366,7 +370,7 @@ func (m *Mapper) marshalMapResponse( } perms := fs.FileMode(debugMapResponsePerm) - mPath := path.Join(debugDumpMapResponsePath, node.Hostname) + mPath := path.Join(debugDumpMapResponsePath, node.Hostname()) err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -444,7 +448,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { // It is used in for bigger updates, such as full and lite, not // incremental. func (m *Mapper) baseWithConfigMapResponse( - node *types.Node, + node types.NodeView, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() @@ -523,9 +527,9 @@ func appendPeerChanges( fullChange bool, state *state.State, - node *types.Node, + node types.NodeView, capVer tailcfg.CapabilityVersion, - changed types.Nodes, + changed views.Slice[types.NodeView], cfg *types.Config, ) error { filter, matchers := state.Filter() @@ -537,16 +541,19 @@ func appendPeerChanges( // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. + var reducedChanged views.Slice[types.NodeView] if len(filter) > 0 { - changed = policy.ReduceNodes(node, changed, matchers) + reducedChanged = policy.ReduceNodes(node, changed, matchers) + } else { + reducedChanged = changed } - profiles := generateUserProfiles(node, changed) + profiles := generateUserProfiles(node, reducedChanged) dnsConfig := generateDNSConfig(cfg, node) tailPeers, err := tailNodes( - changed, capVer, state, + reducedChanged, capVer, state, func(id types.NodeID) []netip.Prefix { return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers) }, diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 73bb5060..71b9e4b9 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -70,7 +70,7 @@ func TestDNSConfigMapResponse(t *testing.T) { &types.Config{ TailcfgDNSConfig: &dnsConfigOrig, }, - nodeInShared1, + nodeInShared1.View(), ) if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { @@ -100,14 +100,14 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { return m.polMan.Filter() } -func (m *mockState) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { +func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { if m.polMan == nil { return nil, nil } return m.polMan.SSHPolicy(node) } -func (m *mockState) NodeCanHaveTag(node *types.Node, tag string) bool { +func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { if m.polMan == nil { return false } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index ac3d5b16..9b58ad34 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -8,24 +8,25 @@ import ( "github.com/samber/lo" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/types/views" ) // NodeCanHaveTagChecker is an interface for checking if a node can have a tag type NodeCanHaveTagChecker interface { - NodeCanHaveTag(node *types.Node, tag string) bool + NodeCanHaveTag(node types.NodeView, tag string) bool } func tailNodes( - nodes types.Nodes, + nodes views.Slice[types.NodeView], capVer tailcfg.CapabilityVersion, checker NodeCanHaveTagChecker, primaryRouteFunc routeFilterFunc, cfg *types.Config, ) ([]*tailcfg.Node, error) { - tNodes := make([]*tailcfg.Node, len(nodes)) + tNodes := make([]*tailcfg.Node, 0, nodes.Len()) - for index, node := range nodes { - node, err := tailNode( + for _, node := range nodes.All() { + tNode, err := tailNode( node, capVer, checker, @@ -36,7 +37,7 @@ func tailNodes( return nil, err } - tNodes[index] = node + tNodes = append(tNodes, tNode) } return tNodes, nil @@ -44,7 +45,7 @@ func tailNodes( // tailNode converts a Node into a Tailscale Node. func tailNode( - node *types.Node, + node types.NodeView, capVer tailcfg.CapabilityVersion, checker NodeCanHaveTagChecker, primaryRouteFunc routeFilterFunc, @@ -57,61 +58,64 @@ func tailNode( // TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077 // and should be removed after 111 is the minimum capver. var legacyDERP string - if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil { - legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) - derp = node.Hostinfo.NetInfo.PreferredDERP + if node.Hostinfo().Valid() && node.Hostinfo().NetInfo().Valid() { + legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo().NetInfo().PreferredDERP()) + derp = node.Hostinfo().NetInfo().PreferredDERP() } else { legacyDERP = "127.3.3.40:0" // Zero means disconnected or unknown. } var keyExpiry time.Time - if node.Expiry != nil { - keyExpiry = *node.Expiry + if node.Expiry().Valid() { + keyExpiry = node.Expiry().Get() } else { keyExpiry = time.Time{} } hostname, err := node.GetFQDN(cfg.BaseDomain) if err != nil { - return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) + return nil, err } var tags []string - for _, tag := range node.RequestTags() { + for _, tag := range node.RequestTagsSlice().All() { if checker.NodeCanHaveTag(node, tag) { tags = append(tags, tag) } } - tags = lo.Uniq(append(tags, node.ForcedTags...)) + for _, tag := range node.ForcedTags().All() { + tags = append(tags, tag) + } + tags = lo.Uniq(tags) - routes := primaryRouteFunc(node.ID) - allowed := append(node.Prefixes(), routes...) + routes := primaryRouteFunc(node.ID()) + allowed := append(addrs, routes...) allowed = append(allowed, node.ExitRoutes()...) tsaddr.SortPrefixes(allowed) tNode := tailcfg.Node{ - ID: tailcfg.NodeID(node.ID), // this is the actual ID - StableID: node.ID.StableID(), + ID: tailcfg.NodeID(node.ID()), // this is the actual ID + StableID: node.ID().StableID(), Name: hostname, Cap: capVer, - User: tailcfg.UserID(node.UserID), + User: tailcfg.UserID(node.UserID()), - Key: node.NodeKey, + Key: node.NodeKey(), KeyExpiry: keyExpiry.UTC(), - Machine: node.MachineKey, - DiscoKey: node.DiscoKey, + Machine: node.MachineKey(), + DiscoKey: node.DiscoKey(), Addresses: addrs, PrimaryRoutes: routes, AllowedIPs: allowed, - Endpoints: node.Endpoints, + Endpoints: node.Endpoints().AsSlice(), HomeDERP: derp, LegacyDERPString: legacyDERP, - Hostinfo: node.Hostinfo.View(), - Created: node.CreatedAt.UTC(), + Hostinfo: node.Hostinfo(), + Created: node.CreatedAt().UTC(), - Online: node.IsOnline, + Online: node.IsOnline().Clone(), Tags: tags, @@ -129,10 +133,13 @@ func tailNode( tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} } - if node.IsOnline == nil || !*node.IsOnline { + if !node.IsOnline().Valid() || !node.IsOnline().Get() { // LastSeen is only set when node is // not connected to the control server. - tNode.LastSeen = node.LastSeen + if node.LastSeen().Valid() { + lastSeen := node.LastSeen().Get() + tNode.LastSeen = &lastSeen + } } return &tNode, nil diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index cacc4930..c699943f 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -202,7 +202,7 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node}) + polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node}.ViewSlice()) require.NoError(t, err) primary := routes.New() cfg := &types.Config{ @@ -216,7 +216,7 @@ func TestTailNode(t *testing.T) { // This should be baked into the test case proper if it is extended in the future. _ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24")) got, err := tailNode( - tt.node, + tt.node.View(), 0, polMan, func(id types.NodeID) []netip.Prefix { @@ -272,11 +272,11 @@ func TestNodeExpiry(t *testing.T) { GivenName: "test", Expiry: tt.exp, } - polMan, err := policy.NewPolicyManager(nil, nil, nil) + polMan, err := policy.NewPolicyManager(nil, nil, types.Nodes{}.ViewSlice()) require.NoError(t, err) tn, err := tailNode( - node, + node.View(), 0, polMan, func(id types.NodeID) []netip.Prefix { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index c4758929..cfeb65a1 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -8,27 +8,28 @@ import ( policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" + "tailscale.com/types/views" ) type PolicyManager interface { // Filter returns the current filter rules for the entire tailnet and the associated matchers. Filter() ([]tailcfg.FilterRule, []matcher.Match) - SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error) SetPolicy([]byte) (bool, error) SetUsers(users []types.User) (bool, error) - SetNodes(nodes types.Nodes) (bool, error) + SetNodes(nodes views.Slice[types.NodeView]) (bool, error) // NodeCanHaveTag reports whether the given node can have the given tag. - NodeCanHaveTag(*types.Node, string) bool + NodeCanHaveTag(types.NodeView, string) bool // NodeCanApproveRoute reports whether the given node can approve the given route. - NodeCanApproveRoute(*types.Node, netip.Prefix) bool + NodeCanApproveRoute(types.NodeView, netip.Prefix) bool Version() int DebugString() string } // NewPolicyManager returns a new policy manager. -func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { +func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) { var polMan PolicyManager var err error polMan, err = policyv2.NewPolicyManager(pol, users, nodes) @@ -42,7 +43,7 @@ func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (Policy // PolicyManagersForTest returns all available PostureManagers to be used // in tests to validate them in tests that try to determine that they // behave the same. -func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) { +func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) ([]PolicyManager, error) { var polMans []PolicyManager for _, pmf := range PolicyManagerFuncsForTest(pol) { @@ -56,10 +57,10 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([ return polMans, nil } -func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) { - var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error) +func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) { + var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) - polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) { + polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) { return policyv2.NewPolicyManager(pol, u, n) }) diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 5859a198..4efd1e01 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -11,32 +11,33 @@ import ( "github.com/samber/lo" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/types/views" ) // ReduceNodes returns the list of peers authorized to be accessed from a given node. func ReduceNodes( - node *types.Node, - nodes types.Nodes, + node types.NodeView, + nodes views.Slice[types.NodeView], matchers []matcher.Match, -) types.Nodes { - var result types.Nodes +) views.Slice[types.NodeView] { + var result []types.NodeView - for index, peer := range nodes { - if peer.ID == node.ID { + for _, peer := range nodes.All() { + if peer.ID() == node.ID() { continue } - if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) { + if node.CanAccess(matchers, peer) || peer.CanAccess(matchers, node) { result = append(result, peer) } } - return result + return views.SliceOf(result) } // ReduceRoutes returns a reduced list of routes for a given node that it can access. func ReduceRoutes( - node *types.Node, + node types.NodeView, routes []netip.Prefix, matchers []matcher.Match, ) []netip.Prefix { @@ -51,9 +52,36 @@ func ReduceRoutes( return result } +// BuildPeerMap builds a map of all peers that can be accessed by each node. +func BuildPeerMap( + nodes views.Slice[types.NodeView], + matchers []matcher.Match, +) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, nodes.Len()) + + // Build the map of all peers according to the matchers. + // Compared to ReduceNodes, which builds the list per node, we end up with doing + // the full work for every node (On^2), while this will reduce the list as we see + // relationships while building the map, making it O(n^2/2) in the end, but with less work per node. + for i := range nodes.Len() { + for j := i + 1; j < nodes.Len(); j++ { + if nodes.At(i).ID() == nodes.At(j).ID() { + continue + } + + if nodes.At(i).CanAccess(matchers, nodes.At(j)) || nodes.At(j).CanAccess(matchers, nodes.At(i)) { + ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j)) + ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i)) + } + } + } + + return ret +} + // ReduceFilterRules takes a node and a set of rules and removes all rules and destinations // that are not relevant to that particular node. -func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule { +func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule { ret := []tailcfg.FilterRule{} for _, rule := range rules { @@ -75,9 +103,10 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F // If the node exposes routes, ensure they are note removed // when the filters are reduced. - if node.Hostinfo != nil { - if len(node.Hostinfo.RoutableIPs) > 0 { - for _, routableIP := range node.Hostinfo.RoutableIPs { + if node.Hostinfo().Valid() { + routableIPs := node.Hostinfo().RoutableIPs() + if routableIPs.Len() > 0 { + for _, routableIP := range routableIPs.All() { if expanded.OverlapsPrefix(routableIP) { dests = append(dests, dest) continue DEST_LOOP @@ -102,13 +131,15 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F // AutoApproveRoutes approves any route that can be autoapproved from // the nodes perspective according to the given policy. // It reports true if any routes were approved. +// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { if pm == nil { return false } + nodeView := node.View() var newApproved []netip.Prefix - for _, route := range node.AnnouncedRoutes() { - if pm.NodeCanApproveRoute(node, route) { + for _, route := range nodeView.AnnouncedRoutes() { + if pm.NodeCanApproveRoute(nodeView, route) { newApproved = append(newApproved, route) } } diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 83d69eb8..9f2f7573 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -815,11 +815,11 @@ func TestReduceFilterRules(t *testing.T) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { var pm PolicyManager var err error - pm, err = pmf(users, append(tt.peers, tt.node)) + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) require.NoError(t, err) got, _ := pm.Filter() t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) - got = ReduceFilterRules(tt.node, got) + got = ReduceFilterRules(tt.node.View(), got) if diff := cmp.Diff(tt.want, got); diff != "" { log.Trace().Interface("got", got).Msg("result") @@ -1576,11 +1576,16 @@ func TestReduceNodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) - got := ReduceNodes( - tt.args.node, - tt.args.nodes, + gotViews := ReduceNodes( + tt.args.node.View(), + tt.args.nodes.ViewSlice(), matchers, ) + // Convert views back to nodes for comparison in tests + var got types.Nodes + for _, v := range gotViews.All() { + got = append(got, v.AsStruct()) + } if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) } @@ -1949,7 +1954,7 @@ func TestSSHPolicyRules(t *testing.T) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { var pm PolicyManager var err error - pm, err = pmf(users, append(tt.peers, &tt.targetNode)) + pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) if tt.expectErr { require.Error(t, err) @@ -1959,7 +1964,7 @@ func TestSSHPolicyRules(t *testing.T) { require.NoError(t, err) - got, err := pm.SSHPolicy(&tt.targetNode) + got, err := pm.SSHPolicy(tt.targetNode.View()) require.NoError(t, err) if diff := cmp.Diff(tt.wantSSH, got); diff != "" { @@ -2426,7 +2431,7 @@ func TestReduceRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) got := ReduceRoutes( - tt.args.node, + tt.args.node.View(), tt.args.routes, matchers, ) diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 19d61d82..5e332fd3 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -776,7 +776,7 @@ func TestNodeCanApproveRoute(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Initialize all policy manager implementations - policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node}) + policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node}.ViewSlice()) if tt.name == "empty policy" { // We expect this one to have a valid but empty policy require.NoError(t, err) @@ -789,7 +789,7 @@ func TestNodeCanApproveRoute(t *testing.T) { for i, pm := range policyManagers { t.Run(fmt.Sprintf("policy-index%d", i), func(t *testing.T) { - result := pm.NodeCanApproveRoute(&tt.node, tt.route) + result := pm.NodeCanApproveRoute(tt.node.View(), tt.route) if diff := cmp.Diff(tt.canApprove, result); diff != "" { t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 6bbc8030..1825926f 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -10,6 +10,7 @@ import ( "github.com/rs/zerolog/log" "go4.org/netipx" "tailscale.com/tailcfg" + "tailscale.com/types/views" ) var ( @@ -20,7 +21,7 @@ var ( // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *Policy) compileFilterRules( users types.Users, - nodes types.Nodes, + nodes views.Slice[types.NodeView], ) ([]tailcfg.FilterRule, error) { if pol == nil { return tailcfg.FilterAllowAll, nil @@ -97,8 +98,8 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { func (pol *Policy) compileSSHPolicy( users types.Users, - node *types.Node, - nodes types.Nodes, + node types.NodeView, + nodes views.Slice[types.NodeView], ) (*tailcfg.SSHPolicy, error) { if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { return nil, nil diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index b5f08164..12c60fbb 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -362,7 +362,7 @@ func TestParsing(t *testing.T) { User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, - }) + }.ViewSlice()) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 80235354..cbc34215 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -16,13 +16,14 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/util/deephash" + "tailscale.com/types/views" ) type PolicyManager struct { mu sync.Mutex pol *Policy users []types.User - nodes types.Nodes + nodes views.Slice[types.NodeView] filterHash deephash.Sum filter []tailcfg.FilterRule @@ -43,7 +44,7 @@ type PolicyManager struct { // NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes. // It returns an error if the policy file is invalid. // The policy manager will update the filter rules based on the users and nodes. -func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { +func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.NodeView]) (*PolicyManager, error) { policy, err := unmarshalPolicy(b) if err != nil { return nil, fmt.Errorf("parsing policy: %w", err) @@ -53,7 +54,7 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM pol: policy, users: users, nodes: nodes, - sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)), + sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()), } _, err = pm.updateLocked() @@ -122,11 +123,11 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return true, nil } -func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { +func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { pm.mu.Lock() defer pm.mu.Unlock() - if sshPol, ok := pm.sshPolicyMap[node.ID]; ok { + if sshPol, ok := pm.sshPolicyMap[node.ID()]; ok { return sshPol, nil } @@ -134,7 +135,7 @@ func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } - pm.sshPolicyMap[node.ID] = sshPol + pm.sshPolicyMap[node.ID()] = sshPol return sshPol, nil } @@ -181,7 +182,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { } // SetNodes updates the nodes in the policy manager and updates the filter rules. -func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { +func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) { if pm == nil { return false, nil } @@ -192,7 +193,7 @@ func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { return pm.updateLocked() } -func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool { +func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool { if pm == nil { return false } @@ -209,7 +210,7 @@ func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool { return false } -func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool { +func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool { if pm == nil { return false } @@ -322,7 +323,11 @@ func (pm *PolicyManager) DebugString() string { } sb.WriteString("\n\n") - sb.WriteString(pm.nodes.DebugString()) + sb.WriteString("Nodes:\n") + for _, node := range pm.nodes.All() { + sb.WriteString(node.String()) + sb.WriteString("\n") + } return sb.String() } diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index b61c5758..b3540e63 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -47,7 +47,7 @@ func TestPolicyManager(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) + pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes.ViewSlice()) require.NoError(t, err) filter, matchers := pm.Filter() diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 941a645b..550287c2 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -18,6 +18,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/multierr" ) @@ -91,7 +92,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error { return nil } -func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder // TODO(kradalby): @@ -179,7 +180,7 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { return potentialUsers[0], nil } -func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -188,12 +189,13 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*net errs = append(errs, err) } - for _, node := range nodes { + for _, node := range nodes.All() { + // Skip tagged nodes if node.IsTagged() { continue } - if node.User.ID == user.ID { + if node.User().ID == user.ID { node.AppendToIPSet(&ips) } } @@ -246,7 +248,7 @@ func (g Group) MarshalJSON() ([]byte, error) { return json.Marshal(string(g)) } -func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -280,7 +282,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error { return nil } -func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder // TODO(kradalby): This is currently resolved twice, and should be resolved once. @@ -295,17 +297,19 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.I return nil, err } - for _, node := range nodes { - if node.HasTag(string(t)) { + for _, node := range nodes.All() { + // Check if node has this tag in all tags (ForcedTags + AuthKey.Tags) + if slices.Contains(node.Tags(), string(t)) { node.AppendToIPSet(&ips) } // TODO(kradalby): remove as part of #2417, see comment above if tagMap != nil { - if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo != nil { - for _, tag := range node.Hostinfo.RequestTags { + if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo().Valid() { + for _, tag := range node.RequestTagsSlice().All() { if tag == string(t) { node.AppendToIPSet(&ips) + break } } } @@ -346,7 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error { return nil } -func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -371,7 +375,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSe if err != nil { errs = append(errs, err) } - for _, node := range nodes { + for _, node := range nodes.All() { if node.InIPSet(ipsTemp) { node.AppendToIPSet(&ips) } @@ -432,7 +436,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // of the Prefix and the Policy, Users, and Nodes. // // See [Policy], [types.Users], and [types.Nodes] for more details. -func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -446,12 +450,12 @@ func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IP // appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the // IP address in the prefix. -func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref netip.Prefix) { +func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuilder, pref netip.Prefix) { if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) { return } - for _, node := range nodes { + for _, node := range nodes.All() { if node.HasIP(pref.Addr()) { node.AppendToIPSet(ips) } @@ -499,7 +503,7 @@ func (ag AutoGroup) MarshalJSON() ([]byte, error) { return json.Marshal(string(ag)) } -func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder switch ag { @@ -513,17 +517,17 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n return nil, err } - for _, node := range nodes { - // Skip if node has forced tags - if len(node.ForcedTags) != 0 { + for _, node := range nodes.All() { + // Skip if node is tagged + if node.IsTagged() { continue } // Skip if node has any allowed requested tags hasAllowedTag := false - if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 { - for _, tag := range node.Hostinfo.RequestTags { - if tagips, ok := tagMap[Tag(tag)]; ok && node.InIPSet(tagips) { + if node.RequestTagsSlice().Len() != 0 { + for _, tag := range node.RequestTagsSlice().All() { + if _, ok := tagMap[Tag(tag)]; ok { hasAllowedTag = true break } @@ -546,16 +550,16 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n return nil, err } - for _, node := range nodes { - // Include if node has forced tags - if len(node.ForcedTags) != 0 { + for _, node := range nodes.All() { + // Include if node is tagged + if node.IsTagged() { node.AppendToIPSet(&build) continue } // Include if node has any allowed requested tags - if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 { - for _, tag := range node.Hostinfo.RequestTags { + if node.RequestTagsSlice().Len() != 0 { + for _, tag := range node.RequestTagsSlice().All() { if _, ok := tagMap[Tag(tag)]; ok { node.AppendToIPSet(&build) break @@ -588,7 +592,7 @@ type Alias interface { // of the Alias and the Policy, Users and Nodes. // This is an interface definition and the implementation is independent of // the Alias type. - Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error) + Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -759,7 +763,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -1094,7 +1098,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error { // resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet. // The resulting map can be used to quickly look up the IPSet for a given Tag. // It is intended for internal use in a PolicyManager. -func resolveTagOwners(p *Policy, users types.Users, nodes types.Nodes) (map[Tag]*netipx.IPSet, error) { +func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) { if p == nil { return nil, nil } @@ -1158,7 +1162,7 @@ func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) { // resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet. // The resulting map can be used to quickly look up if a node can self-approve a route. // It is intended for internal use in a PolicyManager. -func resolveAutoApprovers(p *Policy, users types.Users, nodes types.Nodes) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) { +func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) { if p == nil { return nil, nil, nil } @@ -1671,7 +1675,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { +func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index ac2fc3b1..8cddfeba 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1377,7 +1377,7 @@ func TestResolvePolicy(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ips, err := tt.toResolve.Resolve(tt.pol, xmaps.Values(users), - tt.nodes) + tt.nodes.ViewSlice()) if tt.wantErr == "" { if err != nil { t.Fatalf("got %v; want no error", err) @@ -1557,7 +1557,7 @@ func TestResolveAutoApprovers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes) + got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes.ViewSlice()) if (err != nil) != tt.wantErr { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return @@ -1716,10 +1716,10 @@ func TestNodeCanApproveRoute(t *testing.T) { b, err := json.Marshal(tt.policy) require.NoError(t, err) - pm, err := NewPolicyManager(b, users, nodes) + pm, err := NewPolicyManager(b, users, nodes.ViewSlice()) require.NoErrorf(t, err, "NewPolicyManager() error = %v", err) - got := pm.NodeCanApproveRoute(tt.node, tt.route) + got := pm.NodeCanApproveRoute(tt.node.View(), tt.route) if got != tt.want { t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want) } @@ -1800,7 +1800,7 @@ func TestResolveTagOwners(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := resolveTagOwners(tt.policy, users, nodes) + got, err := resolveTagOwners(tt.policy, users, nodes.ViewSlice()) if (err != nil) != tt.wantErr { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return @@ -1911,14 +1911,14 @@ func TestNodeCanHaveTag(t *testing.T) { b, err := json.Marshal(tt.policy) require.NoError(t, err) - pm, err := NewPolicyManager(b, users, nodes) + pm, err := NewPolicyManager(b, users, nodes.ViewSlice()) if tt.wantErr != "" { require.ErrorContains(t, err, tt.wantErr) return } require.NoError(t, err) - got := pm.NodeCanHaveTag(tt.node, tt.tag) + got := pm.NodeCanHaveTag(tt.node.View(), tt.tag) if got != tt.want { t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want) } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 56175fdb..13504071 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -168,6 +168,10 @@ func (m *mapSession) serve() { func (m *mapSession) serveLongPoll() { m.beforeServeLongPoll() + // For now, mapSession uses a normal node, but since serveLongPoll is a read operation, + // convert the node to a view at the beginning. + nv := m.node.View() + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -179,16 +183,16 @@ func (m *mapSession) serveLongPoll() { // in principal, it will be removed, but the client rapidly // reconnects, the channel might be of another connection. // In that case, it is not closed and the node is still online. - if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) { + if m.h.nodeNotifier.RemoveNode(nv.ID(), m.ch) { // TODO(kradalby): This can likely be made more effective, but likely most // nodes has access to the same routes, so it might not be a big deal. - change, err := m.h.state.Disconnect(m.node) + change, err := m.h.state.Disconnect(nv) if err != nil { - m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + m.errf(err, "Failed to disconnect node %s", nv.Hostname()) } if change { - ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) + ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } } @@ -201,8 +205,8 @@ func (m *mapSession) serveLongPoll() { m.h.pollNetMapStreamWG.Add(1) defer m.h.pollNetMapStreamWG.Done() - if m.h.state.Connect(m.node) { - ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) + if m.h.state.Connect(nv) { + ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } @@ -213,17 +217,17 @@ func (m *mapSession) serveLongPoll() { // so it needs to be disabled. rc.SetWriteDeadline(time.Time{}) - ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) + ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, nv.Hostname())) defer cancel() m.keepAliveTicker = time.NewTicker(m.keepAlive) - m.h.nodeNotifier.AddNode(m.node.ID, m.ch) + m.h.nodeNotifier.AddNode(nv.ID(), m.ch) go func() { - changed := m.h.state.Connect(m.node) + changed := m.h.state.Connect(nv) if changed { - ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) + ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } }() @@ -253,7 +257,7 @@ func (m *mapSession) serveLongPoll() { } // If the node has been removed from headscale, close the stream - if slices.Contains(update.Removed, m.node.ID) { + if slices.Contains(update.Removed, nv.ID()) { m.tracef("node removed, closing stream") return } @@ -268,18 +272,22 @@ func (m *mapSession) serveLongPoll() { // Ensure the node object is updated, for example, there // might have been a hostinfo update in a sidechannel // which contains data needed to generate a map response. - m.node, err = m.h.state.GetNodeByID(m.node.ID) + m.node, err = m.h.state.GetNodeByID(nv.ID()) if err != nil { m.errf(err, "Could not get machine from db") return } + // Update the node view to reflect the latest node state + // TODO(kradalby): This should become a full read only path, with no update for the node view + // in the new mapper model. + nv = m.node.View() updateType := "full" switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + data, err = m.mapper.FullMapResponse(m.req, nv, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) case types.StatePeerChanged: changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) @@ -289,12 +297,12 @@ func (m *mapSession) serveLongPoll() { lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage) updateType = "change" case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) + data, err = m.mapper.PeerChangedPatchResponse(m.req, nv, update.ChangePatches) updateType = "patch" case types.StatePeerRemoved: changed := make(map[types.NodeID]bool, len(update.Removed)) @@ -303,17 +311,17 @@ func (m *mapSession) serveLongPoll() { changed[nodeID] = false } m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage) updateType = "remove" case types.StateSelfUpdate: lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, nv, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) updateType = "remove" case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") - data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap()) + data, err = m.mapper.DERPMapResponse(m.req, nv, m.h.state.DERPMap()) updateType = "derp" } @@ -340,10 +348,10 @@ func (m *mapSession) serveLongPoll() { return } - log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") + log.Trace().Str("node", nv.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", nv.MachineKey().String()).Msg("finished writing mapresp to node") if debugHighCardinalityMetrics { - mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID.String()).Set(float64(time.Now().Unix())) + mapResponseLastSentSeconds.WithLabelValues(updateType, nv.ID().String()).Set(float64(time.Now().Unix())) } mapResponseSent.WithLabelValues("ok", updateType).Inc() m.tracef("update sent") @@ -351,7 +359,7 @@ func (m *mapSession) serveLongPoll() { } case <-m.keepAliveTicker.C: - data, err := m.mapper.KeepAliveResponse(m.req, m.node) + data, err := m.mapper.KeepAliveResponse(m.req, nv) if err != nil { m.errf(err, "Error generating the keep alive msg") mapResponseSent.WithLabelValues("error", "keepalive").Inc() @@ -371,7 +379,7 @@ func (m *mapSession) serveLongPoll() { } if debugHighCardinalityMetrics { - mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) + mapResponseLastSentSeconds.WithLabelValues("keepalive", nv.ID().String()).Set(float64(time.Now().Unix())) } mapResponseSent.WithLabelValues("ok", "keepalive").Inc() } @@ -490,7 +498,7 @@ func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleReadOnlyRequest() { m.tracef("Client asked for a lite update, responding without peers") - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node.View()) if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index c8927810..2a08ef29 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -104,7 +104,7 @@ func NewState(cfg *types.Config) (*State, error) { return nil, fmt.Errorf("loading policy: %w", err) } - polMan, err := policy.NewPolicyManager(pol, users, nodes) + polMan, err := policy.NewPolicyManager(pol, users, nodes.ViewSlice()) if err != nil { return nil, fmt.Errorf("init policy manager: %w", err) } @@ -400,22 +400,22 @@ func (s *State) DeleteNode(node *types.Node) (bool, error) { return policyChanged, nil } -func (s *State) Connect(node *types.Node) bool { - _ = s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) +func (s *State) Connect(node types.NodeView) bool { + changed := s.primaryRoutes.SetRoutes(node.ID(), node.SubnetRoutes()...) // TODO(kradalby): this should be more granular, allowing us to // only send a online update change. - return true + return changed } -func (s *State) Disconnect(node *types.Node) (bool, error) { +func (s *State) Disconnect(node types.NodeView) (bool, error) { // TODO(kradalby): This node should update the in memory state - _, polChanged, err := s.SetLastSeen(node.ID, time.Now()) + _, polChanged, err := s.SetLastSeen(node.ID(), time.Now()) if err != nil { return false, fmt.Errorf("disconnecting node: %w", err) } - changed := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) + changed := s.primaryRoutes.SetRoutes(node.ID()) // TODO(kradalby): the returned change should be more nuanced allowing us to // send more directed updates. @@ -512,7 +512,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateU } // SSHPolicy returns the SSH access policy for a node. -func (s *State) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { +func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { return s.polMan.SSHPolicy(node) } @@ -522,7 +522,7 @@ func (s *State) Filter() ([]tailcfg.FilterRule, []matcher.Match) { } // NodeCanHaveTag checks if a node is allowed to have a specific tag. -func (s *State) NodeCanHaveTag(node *types.Node, tag string) bool { +func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool { return s.polMan.NodeCanHaveTag(node, tag) } @@ -761,7 +761,7 @@ func (s *State) updatePolicyManagerNodes() (bool, error) { return false, fmt.Errorf("listing nodes for policy update: %w", err) } - changed, err := s.polMan.SetNodes(nodes) + changed, err := s.polMan.SetNodes(nodes.ViewSlice()) if err != nil { return false, fmt.Errorf("updating policy manager nodes: %w", err) } diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index c4cc8a2e..69c298b9 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -1,3 +1,5 @@ +//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey + package types import ( diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index da185563..11383950 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -18,6 +18,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/views" ) var ( @@ -115,6 +116,15 @@ type Node struct { type Nodes []*Node +func (ns Nodes) ViewSlice() views.Slice[NodeView] { + vs := make([]NodeView, len(ns)) + for i, n := range ns { + vs[i] = n.View() + } + + return views.SliceOf(vs) +} + // GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node. func (node *Node) GivenNameHasBeenChanged() bool { return node.GivenName == util.ConvertWithFQDNRules(node.Hostname) @@ -582,3 +592,185 @@ func (node Node) DebugString() string { sb.WriteString("\n") return sb.String() } + +func (v NodeView) IPs() []netip.Addr { + if !v.Valid() { + return nil + } + return v.ж.IPs() +} + +func (v NodeView) InIPSet(set *netipx.IPSet) bool { + if !v.Valid() { + return false + } + return v.ж.InIPSet(set) +} + +func (v NodeView) CanAccess(matchers []matcher.Match, node2 NodeView) bool { + if !v.Valid() || !node2.Valid() { + return false + } + src := v.IPs() + allowedIPs := node2.IPs() + + for _, matcher := range matchers { + if !matcher.SrcsContainsIPs(src...) { + continue + } + + if matcher.DestsContainsIP(allowedIPs...) { + return true + } + + if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) { + return true + } + } + + return false +} + +func (v NodeView) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool { + if !v.Valid() { + return false + } + src := v.IPs() + + for _, matcher := range matchers { + if !matcher.SrcsContainsIPs(src...) { + continue + } + + if matcher.DestsOverlapsPrefixes(route) { + return true + } + } + + return false +} + +func (v NodeView) AnnouncedRoutes() []netip.Prefix { + if !v.Valid() { + return nil + } + return v.ж.AnnouncedRoutes() +} + +func (v NodeView) SubnetRoutes() []netip.Prefix { + if !v.Valid() { + return nil + } + return v.ж.SubnetRoutes() +} + +func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) { + if !v.Valid() { + return + } + v.ж.AppendToIPSet(build) +} + +func (v NodeView) RequestTagsSlice() views.Slice[string] { + if !v.Valid() || !v.Hostinfo().Valid() { + return views.Slice[string]{} + } + return v.Hostinfo().RequestTags() +} + +func (v NodeView) Tags() []string { + if !v.Valid() { + return nil + } + return v.ж.Tags() +} + +// IsTagged reports if a device is tagged +// and therefore should not be treated as a +// user owned device. +// Currently, this function only handles tags set +// via CLI ("forced tags" and preauthkeys) +func (v NodeView) IsTagged() bool { + if !v.Valid() { + return false + } + return v.ж.IsTagged() +} + +// IsExpired returns whether the node registration has expired. +func (v NodeView) IsExpired() bool { + if !v.Valid() { + return true + } + return v.ж.IsExpired() +} + +// IsEphemeral returns if the node is registered as an Ephemeral node. +// https://tailscale.com/kb/1111/ephemeral-nodes/ +func (v NodeView) IsEphemeral() bool { + if !v.Valid() { + return false + } + return v.ж.IsEphemeral() +} + +// PeerChangeFromMapRequest takes a MapRequest and compares it to the node +// to produce a PeerChange struct that can be used to updated the node and +// inform peers about smaller changes to the node. +func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { + if !v.Valid() { + return tailcfg.PeerChange{} + } + return v.ж.PeerChangeFromMapRequest(req) +} + +// GetFQDN returns the fully qualified domain name for the node. +func (v NodeView) GetFQDN(baseDomain string) (string, error) { + if !v.Valid() { + return "", fmt.Errorf("failed to create valid FQDN: node view is invalid") + } + return v.ж.GetFQDN(baseDomain) +} + +// ExitRoutes returns a list of both exit routes if the +// node has any exit routes enabled. +// If none are enabled, it will return nil. +func (v NodeView) ExitRoutes() []netip.Prefix { + if !v.Valid() { + return nil + } + return v.ж.ExitRoutes() +} + +// HasIP reports if a node has a given IP address. +func (v NodeView) HasIP(i netip.Addr) bool { + if !v.Valid() { + return false + } + return v.ж.HasIP(i) +} + +// HasTag reports if a node has a given tag. +func (v NodeView) HasTag(tag string) bool { + if !v.Valid() { + return false + } + return v.ж.HasTag(tag) +} + +// Prefixes returns the node IPs as netip.Prefix. +func (v NodeView) Prefixes() []netip.Prefix { + if !v.Valid() { + return nil + } + return v.ж.Prefixes() +} + +// IPsAsString returns the node IPs as strings. +func (v NodeView) IPsAsString() []string { + if !v.Valid() { + return nil + } + return v.ж.IPsAsString() +} + diff --git a/hscontrol/types/types_clone.go b/hscontrol/types/types_clone.go new file mode 100644 index 00000000..3f530dc9 --- /dev/null +++ b/hscontrol/types/types_clone.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT. + +package types + +import ( + "database/sql" + "net/netip" + "time" + + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +// Clone makes a deep copy of User. +// The result aliases no memory with the original. +func (src *User) Clone() *User { + if src == nil { + return nil + } + dst := new(User) + *dst = *src + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _UserCloneNeedsRegeneration = User(struct { + gorm.Model + Name string + DisplayName string + Email string + ProviderIdentifier sql.NullString + Provider string + ProfilePicURL string +}{}) + +// Clone makes a deep copy of Node. +// The result aliases no memory with the original. +func (src *Node) Clone() *Node { + if src == nil { + return nil + } + dst := new(Node) + *dst = *src + dst.Endpoints = append(src.Endpoints[:0:0], src.Endpoints...) + dst.Hostinfo = src.Hostinfo.Clone() + if dst.IPv4 != nil { + dst.IPv4 = ptr.To(*src.IPv4) + } + if dst.IPv6 != nil { + dst.IPv6 = ptr.To(*src.IPv6) + } + dst.ForcedTags = append(src.ForcedTags[:0:0], src.ForcedTags...) + if dst.AuthKeyID != nil { + dst.AuthKeyID = ptr.To(*src.AuthKeyID) + } + dst.AuthKey = src.AuthKey.Clone() + if dst.Expiry != nil { + dst.Expiry = ptr.To(*src.Expiry) + } + if dst.LastSeen != nil { + dst.LastSeen = ptr.To(*src.LastSeen) + } + dst.ApprovedRoutes = append(src.ApprovedRoutes[:0:0], src.ApprovedRoutes...) + if dst.DeletedAt != nil { + dst.DeletedAt = ptr.To(*src.DeletedAt) + } + if dst.IsOnline != nil { + dst.IsOnline = ptr.To(*src.IsOnline) + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _NodeCloneNeedsRegeneration = Node(struct { + ID NodeID + MachineKey key.MachinePublic + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + Endpoints []netip.AddrPort + Hostinfo *tailcfg.Hostinfo + IPv4 *netip.Addr + IPv6 *netip.Addr + Hostname string + GivenName string + UserID uint + User User + RegisterMethod string + ForcedTags []string + AuthKeyID *uint64 + AuthKey *PreAuthKey + Expiry *time.Time + LastSeen *time.Time + ApprovedRoutes []netip.Prefix + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time + IsOnline *bool +}{}) + +// Clone makes a deep copy of PreAuthKey. +// The result aliases no memory with the original. +func (src *PreAuthKey) Clone() *PreAuthKey { + if src == nil { + return nil + } + dst := new(PreAuthKey) + *dst = *src + dst.Tags = append(src.Tags[:0:0], src.Tags...) + if dst.CreatedAt != nil { + dst.CreatedAt = ptr.To(*src.CreatedAt) + } + if dst.Expiration != nil { + dst.Expiration = ptr.To(*src.Expiration) + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct { + ID uint64 + Key string + UserID uint + User User + Reusable bool + Ephemeral bool + Used bool + Tags []string + CreatedAt *time.Time + Expiration *time.Time +}{}) diff --git a/hscontrol/types/types_view.go b/hscontrol/types/types_view.go new file mode 100644 index 00000000..5c31eac8 --- /dev/null +++ b/hscontrol/types/types_view.go @@ -0,0 +1,270 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by tailscale/cmd/viewer; DO NOT EDIT. + +package types + +import ( + "database/sql" + "encoding/json" + "errors" + "net/netip" + "time" + + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=User,Node,PreAuthKey + +// View returns a read-only view of User. +func (p *User) View() UserView { + return UserView{ж: p} +} + +// UserView provides a read-only view over User. +// +// Its methods should only be called if `Valid()` returns true. +type UserView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *User +} + +// Valid reports whether v's underlying value is non-nil. +func (v UserView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v UserView) AsStruct() *User { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v UserView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *UserView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x User + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v UserView) Model() gorm.Model { return v.ж.Model } +func (v UserView) Name() string { return v.ж.Name } +func (v UserView) DisplayName() string { return v.ж.DisplayName } +func (v UserView) Email() string { return v.ж.Email } +func (v UserView) ProviderIdentifier() sql.NullString { return v.ж.ProviderIdentifier } +func (v UserView) Provider() string { return v.ж.Provider } +func (v UserView) ProfilePicURL() string { return v.ж.ProfilePicURL } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _UserViewNeedsRegeneration = User(struct { + gorm.Model + Name string + DisplayName string + Email string + ProviderIdentifier sql.NullString + Provider string + ProfilePicURL string +}{}) + +// View returns a read-only view of Node. +func (p *Node) View() NodeView { + return NodeView{ж: p} +} + +// NodeView provides a read-only view over Node. +// +// Its methods should only be called if `Valid()` returns true. +type NodeView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *Node +} + +// Valid reports whether v's underlying value is non-nil. +func (v NodeView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v NodeView) AsStruct() *Node { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v NodeView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *NodeView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x Node + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v NodeView) ID() NodeID { return v.ж.ID } +func (v NodeView) MachineKey() key.MachinePublic { return v.ж.MachineKey } +func (v NodeView) NodeKey() key.NodePublic { return v.ж.NodeKey } +func (v NodeView) DiscoKey() key.DiscoPublic { return v.ж.DiscoKey } +func (v NodeView) Endpoints() views.Slice[netip.AddrPort] { return views.SliceOf(v.ж.Endpoints) } +func (v NodeView) Hostinfo() tailcfg.HostinfoView { return v.ж.Hostinfo.View() } +func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv4) } + +func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) } + +func (v NodeView) Hostname() string { return v.ж.Hostname } +func (v NodeView) GivenName() string { return v.ж.GivenName } +func (v NodeView) UserID() uint { return v.ж.UserID } +func (v NodeView) User() User { return v.ж.User } +func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod } +func (v NodeView) ForcedTags() views.Slice[string] { return views.SliceOf(v.ж.ForcedTags) } +func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) } + +func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() } +func (v NodeView) Expiry() views.ValuePointer[time.Time] { return views.ValuePointerOf(v.ж.Expiry) } + +func (v NodeView) LastSeen() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.LastSeen) +} + +func (v NodeView) ApprovedRoutes() views.Slice[netip.Prefix] { + return views.SliceOf(v.ж.ApprovedRoutes) +} +func (v NodeView) CreatedAt() time.Time { return v.ж.CreatedAt } +func (v NodeView) UpdatedAt() time.Time { return v.ж.UpdatedAt } +func (v NodeView) DeletedAt() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.DeletedAt) +} + +func (v NodeView) IsOnline() views.ValuePointer[bool] { return views.ValuePointerOf(v.ж.IsOnline) } + +func (v NodeView) String() string { return v.ж.String() } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _NodeViewNeedsRegeneration = Node(struct { + ID NodeID + MachineKey key.MachinePublic + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + Endpoints []netip.AddrPort + Hostinfo *tailcfg.Hostinfo + IPv4 *netip.Addr + IPv6 *netip.Addr + Hostname string + GivenName string + UserID uint + User User + RegisterMethod string + ForcedTags []string + AuthKeyID *uint64 + AuthKey *PreAuthKey + Expiry *time.Time + LastSeen *time.Time + ApprovedRoutes []netip.Prefix + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time + IsOnline *bool +}{}) + +// View returns a read-only view of PreAuthKey. +func (p *PreAuthKey) View() PreAuthKeyView { + return PreAuthKeyView{ж: p} +} + +// PreAuthKeyView provides a read-only view over PreAuthKey. +// +// Its methods should only be called if `Valid()` returns true. +type PreAuthKeyView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *PreAuthKey +} + +// Valid reports whether v's underlying value is non-nil. +func (v PreAuthKeyView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v PreAuthKeyView) AsStruct() *PreAuthKey { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v PreAuthKeyView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x PreAuthKey + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } +func (v PreAuthKeyView) Key() string { return v.ж.Key } +func (v PreAuthKeyView) UserID() uint { return v.ж.UserID } +func (v PreAuthKeyView) User() User { return v.ж.User } +func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } +func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } +func (v PreAuthKeyView) Used() bool { return v.ж.Used } +func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } +func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.CreatedAt) +} + +func (v PreAuthKeyView) Expiration() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.Expiration) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct { + ID uint64 + Key string + UserID uint + User User + Reusable bool + Ephemeral bool + Used bool + Tags []string + CreatedAt *time.Time + Expiration *time.Time +}{})