fix webauth + autoapprove routes (#2528)

* types/node: add helper funcs for node tags

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* types/node: add DebugString method for node

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy/v2: add String func to AutoApprover interface

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy/v2: simplify, use slices.Contains

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy/v2: debug, use nodes.DebugString

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy/v1: fix potential nil pointer in NodeCanApproveRoute

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy/v1: slices.Contains

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration/tsic: fix diff in login commands

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: fix webauth running with wrong scenario

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: move common oidc opts to func

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: require node count, more verbose

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* auth: remove uneffective route approve

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* .github/workflows: fmt

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration/tsic: add id func

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: remove call that might be nil

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: test autoapprovers against web/authkey x group/tag/user

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: unique network id per scenario

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* Revert "integration: move common oidc opts to func"

This reverts commit 7e9d165d4a900c304f1083b665f1a24a26e06e55.

* remove cmd

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: clean docker images between runs in ci

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: run autoapprove test against differnt policy modes

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration/tsic: append, not overrwrite extra login args

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* .github/workflows: remove polv2

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-04-30 08:54:04 +03:00 committed by GitHub
parent 57861507ab
commit f1206328dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 732 additions and 401 deletions

View File

@ -7,6 +7,8 @@ import (
"os" "os"
"sync" "sync"
"slices"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -145,13 +147,7 @@ func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
tags, invalid := pm.pol.TagsOfNode(pm.users, node) tags, invalid := pm.pol.TagsOfNode(pm.users, node)
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy") log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
for _, t := range tags { return slices.Contains(tags, tag)
if t == tag {
return true
}
}
return false
} }
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool { func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
@ -174,7 +170,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefi
} }
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first // approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if ips.Contains(*node.IPv4) { if ips != nil && ips.Contains(*node.IPv4) {
return true return true
} }
} }

View File

@ -7,6 +7,8 @@ import (
"strings" "strings"
"sync" "sync"
"slices"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
@ -174,12 +176,10 @@ func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
defer pm.mu.Unlock() defer pm.mu.Unlock()
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok { if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
for _, nodeAddr := range node.IPs() { if slices.ContainsFunc(node.IPs(), ips.Contains) {
if ips.Contains(nodeAddr) {
return true return true
} }
} }
}
return false return false
} }
@ -196,12 +196,10 @@ func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefi
// where there is an exact entry, e.g. 10.0.0.0/8, then // where there is an exact entry, e.g. 10.0.0.0/8, then
// check and return quickly // check and return quickly
if _, ok := pm.autoApproveMap[route]; ok { if _, ok := pm.autoApproveMap[route]; ok {
for _, nodeAddr := range node.IPs() { if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) {
if pm.autoApproveMap[route].Contains(nodeAddr) {
return true return true
} }
} }
}
// The slow path is that the node tries to approve // The slow path is that the node tries to approve
// 10.0.10.0/24, which is a part of 10.0.0.0/8, then we // 10.0.10.0/24, which is a part of 10.0.0.0/8, then we
@ -220,13 +218,11 @@ func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefi
// Check if prefix is larger (so containing) and then overlaps // Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover // the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
for _, nodeAddr := range node.IPs() { if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) {
if approveAddrs.Contains(nodeAddr) {
return true return true
} }
} }
} }
}
return false return false
} }
@ -279,5 +275,8 @@ func (pm *PolicyManager) DebugString() string {
} }
} }
sb.WriteString("\n\n")
sb.WriteString(pm.nodes.DebugString())
return sb.String() return sb.String()
} }

View File

@ -162,6 +162,10 @@ func (g Group) CanBeAutoApprover() bool {
return true return true
} }
func (g Group) String() string {
return string(g)
}
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
var errs []error var errs []error
@ -235,6 +239,10 @@ func (t Tag) CanBeAutoApprover() bool {
return true return true
} }
func (t Tag) String() string {
return string(t)
}
// Host is a string that represents a hostname. // Host is a string that represents a hostname.
type Host string type Host string
@ -590,6 +598,7 @@ func unmarshalPointer[T any](
type AutoApprover interface { type AutoApprover interface {
CanBeAutoApprover() bool CanBeAutoApprover() bool
UnmarshalJSON([]byte) error UnmarshalJSON([]byte) error
String() string
} }
type AutoApprovers []AutoApprover type AutoApprovers []AutoApprover

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"slices" "slices"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -194,19 +195,26 @@ func (node *Node) IsTagged() bool {
// Currently, this function only handles tags set // Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys) // via CLI ("forced tags" and preauthkeys)
func (node *Node) HasTag(tag string) bool { func (node *Node) HasTag(tag string) bool {
if slices.Contains(node.ForcedTags, tag) { return slices.Contains(node.Tags(), tag)
return true }
}
if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) { func (node *Node) Tags() []string {
return true var tags []string
if node.AuthKey != nil {
tags = append(tags, node.AuthKey.Tags...)
} }
// TODO(kradalby): Figure out how tagging should work // TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags. // and hostinfo.requestedtags.
// Do this in other work. // Do this in other work.
// #2417
return false tags = append(tags, node.ForcedTags...)
sort.Strings(tags)
tags = slices.Compact(tags)
return tags
} }
func (node *Node) RequestTags() []string { func (node *Node) RequestTags() []string {
@ -549,3 +557,25 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
return ret return ret
} }
func (nodes Nodes) DebugString() string {
var sb strings.Builder
sb.WriteString("Nodes:\n")
for _, node := range nodes {
sb.WriteString(node.DebugString())
sb.WriteString("\n")
}
return sb.String()
}
func (node Node) DebugString() string {
var sb strings.Builder
fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID)
fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags())
fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs())
fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes)
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
sb.WriteString("\n")
return sb.String()
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url" "net/url"
"os"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -173,3 +174,15 @@ func ParseTraceroute(output string) (Traceroute, error) {
return result, nil return result, nil
} }
func IsCI() bool {
if _, ok := os.LookupEnv("CI"); ok {
return true
}
if _, ok := os.LookupEnv("GITHUB_RUN_ID"); ok {
return true
}
return false
}

View File

@ -1054,7 +1054,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
// Initially all nodes can reach each other // Initially all nodes can reach each other
for _, client := range all { for _, client := range all {
for _, peer := range all { for _, peer := range all {
if client.ID() == peer.ID() { if client.ContainerID() == peer.ContainerID() {
continue continue
} }

View File

@ -442,7 +442,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, listUsers, 0) assert.Len(t, listUsers, 0)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err) assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint()) u, err := ts.LoginWithURL(headscale.GetEndpoint())

View File

@ -26,7 +26,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
} }
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnvWithLoginURL(
nil, nil,
hsic.WithTestName("webauthping"), hsic.WithTestName("webauthping"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
@ -66,7 +66,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnvWithLoginURL(
nil, nil,
hsic.WithTestName("weblogout"), hsic.WithTestName("weblogout"),
hsic.WithTLS(), hsic.WithTLS(),

View File

@ -6,6 +6,7 @@ import (
"log" "log"
"net" "net"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
) )
@ -105,3 +106,23 @@ func CleanUnreferencedNetworks(pool *dockertest.Pool) error {
return nil return nil
} }
// CleanImagesInCI removes images if running in CI.
func CleanImagesInCI(pool *dockertest.Pool) error {
if !util.IsCI() {
log.Println("Skipping image cleanup outside of CI")
return nil
}
images, err := pool.Client.ListImages(docker.ListImagesOptions{})
if err != nil {
return fmt.Errorf("getting images: %w", err)
}
for _, image := range images {
log.Printf("removing image: %s, %v", image.ID, image.RepoTags)
_ = pool.Client.RemoveImage(image.ID)
}
return nil
}

View File

@ -138,7 +138,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
t.Fatalf("failed to create user %s: %s", userName, err) t.Fatalf("failed to create user %s: %s", userName, err)
} }
err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
if err != nil { if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
} }
@ -216,7 +216,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
t.Fatalf("failed to create user %s: %s", userName, err) t.Fatalf("failed to create user %s: %s", userName, err)
} }
err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
if err != nil { if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
} }

View File

@ -287,9 +287,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 0, 0) requireNodeRouteCount(t, nodes[0], 1, 0, 0)
assertNodeRouteCount(t, nodes[1], 1, 0, 0) requireNodeRouteCount(t, nodes[1], 1, 0, 0)
assertNodeRouteCount(t, nodes[2], 1, 0, 0) requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
// they are not yet enabled. // they are not yet enabled.
@ -319,9 +319,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, nodes[0], 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 0, 0) requireNodeRouteCount(t, nodes[1], 1, 0, 0)
assertNodeRouteCount(t, nodes[2], 1, 0, 0) requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the client has routes from the primary machine and can access // Verify that the client has routes from the primary machine and can access
// the webservice. // the webservice.
@ -375,9 +375,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, nodes[0], 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[1], 1, 1, 0)
assertNodeRouteCount(t, nodes[2], 1, 0, 0) requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1 = subRouter1.MustStatus() srs1 = subRouter1.MustStatus()
@ -431,9 +431,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, nodes[0], 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[1], 1, 1, 0)
assertNodeRouteCount(t, nodes[2], 1, 1, 0) requireNodeRouteCount(t, nodes[2], 1, 1, 0)
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1 = subRouter1.MustStatus() srs1 = subRouter1.MustStatus()
@ -645,9 +645,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, nodes[0], 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[1], 1, 1, 0)
assertNodeRouteCount(t, nodes[2], 1, 0, 0) requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -690,9 +690,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 0, 0) requireNodeRouteCount(t, nodes[0], 1, 0, 0)
assertNodeRouteCount(t, nodes[1], 1, 1, 1) requireNodeRouteCount(t, nodes[1], 1, 1, 1)
assertNodeRouteCount(t, nodes[2], 1, 0, 0) requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -738,9 +738,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 6) assert.Len(t, nodes, 6)
assertNodeRouteCount(t, nodes[0], 1, 1, 0) requireNodeRouteCount(t, nodes[0], 1, 1, 0)
assertNodeRouteCount(t, nodes[1], 1, 1, 1) requireNodeRouteCount(t, nodes[1], 1, 1, 1)
assertNodeRouteCount(t, nodes[2], 1, 0, 0) requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -870,8 +870,8 @@ func TestSubnetRouteACL(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, nodes, 2) require.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 1, 0, 0) requireNodeRouteCount(t, nodes[0], 1, 0, 0)
assertNodeRouteCount(t, nodes[1], 0, 0, 0) requireNodeRouteCount(t, nodes[1], 0, 0, 0)
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
// they are not yet enabled. // they are not yet enabled.
@ -899,8 +899,8 @@ func TestSubnetRouteACL(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, nodes, 2) require.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, nodes[0], 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 0, 0, 0) requireNodeRouteCount(t, nodes[1], 0, 0, 0)
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1, _ := subRouter1.Status() srs1, _ := subRouter1.Status()
@ -1034,8 +1034,8 @@ func TestEnablingExitRoutes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, nodes, 2) require.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 2, 0, 0) requireNodeRouteCount(t, nodes[0], 2, 0, 0)
assertNodeRouteCount(t, nodes[1], 2, 0, 0) requireNodeRouteCount(t, nodes[1], 2, 0, 0)
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
// they are not yet enabled. // they are not yet enabled.
@ -1067,8 +1067,8 @@ func TestEnablingExitRoutes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, nodes, 2) require.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 2, 2, 2) requireNodeRouteCount(t, nodes[0], 2, 2, 2)
assertNodeRouteCount(t, nodes[1], 2, 2, 2) requireNodeRouteCount(t, nodes[1], 2, 2, 2)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -1158,7 +1158,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
nodes, err := headscale.ListNodes() nodes, err := headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 1, 0, 0) requireNodeRouteCount(t, nodes[0], 1, 0, 0)
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
// they are not yet enabled. // they are not yet enabled.
@ -1184,7 +1184,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, nodes[0], 1, 1, 1)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = user2c.Status() status, err = user2c.Status()
@ -1282,7 +1282,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
nodes, err := headscale.ListNodes() nodes, err := headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 2, 0, 0) requireNodeRouteCount(t, nodes[0], 2, 0, 0)
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
// they are not yet enabled. // they are not yet enabled.
@ -1305,7 +1305,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assertNodeRouteCount(t, nodes[0], 2, 2, 2) requireNodeRouteCount(t, nodes[0], 2, 2, 2)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = user2c.Status() status, err = user2c.Status()
@ -1349,6 +1349,15 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node {
for _, node := range nodes {
if node.GetName() == hostname {
return node
}
}
panic("node not found")
}
// TestAutoApproveMultiNetwork tests auto approving of routes // TestAutoApproveMultiNetwork tests auto approving of routes
// by setting up two networks where network1 has three subnet // by setting up two networks where network1 has three subnet
// routers: // routers:
@ -1367,32 +1376,20 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
// - Verify that routes can now be seen by peers. // - Verify that routes can now be seen by peers.
func TestAutoApproveMultiNetwork(t *testing.T) { func TestAutoApproveMultiNetwork(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() bigRoute := netip.MustParsePrefix("10.42.0.0/16")
spec := ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
}
rootRoute := netip.MustParsePrefix("10.42.0.0/16")
subRoute := netip.MustParsePrefix("10.42.7.0/24") subRoute := netip.MustParsePrefix("10.42.7.0/24")
notApprovedRoute := netip.MustParsePrefix("192.168.0.0/24") notApprovedRoute := netip.MustParsePrefix("192.168.0.0/24")
scenario, err := NewScenario(spec) tests := []struct {
require.NoErrorf(t, err, "failed to create scenario: %s", err) name string
defer scenario.ShutdownAssertNoPanics(t) pol *policyv1.ACLPolicy
approver string
pol := &policyv1.ACLPolicy{ spec ScenarioSpec
withURL bool
}{
{
name: "authkey-tag",
pol: &policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
Action: "accept", Action: "accept",
@ -1405,21 +1402,238 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
}, },
AutoApprovers: policyv1.AutoApprovers{ AutoApprovers: policyv1.AutoApprovers{
Routes: map[string][]string{ Routes: map[string][]string{
rootRoute.String(): {"tag:approve"}, bigRoute.String(): {"tag:approve"},
}, },
ExitNode: []string{"tag:approve"}, ExitNode: []string{"tag:approve"},
}, },
},
approver: "tag:approve",
spec: ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
},
},
{
name: "authkey-user",
pol: &policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
AutoApprovers: policyv1.AutoApprovers{
Routes: map[string][]string{
bigRoute.String(): {"user1@"},
},
ExitNode: []string{"user1@"},
},
},
approver: "user1@",
spec: ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
},
},
{
name: "authkey-group",
pol: &policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
Groups: policyv1.Groups{
"group:approve": []string{"user1@"},
},
AutoApprovers: policyv1.AutoApprovers{
Routes: map[string][]string{
bigRoute.String(): {"group:approve"},
},
ExitNode: []string{"group:approve"},
},
},
approver: "group:approve",
spec: ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
},
},
{
name: "webauth-user",
pol: &policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
AutoApprovers: policyv1.AutoApprovers{
Routes: map[string][]string{
bigRoute.String(): {"user1@"},
},
ExitNode: []string{"user1@"},
},
},
approver: "user1@",
spec: ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
},
withURL: true,
},
{
name: "webauth-tag",
pol: &policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
TagOwners: map[string][]string{
"tag:approve": {"user1@"},
},
AutoApprovers: policyv1.AutoApprovers{
Routes: map[string][]string{
bigRoute.String(): {"tag:approve"},
},
ExitNode: []string{"tag:approve"},
},
},
approver: "tag:approve",
spec: ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
},
withURL: true,
},
{
name: "webauth-group",
pol: &policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
Groups: policyv1.Groups{
"group:approve": []string{"user1@"},
},
AutoApprovers: policyv1.AutoApprovers{
Routes: map[string][]string{
bigRoute.String(): {"group:approve"},
},
ExitNode: []string{"group:approve"},
},
},
approver: "group:approve",
spec: ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
},
withURL: true,
},
} }
err = scenario.CreateHeadscaleEnv([]tsic.Option{ for _, tt := range tests {
tsic.WithAcceptRoutes(), for _, dbMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
tsic.WithTags([]string{"tag:approve"}), for _, advertiseDuringUp := range []bool{false, true} {
}, name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, dbMode)
hsic.WithTestName("clienableroute"), t.Run(name, func(t *testing.T) {
scenario, err := NewScenario(tt.spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
opts := []hsic.Option{
hsic.WithTestName("autoapprovemulti"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithACLPolicy(pol), hsic.WithACLPolicy(tt.pol),
hsic.WithPolicyMode(types.PolicyModeDB), hsic.WithPolicyMode(dbMode),
}
tsOpts := []tsic.Option{
tsic.WithAcceptRoutes(),
}
if tt.approver == "tag:approve" {
tsOpts = append(tsOpts,
tsic.WithTags([]string{"tag:approve"}),
)
}
route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)
err = scenario.createHeadscaleEnv(tt.withURL, tsOpts,
opts...,
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -1429,18 +1643,6 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
assert.NotNil(t, headscale)
route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)
// Set the route of usernet1 to be autoapproved
pol.AutoApprovers.Routes[route.String()] = []string{"tag:approve"}
err = headscale.SetPolicy(pol)
require.NoError(t, err)
services, err := scenario.Services("usernet1") services, err := scenario.Services("usernet1")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, services, 1) require.Len(t, services, 1)
@ -1448,6 +1650,51 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
usernet1, err := scenario.Network("usernet1") usernet1, err := scenario.Network("usernet1")
require.NoError(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
assert.NotNil(t, headscale)
if advertiseDuringUp {
tsOpts = append(tsOpts,
tsic.WithExtraLoginArgs([]string{"--advertise-routes=" + route.String()}),
)
}
tsOpts = append(tsOpts, tsic.WithNetwork(usernet1))
// This whole dance is to add a node _after_ all the other nodes
// with an additional tsOpt which advertises the route as part
// of the `tailscale up` command. If we do this as part of the
// scenario creation, it will be added to all nodes and turn
// into a HA node, which isnt something we are testing here.
routerUsernet1, err := scenario.CreateTailscaleNode("head", tsOpts...)
require.NoError(t, err)
defer routerUsernet1.Shutdown()
if tt.withURL {
u, err := routerUsernet1.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
body, err := doLoginURL(routerUsernet1.Hostname(), u)
assertNoErr(t, err)
scenario.runHeadscaleRegister("user1", body)
} else {
pak, err := scenario.CreatePreAuthKey("user1", false, false)
assertNoErr(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key)
assertNoErr(t, err)
}
// extra creation end.
// Set the route of usernet1 to be autoapproved
tt.pol.AutoApprovers.Routes[route.String()] = []string{tt.approver}
err = headscale.SetPolicy(tt.pol)
require.NoError(t, err)
routerUsernet1ID := routerUsernet1.MustID()
web := services[0] web := services[0]
webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1))
weburl := fmt.Sprintf("http://%s/etc/hostname", webip) weburl := fmt.Sprintf("http://%s/etc/hostname", webip)
@ -1464,12 +1711,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// This is ok because the scenario makes users in order, so the three first // This is ok because the scenario makes users in order, so the three first
// nodes, which are subnet routes, will be created first, and the last user // nodes, which are subnet routes, will be created first, and the last user
// will be created with the second. // will be created with the second.
routerUsernet1 := allClients[0]
routerSubRoute := allClients[1] routerSubRoute := allClients[1]
routerExitNode := allClients[2] routerExitNode := allClients[2]
client := allClients[3] client := allClients[3]
if !advertiseDuringUp {
// Advertise the route for the dockersubnet of user1 // Advertise the route for the dockersubnet of user1
command := []string{ command := []string{
"tailscale", "tailscale",
@ -1478,6 +1725,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
} }
_, _, err = routerUsernet1.Execute(command) _, _, err = routerUsernet1.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err) require.NoErrorf(t, err, "failed to advertise route: %s", err)
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -1485,7 +1733,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts. // for all counts.
nodes, err := headscale.ListNodes() nodes, err := headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err := client.Status() status, err := client.Status()
@ -1494,7 +1742,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { if peerStatus.ID == routerUsernet1ID.StableID() {
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route) assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route})
} else { } else {
@ -1514,8 +1762,8 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
assertTracerouteViaIP(t, tr, routerUsernet1.MustIPv4()) assertTracerouteViaIP(t, tr, routerUsernet1.MustIPv4())
// Remove the auto approval from the policy, any routes already enabled should be allowed. // Remove the auto approval from the policy, any routes already enabled should be allowed.
delete(pol.AutoApprovers.Routes, route.String()) delete(tt.pol.AutoApprovers.Routes, route.String())
err = headscale.SetPolicy(pol) err = headscale.SetPolicy(tt.pol)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -1524,7 +1772,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1533,7 +1781,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { if peerStatus.ID == routerUsernet1ID.StableID() {
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route) assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route})
} else { } else {
@ -1554,7 +1802,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// Disable the route, making it unavailable since it is no longer auto-approved // Disable the route, making it unavailable since it is no longer auto-approved
_, err = headscale.ApproveRoutes( _, err = headscale.ApproveRoutes(
nodes[0].GetId(), MustFindNode(routerUsernet1.Hostname(), nodes).GetId(),
[]netip.Prefix{}, []netip.Prefix{},
) )
require.NoError(t, err) require.NoError(t, err)
@ -1565,7 +1813,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 0, 0) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1578,8 +1826,8 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// Add the route back to the auto approver in the policy, the route should // Add the route back to the auto approver in the policy, the route should
// now become available again. // now become available again.
pol.AutoApprovers.Routes[route.String()] = []string{"tag:approve"} tt.pol.AutoApprovers.Routes[route.String()] = []string{tt.approver}
err = headscale.SetPolicy(pol) err = headscale.SetPolicy(tt.pol)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -1588,7 +1836,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1597,7 +1845,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { if peerStatus.ID == routerUsernet1ID.StableID() {
require.NotNil(t, peerStatus.PrimaryRoutes) require.NotNil(t, peerStatus.PrimaryRoutes)
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route) assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route})
@ -1619,7 +1867,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// Advertise and validate a subnet of an auto approved route, /24 inside the // Advertise and validate a subnet of an auto approved route, /24 inside the
// auto approved /16. // auto approved /16.
command = []string{ command := []string{
"tailscale", "tailscale",
"set", "set",
"--advertise-routes=" + subRoute.String(), "--advertise-routes=" + subRoute.String(),
@ -1633,8 +1881,8 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 1, 1) requireNodeRouteCount(t, nodes[1], 1, 1, 1)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1643,7 +1891,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { if peerStatus.ID == routerUsernet1ID.StableID() {
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route) assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route})
} else if peerStatus.ID == "2" { } else if peerStatus.ID == "2" {
@ -1669,9 +1917,9 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[1], 1, 1, 0)
assertNodeRouteCount(t, nodes[2], 0, 0, 0) requireNodeRouteCount(t, nodes[2], 0, 0, 0)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1680,7 +1928,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { if peerStatus.ID == routerUsernet1ID.StableID() {
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route) assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route})
} else { } else {
@ -1701,9 +1949,9 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assertNodeRouteCount(t, nodes[0], 1, 1, 1) requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
assertNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[1], 1, 1, 0)
assertNodeRouteCount(t, nodes[2], 2, 2, 2) requireNodeRouteCount(t, nodes[2], 2, 2, 2)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1712,7 +1960,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { if peerStatus.ID == routerUsernet1ID.StableID() {
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route) assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *route)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*route})
} else if peerStatus.ID == "3" { } else if peerStatus.ID == "3" {
@ -1721,6 +1969,10 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requirePeerSubnetRoutes(t, peerStatus, nil) requirePeerSubnetRoutes(t, peerStatus, nil)
} }
} }
})
}
}
}
} }
func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) { func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
@ -1757,9 +2009,9 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
} }
} }
func assertNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) { func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
t.Helper() t.Helper()
assert.Len(t, node.GetAvailableRoutes(), announced) require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
assert.Len(t, node.GetApprovedRoutes(), approved) require.Lenf(t, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
assert.Len(t, node.GetSubnetRoutes(), subnet) require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
} }

View File

@ -109,6 +109,9 @@ type Scenario struct {
spec ScenarioSpec spec ScenarioSpec
userToNetwork map[string]*dockertest.Network userToNetwork map[string]*dockertest.Network
testHashPrefix string
testDefaultNetwork string
} }
// ScenarioSpec describes the users, nodes, and network topology to // ScenarioSpec describes the users, nodes, and network topology to
@ -150,11 +153,8 @@ type ScenarioSpec struct {
MaxWait time.Duration MaxWait time.Duration
} }
var TestHashPrefix = "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) func (s *Scenario) prefixedNetworkName(name string) string {
var TestDefaultNetwork = TestHashPrefix + "-default" return s.testHashPrefix + "-" + name
func prefixedNetworkName(name string) string {
return TestHashPrefix + "-" + name
} }
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
@ -169,6 +169,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
// This might be a no op, but it is worth a try as we sometime // This might be a no op, but it is worth a try as we sometime
// dont clean up nicely after ourselves. // dont clean up nicely after ourselves.
dockertestutil.CleanUnreferencedNetworks(pool) dockertestutil.CleanUnreferencedNetworks(pool)
dockertestutil.CleanImagesInCI(pool)
if spec.MaxWait == 0 { if spec.MaxWait == 0 {
pool.MaxWait = dockertestMaxWait() pool.MaxWait = dockertestMaxWait()
@ -176,18 +177,22 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
pool.MaxWait = spec.MaxWait pool.MaxWait = spec.MaxWait
} }
testHashPrefix := "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength)
s := &Scenario{ s := &Scenario{
controlServers: xsync.NewMapOf[string, ControlServer](), controlServers: xsync.NewMapOf[string, ControlServer](),
users: make(map[string]*User), users: make(map[string]*User),
pool: pool, pool: pool,
spec: spec, spec: spec,
testHashPrefix: testHashPrefix,
testDefaultNetwork: testHashPrefix + "-default",
} }
var userToNetwork map[string]*dockertest.Network var userToNetwork map[string]*dockertest.Network
if spec.Networks != nil || len(spec.Networks) != 0 { if spec.Networks != nil || len(spec.Networks) != 0 {
for name, users := range s.spec.Networks { for name, users := range s.spec.Networks {
networkName := TestHashPrefix + "-" + name networkName := testHashPrefix + "-" + name
network, err := s.AddNetwork(networkName) network, err := s.AddNetwork(networkName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -201,7 +206,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
} }
} }
} else { } else {
_, err := s.AddNetwork(TestDefaultNetwork) _, err := s.AddNetwork(s.testDefaultNetwork)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -213,7 +218,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
mak.Set(&s.extraServices, prefixedNetworkName(network), append(s.extraServices[prefixedNetworkName(network)], svc)) mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc))
} }
} }
@ -261,7 +266,7 @@ func (s *Scenario) Networks() []*dockertest.Network {
} }
func (s *Scenario) Network(name string) (*dockertest.Network, error) { func (s *Scenario) Network(name string) (*dockertest.Network, error) {
net, ok := s.networks[prefixedNetworkName(name)] net, ok := s.networks[s.prefixedNetworkName(name)]
if !ok { if !ok {
return nil, fmt.Errorf("no network named: %s", name) return nil, fmt.Errorf("no network named: %s", name)
} }
@ -270,7 +275,7 @@ func (s *Scenario) Network(name string) (*dockertest.Network, error) {
} }
func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
net, ok := s.networks[prefixedNetworkName(name)] net, ok := s.networks[s.prefixedNetworkName(name)]
if !ok { if !ok {
return nil, fmt.Errorf("no network named: %s", name) return nil, fmt.Errorf("no network named: %s", name)
} }
@ -288,7 +293,7 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
} }
func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
res, ok := s.extraServices[prefixedNetworkName(name)] res, ok := s.extraServices[s.prefixedNetworkName(name)]
if !ok { if !ok {
return nil, fmt.Errorf("no network named: %s", name) return nil, fmt.Errorf("no network named: %s", name)
} }
@ -298,6 +303,7 @@ func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
defer dockertestutil.CleanUnreferencedNetworks(s.pool) defer dockertestutil.CleanUnreferencedNetworks(s.pool)
defer dockertestutil.CleanImagesInCI(s.pool)
s.controlServers.Range(func(_ string, control ControlServer) bool { s.controlServers.Range(func(_ string, control ControlServer) bool {
stdoutPath, stderrPath, err := control.Shutdown() stdoutPath, stderrPath, err := control.Shutdown()
@ -493,8 +499,7 @@ func (s *Scenario) CreateTailscaleNode(
) )
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"failed to create tailscale (%s) node: %w", "failed to create tailscale node: %w",
tsClient.Hostname(),
err, err,
) )
} }
@ -707,7 +712,7 @@ func (s *Scenario) createHeadscaleEnv(
if s.userToNetwork != nil { if s.userToNetwork != nil {
opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user])) opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user]))
} else { } else {
opts = append(tsOpts, tsic.WithNetwork(s.networks[TestDefaultNetwork])) opts = append(tsOpts, tsic.WithNetwork(s.networks[s.testDefaultNetwork]))
} }
err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...) err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...)
@ -1181,7 +1186,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
hostname := fmt.Sprintf("hs-webservice-%s", hash) hostname := fmt.Sprintf("hs-webservice-%s", hash)
network, ok := s.networks[prefixedNetworkName(networkName)] network, ok := s.networks[s.prefixedNetworkName(networkName)]
if !ok { if !ok {
return nil, fmt.Errorf("network does not exist: %s", networkName) return nil, fmt.Errorf("network does not exist: %s", networkName)
} }

View File

@ -111,7 +111,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
}) })
t.Run("create-tailscale", func(t *testing.T) { t.Run("create-tailscale", func(t *testing.T) {
err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
if err != nil { if err != nil {
t.Fatalf("failed to add tailscale nodes: %s", err) t.Fatalf("failed to add tailscale nodes: %s", err)
} }

View File

@ -410,7 +410,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
result, _, err := doSSH(t, client, peer) result, _, err := doSSH(t, client, peer)
assertNoErr(t, err) assertNoErr(t, err)
assertContains(t, peer.ID(), strings.ReplaceAll(result, "\n", "")) assertContains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", ""))
} }
func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {

View File

@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
@ -43,7 +44,8 @@ type TailscaleClient interface {
Ping(hostnameOrIP string, opts ...tsic.PingOption) error Ping(hostnameOrIP string, opts ...tsic.PingOption) error
Curl(url string, opts ...tsic.CurlOption) (string, error) Curl(url string, opts ...tsic.CurlOption) (string, error)
Traceroute(netip.Addr) (util.Traceroute, error) Traceroute(netip.Addr) (util.Traceroute, error)
ID() string ContainerID() string
MustID() types.NodeID
ReadFile(path string) ([]byte, error) ReadFile(path string) ([]byte, error)
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client // FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client

View File

@ -18,6 +18,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/integrationutil"
@ -194,7 +195,7 @@ func WithBuildTag(tag string) Option {
// as part of the Login function. // as part of the Login function.
func WithExtraLoginArgs(args []string) Option { func WithExtraLoginArgs(args []string) Option {
return func(tsic *TailscaleInContainer) { return func(tsic *TailscaleInContainer) {
tsic.extraLoginArgs = args tsic.extraLoginArgs = append(tsic.extraLoginArgs, args...)
} }
} }
@ -383,7 +384,7 @@ func (t *TailscaleInContainer) Version() string {
// ID returns the Docker container ID of the TailscaleInContainer // ID returns the Docker container ID of the TailscaleInContainer
// instance. // instance.
func (t *TailscaleInContainer) ID() string { func (t *TailscaleInContainer) ContainerID() string {
return t.container.Container.ID return t.container.Container.ID
} }
@ -426,20 +427,21 @@ func (t *TailscaleInContainer) Logs(stdout, stderr io.Writer) error {
) )
} }
// Up runs the login routine on the given Tailscale instance. func (t *TailscaleInContainer) buildLoginCommand(
// This login mechanism uses the authorised key for authentication.
func (t *TailscaleInContainer) Login(
loginServer, authKey string, loginServer, authKey string,
) error { ) []string {
command := []string{ command := []string{
"tailscale", "tailscale",
"up", "up",
"--login-server=" + loginServer, "--login-server=" + loginServer,
"--authkey=" + authKey,
"--hostname=" + t.hostname, "--hostname=" + t.hostname,
fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes), fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes),
} }
if authKey != "" {
command = append(command, "--authkey="+authKey)
}
if t.extraLoginArgs != nil { if t.extraLoginArgs != nil {
command = append(command, t.extraLoginArgs...) command = append(command, t.extraLoginArgs...)
} }
@ -458,6 +460,16 @@ func (t *TailscaleInContainer) Login(
) )
} }
return command
}
// Login runs the login routine on the given Tailscale instance.
// This login mechanism uses the authorised key for authentication.
func (t *TailscaleInContainer) Login(
loginServer, authKey string,
) error {
command := t.buildLoginCommand(loginServer, authKey)
if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil {
return fmt.Errorf( return fmt.Errorf(
"%s failed to join tailscale client (%s): %w", "%s failed to join tailscale client (%s): %w",
@ -475,17 +487,7 @@ func (t *TailscaleInContainer) Login(
func (t *TailscaleInContainer) LoginWithURL( func (t *TailscaleInContainer) LoginWithURL(
loginServer string, loginServer string,
) (loginURL *url.URL, err error) { ) (loginURL *url.URL, err error) {
command := []string{ command := t.buildLoginCommand(loginServer, "")
"tailscale",
"up",
"--login-server=" + loginServer,
"--hostname=" + t.hostname,
"--accept-routes=false",
}
if t.extraLoginArgs != nil {
command = append(command, t.extraLoginArgs...)
}
stdout, stderr, err := t.Execute(command) stdout, stderr, err := t.Execute(command)
if errors.Is(err, errTailscaleNotLoggedIn) { if errors.Is(err, errTailscaleNotLoggedIn) {
@ -646,7 +648,7 @@ func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
return &status, err return &status, err
} }
// Status returns the ipnstate.Status of the Tailscale instance. // MustStatus returns the ipnstate.Status of the Tailscale instance.
func (t *TailscaleInContainer) MustStatus() *ipnstate.Status { func (t *TailscaleInContainer) MustStatus() *ipnstate.Status {
status, err := t.Status() status, err := t.Status()
if err != nil { if err != nil {
@ -656,6 +658,21 @@ func (t *TailscaleInContainer) MustStatus() *ipnstate.Status {
return status return status
} }
// MustID returns the ID of the Tailscale instance.
func (t *TailscaleInContainer) MustID() types.NodeID {
status, err := t.Status()
if err != nil {
panic(err)
}
id, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
if err != nil {
panic(fmt.Sprintf("failed to parse ID: %s", err))
}
return types.NodeID(id)
}
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance. // Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
// Only works with Tailscale 1.56 and newer. // Only works with Tailscale 1.56 and newer.
// Panics if version is lower then minimum. // Panics if version is lower then minimum.

View File

@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"os"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -344,22 +343,10 @@ func isSelfClient(client TailscaleClient, addr string) bool {
return false return false
} }
func isCI() bool {
if _, ok := os.LookupEnv("CI"); ok {
return true
}
if _, ok := os.LookupEnv("GITHUB_RUN_ID"); ok {
return true
}
return false
}
func dockertestMaxWait() time.Duration { func dockertestMaxWait() time.Duration {
wait := 120 * time.Second //nolint wait := 120 * time.Second //nolint
if isCI() { if util.IsCI() {
wait = 300 * time.Second //nolint wait = 300 * time.Second //nolint
} }