diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 58c5705a..3c8141c7 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -70,6 +70,7 @@ jobs: - TestSubnetRouterMultiNetwork - TestSubnetRouterMultiNetworkExitNode - TestAutoApproveMultiNetwork + - TestSubnetRouteACLFiltering - TestHeadscale - TestTailscaleNodesJoiningHeadcale - TestSSHOneUserToAll diff --git a/CHANGELOG.md b/CHANGELOG.md index 48d11080..80e08c6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ will be approved. [#2422](https://github.com/juanfont/headscale/pull/2422) - Routes are now managed via the Node API [#2422](https://github.com/juanfont/headscale/pull/2422) +- Only routes accessible to the node will be sent to the node + [#2561](https://github.com/juanfont/headscale/pull/2561) #### Policy v2 diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 662e491c..d7deb0a5 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/fs" + "net/netip" "net/url" "os" "path" @@ -308,9 +309,15 @@ func (m *Mapper) PeerChangedResponse( resp.PeersChangedPatch = patches } + _, matchers := m.polMan.Filter() // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.primary, m.cfg) + tailnode, err := tailNode( + node, mapRequest.Version, m.polMan, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers) + }, + m.cfg) if err != nil { return nil, err } @@ -347,7 +354,7 @@ func (m *Mapper) marshalMapResponse( } if debugDumpMapResponsePath != "" { - data := map[string]interface{}{ + data := map[string]any{ "Messages": messages, "MapRequest": mapRequest, "MapResponse": resp, @@ -457,7 +464,13 @@ func (m *Mapper) baseWithConfigMapResponse( ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, m.polMan, m.primary, m.cfg) + _, matchers := m.polMan.Filter() + tailnode, err := tailNode( + node, capVer, m.polMan, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers) + }, + m.cfg) if err != nil { return nil, err } @@ -513,15 +526,10 @@ func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { return nodes, nil } -func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { - ret := make(types.Nodes, 0) - - for _, node := range nodes { - ret = append(ret, node) - } - - return ret -} +// routeFilterFunc is a function that takes a node ID and returns a list of +// netip.Prefixes that are allowed for that node. It is used to filter routes +// from the primary route manager to the node. +type routeFilterFunc func(id types.NodeID) []netip.Prefix // appendPeerChanges mutates a tailcfg.MapResponse with all the // necessary changes when peers have changed. @@ -546,14 +554,19 @@ func appendPeerChanges( // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. if len(filter) > 0 { - changed = policy.FilterNodesByACL(node, changed, matchers) + changed = policy.ReduceNodes(node, changed, matchers) } profiles := generateUserProfiles(node, changed) dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, polMan, primary, cfg) + tailPeers, err := tailNodes( + changed, capVer, polMan, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, primary.PrimaryRoutes(id), matchers) + }, + cfg) if err != nil { return err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 5d718b54..dfce60bb 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -348,6 +348,11 @@ func Test_fullMapResponse(t *testing.T) { "src": ["100.64.0.2"], "dst": ["user1@:*"], }, + { + "action": "accept", + "src": ["100.64.0.1"], + "dst": ["192.168.0.0/24:*"], + }, ], } `), @@ -380,6 +385,10 @@ func Test_fullMapResponse(t *testing.T) { {IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}, }, }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "192.168.0.0/24", Ports: tailcfg.PortRangeAny}}, + }, }, }, SSHPolicy: nil, diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 32905345..eae70e96 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -5,7 +5,6 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "github.com/samber/lo" "tailscale.com/net/tsaddr" @@ -16,7 +15,7 @@ func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, polMan policy.PolicyManager, - primary *routes.PrimaryRoutes, + primaryRouteFunc routeFilterFunc, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -26,7 +25,7 @@ func tailNodes( node, capVer, polMan, - primary, + primaryRouteFunc, cfg, ) if err != nil { @@ -44,7 +43,7 @@ func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, polMan policy.PolicyManager, - primary *routes.PrimaryRoutes, + primaryRouteFunc routeFilterFunc, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,7 +80,8 @@ func tailNode( } tags = lo.Uniq(append(tags, node.ForcedTags...)) - allowed := append(node.Prefixes(), primary.PrimaryRoutes(node.ID)...) + routes := primaryRouteFunc(node.ID) + allowed := append(node.Prefixes(), routes...) allowed = append(allowed, node.ExitRoutes()...) tsaddr.SortPrefixes(allowed) @@ -99,7 +99,7 @@ func tailNode( Machine: node.MachineKey, DiscoKey: node.DiscoKey, Addresses: addrs, - PrimaryRoutes: primary.PrimaryRoutes(node.ID), + PrimaryRoutes: routes, AllowedIPs: allowed, Endpoints: node.Endpoints, HomeDERP: derp, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 1c3c018f..cacc4930 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -219,7 +219,9 @@ func TestTailNode(t *testing.T) { tt.node, 0, polMan, - primary, + func(id types.NodeID) []netip.Prefix { + return primary.PrimaryRoutes(id) + }, cfg, ) @@ -266,14 +268,20 @@ func TestNodeExpiry(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { node := &types.Node{ + ID: 0, GivenName: "test", Expiry: tt.exp, } + polMan, err := policy.NewPolicyManager(nil, nil, nil) + require.NoError(t, err) + tn, err := tailNode( node, 0, - nil, // TODO(kradalby): removed in merge but error? - nil, + polMan, + func(id types.NodeID) []netip.Prefix { + return []netip.Prefix{} + }, &types.Config{}, ) if err != nil { diff --git a/hscontrol/notifier/notifier_test.go b/hscontrol/notifier/notifier_test.go index a7369740..9654cfc8 100644 --- a/hscontrol/notifier/notifier_test.go +++ b/hscontrol/notifier/notifier_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "slices" + "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -252,9 +254,7 @@ func TestBatcher(t *testing.T) { // Make the inner order stable for comparison. for _, u := range got { - sort.Slice(u.ChangeNodes, func(i, j int) bool { - return u.ChangeNodes[i] < u.ChangeNodes[j] - }) + slices.Sort(u.ChangeNodes) sort.Slice(u.ChangePatches, func(i, j int) bool { return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID }) @@ -301,11 +301,11 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) { // Start goroutines to cause a race wg.Add(concurrentAccessors) - for i := 0; i < concurrentAccessors; i++ { + for i := range concurrentAccessors { go func(routineID int) { defer wg.Done() - for j := 0; j < iterations; j++ { + for range iterations { // Simulate race by having some goroutines check IsLikelyConnected // while others add/remove the node if routineID%3 == 0 { diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index ec07d19c..d246d5e2 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -2,6 +2,7 @@ package matcher import ( "net/netip" + "strings" "slices" @@ -15,6 +16,21 @@ type Match struct { dests *netipx.IPSet } +func (m Match) DebugString() string { + var sb strings.Builder + + sb.WriteString("Match:\n") + sb.WriteString(" Sources:\n") + for _, prefix := range m.srcs.Prefixes() { + sb.WriteString(" " + prefix.String() + "\n") + } + sb.WriteString(" Destinations:\n") + for _, prefix := range m.dests.Prefixes() { + sb.WriteString(" " + prefix.String() + "\n") + } + return sb.String() +} + func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { matches := make([]Match, 0, len(rules)) for _, rule := range rules { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 0df1bcc4..b90d2efc 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -1,9 +1,10 @@ package policy import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" @@ -33,7 +34,7 @@ type PolicyManager interface { } // NewPolicyManager returns a new policy manager, the version is determined by -// the environment flag "HEADSCALE_EXPERIMENTAL_POLICY_V2". +// the environment flag "HEADSCALE_POLICY_V1". func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { var polMan PolicyManager var err error diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index d86de29b..5859a198 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -1,10 +1,11 @@ package policy import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "slices" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/samber/lo" @@ -12,8 +13,8 @@ import ( "tailscale.com/tailcfg" ) -// FilterNodesByACL returns the list of peers authorized to be accessed from a given node. -func FilterNodesByACL( +// ReduceNodes returns the list of peers authorized to be accessed from a given node. +func ReduceNodes( node *types.Node, nodes types.Nodes, matchers []matcher.Match, @@ -33,6 +34,23 @@ func FilterNodesByACL( return result } +// ReduceRoutes returns a reduced list of routes for a given node that it can access. +func ReduceRoutes( + node *types.Node, + routes []netip.Prefix, + matchers []matcher.Match, +) []netip.Prefix { + var result []netip.Prefix + + for _, route := range routes { + if node.CanAccessRoute(matchers, route) { + result = append(result, route) + } + } + + return result +} + // ReduceFilterRules takes a node and a set of rules and removes all rules and destinations // that are not relevant to that particular node. func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule { diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 5b3814a2..c1000334 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1,6 +1,7 @@ package policy import ( + "encoding/json" "fmt" "net/netip" "testing" @@ -16,6 +17,7 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/util/must" ) var ap = func(ipStr string) *netip.Addr { @@ -23,6 +25,11 @@ var ap = func(ipStr string) *netip.Addr { return &ip } +var p = func(prefStr string) netip.Prefix { + ip := netip.MustParsePrefix(prefStr) + return ip +} + // hsExitNodeDestForTest is the list of destination IP ranges that are allowed when // we use headscale "autogroup:internet". var hsExitNodeDestForTest = []tailcfg.NetPortRange{ @@ -762,6 +769,54 @@ func TestReduceFilterRules(t *testing.T) { }, }, }, + { + name: "2365-only-route-policy", + pol: ` +{ + "hosts": { + "router": "100.64.0.1/32", + "node": "100.64.0.2/32" + }, + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "router:8000" + ] + }, + { + "action": "accept", + "src": [ + "node" + ], + "dst": [ + "172.26.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[3], + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, + }, + ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, + }, + }, + want: []tailcfg.FilterRule{}, + }, } for _, tt := range tests { @@ -773,6 +828,7 @@ func TestReduceFilterRules(t *testing.T) { pm, err = pmf(users, append(tt.peers, tt.node)) require.NoError(t, err) got, _ := pm.Filter() + t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) got = ReduceFilterRules(tt.node, got) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -784,7 +840,7 @@ func TestReduceFilterRules(t *testing.T) { } } -func TestFilterNodesByACL(t *testing.T) { +func TestReduceNodes(t *testing.T) { type args struct { nodes types.Nodes rules []tailcfg.FilterRule @@ -1530,7 +1586,7 @@ func TestFilterNodesByACL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) - got := FilterNodesByACL( + got := ReduceNodes( tt.args.node, tt.args.nodes, matchers, @@ -1946,3 +2002,470 @@ func TestSSHPolicyRules(t *testing.T) { } } } +func TestReduceRoutes(t *testing.T) { + type args struct { + node *types.Node + routes []netip.Prefix + rules []tailcfg.FilterRule + } + tests := []struct { + name string + args args + want []netip.Prefix + }{ + { + name: "node-can-access-all-routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + }, + { + name: "node-can-access-specific-route", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + }, + { + name: "node-can-access-multiple-specific-routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, + {IP: "192.168.1.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "node-can-access-overlapping-routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/16"), // Overlaps with the first one + netip.MustParsePrefix("192.168.1.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/16"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/16"), + }, + }, + { + name: "node-with-no-matching-rules", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, // Different source IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + }, + want: nil, + }, + { + name: "node-with-both-ipv4-and-ipv6", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/64"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"fd7a:115c:a1e0::1"}, // IPv6 source + DstPorts: []tailcfg.NetPortRange{ + {IP: "2001:db8::/64"}, // IPv6 destination + }, + }, + { + SrcIPs: []string{"100.64.0.1"}, // IPv4 source + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, // IPv4 destination + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/64"), + }, + }, + { + name: "router-with-multiple-routes-and-node-with-specific-access", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // Node IP + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, // Any source + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1"}, // Router node + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // Node IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, // Only one subnet allowed + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + }, + }, + { + name: "node-with-access-to-one-subnet-and-partial-overlap", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.10.0/16"), // Overlaps with the first one + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, // Only specific subnet + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.10.0/16"), // With current implementation, this is included because it overlaps with the allowed subnet + }, + }, + { + name: "node-with-access-to-wildcard-subnet", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.0.0/16"}, // Broader subnet that includes all three + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + }, + { + name: "multiple-nodes-with-different-subnet-permissions", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, // Different node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.11.0/24"}, + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // Our node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, + }, + }, + { + SrcIPs: []string{"100.64.0.3"}, // Different node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.12.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + }, + }, + { + name: "exactly-matching-users-acl-example", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // node with IP 100.64.0.2 + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + // This represents the rule: action: accept, src: ["*"], dst: ["router:0"] + SrcIPs: []string{"*"}, // Any source + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1"}, // Router IP + }, + }, + { + // This represents the rule: action: accept, src: ["node"], dst: ["10.10.10.0/24:*"] + SrcIPs: []string{"100.64.0.2"}, // Node IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24", Ports: tailcfg.PortRangeAny}, // All ports on this subnet + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + }, + }, + { + name: "acl-all-source-nodes-can-access-router-only-node-can-access-10.10.10.0-24", + args: args{ + // When testing from router node's perspective + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), // router with IP 100.64.0.1 + User: types.User{Name: "router"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1"}, // Router can be accessed by all + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // Only node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, // Can access this subnet + }, + }, + // Add a rule for router to access its own routes + { + SrcIPs: []string{"100.64.0.1"}, // Router node + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, // Can access everything + }, + }, + }, + }, + // Router needs explicit rules to access routes + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + }, + { + name: "acl-specific-port-ranges-for-subnets", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // node + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, // node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24", Ports: tailcfg.PortRange{First: 22, Last: 22}}, // Only SSH + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.11.0/24", Ports: tailcfg.PortRange{First: 80, Last: 80}}, // Only HTTP + }, + }, + }, + }, + // Should get both subnets with specific port ranges + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + }, + }, + { + name: "acl-order-of-rules-and-rule-specificity", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // node + User: types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + // First rule allows all traffic + { + SrcIPs: []string{"*"}, // Any source + DstPorts: []tailcfg.NetPortRange{ + {IP: "*", Ports: tailcfg.PortRangeAny}, // Any destination and any port + }, + }, + // Second rule is more specific but should be overridden by the first rule + { + SrcIPs: []string{"100.64.0.2"}, // node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, + }, + }, + }, + }, + // Due to the first rule allowing all traffic, node should have access to all routes + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( + tt.args.node, + tt.args.routes, + matchers, + ) + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("ReduceRoutes() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index ec4b7737..4dec2bd4 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -152,6 +152,10 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { // Filter returns the current filter rules for the entire tailnet and the associated matchers. func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { + if pm == nil { + return nil, nil + } + pm.mu.Lock() defer pm.mu.Unlock() return pm.filter, pm.matchers @@ -159,6 +163,10 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { // SetUsers updates the users in the policy manager and updates the filter rules. func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { + if pm == nil { + return false, nil + } + pm.mu.Lock() defer pm.mu.Unlock() pm.users = users @@ -167,6 +175,10 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { // SetNodes updates the nodes in the policy manager and updates the filter rules. func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { + if pm == nil { + return false, nil + } + pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes @@ -238,6 +250,10 @@ func (pm *PolicyManager) Version() int { } func (pm *PolicyManager) DebugString() string { + if pm == nil { + return "PolicyManager is not setup" + } + var sb strings.Builder fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version()) @@ -281,6 +297,14 @@ func (pm *PolicyManager) DebugString() string { } } + sb.WriteString("\n\n") + sb.WriteString("Matchers:\n") + sb.WriteString("an internal structure used to filter nodes and routes\n") + for _, match := range pm.matchers { + sb.WriteString(match.DebugString()) + sb.WriteString("\n") + } + sb.WriteString("\n\n") sb.WriteString(pm.nodes.DebugString()) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 826867eb..2749237e 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -239,10 +239,8 @@ func (node *Node) Prefixes() []netip.Prefix { // node has any exit routes enabled. // If none are enabled, it will return nil. func (node *Node) ExitRoutes() []netip.Prefix { - for _, route := range node.SubnetRoutes() { - if tsaddr.IsExitRoute(route) { - return tsaddr.ExitRoutes() - } + if slices.ContainsFunc(node.SubnetRoutes(), tsaddr.IsExitRoute) { + return tsaddr.ExitRoutes() } return nil @@ -291,6 +289,22 @@ func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { return false } +func (node *Node) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool { + src := node.IPs() + + for _, matcher := range matchers { + if !matcher.SrcsContainsIPs(src...) { + continue + } + + if matcher.DestsOverlapsPrefixes(route) { + return true + } + } + + return false +} + func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { var found Nodes @@ -567,6 +581,7 @@ func (node Node) DebugString() string { fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags()) fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs()) fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes) + fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes()) sb.WriteString("\n") return sb.String() diff --git a/integration/control.go b/integration/control.go index 9dfe150c..22e7552b 100644 --- a/integration/control.go +++ b/integration/control.go @@ -21,6 +21,8 @@ type ControlServer interface { CreateUser(user string) (*v1.User, error) CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) ListNodes(users ...string) ([]*v1.Node, error) + NodesByUser() (map[string][]*v1.Node, error) + NodesByName() (map[string]*v1.Node, error) ListUsers() ([]*v1.User, error) MapUsers() (map[string]*v1.User, error) ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 27e18697..e6762cf0 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -819,6 +819,38 @@ func (t *HeadscaleInContainer) ListNodes( return ret, nil } +func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) { + nodes, err := t.ListNodes() + if err != nil { + return nil, err + } + + var userMap map[string][]*v1.Node + for _, node := range nodes { + if _, ok := userMap[node.User.Name]; !ok { + mak.Set(&userMap, node.User.Name, []*v1.Node{node}) + } else { + userMap[node.User.Name] = append(userMap[node.User.Name], node) + } + } + + return userMap, nil +} + +func (t *HeadscaleInContainer) NodesByName() (map[string]*v1.Node, error) { + nodes, err := t.ListNodes() + if err != nil { + return nil, err + } + + var nameMap map[string]*v1.Node + for _, node := range nodes { + mak.Set(&nameMap, node.GetName(), node) + } + + return nameMap, nil +} + // ListUsers returns a list of users from Headscale. func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { command := []string{"headscale", "users", "list", "--output", "json"} @@ -973,7 +1005,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( "headscale", "nodes", "approve-routes", "--output", "json", "--identifier", strconv.FormatUint(id, 10), - fmt.Sprintf("--routes=%q", strings.Join(util.PrefixesToString(routes), ",")), + fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")), } result, _, err := dockertestutil.ExecuteCommand( diff --git a/integration/route_test.go b/integration/route_test.go index e4b6239b..5a85f436 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1,6 +1,7 @@ package integration import ( + "encoding/json" "fmt" "net/netip" "sort" @@ -9,7 +10,7 @@ import ( "slices" - "github.com/google/go-cmp/cmp" + cmpdiff "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" @@ -23,6 +24,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/types/ipproto" "tailscale.com/types/views" + "tailscale.com/util/must" "tailscale.com/util/slicesx" "tailscale.com/wgengine/filter" ) @@ -940,7 +942,7 @@ func TestSubnetRouteACL(t *testing.T) { }, } - if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { + if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff) } @@ -990,7 +992,7 @@ func TestSubnetRouteACL(t *testing.T) { }, } - if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { + if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff) } } @@ -1603,9 +1605,9 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } for _, tt := range tests { - for _, dbMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} { + for _, polMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} { for _, advertiseDuringUp := range []bool{false, true} { - name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, dbMode) + name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, polMode) t.Run(name, func(t *testing.T) { scenario, err := NewScenario(tt.spec) require.NoErrorf(t, err, "failed to create scenario: %s", err) @@ -1616,7 +1618,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), hsic.WithACLPolicy(tt.pol), - hsic.WithPolicyMode(dbMode), + hsic.WithPolicyMode(polMode), } tsOpts := []tsic.Option{ @@ -2007,7 +2009,7 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) }) - if diff := cmp.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" { + if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" { t.Fatalf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff) } } @@ -2018,3 +2020,193 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub require.Lenf(t, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes())) require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) } + +// TestSubnetRouteACLFiltering tests that a node can only access subnet routes +// that are explicitly allowed in the ACL. +func TestSubnetRouteACLFiltering(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + // Use router and node users for better clarity + routerUser := "router" + nodeUser := "node" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{routerUser, nodeUser}, + Networks: map[string][]string{ + "usernet1": {routerUser, nodeUser}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + // Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24) + aclPolicyStr := fmt.Sprintf(`{ + "hosts": { + "router": "100.64.0.1/32", + "node": "100.64.0.2/32" + }, + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "router:8000" + ] + }, + { + "action": "accept", + "src": [ + "node" + ], + "dst": [] + } + ] + }`) + + route, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + weburl := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("webservice: %s, %s", webip.String(), weburl) + + // Create ACL policy + aclPolicy := &policyv1.ACLPolicy{} + err = json.Unmarshal([]byte(aclPolicyStr), aclPolicy) + require.NoError(t, err) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithAcceptRoutes(), + }, hsic.WithTestName("routeaclfilter"), + hsic.WithACLPolicy(aclPolicy), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + // Sort clients by ID for consistent order + slices.SortFunc(allClients, func(a, b TailscaleClient) int { + return b.MustIPv4().Compare(a.MustIPv4()) + }) + + // Get the router and node clients + routerClient := allClients[0] + nodeClient := allClients[1] + + aclPolicy.Hosts = policyv1.Hosts{ + routerUser: must.Get(routerClient.MustIPv4().Prefix(32)), + nodeUser: must.Get(nodeClient.MustIPv4().Prefix(32)), + } + aclPolicy.ACLs[1].Destinations = []string{ + route.String() + ":*", + } + + require.NoError(t, headscale.SetPolicy(aclPolicy)) + + // Set up the subnet routes for the router + routes := []string{ + route.String(), // This should be accessible by the client + "10.10.11.0/24", // These should NOT be accessible + "10.10.12.0/24", + } + + routeArg := "--advertise-routes=" + routes[0] + "," + routes[1] + "," + routes[2] + command := []string{ + "tailscale", + "set", + routeArg, + } + + _, _, err = routerClient.Execute(command) + require.NoErrorf(t, err, "failed to advertise routes: %s", err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + // List nodes and verify the router has 3 available routes + nodes, err := headscale.NodesByUser() + require.NoError(t, err) + require.Len(t, nodes, 2) + + // Find the router node + routerNode := nodes[routerUser][0] + nodeNode := nodes[nodeUser][0] + + require.NotNil(t, routerNode, "Router node not found") + require.NotNil(t, nodeNode, "Client node not found") + + // Check that the router has 3 routes available but not approved yet + requireNodeRouteCount(t, routerNode, 3, 0, 0) + requireNodeRouteCount(t, nodeNode, 0, 0, 0) + + // Approve all routes for the router + _, err = headscale.ApproveRoutes( + routerNode.GetId(), + util.MustStringsToPrefixes(routerNode.GetAvailableRoutes()), + ) + require.NoError(t, err) + + // Give some time for the routes to propagate + time.Sleep(5 * time.Second) + + // List nodes and verify the router has 3 available routes + nodes, err = headscale.NodesByUser() + require.NoError(t, err) + require.Len(t, nodes, 2) + + // Find the router node + routerNode = nodes[routerUser][0] + + // Check that the router has 3 routes now approved and available + requireNodeRouteCount(t, routerNode, 3, 3, 3) + + // Now check the client node status + nodeStatus, err := nodeClient.Status() + require.NoError(t, err) + + routerStatus, err := routerClient.Status() + require.NoError(t, err) + + // Check that the node can see the subnet routes from the router + routerPeerStatus := nodeStatus.Peer[routerStatus.Self.PublicKey] + + // The node should only have 1 subnet route + requirePeerSubnetRoutes(t, routerPeerStatus, []netip.Prefix{*route}) + + result, err := nodeClient.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err := nodeClient.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, routerClient.MustIPv4()) + +}