From 603f3ad4902e11decba0c7ea9d156e96253b186e Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 21 Mar 2025 11:49:32 +0100 Subject: [PATCH] Multi network integration tests (#2464) --- .../workflows/test-integration-policyv2.yaml | 3 +- .github/workflows/test-integration.yaml | 3 +- hscontrol/mapper/mapper_test.go | 6 +- hscontrol/mapper/tail.go | 16 +- hscontrol/mapper/tail_test.go | 12 +- hscontrol/routes/primary.go | 31 +- hscontrol/routes/primary_test.go | 356 +++++--- hscontrol/types/node.go | 16 +- hscontrol/util/string.go | 9 + hscontrol/util/util.go | 127 +++ hscontrol/util/util_test.go | 192 ++++- integration/acl_test.go | 72 +- integration/auth_key_test.go | 48 +- integration/auth_oidc_test.go | 503 +++--------- integration/auth_web_flow_test.go | 183 +---- integration/cli_test.go | 194 ++--- integration/control.go | 1 - integration/derp_verify_endpoint_test.go | 13 +- integration/dns_test.go | 42 +- integration/dsic/dsic.go | 18 +- integration/embedded_derp_test.go | 231 +----- integration/general_test.go | 167 ++-- integration/hsic/hsic.go | 17 +- integration/route_test.go | 772 +++++++++++++++--- integration/scenario.go | 655 +++++++++++++-- integration/scenario_test.go | 39 +- integration/ssh_test.go | 14 +- integration/tailscale.go | 5 + integration/tsic/tsic.go | 89 +- 29 files changed, 2385 insertions(+), 1449 deletions(-) diff --git a/.github/workflows/test-integration-policyv2.yaml b/.github/workflows/test-integration-policyv2.yaml index 73015603..3959c67a 100644 --- a/.github/workflows/test-integration-policyv2.yaml +++ b/.github/workflows/test-integration-policyv2.yaml @@ -70,8 +70,9 @@ jobs: - TestAutoApprovedSubRoute2068 - TestSubnetRouteACL - TestEnablingExitRoutes + - TestSubnetRouterMultiNetwork + - TestSubnetRouterMultiNetworkExitNode - TestHeadscale - - TestCreateTailscale - TestTailscaleNodesJoiningHeadcale - TestSSHOneUserToAll - TestSSHMultipleUsersAllToAll diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 2898b4ba..ff20fbc3 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -70,8 +70,9 @@ jobs: - TestAutoApprovedSubRoute2068 - TestSubnetRouteACL - TestEnablingExitRoutes + - TestSubnetRouterMultiNetwork + - TestSubnetRouterMultiNetworkExitNode - TestHeadscale - - TestCreateTailscale - TestTailscaleNodesJoiningHeadcale - TestSSHOneUserToAll - TestSSHMultipleUsersAllToAll diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 0fc797a7..ced0c9f4 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -165,9 +165,13 @@ func Test_fullMapResponse(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.1/32"), tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("100.64.0.1/32"), + tsaddr.AllIPv6(), + }, + PrimaryRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), }, HomeDERP: 0, LegacyDERPString: "127.3.3.40:0", diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 9e3ff4cf..32905345 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -2,13 +2,13 @@ package mapper import ( "fmt" - "net/netip" "time" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "github.com/samber/lo" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -49,14 +49,6 @@ func tailNode( ) (*tailcfg.Node, error) { addrs := node.Prefixes() - allowedIPs := append( - []netip.Prefix{}, - addrs...) // we append the node own IP, as it is required by the clients - - for _, route := range node.SubnetRoutes() { - allowedIPs = append(allowedIPs, netip.Prefix(route)) - } - var derp int // TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077 @@ -89,6 +81,10 @@ func tailNode( } tags = lo.Uniq(append(tags, node.ForcedTags...)) + allowed := append(node.Prefixes(), primary.PrimaryRoutes(node.ID)...) + allowed = append(allowed, node.ExitRoutes()...) + tsaddr.SortPrefixes(allowed) + tNode := tailcfg.Node{ ID: tailcfg.NodeID(node.ID), // this is the actual ID StableID: node.ID.StableID(), @@ -104,7 +100,7 @@ func tailNode( DiscoKey: node.DiscoKey, Addresses: addrs, PrimaryRoutes: primary.PrimaryRoutes(node.ID), - AllowedIPs: allowedIPs, + AllowedIPs: allowed, Endpoints: node.Endpoints, HomeDERP: derp, LegacyDERPString: legacyDERP, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 919ea43c..9722df2e 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -67,8 +67,6 @@ func TestTailNode(t *testing.T) { want: &tailcfg.Node{ Name: "empty", StableID: "0", - Addresses: []netip.Prefix{}, - AllowedIPs: []netip.Prefix{}, HomeDERP: 0, LegacyDERPString: "127.3.3.40:0", Hostinfo: hiview(tailcfg.Hostinfo{}), @@ -139,9 +137,13 @@ func TestTailNode(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.1/32"), tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("100.64.0.1/32"), + tsaddr.AllIPv6(), + }, + PrimaryRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), }, HomeDERP: 0, LegacyDERPString: "127.3.3.40:0", @@ -156,10 +158,6 @@ func TestTailNode(t *testing.T) { Tags: []string{}, - PrimaryRoutes: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/24"), - }, - LastSeen: &lastSeen, MachineAuthorized: true, diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 317bf450..67eb8d1f 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" xmaps "golang.org/x/exp/maps" + "tailscale.com/net/tsaddr" "tailscale.com/util/set" ) @@ -74,18 +75,12 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { // If the current primary is not available, select a new one. for prefix, nodes := range allPrimaries { if node, ok := pr.primaries[prefix]; ok { - if len(nodes) < 2 { - delete(pr.primaries, prefix) - changed = true - continue - } - // If the current primary is still available, continue. if slices.Contains(nodes, node) { continue } } - if len(nodes) >= 2 { + if len(nodes) >= 1 { pr.primaries[prefix] = nodes[0] changed = true } @@ -107,12 +102,16 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { return changed } -func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bool { +// SetRoutes sets the routes for a given Node ID and recalculates the primary routes +// of the headscale. +// It returns true if there was a change in primary routes. +// All exit routes are ignored as they are not used in primary route context. +func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) bool { pr.mu.Lock() defer pr.mu.Unlock() // If no routes are being set, remove the node from the routes map. - if len(prefix) == 0 { + if len(prefixes) == 0 { if _, ok := pr.routes[node]; ok { delete(pr.routes, node) return pr.updatePrimaryLocked() @@ -121,12 +120,17 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bo return false } - if _, ok := pr.routes[node]; !ok { - pr.routes[node] = make(set.Set[netip.Prefix], len(prefix)) + rs := make(set.Set[netip.Prefix], len(prefixes)) + for _, prefix := range prefixes { + if !tsaddr.IsExitRoute(prefix) { + rs.Add(prefix) + } } - for _, p := range prefix { - pr.routes[node].Add(p) + if rs.Len() != 0 { + pr.routes[node] = rs + } else { + delete(pr.routes, node) } return pr.updatePrimaryLocked() @@ -153,6 +157,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } } + tsaddr.SortPrefixes(routes) return routes } diff --git a/hscontrol/routes/primary_test.go b/hscontrol/routes/primary_test.go index c58337c0..7a9767b2 100644 --- a/hscontrol/routes/primary_test.go +++ b/hscontrol/routes/primary_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "tailscale.com/util/set" ) // mp is a helper function that wraps netip.MustParsePrefix. @@ -17,20 +19,34 @@ func mp(prefix string) netip.Prefix { func TestPrimaryRoutes(t *testing.T) { tests := []struct { - name string - operations func(pr *PrimaryRoutes) bool - nodeID types.NodeID - expectedRoutes []netip.Prefix - expectedChange bool + name string + operations func(pr *PrimaryRoutes) bool + expectedRoutes map[types.NodeID]set.Set[netip.Prefix] + expectedPrimaries map[netip.Prefix]types.NodeID + expectedIsPrimary map[types.NodeID]bool + expectedChange bool + + // primaries is a map of prefixes to the node that is the primary for that prefix. + primaries map[netip.Prefix]types.NodeID + isPrimary map[types.NodeID]bool }{ { name: "single-node-registers-single-route", operations: func(pr *PrimaryRoutes) bool { return pr.SetRoutes(1, mp("192.168.1.0/24")) }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: true, }, { name: "multiple-nodes-register-different-routes", @@ -38,19 +54,45 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) return pr.SetRoutes(2, mp("192.168.2.0/24")) }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.2.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + mp("192.168.2.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + expectedChange: true, }, { name: "multiple-nodes-register-overlapping-routes", operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - return pr.SetRoutes(2, mp("192.168.1.0/24")) // true + pr.SetRoutes(1, mp("192.168.1.0/24")) // true + return pr.SetRoutes(2, mp("192.168.1.0/24")) // false }, - nodeID: 1, - expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, - expectedChange: true, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, }, { name: "node-deregisters-a-route", @@ -58,9 +100,10 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) return pr.SetRoutes(1) // Deregister by setting no routes }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, + expectedRoutes: nil, + expectedPrimaries: nil, + expectedIsPrimary: nil, + expectedChange: true, }, { name: "node-deregisters-one-of-multiple-routes", @@ -68,9 +111,18 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24")) return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.2.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.2.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: true, }, { name: "node-registers-and-deregisters-routes-in-sequence", @@ -80,18 +132,23 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1) // Deregister by setting no routes return pr.SetRoutes(1, mp("192.168.3.0/24")) }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, - }, - { - name: "no-change-in-primary-routes", - operations: func(pr *PrimaryRoutes) bool { - return pr.SetRoutes(1, mp("192.168.1.0/24")) + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.3.0/24"): {}, + }, + 2: { + mp("192.168.2.0/24"): {}, + }, }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.2.0/24"): 2, + mp("192.168.3.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + expectedChange: true, }, { name: "multiple-nodes-register-same-route", @@ -100,21 +157,24 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(2, mp("192.168.1.0/24")) // true return pr.SetRoutes(3, mp("192.168.1.0/24")) // false }, - nodeID: 1, - expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, - expectedChange: false, - }, - { - name: "register-multiple-routes-shift-primary-check-old-primary", - operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary - pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary - return pr.SetRoutes(1) // true, 2 primary + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: true, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, }, { name: "register-multiple-routes-shift-primary-check-primary", @@ -124,20 +184,20 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary return pr.SetRoutes(1) // true, 2 primary }, - nodeID: 2, - expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, - expectedChange: true, - }, - { - name: "register-multiple-routes-shift-primary-check-non-primary", - operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary - pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary - return pr.SetRoutes(1) // true, 2 primary + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 2: true, }, - nodeID: 3, - expectedRoutes: nil, expectedChange: true, }, { @@ -150,8 +210,17 @@ func TestPrimaryRoutes(t *testing.T) { return pr.SetRoutes(2) // true, no primary }, - nodeID: 2, - expectedRoutes: nil, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 3, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 3: true, + }, expectedChange: true, }, { @@ -165,9 +234,7 @@ func TestPrimaryRoutes(t *testing.T) { return pr.SetRoutes(3) // false, no primary }, - nodeID: 2, - expectedRoutes: nil, - expectedChange: false, + expectedChange: true, }, { name: "primary-route-map-is-cleared-up", @@ -179,8 +246,17 @@ func TestPrimaryRoutes(t *testing.T) { return pr.SetRoutes(2) // true, no primary }, - nodeID: 2, - expectedRoutes: nil, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 3, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 3: true, + }, expectedChange: true, }, { @@ -193,8 +269,23 @@ func TestPrimaryRoutes(t *testing.T) { return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary }, - nodeID: 2, - expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 2: true, + }, expectedChange: false, }, { @@ -207,8 +298,23 @@ func TestPrimaryRoutes(t *testing.T) { return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary }, - nodeID: 1, - expectedRoutes: nil, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 2: true, + }, expectedChange: false, }, { @@ -218,15 +324,30 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary pr.SetRoutes(1) // true, 2 primary - pr.SetRoutes(2) // true, no primary - pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 1 primary - pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary - pr.SetRoutes(1) // true, 2 primary + pr.SetRoutes(2) // true, 3 primary + pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 3 primary + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 3 primary + pr.SetRoutes(1) // true, 3 primary - return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary + return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 3 primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 3, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 3: true, }, - nodeID: 2, - expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, expectedChange: false, }, { @@ -235,16 +356,27 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24")) return pr.SetRoutes(2, mp("192.168.1.0/24")) }, - nodeID: 1, - expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, - expectedChange: true, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, }, { name: "deregister-non-existent-route", operations: func(pr *PrimaryRoutes) bool { return pr.SetRoutes(1) // Deregister by setting no routes }, - nodeID: 1, expectedRoutes: nil, expectedChange: false, }, @@ -253,17 +385,27 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { return pr.SetRoutes(1) }, - nodeID: 1, expectedRoutes: nil, expectedChange: false, }, { - name: "deregister-empty-prefix-list", + name: "exit-nodes", operations: func(pr *PrimaryRoutes) bool { - return pr.SetRoutes(1) + pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) + pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) + return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("10.0.0.0/16"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("10.0.0.0/16"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, }, - nodeID: 1, - expectedRoutes: nil, expectedChange: false, }, { @@ -284,19 +426,23 @@ func TestPrimaryRoutes(t *testing.T) { return change1 || change2 }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, - }, - { - name: "no-routes-registered", - operations: func(pr *PrimaryRoutes) bool { - // No operations - return false + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.2.0/24"): {}, + }, }, - nodeID: 1, - expectedRoutes: nil, - expectedChange: false, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + mp("192.168.2.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + expectedChange: true, }, } @@ -307,9 +453,15 @@ func TestPrimaryRoutes(t *testing.T) { if change != tt.expectedChange { t.Errorf("change = %v, want %v", change, tt.expectedChange) } - routes := pr.PrimaryRoutes(tt.nodeID) - if diff := cmp.Diff(tt.expectedRoutes, routes, util.Comparers...); diff != "" { - t.Errorf("PrimaryRoutes() mismatch (-want +got):\n%s", diff) + comps := append(util.Comparers, cmpopts.EquateEmpty()) + if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { + t.Errorf("routes mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { + t.Errorf("primaries mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { + t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) } }) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index e506a2c5..767ccdff 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -14,6 +14,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -213,7 +214,7 @@ func (node *Node) RequestTags() []string { } func (node *Node) Prefixes() []netip.Prefix { - addrs := []netip.Prefix{} + var addrs []netip.Prefix for _, nodeAddress := range node.IPs() { ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) addrs = append(addrs, ip) @@ -222,6 +223,19 @@ func (node *Node) Prefixes() []netip.Prefix { return addrs } +// ExitRoutes returns a list of both exit routes if the +// node has any exit routes enabled. +// If none are enabled, it will return nil. +func (node *Node) ExitRoutes() []netip.Prefix { + for _, route := range node.SubnetRoutes() { + if tsaddr.IsExitRoute(route) { + return tsaddr.ExitRoutes() + } + } + + return nil +} + func (node *Node) IPsAsString() []string { var ret []string diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index a9e7ca96..624d8bc0 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -57,6 +57,15 @@ func GenerateRandomStringDNSSafe(size int) (string, error) { return str[:size], nil } +func MustGenerateRandomStringDNSSafe(size int) string { + hash, err := GenerateRandomStringDNSSafe(size) + if err != nil { + panic(err) + } + + return hash +} + func TailNodesToString(nodes []*tailcfg.Node) string { temp := make([]string, len(nodes)) diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 569af354..a41ee6f8 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -3,8 +3,12 @@ package util import ( "errors" "fmt" + "net/netip" "net/url" + "regexp" + "strconv" "strings" + "time" "tailscale.com/util/cmpver" ) @@ -46,3 +50,126 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { return loginURL, nil } + +type TraceroutePath struct { + // Hop is the current jump in the total traceroute. + Hop int + + // Hostname is the resolved hostname or IP address identifying the jump + Hostname string + + // IP is the IP address of the jump + IP netip.Addr + + // Latencies is a list of the latencies for this jump + Latencies []time.Duration +} + +type Traceroute struct { + // Hostname is the resolved hostname or IP address identifying the target + Hostname string + + // IP is the IP address of the target + IP netip.Addr + + // Route is the path taken to reach the target if successful. The list is ordered by the path taken. + Route []TraceroutePath + + // Success indicates if the traceroute was successful. + Success bool + + // Err contains an error if the traceroute was not successful. + Err error +} + +// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct +func ParseTraceroute(output string) (Traceroute, error) { + lines := strings.Split(strings.TrimSpace(output), "\n") + if len(lines) < 1 { + return Traceroute{}, errors.New("empty traceroute output") + } + + // Parse the header line + headerRegex := regexp.MustCompile(`traceroute to ([^ ]+) \(([^)]+)\)`) + headerMatches := headerRegex.FindStringSubmatch(lines[0]) + if len(headerMatches) != 3 { + return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) + } + + hostname := headerMatches[1] + ipStr := headerMatches[2] + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) + } + + result := Traceroute{ + Hostname: hostname, + IP: ip, + Route: []TraceroutePath{}, + Success: false, + } + + // Parse each hop line + hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) + + for i := 1; i < len(lines); i++ { + matches := hopRegex.FindStringSubmatch(lines[i]) + if len(matches) == 0 { + continue + } + + hop, err := strconv.Atoi(matches[1]) + if err != nil { + return Traceroute{}, fmt.Errorf("parsing hop number: %w", err) + } + + var hopHostname string + var hopIP netip.Addr + var latencies []time.Duration + + // Handle hostname and IP + if matches[2] != "" && matches[3] != "" { + hopHostname = matches[2] + hopIP, err = netip.ParseAddr(matches[3]) + if err != nil { + return Traceroute{}, fmt.Errorf("parsing hop IP address %s: %w", matches[3], err) + } + } else if matches[4] == "*" { + hopHostname = "*" + // No IP for timeouts + } + + // Parse latencies + for j := 5; j <= 7; j++ { + if matches[j] != "" { + ms, err := strconv.ParseFloat(matches[j], 64) + if err != nil { + return Traceroute{}, fmt.Errorf("parsing latency: %w", err) + } + latencies = append(latencies, time.Duration(ms*float64(time.Millisecond))) + } + } + + path := TraceroutePath{ + Hop: hop, + Hostname: hopHostname, + IP: hopIP, + Latencies: latencies, + } + + result.Route = append(result.Route, path) + + // Check if we've reached the target + if hopIP == ip { + result.Success = true + } + } + + // If we didn't reach the target, it's unsuccessful + if !result.Success { + result.Err = errors.New("traceroute did not reach target") + } + + return result, nil +} diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 1e331fe2..b1a18610 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1,6 +1,13 @@ package util -import "testing" +import ( + "errors" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) func TestTailscaleVersionNewerOrEqual(t *testing.T) { type args struct { @@ -178,3 +185,186 @@ Success.`, }) } } + +func TestParseTraceroute(t *testing.T) { + tests := []struct { + name string + input string + want Traceroute + wantErr bool + }{ + { + name: "simple successful traceroute", + input: `traceroute to 172.24.0.3 (172.24.0.3), 30 hops max, 46 byte packets + 1 ts-head-hk0urr.headscale.net (100.64.0.1) 1.135 ms 0.922 ms 0.619 ms + 2 172.24.0.3 (172.24.0.3) 0.593 ms 0.549 ms 0.522 ms`, + want: Traceroute{ + Hostname: "172.24.0.3", + IP: netip.MustParseAddr("172.24.0.3"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "ts-head-hk0urr.headscale.net", + IP: netip.MustParseAddr("100.64.0.1"), + Latencies: []time.Duration{ + 1135 * time.Microsecond, + 922 * time.Microsecond, + 619 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "172.24.0.3", + IP: netip.MustParseAddr("172.24.0.3"), + Latencies: []time.Duration{ + 593 * time.Microsecond, + 549 * time.Microsecond, + 522 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "traceroute with timeouts", + input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets + 1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms + 2 * * * + 3 isp-gateway.net (10.0.0.1) 15.678 ms 14.789 ms 15.432 ms + 4 8.8.8.8 (8.8.8.8) 20.123 ms 19.876 ms 20.345 ms`, + want: Traceroute{ + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "router.local", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + 1121 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "*", + }, + { + Hop: 3, + Hostname: "isp-gateway.net", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 15678 * time.Microsecond, + 14789 * time.Microsecond, + 15432 * time.Microsecond, + }, + }, + { + Hop: 4, + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Latencies: []time.Duration{ + 20123 * time.Microsecond, + 19876 * time.Microsecond, + 20345 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "unsuccessful traceroute", + input: `traceroute to 10.0.0.99 (10.0.0.99), 5 hops max, 60 byte packets + 1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms + 2 * * * + 3 * * * + 4 * * * + 5 * * *`, + want: Traceroute{ + Hostname: "10.0.0.99", + IP: netip.MustParseAddr("10.0.0.99"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "router.local", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + 1121 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "*", + }, + { + Hop: 3, + Hostname: "*", + }, + { + Hop: 4, + Hostname: "*", + }, + { + Hop: 5, + Hostname: "*", + }, + }, + Success: false, + Err: errors.New("traceroute did not reach target"), + }, + wantErr: false, + }, + { + name: "empty input", + input: "", + want: Traceroute{}, + wantErr: true, + }, + { + name: "invalid header", + input: "not a valid traceroute output", + want: Traceroute{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseTraceroute(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseTraceroute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + // Special handling for error field since it can't be directly compared with cmp.Diff + gotErr := got.Err + wantErr := tt.want.Err + got.Err = nil + tt.want.Err = nil + + if diff := cmp.Diff(tt.want, got, IPComparer); diff != "" { + t.Errorf("ParseTraceroute() mismatch (-want +got):\n%s", diff) + } + + // Now check error field separately + if (gotErr == nil) != (wantErr == nil) { + t.Errorf("Error field: got %v, want %v", gotErr, wantErr) + } else if gotErr != nil && wantErr != nil && gotErr.Error() != wantErr.Error() { + t.Errorf("Error message: got %q, want %q", gotErr.Error(), wantErr.Error()) + } + }) + } +} diff --git a/integration/acl_test.go b/integration/acl_test.go index fefd75c0..d1bf0342 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -54,15 +54,16 @@ func aclScenario( clientsPerUser int, ) *Scenario { t.Helper() - scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) - spec := map[string]int{ - "user1": clientsPerUser, - "user2": clientsPerUser, + spec := ScenarioSpec{ + NodesPerUser: clientsPerUser, + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, + scenario, err := NewScenario(spec) + require.NoError(t, err) + + err = scenario.CreateHeadscaleEnv( []tsic.Option{ // Alpine containers dont have ip6tables set up, which causes // tailscaled to stop configuring the wgengine, causing it @@ -96,22 +97,24 @@ func aclScenario( func TestACLHostsInNetMapTable(t *testing.T) { IntegrationSkip(t) + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1", "user2"}, + } + // NOTE: All want cases currently checks the // total count of expected peers, this would // typically be the client count of the users // they can access minus one (them self). tests := map[string]struct { - users map[string]int + users ScenarioSpec policy policyv1.ACLPolicy want map[string]int }{ // Test that when we have no ACL, each client netmap has // the amount of peers of the total amount of clients "base-acls": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, + users: spec, policy: policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -129,10 +132,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { // each other, each node has only the number of pairs from // their own user. "two-isolated-users": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, + users: spec, policy: policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -155,10 +155,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { // are restricted to a single port, nodes are still present // in the netmap. "two-restricted-present-in-netmap": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, + users: spec, policy: policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -192,10 +189,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { // of peers. This will still result in all the peers as we // need them present on the other side for the "return path". "two-ns-one-isolated": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, + users: spec, policy: policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -220,10 +214,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { }, }, "very-large-destination-prefix-1372": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, + users: spec, policy: policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -248,10 +239,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { }, }, "ipv6-acls-1470": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, + users: spec, policy: policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -269,12 +257,11 @@ func TestACLHostsInNetMapTable(t *testing.T) { for name, testCase := range tests { t.Run(name, func(t *testing.T) { - scenario, err := NewScenario(dockertestMaxWait()) + caseSpec := testCase.users + scenario, err := NewScenario(caseSpec) require.NoError(t, err) - spec := testCase.users - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithACLPolicy(&testCase.policy), ) @@ -944,6 +931,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { for name, testCase := range tests { t.Run(name, func(t *testing.T) { scenario := aclScenario(t, &testCase.policy, 1) + defer scenario.ShutdownAssertNoPanics(t) test1ip := netip.MustParseAddr("100.64.0.1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") @@ -1022,16 +1010,16 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 1, - "user2": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{ // Alpine containers dont have ip6tables set up, which causes // tailscaled to stop configuring the wgengine, causing it diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index a2bda02a..9d219fca 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -19,15 +19,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { for _, https := range []bool{true, false} { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - opts := []hsic.Option{hsic.WithTestName("pingallbyip")} if https { opts = append(opts, []hsic.Option{ @@ -35,7 +35,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { }...) } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -84,7 +84,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { time.Sleep(5 * time.Minute) } - for userName := range spec { + for _, userName := range spec.Users { key, err := scenario.CreatePreAuthKey(userName, true, false) if err != nil { t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) @@ -152,16 +152,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("keyrelognewuser"), hsic.WithTLS(), ) @@ -203,7 +203,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // Log in all clients as user1, iterating over the spec only returns the // clients, not the usernames. - for userName := range spec { + for _, userName := range spec.Users { err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) if err != nil { t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) @@ -235,15 +235,15 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { for _, https := range []bool{true, false} { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - opts := []hsic.Option{hsic.WithTestName("pingallbyip")} if https { opts = append(opts, []hsic.Option{ @@ -251,7 +251,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { }...) } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -300,7 +300,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { time.Sleep(5 * time.Minute) } - for userName := range spec { + for _, userName := range spec.Users { key, err := scenario.CreatePreAuthKey(userName, true, false) if err != nil { t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index a76220d8..c86138a8 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1,93 +1,58 @@ package integration import ( - "context" - "crypto/tls" - "encoding/json" - "errors" "fmt" - "io" - "log" - "net" - "net/http" - "net/http/cookiejar" "net/netip" - "net/url" "sort" - "strconv" "testing" "time" + "maps" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" "github.com/oauth2-proxy/mockoidc" - "github.com/ory/dockertest/v3" - "github.com/ory/dockertest/v3/docker" "github.com/samber/lo" "github.com/stretchr/testify/assert" ) -const ( - dockerContextPath = "../." - hsicOIDCMockHashLength = 6 - defaultAccessTTL = 10 * time.Minute -) - -var errStatusCodeNotOK = errors.New("status code not OK") - -type AuthOIDCScenario struct { - *Scenario - - mockOIDC *dockertest.Resource -} - func TestOIDCAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) t.Parallel() - baseScenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, - } - defer scenario.ShutdownAssertNoPanics(t) - // Logins to MockOIDC is served by a queue with a strict order, // if we use more than one node per user, the order of the logins // will not be deterministic and the test will fail. - spec := map[string]int{ - "user1": 1, - "user2": 1, + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + }, } - mockusers := []mockoidc.MockUser{ - oidcMockUser("user1", true), - oidcMockUser("user2", false), - } + scenario, err := NewScenario(spec) + assertNoErr(t, err) - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) - defer scenario.mockOIDC.Close() + defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } - err = scenario.CreateHeadscaleEnv( - spec, + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, hsic.WithTestName("oidcauthping"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) assertNoErrHeadscaleEnv(t, err) @@ -126,7 +91,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { Name: "user1", Email: "user1@headscale.net", Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user1", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", }, { Id: 3, @@ -138,7 +103,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { Name: "user2", Email: "", // Unverified Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user2", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", }, } @@ -158,37 +123,29 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { shortAccessTTL := 5 * time.Minute - baseScenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - - baseScenario.pool.MaxWait = 5 * time.Minute - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + }, + OIDCAccessTTL: shortAccessTTL, } + + scenario, err := NewScenario(spec) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 1, - "user2": 1, - } - - oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ - oidcMockUser("user1", true), - oidcMockUser("user2", false), - }) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) - defer scenario.mockOIDC.Close() - oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, - "HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret, + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "HEADSCALE_OIDC_CLIENT_SECRET": scenario.mockOIDC.ClientSecret(), "HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1", } - err = scenario.CreateHeadscaleEnv( - spec, + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, hsic.WithTestName("oidcexpirenodes"), hsic.WithConfigEnv(oidcMap), ) @@ -334,45 +291,35 @@ func TestOIDC024UserCreation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - baseScenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, + spec := ScenarioSpec{ + NodesPerUser: 1, } + for _, user := range tt.cliUsers { + spec.Users = append(spec.Users, user) + } + + for _, user := range tt.oidcUsers { + spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified)) + } + + scenario, err := NewScenario(spec) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{} - for _, user := range tt.cliUsers { - spec[user] = 1 - } - - var mockusers []mockoidc.MockUser - for _, user := range tt.oidcUsers { - mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) - } - - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) - defer scenario.mockOIDC.Close() - oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } + maps.Copy(oidcMap, tt.config) - for k, v := range tt.config { - oidcMap[k] = v - } - - err = scenario.CreateHeadscaleEnv( - spec, + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, hsic.WithTestName("oidcmigration"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) assertNoErrHeadscaleEnv(t, err) @@ -384,7 +331,7 @@ func TestOIDC024UserCreation(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - want := tt.want(oidcConfig.Issuer) + want := tt.want(scenario.mockOIDC.Issuer()) listUsers, err := headscale.ListUsers() assertNoErr(t, err) @@ -404,41 +351,33 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { IntegrationSkip(t) t.Parallel() - baseScenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, + // Single user with one node for testing PKCE flow + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, } + + scenario, err := NewScenario(spec) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - // Single user with one node for testing PKCE flow - spec := map[string]int{ - "user1": 1, - } - - mockusers := []mockoidc.MockUser{ - oidcMockUser("user1", true), - } - - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) - defer scenario.mockOIDC.Close() - oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE } - err = scenario.CreateHeadscaleEnv( - spec, + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, hsic.WithTestName("oidcauthpkce"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) assertNoErrHeadscaleEnv(t, err) @@ -464,43 +403,33 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { IntegrationSkip(t) t.Parallel() - baseScenario, err := NewScenario(dockertestMaxWait()) + // Create no nodes and no users + scenario, err := NewScenario(ScenarioSpec{ + // First login creates the first OIDC user + // Second login logs in the same node, which creates a new node + // Third login logs in the same node back into the original user + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", true), + oidcMockUser("user1", true), + }, + }) assertNoErr(t, err) - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, - } defer scenario.ShutdownAssertNoPanics(t) - // Create no nodes and no users - spec := map[string]int{} - - // First login creates the first OIDC user - // Second login logs in the same node, which creates a new node - // Third login logs in the same node back into the original user - mockusers := []mockoidc.MockUser{ - oidcMockUser("user1", true), - oidcMockUser("user2", true), - oidcMockUser("user1", true), - } - - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) - // defer scenario.mockOIDC.Close() - oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } - err = scenario.CreateHeadscaleEnv( - spec, + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, hsic.WithTestName("oidcauthrelog"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), hsic.WithEmbeddedDERPServerOnly(), ) assertNoErrHeadscaleEnv(t, err) @@ -512,7 +441,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { assertNoErr(t, err) assert.Len(t, listUsers, 0) - ts, err := scenario.CreateTailscaleNode("unstable") + ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) assertNoErr(t, err) u, err := ts.LoginWithURL(headscale.GetEndpoint()) @@ -530,7 +459,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { Name: "user1", Email: "user1@headscale.net", Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user1", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", }, } @@ -575,14 +504,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { Name: "user1", Email: "user1@headscale.net", Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user1", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", }, { Id: 2, Name: "user2", Email: "user2@headscale.net", Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user2", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", }, } @@ -632,14 +561,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { Name: "user1", Email: "user1@headscale.net", Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user1", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", }, { Id: 2, Name: "user2", Email: "user2@headscale.net", Provider: "oidc", - ProviderId: oidcConfig.Issuer + "/user2", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", }, } @@ -678,254 +607,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) } -func (s *AuthOIDCScenario) CreateHeadscaleEnv( - users map[string]int, - opts ...hsic.Option, -) error { - headscale, err := s.Headscale(opts...) - if err != nil { - return err - } - - err = headscale.WaitForRunning() - if err != nil { - return err - } - - for userName, clientCount := range users { - if clientCount != 1 { - // OIDC scenario only supports one client per user. - // This is because the MockOIDC server can only serve login - // requests based on a queue it has been given on startup. - // We currently only populates it with one login request per user. - return fmt.Errorf("client count must be 1 for OIDC scenario.") - } - log.Printf("creating user %s with %d clients", userName, clientCount) - err = s.CreateUser(userName) - if err != nil { - return err - } - - err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) - if err != nil { - return err - } - - err = s.runTailscaleUp(userName, headscale.GetEndpoint()) - if err != nil { - return err - } - } - - return nil -} - -func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { - port, err := dockertestutil.RandomFreeHostPort() - if err != nil { - log.Fatalf("could not find an open port: %s", err) - } - portNotation := fmt.Sprintf("%d/tcp", port) - - hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) - - hostname := fmt.Sprintf("hs-oidcmock-%s", hash) - - usersJSON, err := json.Marshal(users) - if err != nil { - return nil, err - } - - mockOidcOptions := &dockertest.RunOptions{ - Name: hostname, - Cmd: []string{"headscale", "mockoidc"}, - ExposedPorts: []string{portNotation}, - PortBindings: map[docker.Port][]docker.PortBinding{ - docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, - }, - Networks: []*dockertest.Network{s.Scenario.network}, - Env: []string{ - fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), - fmt.Sprintf("MOCKOIDC_PORT=%d", port), - "MOCKOIDC_CLIENT_ID=superclient", - "MOCKOIDC_CLIENT_SECRET=supersecret", - fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), - fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), - }, - } - - headscaleBuildOptions := &dockertest.BuildOptions{ - Dockerfile: hsic.IntegrationTestDockerFileName, - ContextDir: dockerContextPath, - } - - err = s.pool.RemoveContainerByName(hostname) - if err != nil { - return nil, err - } - - if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( - headscaleBuildOptions, - mockOidcOptions, - dockertestutil.DockerRestartPolicy); err == nil { - s.mockOIDC = pmockoidc - } else { - return nil, err - } - - log.Println("Waiting for headscale mock oidc to be ready for tests") - hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port) - - if err := s.pool.Retry(func() error { - oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) - httpClient := &http.Client{} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf("headscale mock OIDC tests is not ready: %s\n", err) - - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errStatusCodeNotOK - } - - return nil - }); err != nil { - return nil, err - } - - log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) - - return &types.OIDCConfig{ - Issuer: fmt.Sprintf( - "http://%s/oidc", - net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)), - ), - ClientID: "superclient", - ClientSecret: "supersecret", - OnlyStartIfOIDCIsAvailable: true, - }, nil -} - -type LoggingRoundTripper struct{} - -func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - noTls := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint - } - resp, err := noTls.RoundTrip(req) - if err != nil { - return nil, err - } - - log.Printf("---") - log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) - log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) - - return resp, nil -} - -func (s *AuthOIDCScenario) runTailscaleUp( - userStr, loginServer string, -) error { - log.Printf("running tailscale up for user %s", userStr) - if user, ok := s.users[userStr]; ok { - for _, client := range user.Clients { - tsc := client - user.joinWaitGroup.Go(func() error { - loginURL, err := tsc.LoginWithURL(loginServer) - if err != nil { - log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) - } - - _, err = doLoginURL(tsc.Hostname(), loginURL) - if err != nil { - return err - } - - return nil - }) - - log.Printf("client %s is ready", client.Hostname()) - } - - if err := user.joinWaitGroup.Wait(); err != nil { - return err - } - - for _, client := range user.Clients { - err := client.WaitForRunning() - if err != nil { - return fmt.Errorf( - "%s tailscale node has not reached running: %w", - client.Hostname(), - err, - ) - } - } - - return nil - } - - return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) -} - -// doLoginURL visits the given login URL and returns the body as a -// string. -func doLoginURL(hostname string, loginURL *url.URL) (string, error) { - log.Printf("%s login url: %s\n", hostname, loginURL.String()) - - var err error - hc := &http.Client{ - Transport: LoggingRoundTripper{}, - } - hc.Jar, err = cookiejar.New(nil) - if err != nil { - return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err) - } - - log.Printf("%s logging in with url", hostname) - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := hc.Do(req) - if err != nil { - return "", fmt.Errorf("%s failed to send http request: %w", hostname, err) - } - - log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - log.Printf("body: %s", body) - - return "", fmt.Errorf("%s response code of login request was %w", hostname, err) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("%s failed to read response body: %s", hostname, err) - - return "", fmt.Errorf("%s failed to read response body: %w", hostname, err) - } - - return string(body), nil -} - -func (s *AuthOIDCScenario) Shutdown() { - err := s.pool.Purge(s.mockOIDC) - if err != nil { - log.Printf("failed to remove mock oidc container") - } - - s.Scenario.Shutdown() -} - func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { t.Helper() diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index acc96cec..034ad5ae 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,47 +1,33 @@ package integration import ( - "errors" - "fmt" - "log" "net/netip" - "net/url" - "strings" "testing" + "slices" + "github.com/juanfont/headscale/integration/hsic" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var errParseAuthPage = errors.New("failed to parse auth page") - -type AuthWebFlowScenario struct { - *Scenario -} - func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() - baseScenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) if err != nil { t.Fatalf("failed to create scenario: %s", err) } - - scenario := AuthWebFlowScenario{ - Scenario: baseScenario, - } defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - err = scenario.CreateHeadscaleEnv( - spec, + nil, hsic.WithTestName("webauthping"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), @@ -71,20 +57,17 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { IntegrationSkip(t) t.Parallel() - baseScenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - - scenario := AuthWebFlowScenario{ - Scenario: baseScenario, + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, } + + scenario, err := NewScenario(spec) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( + nil, hsic.WithTestName("weblogout"), hsic.WithTLS(), ) @@ -137,8 +120,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients logged out") - for userName := range spec { - err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) + for _, userName := range spec.Users { + err = scenario.RunTailscaleUpWithURL(userName, headscale.GetEndpoint()) if err != nil { t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err) } @@ -172,14 +155,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } for _, ip := range ips { - found := false - for _, oldIP := range clientIPs[client] { - if ip == oldIP { - found = true - - break - } - } + found := slices.Contains(clientIPs[client], ip) if !found { t.Fatalf( @@ -194,122 +170,3 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients IPs are the same") } - -func (s *AuthWebFlowScenario) CreateHeadscaleEnv( - users map[string]int, - opts ...hsic.Option, -) error { - headscale, err := s.Headscale(opts...) - if err != nil { - return err - } - - err = headscale.WaitForRunning() - if err != nil { - return err - } - - for userName, clientCount := range users { - log.Printf("creating user %s with %d clients", userName, clientCount) - err = s.CreateUser(userName) - if err != nil { - return err - } - - err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) - if err != nil { - return err - } - - err = s.runTailscaleUp(userName, headscale.GetEndpoint()) - if err != nil { - return err - } - } - - return nil -} - -func (s *AuthWebFlowScenario) runTailscaleUp( - userStr, loginServer string, -) error { - log.Printf("running tailscale up for user %q", userStr) - if user, ok := s.users[userStr]; ok { - for _, client := range user.Clients { - c := client - user.joinWaitGroup.Go(func() error { - log.Printf("logging %q into %q", c.Hostname(), loginServer) - loginURL, err := c.LoginWithURL(loginServer) - if err != nil { - log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err) - - return err - } - - err = s.runHeadscaleRegister(userStr, loginURL) - if err != nil { - log.Printf("failed to register client (%s): %s", c.Hostname(), err) - - return err - } - - return nil - }) - - err := client.WaitForRunning() - if err != nil { - log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err) - } - } - - if err := user.joinWaitGroup.Wait(); err != nil { - return err - } - - for _, client := range user.Clients { - err := client.WaitForRunning() - if err != nil { - return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err) - } - } - - return nil - } - - return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) -} - -func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error { - body, err := doLoginURL("web-auth-not-set", loginURL) - if err != nil { - return err - } - - // see api.go HTML template - codeSep := strings.Split(string(body), "") - if len(codeSep) != 2 { - return errParseAuthPage - } - - keySep := strings.Split(codeSep[0], "key ") - if len(keySep) != 2 { - return errParseAuthPage - } - key := keySep[1] - log.Printf("registering node %s", key) - - if headscale, err := s.Headscale(); err == nil { - _, err = headscale.Execute( - []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, - ) - if err != nil { - log.Printf("failed to register node: %s", err) - - return err - } - - return nil - } - - return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) -} diff --git a/integration/cli_test.go b/integration/cli_test.go index 2f23e8f6..85b20702 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -48,16 +48,15 @@ func TestUserCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 0, - "user2": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -247,15 +246,15 @@ func TestPreAuthKeyCommand(t *testing.T) { user := "preauthkeyspace" count := 3 - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{user}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -388,16 +387,15 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { t.Parallel() user := "pre-auth-key-without-exp-user" + spec := ScenarioSpec{ + Users: []string{user}, + } - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -451,16 +449,15 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { t.Parallel() user := "pre-auth-key-reus-ephm-user" + spec := ScenarioSpec{ + Users: []string{user}, + } - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -530,17 +527,16 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { user1 := "user1" user2 := "user2" - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{user1}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user1: 1, - user2: 0, - } - err = scenario.CreateHeadscaleEnv( - spec, []tsic.Option{}, hsic.WithTestName("clipak"), hsic.WithEmbeddedDERPServerOnly(), @@ -551,6 +547,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) + err = headscale.CreateUser(user2) + assertNoErr(t, err) + var user2Key v1.PreAuthKey err = executeAndUnmarshal( @@ -573,10 +572,15 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { ) assertNoErr(t, err) + listNodes, err := headscale.ListNodes() + require.Nil(t, err) + require.Len(t, listNodes, 1) + assert.Equal(t, user1, listNodes[0].GetUser().GetName()) + allClients, err := scenario.ListTailscaleClients() assertNoErrListClients(t, err) - assert.Len(t, allClients, 1) + require.Len(t, allClients, 1) client := allClients[0] @@ -606,12 +610,11 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String()) } - listNodes, err := headscale.ListNodes() - assert.Nil(t, err) - assert.Len(t, listNodes, 2) - - assert.Equal(t, "user1", listNodes[0].GetUser().GetName()) - assert.Equal(t, "user2", listNodes[1].GetUser().GetName()) + listNodes, err = headscale.ListNodes() + require.Nil(t, err) + require.Len(t, listNodes, 2) + assert.Equal(t, user1, listNodes[0].GetUser().GetName()) + assert.Equal(t, user2, listNodes[1].GetUser().GetName()) } func TestApiKeyCommand(t *testing.T) { @@ -620,16 +623,15 @@ func TestApiKeyCommand(t *testing.T) { count := 5 - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 0, - "user2": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -788,15 +790,15 @@ func TestNodeTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -977,15 +979,16 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(tt.policy), @@ -996,7 +999,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { assertNoErr(t, err) // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) + resultMachines := make([]*v1.Node, spec.NodesPerUser) err = executeAndUnmarshal( headscale, []string{ @@ -1029,16 +1032,15 @@ func TestNodeCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"node-user", "other-user"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "node-user": 0, - "other-user": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -1269,15 +1271,15 @@ func TestNodeExpireCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"node-expire-user"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "node-expire-user": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -1395,15 +1397,15 @@ func TestNodeRenameCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"node-rename-command"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "node-rename-command": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -1560,16 +1562,15 @@ func TestNodeMoveCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"old-user", "new-user"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "old-user": 0, - "new-user": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) assertNoErr(t, err) headscale, err := scenario.Headscale() @@ -1721,16 +1722,15 @@ func TestPolicyCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 0, - } - err = scenario.CreateHeadscaleEnv( - spec, []tsic.Option{}, hsic.WithTestName("clins"), hsic.WithConfigEnv(map[string]string{ @@ -1808,16 +1808,16 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": 1, - } - err = scenario.CreateHeadscaleEnv( - spec, []tsic.Option{}, hsic.WithTestName("clins"), hsic.WithConfigEnv(map[string]string{ diff --git a/integration/control.go b/integration/control.go index e1ad2a7e..2109b99d 100644 --- a/integration/control.go +++ b/integration/control.go @@ -24,5 +24,4 @@ type ControlServer interface { ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) GetCert() []byte GetHostname() string - GetIP() string } diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index bc7a0a7d..20ed4872 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -31,14 +31,15 @@ func TestDERPVerifyEndpoint(t *testing.T) { certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) assertNoErr(t, err) - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - } - derper, err := scenario.CreateDERPServer("head", dsic.WithCACert(certHeadscale), dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), @@ -65,7 +66,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { }, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())}, + err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithCACert(derper.GetCert())}, hsic.WithHostname(hostname), hsic.WithPort(headscalePort), hsic.WithCustomTLS(certHeadscale, keyHeadscale), diff --git a/integration/dns_test.go b/integration/dns_test.go index 1a8b69aa..9bd171f9 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -17,16 +17,16 @@ func TestResolveMagicDNS(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "magicdns1": len(MustTestVersions), - "magicdns2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns")) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -87,15 +87,15 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "magicdns1": 1, - "magicdns2": 1, - } - const erPath = "/tmp/extra_records.json" extraRecords := []tailcfg.DNSRecord{ @@ -107,7 +107,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { } b, _ := json.Marshal(extraRecords) - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{ + err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithDockerEntrypoint([]string{ "/bin/sh", "-c", @@ -364,16 +364,16 @@ func TestValidateResolvConf(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "resolvconf1": 3, - "resolvconf2": 3, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf)) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf)) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index a3dee180..9c5a3320 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -35,7 +35,7 @@ type DERPServerInContainer struct { pool *dockertest.Pool container *dockertest.Resource - network *dockertest.Network + networks []*dockertest.Network stunPort int derpPort int @@ -63,22 +63,22 @@ func WithCACert(cert []byte) Option { // isolating the DERPer, will be created. If a network is // passed, the DERPer instance will join the given network. func WithOrCreateNetwork(network *dockertest.Network) Option { - return func(tsic *DERPServerInContainer) { + return func(dsic *DERPServerInContainer) { if network != nil { - tsic.network = network + dsic.networks = append(dsic.networks, network) return } network, err := dockertestutil.GetFirstOrCreateNetwork( - tsic.pool, - tsic.hostname+"-network", + dsic.pool, + dsic.hostname+"-network", ) if err != nil { log.Fatalf("failed to create network: %s", err) } - tsic.network = network + dsic.networks = append(dsic.networks, network) } } @@ -107,7 +107,7 @@ func WithExtraHosts(hosts []string) Option { func New( pool *dockertest.Pool, version string, - network *dockertest.Network, + networks []*dockertest.Network, opts ...Option, ) (*DERPServerInContainer, error) { hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength) @@ -124,7 +124,7 @@ func New( version: version, hostname: hostname, pool: pool, - network: network, + networks: networks, tlsCert: tlsCert, tlsKey: tlsKey, stunPort: 3478, //nolint @@ -148,7 +148,7 @@ func New( runOptions := &dockertest.RunOptions{ Name: hostname, - Networks: []*dockertest.Network{dsic.network}, + Networks: dsic.networks, ExtraHosts: dsic.withExtraHosts, // we currently need to give us some time to inject the certificate further down. Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()}, diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index e17bbacb..0d930186 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -1,18 +1,12 @@ package integration import ( - "fmt" - "log" - "net/url" "strings" "testing" "time" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" - "github.com/ory/dockertest/v3" ) type ClientsSpec struct { @@ -20,21 +14,18 @@ type ClientsSpec struct { WebsocketDERP int } -type EmbeddedDERPServerScenario struct { - *Scenario - - tsicNetworks map[string]*dockertest.Network -} - func TestDERPServerScenario(t *testing.T) { - spec := map[string]ClientsSpec{ - "user1": { - Plain: len(MustTestVersions), - WebsocketDERP: 0, + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2", "user3"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + "usernet3": {"user3"}, }, } - derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) { + derpServerScenario(t, spec, false, func(scenario *Scenario) { allClients, err := scenario.ListTailscaleClients() assertNoErrListClients(t, err) t.Logf("checking %d clients for websocket connections", len(allClients)) @@ -52,14 +43,17 @@ func TestDERPServerScenario(t *testing.T) { } func TestDERPServerWebsocketScenario(t *testing.T) { - spec := map[string]ClientsSpec{ - "user1": { - Plain: 0, - WebsocketDERP: 2, + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2", "user3"}, + Networks: map[string][]string{ + "usernet1": []string{"user1"}, + "usernet2": []string{"user2"}, + "usernet3": []string{"user3"}, }, } - derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) { + derpServerScenario(t, spec, true, func(scenario *Scenario) { allClients, err := scenario.ListTailscaleClients() assertNoErrListClients(t, err) t.Logf("checking %d clients for websocket connections", len(allClients)) @@ -83,23 +77,22 @@ func TestDERPServerWebsocketScenario(t *testing.T) { //nolint:thelper func derpServerScenario( t *testing.T, - spec map[string]ClientsSpec, - furtherAssertions ...func(*EmbeddedDERPServerScenario), + spec ScenarioSpec, + websocket bool, + furtherAssertions ...func(*Scenario), ) { IntegrationSkip(t) // t.Parallel() - baseScenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(spec) assertNoErr(t, err) - scenario := EmbeddedDERPServerScenario{ - Scenario: baseScenario, - tsicNetworks: map[string]*dockertest.Network{}, - } defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( - spec, + []tsic.Option{ + tsic.WithWebsocketDERP(websocket), + }, hsic.WithTestName("derpserver"), hsic.WithExtraPorts([]string{"3478/udp"}), hsic.WithEmbeddedDERPServerOnly(), @@ -185,182 +178,6 @@ func derpServerScenario( t.Logf("Run2: %d successful pings out of %d", success, len(allClients)*len(allHostnames)) for _, check := range furtherAssertions { - check(&scenario) + check(scenario) } } - -func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv( - users map[string]ClientsSpec, - opts ...hsic.Option, -) error { - hsServer, err := s.Headscale(opts...) - if err != nil { - return err - } - - headscaleEndpoint := hsServer.GetEndpoint() - headscaleURL, err := url.Parse(headscaleEndpoint) - if err != nil { - return err - } - - headscaleURL.Host = fmt.Sprintf("%s:%s", hsServer.GetHostname(), headscaleURL.Port()) - - err = hsServer.WaitForRunning() - if err != nil { - return err - } - log.Printf("headscale server ip address: %s", hsServer.GetIP()) - - hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) - if err != nil { - return err - } - - for userName, clientCount := range users { - err = s.CreateUser(userName) - if err != nil { - return err - } - - if clientCount.Plain > 0 { - // Containers that use default DERP config - err = s.CreateTailscaleIsolatedNodesInUser( - hash, - userName, - "all", - clientCount.Plain, - ) - if err != nil { - return err - } - } - - if clientCount.WebsocketDERP > 0 { - // Containers that use DERP-over-WebSocket - // Note that these clients *must* be built - // from source, which is currently - // only done for HEAD. - err = s.CreateTailscaleIsolatedNodesInUser( - hash, - userName, - tsic.VersionHead, - clientCount.WebsocketDERP, - tsic.WithWebsocketDERP(true), - ) - if err != nil { - return err - } - } - - key, err := s.CreatePreAuthKey(userName, true, false) - if err != nil { - return err - } - - err = s.RunTailscaleUp(userName, headscaleURL.String(), key.GetKey()) - if err != nil { - return err - } - } - - return nil -} - -func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser( - hash string, - userStr string, - requestedVersion string, - count int, - opts ...tsic.Option, -) error { - hsServer, err := s.Headscale() - if err != nil { - return err - } - - if user, ok := s.users[userStr]; ok { - for clientN := 0; clientN < count; clientN++ { - networkName := fmt.Sprintf("tsnet-%s-%s-%d", - hash, - userStr, - clientN, - ) - network, err := dockertestutil.GetFirstOrCreateNetwork( - s.pool, - networkName, - ) - if err != nil { - return fmt.Errorf("failed to create or get %s network: %w", networkName, err) - } - - s.tsicNetworks[networkName] = network - - err = hsServer.ConnectToNetwork(network) - if err != nil { - return fmt.Errorf("failed to connect headscale to %s network: %w", networkName, err) - } - - version := requestedVersion - if requestedVersion == "all" { - version = MustTestVersions[clientN%len(MustTestVersions)] - } - - cert := hsServer.GetCert() - - opts = append(opts, - tsic.WithCACert(cert), - ) - - user.createWaitGroup.Go(func() error { - tsClient, err := tsic.New( - s.pool, - version, - network, - opts..., - ) - if err != nil { - return fmt.Errorf( - "failed to create tailscale (%s) node: %w", - tsClient.Hostname(), - err, - ) - } - - err = tsClient.WaitForNeedsLogin() - if err != nil { - return fmt.Errorf( - "failed to wait for tailscaled (%s) to need login: %w", - tsClient.Hostname(), - err, - ) - } - - s.mu.Lock() - user.Clients[tsClient.Hostname()] = tsClient - s.mu.Unlock() - - return nil - }) - } - - if err := user.createWaitGroup.Wait(); err != nil { - return err - } - - return nil - } - - return fmt.Errorf("failed to add tailscale nodes: %w", errNoUserAvailable) -} - -func (s *EmbeddedDERPServerScenario) Shutdown() { - for _, network := range s.tsicNetworks { - err := s.pool.RemoveNetwork(network) - if err != nil { - log.Printf("failed to remove DERP network %s", network.Network.Name) - } - } - - s.Scenario.Shutdown() -} diff --git a/integration/general_test.go b/integration/general_test.go index d6d9e7e1..0b55f0b7 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -28,18 +28,17 @@ func TestPingAllByIP(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + MaxWait: dockertestMaxWait(), + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - // TODO(kradalby): it does not look like the user thing works, only second - // get created? maybe only when many? - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithTestName("pingallbyip"), hsic.WithEmbeddedDERPServerOnly(), @@ -71,16 +70,16 @@ func TestPingAllByIPPublicDERP(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithTestName("pingallbyippubderp"), ) @@ -121,25 +120,25 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - headscale, err := scenario.Headscale(opts...) assertNoErrHeadscaleEnv(t, err) - for userName, clientCount := range spec { + for _, userName := range spec.Users { err = scenario.CreateUser(userName) if err != nil { t.Fatalf("failed to create user %s: %s", userName, err) } - err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) + err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) if err != nil { t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) } @@ -194,15 +193,15 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - headscale, err := scenario.Headscale( hsic.WithTestName("ephemeral2006"), hsic.WithConfigEnv(map[string]string{ @@ -211,13 +210,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { ) assertNoErrHeadscaleEnv(t, err) - for userName, clientCount := range spec { + for _, userName := range spec.Users { err = scenario.CreateUser(userName) if err != nil { t.Fatalf("failed to create user %s: %s", userName, err) } - err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) + err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) if err != nil { t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) } @@ -287,7 +286,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { // registered. time.Sleep(3 * time.Minute) - for userName := range spec { + for _, userName := range spec.Users { nodes, err := headscale.ListNodes(userName) if err != nil { log.Error(). @@ -308,16 +307,16 @@ func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user3": len(MustTestVersions), - "user4": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname")) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -357,15 +356,16 @@ func TestTaildrop(t *testing.T) { return err } - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "taildrop": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("taildrop"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), @@ -522,23 +522,22 @@ func TestUpdateHostnameFromClient(t *testing.T) { IntegrationSkip(t) t.Parallel() - user := "update-hostname-from-client" - hostnames := map[string]string{ "1": "user1-host", "2": "User2-Host", "3": "user3-host", } - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErrf(t, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 3, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("updatehostname")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname")) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -650,15 +649,16 @@ func TestExpireNode(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("expirenode")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode")) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -684,7 +684,7 @@ func TestExpireNode(t *testing.T) { assertNoErr(t, err) // Assert that we have the original count - self - assert.Len(t, status.Peers(), spec["user1"]-1) + assert.Len(t, status.Peers(), spec.NodesPerUser-1) } headscale, err := scenario.Headscale() @@ -776,15 +776,16 @@ func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "user1": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online")) + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online")) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -891,18 +892,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - // TODO(kradalby): it does not look like the user thing works, only second - // get created? maybe only when many? - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithTestName("pingallbyipmany"), hsic.WithEmbeddedDERPServerOnly(), @@ -973,18 +972,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) - // TODO(kradalby): it does not look like the user thing works, only second - // get created? maybe only when many? - spec := map[string]int{ - "user1": 1, - "user2": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithTestName("deletenocrash"), hsic.WithEmbeddedDERPServerOnly(), diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index fedf220e..1b976f4a 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -56,7 +56,7 @@ type HeadscaleInContainer struct { pool *dockertest.Pool container *dockertest.Resource - network *dockertest.Network + networks []*dockertest.Network pgContainer *dockertest.Resource @@ -268,7 +268,7 @@ func WithTimezone(timezone string) Option { // New returns a new HeadscaleInContainer instance. func New( pool *dockertest.Pool, - network *dockertest.Network, + networks []*dockertest.Network, opts ...Option, ) (*HeadscaleInContainer, error) { hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength) @@ -282,8 +282,8 @@ func New( hostname: hostname, port: headscaleDefaultPort, - pool: pool, - network: network, + pool: pool, + networks: networks, env: DefaultConfigEnv(), filesInContainer: []fileInContainer{}, @@ -315,7 +315,7 @@ func New( Name: fmt.Sprintf("postgres-%s", hash), Repository: "postgres", Tag: "latest", - Networks: []*dockertest.Network{network}, + Networks: networks, Env: []string{ "POSTGRES_USER=headscale", "POSTGRES_PASSWORD=headscale", @@ -357,7 +357,7 @@ func New( runOptions := &dockertest.RunOptions{ Name: hsic.hostname, ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...), - Networks: []*dockertest.Network{network}, + Networks: networks, // Cmd: []string{"headscale", "serve"}, // TODO(kradalby): Get rid of this hack, we currently need to give us some // to inject the headscale configuration further down. @@ -630,11 +630,6 @@ func (t *HeadscaleInContainer) Execute( return stdout, nil } -// GetIP returns the docker container IP as a string. -func (t *HeadscaleInContainer) GetIP() string { - return t.container.GetIPInNetwork(t.network) -} - // GetPort returns the docker container port as a string. func (t *HeadscaleInContainer) GetPort() string { return fmt.Sprintf("%d", t.port) diff --git a/integration/route_test.go b/integration/route_test.go index e92a4c37..04f9073e 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1,12 +1,16 @@ package integration import ( + "fmt" "net/netip" "sort" "testing" "time" + "slices" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -18,6 +22,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/types/ipproto" "tailscale.com/types/views" + "tailscale.com/util/slicesx" "tailscale.com/wgengine/filter" ) @@ -29,17 +34,18 @@ func TestEnablingRoutes(t *testing.T) { IntegrationSkip(t) t.Parallel() - user := "user6" + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1"}, + } - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(spec) require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 3, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute")) + err = scenario.CreateHeadscaleEnv( + []tsic.Option{tsic.WithAcceptRoutes()}, + hsic.WithTestName("clienableroute")) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -123,26 +129,10 @@ func TestEnablingRoutes(t *testing.T) { for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] - assert.Nil(t, peerStatus.PrimaryRoutes) + assert.NotNil(t, peerStatus.PrimaryRoutes) assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3) - - if peerStatus.AllowedIPs.Len() > 2 { - peerRoute := peerStatus.AllowedIPs.At(2) - - // id starts at 1, we created routes with 0 index - assert.Equalf( - t, - expectedRoutes[string(peerStatus.ID)], - peerRoute.String(), - "expected route %s to be present on peer %s (%s) in %s (%s) status", - expectedRoutes[string(peerStatus.ID)], - peerStatus.HostName, - peerStatus.ID, - client.Hostname(), - client.ID(), - ) - } + requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) } } @@ -187,13 +177,12 @@ func TestEnablingRoutes(t *testing.T) { for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] - assert.Nil(t, peerStatus.PrimaryRoutes) if peerStatus.ID == "1" { - assertPeerSubnetRoutes(t, peerStatus, nil) + requirePeerSubnetRoutes(t, peerStatus, nil) } else if peerStatus.ID == "2" { - assertPeerSubnetRoutes(t, peerStatus, nil) + requirePeerSubnetRoutes(t, peerStatus, nil) } else { - assertPeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")}) + requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")}) } } } @@ -203,17 +192,27 @@ func TestHASubnetRouterFailover(t *testing.T) { IntegrationSkip(t) t.Parallel() - user := "user9" + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + } - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(spec) require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 4, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, + err = scenario.CreateHeadscaleEnv( + []tsic.Option{tsic.WithAcceptRoutes()}, hsic.WithTestName("clienableroute"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), @@ -229,11 +228,22 @@ func TestHASubnetRouterFailover(t *testing.T) { headscale, err := scenario.Headscale() assertNoErrGetHeadscale(t, err) - expectedRoutes := map[string]string{ - "1": "10.0.0.0/24", - "2": "10.0.0.0/24", - "3": "10.0.0.0/24", - } + prefp, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + pref := *prefp + t.Logf("usernet1 prefix: %s", pref.String()) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + weburl := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("webservice: %s, %s", webip.String(), weburl) // Sort nodes by ID sort.SliceStable(allClients, func(i, j int) bool { @@ -243,6 +253,9 @@ func TestHASubnetRouterFailover(t *testing.T) { return statusI.Self.ID < statusJ.Self.ID }) + // This is ok because the scenario makes users in order, so the three first + // nodes, which are subnet routes, will be created first, and the last user + // will be created with the second. subRouter1 := allClients[0] subRouter2 := allClients[1] subRouter3 := allClients[2] @@ -255,28 +268,23 @@ func TestHASubnetRouterFailover(t *testing.T) { // ID 2 will be standby // ID 3 will be standby for _, client := range allClients[:3] { - status, err := client.Status() - require.NoError(t, err) - - if route, ok := expectedRoutes[string(status.Self.ID)]; ok { - command := []string{ - "tailscale", - "set", - "--advertise-routes=" + route, - } - _, _, err = client.Execute(command) - require.NoErrorf(t, err, "failed to advertise route: %s", err) - } else { - t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID) + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + pref.String(), } + _, _, err = client.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) } err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + time.Sleep(3 * time.Second) + nodes, err := headscale.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 4) + assert.Len(t, nodes, 6) assertNodeRouteCount(t, nodes[0], 1, 0, 0) assertNodeRouteCount(t, nodes[1], 1, 0, 0) @@ -292,28 +300,30 @@ func TestHASubnetRouterFailover(t *testing.T) { peerStatus := status.Peer[peerKey] assert.Nil(t, peerStatus.PrimaryRoutes) - assertPeerSubnetRoutes(t, peerStatus, nil) + requirePeerSubnetRoutes(t, peerStatus, nil) } } - // Enable all routes - for _, node := range nodes { - _, err := headscale.ApproveRoutes( - node.GetId(), - util.MustStringsToPrefixes(node.GetAvailableRoutes()), - ) - require.NoError(t, err) - } + // Enable route on node 1 + t.Logf("Enabling route on subnet router 1, no HA") + _, err = headscale.ApproveRoutes( + 1, + []netip.Prefix{pref}, + ) + require.NoError(t, err) + + time.Sleep(3 * time.Second) nodes, err = headscale.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 4) + assert.Len(t, nodes, 6) assertNodeRouteCount(t, nodes[0], 1, 1, 1) - assertNodeRouteCount(t, nodes[1], 1, 1, 1) - assertNodeRouteCount(t, nodes[2], 1, 1, 1) + assertNodeRouteCount(t, nodes[1], 1, 0, 0) + assertNodeRouteCount(t, nodes[2], 1, 0, 0) - // Verify that the client has routes from the primary machine + // Verify that the client has routes from the primary machine and can access + // the webservice. srs1 := subRouter1.MustStatus() srs2 := subRouter2.MustStatus() srs3 := subRouter3.MustStatus() @@ -331,11 +341,135 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Nil(t, srs3PeerStatus.PrimaryRoutes) require.NotNil(t, srs1PeerStatus.PrimaryRoutes) + requirePeerSubnetRoutes(t, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, nil) + + t.Logf("got list: %v, want in: %v", srs1PeerStatus.PrimaryRoutes.AsSlice(), pref) assert.Contains(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + pref, ) + t.Logf("Validating access via subnetrouter(%s) to %s, no HA", subRouter1.MustIPv4().String(), webip.String()) + result, err := client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err := client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) + + // Enable route on node 2, now we will have a HA subnet router + t.Logf("Enabling route on subnet router 2, now HA, subnetrouter 1 is primary, 2 is standby") + _, err = headscale.ApproveRoutes( + 2, + []netip.Prefix{pref}, + ) + require.NoError(t, err) + + time.Sleep(3 * time.Second) + + nodes, err = headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 6) + + assertNodeRouteCount(t, nodes[0], 1, 1, 1) + assertNodeRouteCount(t, nodes[1], 1, 1, 1) + assertNodeRouteCount(t, nodes[2], 1, 0, 0) + + // Verify that the client has routes from the primary machine + srs1 = subRouter1.MustStatus() + srs2 = subRouter2.MustStatus() + srs3 = subRouter3.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up") + assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up") + assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up") + + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) + assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + require.NotNil(t, srs1PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutes(t, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, nil) + + t.Logf("got list: %v, want in: %v", srs1PeerStatus.PrimaryRoutes.AsSlice(), pref) + assert.Contains(t, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + + t.Logf("Validating access via subnetrouter(%s) to %s, 2 is standby", subRouter1.MustIPv4().String(), webip.String()) + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) + + // Enable route on node 3, now we will have a second standby and all will + // be enabled. + t.Logf("Enabling route on subnet router 3, now HA, subnetrouter 1 is primary, 2 and 3 is standby") + _, err = headscale.ApproveRoutes( + 3, + []netip.Prefix{pref}, + ) + require.NoError(t, err) + + time.Sleep(3 * time.Second) + + nodes, err = headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 6) + + assertNodeRouteCount(t, nodes[0], 1, 1, 1) + assertNodeRouteCount(t, nodes[1], 1, 1, 1) + assertNodeRouteCount(t, nodes[2], 1, 1, 1) + + // Verify that the client has routes from the primary machine + srs1 = subRouter1.MustStatus() + srs2 = subRouter2.MustStatus() + srs3 = subRouter3.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up") + assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up") + assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up") + + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) + assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + require.NotNil(t, srs1PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutes(t, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, nil) + + t.Logf("got list: %v, want in: %v", srs1PeerStatus.PrimaryRoutes.AsSlice(), pref) + assert.Contains(t, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) + // Take down the current primary t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname()) t.Logf("expecting r2 (%s) to take over as primary", subRouter2.Hostname()) @@ -359,12 +493,24 @@ func TestHASubnetRouterFailover(t *testing.T) { require.NotNil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + requirePeerSubnetRoutes(t, srs1PeerStatus, nil) + requirePeerSubnetRoutes(t, srs2PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutes(t, srs3PeerStatus, nil) + assert.Contains( t, srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), + pref, ) + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter2.MustIPv4()) + // Take down subnet router 2, leaving none available t.Logf("taking down subnet router r2 (%s)", subRouter2.Hostname()) t.Logf("expecting no primary, r3 available, but no HA so no primary") @@ -390,7 +536,19 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) - assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + require.NotNil(t, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutes(t, srs1PeerStatus, nil) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, []netip.Prefix{pref}) + + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter3.MustIPv4()) // Bring up subnet router 1, making the route available from there. t.Logf("bringing up subnet router r1 (%s)", subRouter1.Hostname()) @@ -412,16 +570,28 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down") assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available") - assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) - assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + require.NotNil(t, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutes(t, srs1PeerStatus, nil) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, []netip.Prefix{pref}) assert.Contains( t, - srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + srs3PeerStatus.PrimaryRoutes.AsSlice(), + pref, ) + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter3.MustIPv4()) + // Bring up subnet router 2, should result in no change. t.Logf("bringing up subnet router r2 (%s)", subRouter2.Hostname()) t.Logf("all online, expecting r1 (%s) to still be primary (no flapping)", subRouter1.Hostname()) @@ -442,30 +612,86 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up") assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up") + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) + require.NotNil(t, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutes(t, srs1PeerStatus, nil) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, []netip.Prefix{pref}) + + assert.Contains( + t, + srs3PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter3.MustIPv4()) + + t.Logf("disabling route in subnet router r3 (%s)", subRouter3.Hostname()) + t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname()) + _, err = headscale.ApproveRoutes(nodes[2].GetId(), []netip.Prefix{}) + + time.Sleep(5 * time.Second) + + nodes, err = headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 6) + + assertNodeRouteCount(t, nodes[0], 1, 1, 1) + assertNodeRouteCount(t, nodes[1], 1, 1, 1) + assertNodeRouteCount(t, nodes[2], 1, 0, 0) + + // Verify that the route is announced from subnet router 1 + clientStatus, err = client.Status() + require.NoError(t, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + require.NotNil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + requirePeerSubnetRoutes(t, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutes(t, srs2PeerStatus, nil) + requirePeerSubnetRoutes(t, srs3PeerStatus, nil) + assert.Contains( t, srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + pref, ) + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) + // Disable the route of subnet router 1, making it failover to 2 t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname()) - t.Logf("expecting route to failover to r2 (%s), which is still available with r3", subRouter2.Hostname()) + t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname()) _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{}) time.Sleep(5 * time.Second) nodes, err = headscale.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 4) + assert.Len(t, nodes, 6) assertNodeRouteCount(t, nodes[0], 1, 0, 0) assertNodeRouteCount(t, nodes[1], 1, 1, 1) - assertNodeRouteCount(t, nodes[2], 1, 1, 1) + assertNodeRouteCount(t, nodes[2], 1, 0, 0) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -476,15 +702,27 @@ func TestHASubnetRouterFailover(t *testing.T) { srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] assert.Nil(t, srs1PeerStatus.PrimaryRoutes) - assert.NotNil(t, srs2PeerStatus.PrimaryRoutes) + require.NotNil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs3PeerStatus.PrimaryRoutes) + requirePeerSubnetRoutes(t, srs1PeerStatus, nil) + requirePeerSubnetRoutes(t, srs2PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutes(t, srs3PeerStatus, nil) + assert.Contains( t, srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), + pref, ) + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter2.MustIPv4()) + // enable the route of subnet router 1, no change expected t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname()) t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname()) @@ -497,11 +735,11 @@ func TestHASubnetRouterFailover(t *testing.T) { nodes, err = headscale.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 4) + assert.Len(t, nodes, 6) assertNodeRouteCount(t, nodes[0], 1, 1, 1) assertNodeRouteCount(t, nodes[1], 1, 1, 1) - assertNodeRouteCount(t, nodes[2], 1, 1, 1) + assertNodeRouteCount(t, nodes[2], 1, 0, 0) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -518,8 +756,16 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Contains( t, srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), + pref, ) + + result, err = client.Curl(weburl) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err = client.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, subRouter2.MustIPv4()) } func TestEnableDisableAutoApprovedRoute(t *testing.T) { @@ -528,17 +774,19 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { expectedRoutes := "172.0.0.0/24" - user := "user2" + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + } - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(spec) require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithTags([]string{"tag:approve"}), + tsic.WithAcceptRoutes(), + }, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( &policyv1.ACLPolicy{ ACLs: []policyv1.ACL{ { @@ -548,7 +796,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { }, }, TagOwners: map[string][]string{ - "tag:approve": {user}, + "tag:approve": {"user1"}, }, AutoApprovers: policyv1.AutoApprovers{ Routes: map[string][]string{ @@ -627,15 +875,19 @@ func TestAutoApprovedSubRoute2068(t *testing.T) { user := "user1" - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{user}, + } + + scenario, err := NewScenario(spec) require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithTags([]string{"tag:approve"}), + tsic.WithAcceptRoutes(), + }, hsic.WithTestName("clienableroute"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), @@ -698,15 +950,18 @@ func TestSubnetRouteACL(t *testing.T) { user := "user4" - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{user}, + } + + scenario, err := NewScenario(spec) require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 2, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithAcceptRoutes(), + }, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( &policyv1.ACLPolicy{ Groups: policyv1.Groups{ "group:admins": {user}, @@ -799,7 +1054,7 @@ func TestSubnetRouteACL(t *testing.T) { peerStatus := status.Peer[peerKey] assert.Nil(t, peerStatus.PrimaryRoutes) - assertPeerSubnetRoutes(t, peerStatus, nil) + requirePeerSubnetRoutes(t, peerStatus, nil) } } @@ -826,7 +1081,7 @@ func TestSubnetRouteACL(t *testing.T) { srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] - assertPeerSubnetRoutes(t, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])}) + requirePeerSubnetRoutes(t, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])}) clientNm, err := client.Netmap() require.NoError(t, err) @@ -920,15 +1175,16 @@ func TestEnablingExitRoutes(t *testing.T) { user := "user2" - scenario, err := NewScenario(dockertestMaxWait()) + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{user}, + } + + scenario, err := NewScenario(spec) assertNoErrf(t, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - user: 2, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{ + err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithExtraLoginArgs([]string{"--advertise-exit-node"}), }, hsic.WithTestName("clienableroute")) assertNoErrHeadscaleEnv(t, err) @@ -1003,11 +1259,286 @@ func TestEnablingExitRoutes(t *testing.T) { } } -// assertPeerSubnetRoutes asserts that the peer has the expected subnet routes. -func assertPeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) { +// TestSubnetRouterMultiNetwork is an evolution of the subnet router test. +// This test will set up multiple docker networks and use two isolated tailscale +// clients and a service available in one of the networks to validate that a +// subnet router is working as expected. +func TestSubnetRouterMultiNetwork(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithAcceptRoutes()}, + hsic.WithTestName("clienableroute"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + assert.NotNil(t, headscale) + + pref, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + + var user1c, user2c TailscaleClient + + for _, c := range allClients { + s := c.MustStatus() + if s.User[s.Self.UserID].LoginName == "user1@test.no" { + user1c = c + } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { + user2c = c + } + } + require.NotNil(t, user1c) + require.NotNil(t, user2c) + + // Advertise the route for the dockersubnet of user1 + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + pref.String(), + } + _, _, err = user1c.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + + nodes, err := headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 2) + assertNodeRouteCount(t, nodes[0], 1, 0, 0) + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + status, err := user1c.Status() + require.NoError(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(t, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutes(t, peerStatus, nil) + } + + // Enable route + _, err = headscale.ApproveRoutes( + nodes[0].Id, + []netip.Prefix{*pref}, + ) + require.NoError(t, err) + + time.Sleep(5 * time.Second) + + nodes, err = headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 2) + assertNodeRouteCount(t, nodes[0], 1, 1, 1) + + // Verify that the routes have been sent to the client. + status, err = user2c.Status() + require.NoError(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref) + requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref}) + } + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + + url := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("url from %s to %s", user2c.Hostname(), url) + + result, err := user2c.Curl(url) + require.NoError(t, err) + assert.Len(t, result, 13) + + tr, err := user2c.Traceroute(webip) + require.NoError(t, err) + assertTracerouteViaIP(t, tr, user1c.MustIPv4()) +} + +// TestSubnetRouterMultiNetworkExitNode +func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, + hsic.WithTestName("clienableroute"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + assert.NotNil(t, headscale) + + var user1c, user2c TailscaleClient + + for _, c := range allClients { + s := c.MustStatus() + if s.User[s.Self.UserID].LoginName == "user1@test.no" { + user1c = c + } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { + user2c = c + } + } + require.NotNil(t, user1c) + require.NotNil(t, user2c) + + // Advertise the exit nodes for the dockersubnet of user1 + command := []string{ + "tailscale", + "set", + "--advertise-exit-node", + } + _, _, err = user1c.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + + nodes, err := headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 2) + assertNodeRouteCount(t, nodes[0], 2, 0, 0) + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + status, err := user1c.Status() + require.NoError(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(t, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutes(t, peerStatus, nil) + } + + // Enable route + _, err = headscale.ApproveRoutes( + nodes[0].Id, + []netip.Prefix{tsaddr.AllIPv4()}, + ) + require.NoError(t, err) + + time.Sleep(5 * time.Second) + + nodes, err = headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 2) + assertNodeRouteCount(t, nodes[0], 2, 2, 2) + + // Verify that the routes have been sent to the client. + status, err = user2c.Status() + require.NoError(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) + } + + // Tell user2c to use user1c as an exit node. + command = []string{ + "tailscale", + "set", + "--exit-node", + user1c.Hostname(), + } + _, _, err = user2c.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + + // We cant mess to much with ip forwarding in containers so + // we settle for a simple ping here. + // Direct is false since we use internal DERP which means we + // cant discover a direct path between docker networks. + err = user2c.Ping(webip.String(), + tsic.WithPingUntilDirect(false), + tsic.WithPingCount(1), + tsic.WithPingTimeout(7*time.Second), + ) + require.NoError(t, err) +} + +func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) { + t.Helper() + + require.NotNil(t, tr) + require.True(t, tr.Success) + require.NoError(t, tr.Err) + require.NotEmpty(t, tr.Route) + require.Equal(t, tr.Route[0].IP, ip) +} + +// requirePeerSubnetRoutes asserts that the peer has the expected subnet routes. +func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) { t.Helper() if status.AllowedIPs.Len() <= 2 && len(expected) != 0 { - t.Errorf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected) + t.Fatalf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected) return } @@ -1015,10 +1546,15 @@ func assertPeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected expected = []netip.Prefix{} } - got := status.AllowedIPs.AsSlice()[2:] + got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool { + if tsaddr.IsExitRoute(p) { + return true + } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) + }) - if diff := cmp.Diff(expected, got, util.PrefixComparer); diff != "" { - t.Errorf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff) + if diff := cmp.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" { + t.Fatalf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff) } } diff --git a/integration/scenario.go b/integration/scenario.go index 1cdc8f5d..e0cbdc21 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -1,24 +1,37 @@ package integration import ( + "context" + "crypto/tls" + "encoding/json" "errors" "fmt" + "io" "log" + "net" + "net/http" + "net/http/cookiejar" "net/netip" + "net/url" "os" "sort" + "strconv" + "strings" "sync" "testing" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/capver" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dsic" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" "github.com/puzpuzpuz/xsync/v3" "github.com/samber/lo" "github.com/stretchr/testify/assert" @@ -26,6 +39,7 @@ import ( xmaps "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" "tailscale.com/envknob" + "tailscale.com/util/mak" ) const ( @@ -86,33 +100,136 @@ type Scenario struct { users map[string]*User - pool *dockertest.Pool - network *dockertest.Network + pool *dockertest.Pool + networks map[string]*dockertest.Network + mockOIDC scenarioOIDC + extraServices map[string][]*dockertest.Resource mu sync.Mutex + + spec ScenarioSpec + userToNetwork map[string]*dockertest.Network +} + +// ScenarioSpec describes the users, nodes, and network topology to +// set up for a given scenario. +type ScenarioSpec struct { + // Users is a list of usernames that will be created. + // Each created user will get nodes equivalent to NodesPerUser + Users []string + + // NodesPerUser is how many nodes should be attached to each user. + NodesPerUser int + + // Networks, if set, is the seperate Docker networks that should be + // created and a list of the users that should be placed in those networks. + // If not set, a single network will be created and all users+nodes will be + // added there. + // Please note that Docker networks are not necessarily routable and + // connections between them might fall back to DERP. + Networks map[string][]string + + // ExtraService, if set, is additional a map of network to additional + // container services that should be set up. These container services + // typically dont run Tailscale, e.g. web service to test subnet router. + ExtraService map[string][]extraServiceFunc + + // Versions is specific list of versions to use for the test. + Versions []string + + // OIDCUsers, if populated, will start a Mock OIDC server and populate + // the user login stack with the given users. + // If the NodesPerUser is set, it should align with this list to ensure + // the correct users are logged in. + // This is because the MockOIDC server can only serve login + // requests based on a queue it has been given on startup. + // We currently only populates it with one login request per user. + OIDCUsers []mockoidc.MockUser + OIDCAccessTTL time.Duration + + MaxWait time.Duration +} + +var TestHashPrefix = "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) +var TestDefaultNetwork = TestHashPrefix + "-default" + +func prefixedNetworkName(name string) string { + return TestHashPrefix + "-" + name } // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // a set of Users and TailscaleClients. -func NewScenario(maxWait time.Duration) (*Scenario, error) { - hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) - if err != nil { - return nil, err - } - +func NewScenario(spec ScenarioSpec) (*Scenario, error) { pool, err := dockertest.NewPool("") if err != nil { return nil, fmt.Errorf("could not connect to docker: %w", err) } - pool.MaxWait = maxWait - - networkName := fmt.Sprintf("hs-%s", hash) - if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { - networkName = overrideNetworkName + if spec.MaxWait == 0 { + pool.MaxWait = dockertestMaxWait() + } else { + pool.MaxWait = spec.MaxWait } - network, err := dockertestutil.GetFirstOrCreateNetwork(pool, networkName) + s := &Scenario{ + controlServers: xsync.NewMapOf[string, ControlServer](), + users: make(map[string]*User), + + pool: pool, + spec: spec, + } + + var userToNetwork map[string]*dockertest.Network + if spec.Networks != nil || len(spec.Networks) != 0 { + for name, users := range s.spec.Networks { + networkName := TestHashPrefix + "-" + name + network, err := s.AddNetwork(networkName) + if err != nil { + return nil, err + } + + for _, user := range users { + if n2, ok := userToNetwork[user]; ok { + return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) + } + mak.Set(&userToNetwork, user, network) + } + } + } else { + _, err := s.AddNetwork(TestDefaultNetwork) + if err != nil { + return nil, err + } + } + + for network, extras := range spec.ExtraService { + for _, extra := range extras { + svc, err := extra(s, network) + if err != nil { + return nil, err + } + mak.Set(&s.extraServices, prefixedNetworkName(network), append(s.extraServices[prefixedNetworkName(network)], svc)) + } + } + + s.userToNetwork = userToNetwork + + if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 { + ttl := defaultAccessTTL + if spec.OIDCAccessTTL != 0 { + ttl = spec.OIDCAccessTTL + } + err = s.runMockOIDC(ttl, spec.OIDCUsers) + if err != nil { + return nil, err + } + } + + return s, nil +} + +func (s *Scenario) AddNetwork(name string) (*dockertest.Network, error) { + network, err := dockertestutil.GetFirstOrCreateNetwork(s.pool, name) if err != nil { return nil, fmt.Errorf("failed to create or get network: %w", err) } @@ -120,18 +237,58 @@ func NewScenario(maxWait time.Duration) (*Scenario, error) { // We run the test suite in a docker container that calls a couple of endpoints for // readiness checks, this ensures that we can run the tests with individual networks // and have the client reach the different containers - err = dockertestutil.AddContainerToNetwork(pool, network, "headscale-test-suite") + // TODO(kradalby): Can the test-suite be renamed so we can have multiple? + err = dockertestutil.AddContainerToNetwork(s.pool, network, "headscale-test-suite") if err != nil { return nil, fmt.Errorf("failed to add test suite container to network: %w", err) } - return &Scenario{ - controlServers: xsync.NewMapOf[string, ControlServer](), - users: make(map[string]*User), + mak.Set(&s.networks, name, network) - pool: pool, - network: network, - }, nil + return network, nil +} + +func (s *Scenario) Networks() []*dockertest.Network { + if len(s.networks) == 0 { + panic("Scenario.Networks called with empty network list") + } + return xmaps.Values(s.networks) +} + +func (s *Scenario) Network(name string) (*dockertest.Network, error) { + net, ok := s.networks[prefixedNetworkName(name)] + if !ok { + return nil, fmt.Errorf("no network named: %s", name) + } + + return net, nil +} + +func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { + net, ok := s.networks[prefixedNetworkName(name)] + if !ok { + return nil, fmt.Errorf("no network named: %s", name) + } + + for _, ipam := range net.Network.IPAM.Config { + pref, err := netip.ParsePrefix(ipam.Subnet) + if err != nil { + return nil, err + } + + return &pref, nil + } + + return nil, fmt.Errorf("no prefix found in network: %s", name) +} + +func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { + res, ok := s.extraServices[prefixedNetworkName(name)] + if !ok { + return nil, fmt.Errorf("no network named: %s", name) + } + + return res, nil } func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { @@ -184,14 +341,27 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { } } - if err := s.pool.RemoveNetwork(s.network); err != nil { - log.Printf("failed to remove network: %s", err) + for _, svcs := range s.extraServices { + for _, svc := range svcs { + err := svc.Close() + if err != nil { + log.Printf("failed to tear down service %q: %s", svc.Container.Name, err) + } + } } - // TODO(kradalby): This seem redundant to the previous call - // if err := s.network.Close(); err != nil { - // return fmt.Errorf("failed to tear down network: %w", err) - // } + if s.mockOIDC.r != nil { + s.mockOIDC.r.Close() + if err := s.mockOIDC.r.Close(); err != nil { + log.Printf("failed to tear down oidc server: %s", err) + } + } + + for _, network := range s.networks { + if err := network.Close(); err != nil { + log.Printf("failed to tear down network: %s", err) + } + } } // Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient) @@ -235,7 +405,7 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { opts = append(opts, hsic.WithPolicyV2()) } - headscale, err := hsic.New(s.pool, s.network, opts...) + headscale, err := hsic.New(s.pool, s.Networks(), opts...) if err != nil { return nil, fmt.Errorf("failed to create headscale container: %w", err) } @@ -312,7 +482,6 @@ func (s *Scenario) CreateTailscaleNode( tsClient, err := tsic.New( s.pool, version, - s.network, opts..., ) if err != nil { @@ -345,10 +514,14 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) error { if user, ok := s.users[userStr]; ok { var versions []string - for i := 0; i < count; i++ { + for i := range count { version := requestedVersion if requestedVersion == "all" { - version = MustTestVersions[i%len(MustTestVersions)] + if s.spec.Versions != nil { + version = s.spec.Versions[i%len(s.spec.Versions)] + } else { + version = MustTestVersions[i%len(MustTestVersions)] + } } versions = append(versions, version) @@ -372,14 +545,12 @@ func (s *Scenario) CreateTailscaleNodesInUser( tsClient, err := tsic.New( s.pool, version, - s.network, opts..., ) s.mu.Unlock() if err != nil { return fmt.Errorf( - "failed to create tailscale (%s) node: %w", - tsClient.Hostname(), + "failed to create tailscale node: %w", err, ) } @@ -492,11 +663,24 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int) error { return nil } -// CreateHeadscaleEnv is a convenient method returning a complete Headcale -// test environment with nodes of all versions, joined to the server with X -// users. +func (s *Scenario) CreateHeadscaleEnvWithLoginURL( + tsOpts []tsic.Option, + opts ...hsic.Option, +) error { + return s.createHeadscaleEnv(true, tsOpts, opts...) +} + func (s *Scenario) CreateHeadscaleEnv( - users map[string]int, + tsOpts []tsic.Option, + opts ...hsic.Option, +) error { + return s.createHeadscaleEnv(false, tsOpts, opts...) +} + +// CreateHeadscaleEnv starts the headscale environment and the clients +// according to the ScenarioSpec passed to the Scenario. +func (s *Scenario) createHeadscaleEnv( + withURL bool, tsOpts []tsic.Option, opts ...hsic.Option, ) error { @@ -505,34 +689,188 @@ func (s *Scenario) CreateHeadscaleEnv( return err } - usernames := xmaps.Keys(users) - sort.Strings(usernames) - for _, username := range usernames { - clientCount := users[username] - err = s.CreateUser(username) + sort.Strings(s.spec.Users) + for _, user := range s.spec.Users { + err = s.CreateUser(user) if err != nil { return err } - err = s.CreateTailscaleNodesInUser(username, "all", clientCount, tsOpts...) + var opts []tsic.Option + if s.userToNetwork != nil { + opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user])) + } else { + opts = append(tsOpts, tsic.WithNetwork(s.networks[TestDefaultNetwork])) + } + + err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...) if err != nil { return err } - key, err := s.CreatePreAuthKey(username, true, false) - if err != nil { - return err - } + if withURL { + err = s.RunTailscaleUpWithURL(user, headscale.GetEndpoint()) + if err != nil { + return err + } + } else { + key, err := s.CreatePreAuthKey(user, true, false) + if err != nil { + return err + } - err = s.RunTailscaleUp(username, headscale.GetEndpoint(), key.GetKey()) - if err != nil { - return err + err = s.RunTailscaleUp(user, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + return err + } } } return nil } +func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { + log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { + for _, client := range user.Clients { + tsc := client + user.joinWaitGroup.Go(func() error { + loginURL, err := tsc.LoginWithURL(loginServer) + if err != nil { + log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) + } + + body, err := doLoginURL(tsc.Hostname(), loginURL) + if err != nil { + return err + } + + // If the URL is not a OIDC URL, then we need to + // run the register command to fully log in the client. + if !strings.Contains(loginURL.String(), "/oidc/") { + s.runHeadscaleRegister(userStr, body) + } + + return nil + }) + + log.Printf("client %s is ready", client.Hostname()) + } + + if err := user.joinWaitGroup.Wait(); err != nil { + return err + } + + for _, client := range user.Clients { + err := client.WaitForRunning() + if err != nil { + return fmt.Errorf( + "%s tailscale node has not reached running: %w", + client.Hostname(), + err, + ) + } + } + + return nil + } + + return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) +} + +// doLoginURL visits the given login URL and returns the body as a +// string. +func doLoginURL(hostname string, loginURL *url.URL) (string, error) { + log.Printf("%s login url: %s\n", hostname, loginURL.String()) + + var err error + hc := &http.Client{ + Transport: LoggingRoundTripper{}, + } + hc.Jar, err = cookiejar.New(nil) + if err != nil { + return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err) + } + + log.Printf("%s logging in with url", hostname) + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + resp, err := hc.Do(req) + if err != nil { + return "", fmt.Errorf("%s failed to send http request: %w", hostname, err) + } + + log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + log.Printf("body: %s", body) + + return "", fmt.Errorf("%s response code of login request was %w", hostname, err) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("%s failed to read response body: %s", hostname, err) + + return "", fmt.Errorf("%s failed to read response body: %w", hostname, err) + } + + return string(body), nil +} + +var errParseAuthPage = errors.New("failed to parse auth page") + +func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { + // see api.go HTML template + codeSep := strings.Split(string(body), "") + if len(codeSep) != 2 { + return errParseAuthPage + } + + keySep := strings.Split(codeSep[0], "key ") + if len(keySep) != 2 { + return errParseAuthPage + } + key := keySep[1] + log.Printf("registering node %s", key) + + if headscale, err := s.Headscale(); err == nil { + _, err = headscale.Execute( + []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, + ) + if err != nil { + log.Printf("failed to register node: %s", err) + + return err + } + + return nil + } + + return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) +} + +type LoggingRoundTripper struct{} + +func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + noTls := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + } + resp, err := noTls.RoundTrip(req) + if err != nil { + return nil, err + } + + log.Printf("---") + log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) + log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) + + return resp, nil +} + // GetIPs returns all netip.Addr of TailscaleClients associated with a User // in a Scenario. func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { @@ -670,7 +1008,7 @@ func (s *Scenario) WaitForTailscaleLogout() error { // CreateDERPServer creates a new DERP server in a container. func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) { - derp, err := dsic.New(s.pool, version, s.network, opts...) + derp, err := dsic.New(s.pool, version, s.Networks(), opts...) if err != nil { return nil, fmt.Errorf("failed to create DERP server: %w", err) } @@ -684,3 +1022,216 @@ func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic. return derp, nil } + +type scenarioOIDC struct { + r *dockertest.Resource + cfg *types.OIDCConfig +} + +func (o *scenarioOIDC) Issuer() string { + if o.cfg == nil { + panic("OIDC has not been created") + } + + return o.cfg.Issuer +} + +func (o *scenarioOIDC) ClientSecret() string { + if o.cfg == nil { + panic("OIDC has not been created") + } + + return o.cfg.ClientSecret +} + +func (o *scenarioOIDC) ClientID() string { + if o.cfg == nil { + panic("OIDC has not been created") + } + + return o.cfg.ClientID +} + +const ( + dockerContextPath = "../." + hsicOIDCMockHashLength = 6 + defaultAccessTTL = 10 * time.Minute +) + +var errStatusCodeNotOK = errors.New("status code not OK") + +func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) error { + port, err := dockertestutil.RandomFreeHostPort() + if err != nil { + log.Fatalf("could not find an open port: %s", err) + } + portNotation := fmt.Sprintf("%d/tcp", port) + + hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) + + hostname := fmt.Sprintf("hs-oidcmock-%s", hash) + + usersJSON, err := json.Marshal(users) + if err != nil { + return err + } + + mockOidcOptions := &dockertest.RunOptions{ + Name: hostname, + Cmd: []string{"headscale", "mockoidc"}, + ExposedPorts: []string{portNotation}, + PortBindings: map[docker.Port][]docker.PortBinding{ + docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, + }, + Networks: s.Networks(), + Env: []string{ + fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), + fmt.Sprintf("MOCKOIDC_PORT=%d", port), + "MOCKOIDC_CLIENT_ID=superclient", + "MOCKOIDC_CLIENT_SECRET=supersecret", + fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), + fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), + }, + } + + headscaleBuildOptions := &dockertest.BuildOptions{ + Dockerfile: hsic.IntegrationTestDockerFileName, + ContextDir: dockerContextPath, + } + + err = s.pool.RemoveContainerByName(hostname) + if err != nil { + return err + } + + s.mockOIDC = scenarioOIDC{} + + if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( + headscaleBuildOptions, + mockOidcOptions, + dockertestutil.DockerRestartPolicy); err == nil { + s.mockOIDC.r = pmockoidc + } else { + return err + } + + // headscale needs to set up the provider with a specific + // IP addr to ensure we get the correct config from the well-known + // endpoint. + network := s.Networks()[0] + ipAddr := s.mockOIDC.r.GetIPInNetwork(network) + + log.Println("Waiting for headscale mock oidc to be ready for tests") + hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) + + if err := s.pool.Retry(func() error { + oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) + httpClient := &http.Client{} + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) + resp, err := httpClient.Do(req) + if err != nil { + log.Printf("headscale mock OIDC tests is not ready: %s\n", err) + + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errStatusCodeNotOK + } + + return nil + }); err != nil { + return err + } + + s.mockOIDC.cfg = &types.OIDCConfig{ + Issuer: fmt.Sprintf( + "http://%s/oidc", + hostEndpoint, + ), + ClientID: "superclient", + ClientSecret: "supersecret", + OnlyStartIfOIDCIsAvailable: true, + } + + log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) + + return nil +} + +type extraServiceFunc func(*Scenario, string) (*dockertest.Resource, error) + +func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { + // port, err := dockertestutil.RandomFreeHostPort() + // if err != nil { + // log.Fatalf("could not find an open port: %s", err) + // } + // portNotation := fmt.Sprintf("%d/tcp", port) + + hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) + + hostname := fmt.Sprintf("hs-webservice-%s", hash) + + network, ok := s.networks[prefixedNetworkName(networkName)] + if !ok { + return nil, fmt.Errorf("network does not exist: %s", networkName) + } + + webOpts := &dockertest.RunOptions{ + Name: hostname, + Cmd: []string{"/bin/sh", "-c", "cd / ; python3 -m http.server --bind :: 80"}, + // ExposedPorts: []string{portNotation}, + // PortBindings: map[docker.Port][]docker.PortBinding{ + // docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, + // }, + Networks: []*dockertest.Network{network}, + Env: []string{}, + } + + webBOpts := &dockertest.BuildOptions{ + Dockerfile: hsic.IntegrationTestDockerFileName, + ContextDir: dockerContextPath, + } + + web, err := s.pool.BuildAndRunWithBuildOptions( + webBOpts, + webOpts, + dockertestutil.DockerRestartPolicy) + if err != nil { + return nil, err + } + + // headscale needs to set up the provider with a specific + // IP addr to ensure we get the correct config from the well-known + // endpoint. + // ipAddr := web.GetIPInNetwork(network) + + // log.Println("Waiting for headscale mock oidc to be ready for tests") + // hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) + + // if err := s.pool.Retry(func() error { + // oidcConfigURL := fmt.Sprintf("http://%s/etc/hostname", hostEndpoint) + // httpClient := &http.Client{} + // ctx := context.Background() + // req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) + // resp, err := httpClient.Do(req) + // if err != nil { + // log.Printf("headscale mock OIDC tests is not ready: %s\n", err) + + // return err + // } + // defer resp.Body.Close() + + // if resp.StatusCode != http.StatusOK { + // return errStatusCodeNotOK + // } + + // return nil + // }); err != nil { + // return err + // } + + return web, nil +} diff --git a/integration/scenario_test.go b/integration/scenario_test.go index aec6cb5c..7f34fa77 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/juanfont/headscale/integration/dockertestutil" + "github.com/juanfont/headscale/integration/tsic" ) // This file is intended to "test the test framework", by proxy it will also test @@ -33,7 +34,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(ScenarioSpec{}) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -68,38 +69,6 @@ func TestHeadscale(t *testing.T) { }) } -// If subtests are parallel, then they will start before setup is run. -// This might mean we approach setup slightly wrong, but for now, ignore -// the linter -// nolint:tparallel -func TestCreateTailscale(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - user := "only-create-containers" - - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - scenario.users[user] = &User{ - Clients: make(map[string]TailscaleClient), - } - - t.Run("create-tailscale", func(t *testing.T) { - err := scenario.CreateTailscaleNodesInUser(user, "all", 3) - if err != nil { - t.Fatalf("failed to add tailscale nodes: %s", err) - } - - if clients := len(scenario.users[user].Clients); clients != 3 { - t.Fatalf("wrong number of tailscale clients: %d != %d", clients, 3) - } - - // TODO(kradalby): Test "all" version logic - }) -} - // If subtests are parallel, then they will start before setup is run. // This might mean we approach setup slightly wrong, but for now, ignore // the linter @@ -114,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 - scenario, err := NewScenario(dockertestMaxWait()) + scenario, err := NewScenario(ScenarioSpec{}) assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -142,7 +111,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { }) t.Run("create-tailscale", func(t *testing.T) { - err := scenario.CreateTailscaleNodesInUser(user, "unstable", count) + err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) if err != nil { t.Fatalf("failed to add tailscale nodes: %s", err) } diff --git a/integration/ssh_test.go b/integration/ssh_test.go index ade119d3..d9983f65 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -50,15 +50,15 @@ var retry = func(times int, sleepInterval time.Duration, func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario { t.Helper() - scenario, err := NewScenario(dockertestMaxWait()) + + spec := ScenarioSpec{ + NodesPerUser: clientsPerUser, + Users: []string{"user1", "user2"}, + } + scenario, err := NewScenario(spec) assertNoErr(t, err) - spec := map[string]int{ - "user1": clientsPerUser, - "user2": clientsPerUser, - } - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{ tsic.WithSSH(), diff --git a/integration/tailscale.go b/integration/tailscale.go index 9ab6e1e2..552fc759 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -5,6 +5,7 @@ import ( "net/netip" "net/url" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/tsic" "tailscale.com/ipn/ipnstate" @@ -27,6 +28,9 @@ type TailscaleClient interface { Up() error Down() error IPs() ([]netip.Addr, error) + MustIPs() []netip.Addr + MustIPv4() netip.Addr + MustIPv6() netip.Addr FQDN() (string, error) Status(...bool) (*ipnstate.Status, error) MustStatus() *ipnstate.Status @@ -38,6 +42,7 @@ type TailscaleClient interface { WaitForPeers(expected int) error Ping(hostnameOrIP string, opts ...tsic.PingOption) error Curl(url string, opts ...tsic.CurlOption) (string, error) + Traceroute(netip.Addr) (util.Traceroute, error) ID() string ReadFile(path string) ([]byte, error) diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index b501dc1a..0c8ba734 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -13,6 +13,7 @@ import ( "net/url" "os" "reflect" + "runtime/debug" "strconv" "strings" "time" @@ -81,6 +82,7 @@ type TailscaleInContainer struct { workdir string netfilter string extraLoginArgs []string + withAcceptRoutes bool // build options, solely for HEAD buildConfig TailscaleInContainerBuildConfig @@ -101,26 +103,10 @@ func WithCACert(cert []byte) Option { } } -// WithOrCreateNetwork sets the Docker container network to use with -// the Tailscale instance, if the parameter is nil, a new network, -// isolating the TailscaleClient, will be created. If a network is -// passed, the Tailscale instance will join the given network. -func WithOrCreateNetwork(network *dockertest.Network) Option { +// WithNetwork sets the Docker container network to use with +// the Tailscale instance. +func WithNetwork(network *dockertest.Network) Option { return func(tsic *TailscaleInContainer) { - if network != nil { - tsic.network = network - - return - } - - network, err := dockertestutil.GetFirstOrCreateNetwork( - tsic.pool, - fmt.Sprintf("%s-network", tsic.hostname), - ) - if err != nil { - log.Fatalf("failed to create network: %s", err) - } - tsic.network = network } } @@ -212,11 +198,17 @@ func WithExtraLoginArgs(args []string) Option { } } +// WithAcceptRoutes tells the node to accept incomming routes. +func WithAcceptRoutes() Option { + return func(tsic *TailscaleInContainer) { + tsic.withAcceptRoutes = true + } +} + // New returns a new TailscaleInContainer instance. func New( pool *dockertest.Pool, version string, - network *dockertest.Network, opts ...Option, ) (*TailscaleInContainer, error) { hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength) @@ -230,8 +222,7 @@ func New( version: version, hostname: hostname, - pool: pool, - network: network, + pool: pool, withEntrypoint: []string{ "/bin/sh", @@ -244,6 +235,10 @@ func New( opt(tsic) } + if tsic.network == nil { + return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) + } + tailscaleOptions := &dockertest.RunOptions{ Name: hostname, Networks: []*dockertest.Network{tsic.network}, @@ -442,7 +437,7 @@ func (t *TailscaleInContainer) Login( "--login-server=" + loginServer, "--authkey=" + authKey, "--hostname=" + t.hostname, - "--accept-routes=false", + fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes), } if t.extraLoginArgs != nil { @@ -597,6 +592,33 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { return ips, nil } +func (t *TailscaleInContainer) MustIPs() []netip.Addr { + ips, err := t.IPs() + if err != nil { + panic(err) + } + + return ips +} + +func (t *TailscaleInContainer) MustIPv4() netip.Addr { + for _, ip := range t.MustIPs() { + if ip.Is4() { + return ip + } + } + panic("no ipv4 found") +} + +func (t *TailscaleInContainer) MustIPv6() netip.Addr { + for _, ip := range t.MustIPs() { + if ip.Is6() { + return ip + } + } + panic("no ipv6 found") +} + // Status returns the ipnstate.Status of the Tailscale instance. func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) { command := []string{ @@ -992,6 +1014,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err ), ) if err != nil { + log.Printf("command: %v", command) log.Printf( "failed to run ping command from %s to %s, err: %s", t.Hostname(), @@ -1108,6 +1131,26 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err return result, nil } +func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error) { + command := []string{ + "traceroute", + ip.String(), + } + + var result util.Traceroute + stdout, stderr, err := t.Execute(command) + if err != nil { + return result, err + } + + result, err = util.ParseTraceroute(stdout + stderr) + if err != nil { + return result, err + } + + return result, nil +} + // WriteFile save file inside the Tailscale container. func (t *TailscaleInContainer) WriteFile(path string, data []byte) error { return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)