Multi network integration tests (#2464)

This commit is contained in:
Kristoffer Dalby 2025-03-21 11:49:32 +01:00 committed by GitHub
parent 707438f25e
commit 603f3ad490
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 2385 additions and 1449 deletions

View File

@ -70,8 +70,9 @@ jobs:
- TestAutoApprovedSubRoute2068 - TestAutoApprovedSubRoute2068
- TestSubnetRouteACL - TestSubnetRouteACL
- TestEnablingExitRoutes - TestEnablingExitRoutes
- TestSubnetRouterMultiNetwork
- TestSubnetRouterMultiNetworkExitNode
- TestHeadscale - TestHeadscale
- TestCreateTailscale
- TestTailscaleNodesJoiningHeadcale - TestTailscaleNodesJoiningHeadcale
- TestSSHOneUserToAll - TestSSHOneUserToAll
- TestSSHMultipleUsersAllToAll - TestSSHMultipleUsersAllToAll

View File

@ -70,8 +70,9 @@ jobs:
- TestAutoApprovedSubRoute2068 - TestAutoApprovedSubRoute2068
- TestSubnetRouteACL - TestSubnetRouteACL
- TestEnablingExitRoutes - TestEnablingExitRoutes
- TestSubnetRouterMultiNetwork
- TestSubnetRouterMultiNetworkExitNode
- TestHeadscale - TestHeadscale
- TestCreateTailscale
- TestTailscaleNodesJoiningHeadcale - TestTailscaleNodesJoiningHeadcale
- TestSSHOneUserToAll - TestSSHOneUserToAll
- TestSSHMultipleUsersAllToAll - TestSSHMultipleUsersAllToAll

View File

@ -165,9 +165,13 @@ func Test_fullMapResponse(t *testing.T) {
), ),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{ AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv4(), tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
}, },
HomeDERP: 0, HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0", LegacyDERPString: "127.3.3.40:0",

View File

@ -2,13 +2,13 @@ package mapper
import ( import (
"fmt" "fmt"
"net/netip"
"time" "time"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/samber/lo" "github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -49,14 +49,6 @@ func tailNode(
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
addrs := node.Prefixes() addrs := node.Prefixes()
allowedIPs := append(
[]netip.Prefix{},
addrs...) // we append the node own IP, as it is required by the clients
for _, route := range node.SubnetRoutes() {
allowedIPs = append(allowedIPs, netip.Prefix(route))
}
var derp int var derp int
// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077 // TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077
@ -89,6 +81,10 @@ func tailNode(
} }
tags = lo.Uniq(append(tags, node.ForcedTags...)) tags = lo.Uniq(append(tags, node.ForcedTags...))
allowed := append(node.Prefixes(), primary.PrimaryRoutes(node.ID)...)
allowed = append(allowed, node.ExitRoutes()...)
tsaddr.SortPrefixes(allowed)
tNode := tailcfg.Node{ tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: node.ID.StableID(), StableID: node.ID.StableID(),
@ -104,7 +100,7 @@ func tailNode(
DiscoKey: node.DiscoKey, DiscoKey: node.DiscoKey,
Addresses: addrs, Addresses: addrs,
PrimaryRoutes: primary.PrimaryRoutes(node.ID), PrimaryRoutes: primary.PrimaryRoutes(node.ID),
AllowedIPs: allowedIPs, AllowedIPs: allowed,
Endpoints: node.Endpoints, Endpoints: node.Endpoints,
HomeDERP: derp, HomeDERP: derp,
LegacyDERPString: legacyDERP, LegacyDERPString: legacyDERP,

View File

@ -67,8 +67,6 @@ func TestTailNode(t *testing.T) {
want: &tailcfg.Node{ want: &tailcfg.Node{
Name: "empty", Name: "empty",
StableID: "0", StableID: "0",
Addresses: []netip.Prefix{},
AllowedIPs: []netip.Prefix{},
HomeDERP: 0, HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0", LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}), Hostinfo: hiview(tailcfg.Hostinfo{}),
@ -139,9 +137,13 @@ func TestTailNode(t *testing.T) {
), ),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{ AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv4(), tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
}, },
HomeDERP: 0, HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0", LegacyDERPString: "127.3.3.40:0",
@ -156,10 +158,6 @@ func TestTailNode(t *testing.T) {
Tags: []string{}, Tags: []string{},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
LastSeen: &lastSeen, LastSeen: &lastSeen,
MachineAuthorized: true, MachineAuthorized: true,

View File

@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
xmaps "golang.org/x/exp/maps" xmaps "golang.org/x/exp/maps"
"tailscale.com/net/tsaddr"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
@ -74,18 +75,12 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
// If the current primary is not available, select a new one. // If the current primary is not available, select a new one.
for prefix, nodes := range allPrimaries { for prefix, nodes := range allPrimaries {
if node, ok := pr.primaries[prefix]; ok { if node, ok := pr.primaries[prefix]; ok {
if len(nodes) < 2 {
delete(pr.primaries, prefix)
changed = true
continue
}
// If the current primary is still available, continue. // If the current primary is still available, continue.
if slices.Contains(nodes, node) { if slices.Contains(nodes, node) {
continue continue
} }
} }
if len(nodes) >= 2 { if len(nodes) >= 1 {
pr.primaries[prefix] = nodes[0] pr.primaries[prefix] = nodes[0]
changed = true changed = true
} }
@ -107,12 +102,16 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
return changed return changed
} }
func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bool { // SetRoutes sets the routes for a given Node ID and recalculates the primary routes
// of the headscale.
// It returns true if there was a change in primary routes.
// All exit routes are ignored as they are not used in primary route context.
func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) bool {
pr.mu.Lock() pr.mu.Lock()
defer pr.mu.Unlock() defer pr.mu.Unlock()
// If no routes are being set, remove the node from the routes map. // If no routes are being set, remove the node from the routes map.
if len(prefix) == 0 { if len(prefixes) == 0 {
if _, ok := pr.routes[node]; ok { if _, ok := pr.routes[node]; ok {
delete(pr.routes, node) delete(pr.routes, node)
return pr.updatePrimaryLocked() return pr.updatePrimaryLocked()
@ -121,12 +120,17 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bo
return false return false
} }
if _, ok := pr.routes[node]; !ok { rs := make(set.Set[netip.Prefix], len(prefixes))
pr.routes[node] = make(set.Set[netip.Prefix], len(prefix)) for _, prefix := range prefixes {
if !tsaddr.IsExitRoute(prefix) {
rs.Add(prefix)
}
} }
for _, p := range prefix { if rs.Len() != 0 {
pr.routes[node].Add(p) pr.routes[node] = rs
} else {
delete(pr.routes, node)
} }
return pr.updatePrimaryLocked() return pr.updatePrimaryLocked()
@ -153,6 +157,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
} }
} }
tsaddr.SortPrefixes(routes)
return routes return routes
} }

View File

@ -6,8 +6,10 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/util/set"
) )
// mp is a helper function that wraps netip.MustParsePrefix. // mp is a helper function that wraps netip.MustParsePrefix.
@ -19,18 +21,32 @@ func TestPrimaryRoutes(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
operations func(pr *PrimaryRoutes) bool operations func(pr *PrimaryRoutes) bool
nodeID types.NodeID expectedRoutes map[types.NodeID]set.Set[netip.Prefix]
expectedRoutes []netip.Prefix expectedPrimaries map[netip.Prefix]types.NodeID
expectedIsPrimary map[types.NodeID]bool
expectedChange bool expectedChange bool
// primaries is a map of prefixes to the node that is the primary for that prefix.
primaries map[netip.Prefix]types.NodeID
isPrimary map[types.NodeID]bool
}{ }{
{ {
name: "single-node-registers-single-route", name: "single-node-registers-single-route",
operations: func(pr *PrimaryRoutes) bool { operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1, mp("192.168.1.0/24")) return pr.SetRoutes(1, mp("192.168.1.0/24"))
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 1: {
expectedChange: false, mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: true,
}, },
{ {
name: "multiple-nodes-register-different-routes", name: "multiple-nodes-register-different-routes",
@ -38,19 +54,45 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24")) pr.SetRoutes(1, mp("192.168.1.0/24"))
return pr.SetRoutes(2, mp("192.168.2.0/24")) return pr.SetRoutes(2, mp("192.168.2.0/24"))
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 1: {
expectedChange: false, mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.2.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
mp("192.168.2.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
2: true,
},
expectedChange: true,
}, },
{ {
name: "multiple-nodes-register-overlapping-routes", name: "multiple-nodes-register-overlapping-routes",
operations: func(pr *PrimaryRoutes) bool { operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false pr.SetRoutes(1, mp("192.168.1.0/24")) // true
return pr.SetRoutes(2, mp("192.168.1.0/24")) // true return pr.SetRoutes(2, mp("192.168.1.0/24")) // false
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, 1: {
expectedChange: true, mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: false,
}, },
{ {
name: "node-deregisters-a-route", name: "node-deregisters-a-route",
@ -58,9 +100,10 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24")) pr.SetRoutes(1, mp("192.168.1.0/24"))
return pr.SetRoutes(1) // Deregister by setting no routes return pr.SetRoutes(1) // Deregister by setting no routes
}, },
nodeID: 1,
expectedRoutes: nil, expectedRoutes: nil,
expectedChange: false, expectedPrimaries: nil,
expectedIsPrimary: nil,
expectedChange: true,
}, },
{ {
name: "node-deregisters-one-of-multiple-routes", name: "node-deregisters-one-of-multiple-routes",
@ -68,9 +111,18 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24")) pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24"))
return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 1: {
expectedChange: false, mp("192.168.2.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.2.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: true,
}, },
{ {
name: "node-registers-and-deregisters-routes-in-sequence", name: "node-registers-and-deregisters-routes-in-sequence",
@ -80,18 +132,23 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1) // Deregister by setting no routes pr.SetRoutes(1) // Deregister by setting no routes
return pr.SetRoutes(1, mp("192.168.3.0/24")) return pr.SetRoutes(1, mp("192.168.3.0/24"))
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 1: {
expectedChange: false, mp("192.168.3.0/24"): {},
}, },
{ 2: {
name: "no-change-in-primary-routes", mp("192.168.2.0/24"): {},
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1, mp("192.168.1.0/24"))
}, },
nodeID: 1, },
expectedRoutes: nil, expectedPrimaries: map[netip.Prefix]types.NodeID{
expectedChange: false, mp("192.168.2.0/24"): 2,
mp("192.168.3.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
2: true,
},
expectedChange: true,
}, },
{ {
name: "multiple-nodes-register-same-route", name: "multiple-nodes-register-same-route",
@ -100,22 +157,25 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(2, mp("192.168.1.0/24")) // true pr.SetRoutes(2, mp("192.168.1.0/24")) // true
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, 1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: false, expectedChange: false,
}, },
{
name: "register-multiple-routes-shift-primary-check-old-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: true,
},
{ {
name: "register-multiple-routes-shift-primary-check-primary", name: "register-multiple-routes-shift-primary-check-primary",
operations: func(pr *PrimaryRoutes) bool { operations: func(pr *PrimaryRoutes) bool {
@ -124,20 +184,20 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary return pr.SetRoutes(1) // true, 2 primary
}, },
nodeID: 2, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, 2: {
expectedChange: true, mp("192.168.1.0/24"): {},
}, },
{ 3: {
name: "register-multiple-routes-shift-primary-check-non-primary", mp("192.168.1.0/24"): {},
operations: func(pr *PrimaryRoutes) bool { },
pr.SetRoutes(1, mp("192.168.1.0/24")) // false },
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary expectedPrimaries: map[netip.Prefix]types.NodeID{
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary mp("192.168.1.0/24"): 2,
return pr.SetRoutes(1) // true, 2 primary },
expectedIsPrimary: map[types.NodeID]bool{
2: true,
}, },
nodeID: 3,
expectedRoutes: nil,
expectedChange: true, expectedChange: true,
}, },
{ {
@ -150,8 +210,17 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(2) // true, no primary return pr.SetRoutes(2) // true, no primary
}, },
nodeID: 2, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 3,
},
expectedIsPrimary: map[types.NodeID]bool{
3: true,
},
expectedChange: true, expectedChange: true,
}, },
{ {
@ -165,9 +234,7 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(3) // false, no primary return pr.SetRoutes(3) // false, no primary
}, },
nodeID: 2, expectedChange: true,
expectedRoutes: nil,
expectedChange: false,
}, },
{ {
name: "primary-route-map-is-cleared-up", name: "primary-route-map-is-cleared-up",
@ -179,8 +246,17 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(2) // true, no primary return pr.SetRoutes(2) // true, no primary
}, },
nodeID: 2, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 3,
},
expectedIsPrimary: map[types.NodeID]bool{
3: true,
},
expectedChange: true, expectedChange: true,
}, },
{ {
@ -193,8 +269,23 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
}, },
nodeID: 2, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, 1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
2: true,
},
expectedChange: false, expectedChange: false,
}, },
{ {
@ -207,8 +298,23 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
2: true,
},
expectedChange: false, expectedChange: false,
}, },
{ {
@ -218,15 +324,30 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary pr.SetRoutes(1) // true, 2 primary
pr.SetRoutes(2) // true, no primary pr.SetRoutes(2) // true, 3 primary
pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 3 primary
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 3 primary
pr.SetRoutes(1) // true, 2 primary pr.SetRoutes(1) // true, 3 primary
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 3 primary
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 3,
},
expectedIsPrimary: map[types.NodeID]bool{
3: true,
}, },
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: false, expectedChange: false,
}, },
{ {
@ -235,16 +356,27 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24")) pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24"))
return pr.SetRoutes(2, mp("192.168.1.0/24")) return pr.SetRoutes(2, mp("192.168.1.0/24"))
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, 1: {
expectedChange: true, mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: false,
}, },
{ {
name: "deregister-non-existent-route", name: "deregister-non-existent-route",
operations: func(pr *PrimaryRoutes) bool { operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1) // Deregister by setting no routes return pr.SetRoutes(1) // Deregister by setting no routes
}, },
nodeID: 1,
expectedRoutes: nil, expectedRoutes: nil,
expectedChange: false, expectedChange: false,
}, },
@ -253,17 +385,27 @@ func TestPrimaryRoutes(t *testing.T) {
operations: func(pr *PrimaryRoutes) bool { operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1) return pr.SetRoutes(1)
}, },
nodeID: 1,
expectedRoutes: nil, expectedRoutes: nil,
expectedChange: false, expectedChange: false,
}, },
{ {
name: "deregister-empty-prefix-list", name: "exit-nodes",
operations: func(pr *PrimaryRoutes) bool { operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1) pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("10.0.0.0/16"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("10.0.0.0/16"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
}, },
nodeID: 1,
expectedRoutes: nil,
expectedChange: false, expectedChange: false,
}, },
{ {
@ -284,19 +426,23 @@ func TestPrimaryRoutes(t *testing.T) {
return change1 || change2 return change1 || change2
}, },
nodeID: 1, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
expectedRoutes: nil, 1: {
expectedChange: false, mp("192.168.1.0/24"): {},
}, },
{ 2: {
name: "no-routes-registered", mp("192.168.2.0/24"): {},
operations: func(pr *PrimaryRoutes) bool {
// No operations
return false
}, },
nodeID: 1, },
expectedRoutes: nil, expectedPrimaries: map[netip.Prefix]types.NodeID{
expectedChange: false, mp("192.168.1.0/24"): 1,
mp("192.168.2.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
2: true,
},
expectedChange: true,
}, },
} }
@ -307,9 +453,15 @@ func TestPrimaryRoutes(t *testing.T) {
if change != tt.expectedChange { if change != tt.expectedChange {
t.Errorf("change = %v, want %v", change, tt.expectedChange) t.Errorf("change = %v, want %v", change, tt.expectedChange)
} }
routes := pr.PrimaryRoutes(tt.nodeID) comps := append(util.Comparers, cmpopts.EquateEmpty())
if diff := cmp.Diff(tt.expectedRoutes, routes, util.Comparers...); diff != "" { if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
t.Errorf("PrimaryRoutes() mismatch (-want +got):\n%s", diff) t.Errorf("routes mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
} }
}) })
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -213,7 +214,7 @@ func (node *Node) RequestTags() []string {
} }
func (node *Node) Prefixes() []netip.Prefix { func (node *Node) Prefixes() []netip.Prefix {
addrs := []netip.Prefix{} var addrs []netip.Prefix
for _, nodeAddress := range node.IPs() { for _, nodeAddress := range node.IPs() {
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
addrs = append(addrs, ip) addrs = append(addrs, ip)
@ -222,6 +223,19 @@ func (node *Node) Prefixes() []netip.Prefix {
return addrs return addrs
} }
// ExitRoutes returns a list of both exit routes if the
// node has any exit routes enabled.
// If none are enabled, it will return nil.
func (node *Node) ExitRoutes() []netip.Prefix {
for _, route := range node.SubnetRoutes() {
if tsaddr.IsExitRoute(route) {
return tsaddr.ExitRoutes()
}
}
return nil
}
func (node *Node) IPsAsString() []string { func (node *Node) IPsAsString() []string {
var ret []string var ret []string

View File

@ -57,6 +57,15 @@ func GenerateRandomStringDNSSafe(size int) (string, error) {
return str[:size], nil return str[:size], nil
} }
func MustGenerateRandomStringDNSSafe(size int) string {
hash, err := GenerateRandomStringDNSSafe(size)
if err != nil {
panic(err)
}
return hash
}
func TailNodesToString(nodes []*tailcfg.Node) string { func TailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes)) temp := make([]string, len(nodes))

View File

@ -3,8 +3,12 @@ package util
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/netip"
"net/url" "net/url"
"regexp"
"strconv"
"strings" "strings"
"time"
"tailscale.com/util/cmpver" "tailscale.com/util/cmpver"
) )
@ -46,3 +50,126 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
return loginURL, nil return loginURL, nil
} }
type TraceroutePath struct {
// Hop is the current jump in the total traceroute.
Hop int
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// IP is the IP address of the jump
IP netip.Addr
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
}
type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// IP is the IP address of the target
IP netip.Addr
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Success indicates if the traceroute was successful.
Success bool
// Err contains an error if the traceroute was not successful.
Err error
}
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 {
return Traceroute{}, errors.New("empty traceroute output")
}
// Parse the header line
headerRegex := regexp.MustCompile(`traceroute to ([^ ]+) \(([^)]+)\)`)
headerMatches := headerRegex.FindStringSubmatch(lines[0])
if len(headerMatches) != 3 {
return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0])
}
hostname := headerMatches[1]
ipStr := headerMatches[2]
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err)
}
result := Traceroute{
Hostname: hostname,
IP: ip,
Route: []TraceroutePath{},
Success: false,
}
// Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])
if len(matches) == 0 {
continue
}
hop, err := strconv.Atoi(matches[1])
if err != nil {
return Traceroute{}, fmt.Errorf("parsing hop number: %w", err)
}
var hopHostname string
var hopIP netip.Addr
var latencies []time.Duration
// Handle hostname and IP
if matches[2] != "" && matches[3] != "" {
hopHostname = matches[2]
hopIP, err = netip.ParseAddr(matches[3])
if err != nil {
return Traceroute{}, fmt.Errorf("parsing hop IP address %s: %w", matches[3], err)
}
} else if matches[4] == "*" {
hopHostname = "*"
// No IP for timeouts
}
// Parse latencies
for j := 5; j <= 7; j++ {
if matches[j] != "" {
ms, err := strconv.ParseFloat(matches[j], 64)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing latency: %w", err)
}
latencies = append(latencies, time.Duration(ms*float64(time.Millisecond)))
}
}
path := TraceroutePath{
Hop: hop,
Hostname: hopHostname,
IP: hopIP,
Latencies: latencies,
}
result.Route = append(result.Route, path)
// Check if we've reached the target
if hopIP == ip {
result.Success = true
}
}
// If we didn't reach the target, it's unsuccessful
if !result.Success {
result.Err = errors.New("traceroute did not reach target")
}
return result, nil
}

View File

@ -1,6 +1,13 @@
package util package util
import "testing" import (
"errors"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
func TestTailscaleVersionNewerOrEqual(t *testing.T) { func TestTailscaleVersionNewerOrEqual(t *testing.T) {
type args struct { type args struct {
@ -178,3 +185,186 @@ Success.`,
}) })
} }
} }
func TestParseTraceroute(t *testing.T) {
tests := []struct {
name string
input string
want Traceroute
wantErr bool
}{
{
name: "simple successful traceroute",
input: `traceroute to 172.24.0.3 (172.24.0.3), 30 hops max, 46 byte packets
1 ts-head-hk0urr.headscale.net (100.64.0.1) 1.135 ms 0.922 ms 0.619 ms
2 172.24.0.3 (172.24.0.3) 0.593 ms 0.549 ms 0.522 ms`,
want: Traceroute{
Hostname: "172.24.0.3",
IP: netip.MustParseAddr("172.24.0.3"),
Route: []TraceroutePath{
{
Hop: 1,
Hostname: "ts-head-hk0urr.headscale.net",
IP: netip.MustParseAddr("100.64.0.1"),
Latencies: []time.Duration{
1135 * time.Microsecond,
922 * time.Microsecond,
619 * time.Microsecond,
},
},
{
Hop: 2,
Hostname: "172.24.0.3",
IP: netip.MustParseAddr("172.24.0.3"),
Latencies: []time.Duration{
593 * time.Microsecond,
549 * time.Microsecond,
522 * time.Microsecond,
},
},
},
Success: true,
Err: nil,
},
wantErr: false,
},
{
name: "traceroute with timeouts",
input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets
1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms
2 * * *
3 isp-gateway.net (10.0.0.1) 15.678 ms 14.789 ms 15.432 ms
4 8.8.8.8 (8.8.8.8) 20.123 ms 19.876 ms 20.345 ms`,
want: Traceroute{
Hostname: "8.8.8.8",
IP: netip.MustParseAddr("8.8.8.8"),
Route: []TraceroutePath{
{
Hop: 1,
Hostname: "router.local",
IP: netip.MustParseAddr("192.168.1.1"),
Latencies: []time.Duration{
1234 * time.Microsecond,
1123 * time.Microsecond,
1121 * time.Microsecond,
},
},
{
Hop: 2,
Hostname: "*",
},
{
Hop: 3,
Hostname: "isp-gateway.net",
IP: netip.MustParseAddr("10.0.0.1"),
Latencies: []time.Duration{
15678 * time.Microsecond,
14789 * time.Microsecond,
15432 * time.Microsecond,
},
},
{
Hop: 4,
Hostname: "8.8.8.8",
IP: netip.MustParseAddr("8.8.8.8"),
Latencies: []time.Duration{
20123 * time.Microsecond,
19876 * time.Microsecond,
20345 * time.Microsecond,
},
},
},
Success: true,
Err: nil,
},
wantErr: false,
},
{
name: "unsuccessful traceroute",
input: `traceroute to 10.0.0.99 (10.0.0.99), 5 hops max, 60 byte packets
1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms
2 * * *
3 * * *
4 * * *
5 * * *`,
want: Traceroute{
Hostname: "10.0.0.99",
IP: netip.MustParseAddr("10.0.0.99"),
Route: []TraceroutePath{
{
Hop: 1,
Hostname: "router.local",
IP: netip.MustParseAddr("192.168.1.1"),
Latencies: []time.Duration{
1234 * time.Microsecond,
1123 * time.Microsecond,
1121 * time.Microsecond,
},
},
{
Hop: 2,
Hostname: "*",
},
{
Hop: 3,
Hostname: "*",
},
{
Hop: 4,
Hostname: "*",
},
{
Hop: 5,
Hostname: "*",
},
},
Success: false,
Err: errors.New("traceroute did not reach target"),
},
wantErr: false,
},
{
name: "empty input",
input: "",
want: Traceroute{},
wantErr: true,
},
{
name: "invalid header",
input: "not a valid traceroute output",
want: Traceroute{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseTraceroute(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseTraceroute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
// Special handling for error field since it can't be directly compared with cmp.Diff
gotErr := got.Err
wantErr := tt.want.Err
got.Err = nil
tt.want.Err = nil
if diff := cmp.Diff(tt.want, got, IPComparer); diff != "" {
t.Errorf("ParseTraceroute() mismatch (-want +got):\n%s", diff)
}
// Now check error field separately
if (gotErr == nil) != (wantErr == nil) {
t.Errorf("Error field: got %v, want %v", gotErr, wantErr)
} else if gotErr != nil && wantErr != nil && gotErr.Error() != wantErr.Error() {
t.Errorf("Error message: got %q, want %q", gotErr.Error(), wantErr.Error())
}
})
}
}

View File

@ -54,15 +54,16 @@ func aclScenario(
clientsPerUser int, clientsPerUser int,
) *Scenario { ) *Scenario {
t.Helper() t.Helper()
scenario, err := NewScenario(dockertestMaxWait())
require.NoError(t, err)
spec := map[string]int{ spec := ScenarioSpec{
"user1": clientsPerUser, NodesPerUser: clientsPerUser,
"user2": clientsPerUser, Users: []string{"user1", "user2"},
} }
err = scenario.CreateHeadscaleEnv(spec, scenario, err := NewScenario(spec)
require.NoError(t, err)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{ []tsic.Option{
// Alpine containers dont have ip6tables set up, which causes // Alpine containers dont have ip6tables set up, which causes
// tailscaled to stop configuring the wgengine, causing it // tailscaled to stop configuring the wgengine, causing it
@ -96,22 +97,24 @@ func aclScenario(
func TestACLHostsInNetMapTable(t *testing.T) { func TestACLHostsInNetMapTable(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 2,
Users: []string{"user1", "user2"},
}
// NOTE: All want cases currently checks the // NOTE: All want cases currently checks the
// total count of expected peers, this would // total count of expected peers, this would
// typically be the client count of the users // typically be the client count of the users
// they can access minus one (them self). // they can access minus one (them self).
tests := map[string]struct { tests := map[string]struct {
users map[string]int users ScenarioSpec
policy policyv1.ACLPolicy policy policyv1.ACLPolicy
want map[string]int want map[string]int
}{ }{
// Test that when we have no ACL, each client netmap has // Test that when we have no ACL, each client netmap has
// the amount of peers of the total amount of clients // the amount of peers of the total amount of clients
"base-acls": { "base-acls": {
users: map[string]int{ users: spec,
"user1": 2,
"user2": 2,
},
policy: policyv1.ACLPolicy{ policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
@ -129,10 +132,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// each other, each node has only the number of pairs from // each other, each node has only the number of pairs from
// their own user. // their own user.
"two-isolated-users": { "two-isolated-users": {
users: map[string]int{ users: spec,
"user1": 2,
"user2": 2,
},
policy: policyv1.ACLPolicy{ policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
@ -155,10 +155,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// are restricted to a single port, nodes are still present // are restricted to a single port, nodes are still present
// in the netmap. // in the netmap.
"two-restricted-present-in-netmap": { "two-restricted-present-in-netmap": {
users: map[string]int{ users: spec,
"user1": 2,
"user2": 2,
},
policy: policyv1.ACLPolicy{ policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
@ -192,10 +189,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// of peers. This will still result in all the peers as we // of peers. This will still result in all the peers as we
// need them present on the other side for the "return path". // need them present on the other side for the "return path".
"two-ns-one-isolated": { "two-ns-one-isolated": {
users: map[string]int{ users: spec,
"user1": 2,
"user2": 2,
},
policy: policyv1.ACLPolicy{ policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
@ -220,10 +214,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
}, },
}, },
"very-large-destination-prefix-1372": { "very-large-destination-prefix-1372": {
users: map[string]int{ users: spec,
"user1": 2,
"user2": 2,
},
policy: policyv1.ACLPolicy{ policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
@ -248,10 +239,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
}, },
}, },
"ipv6-acls-1470": { "ipv6-acls-1470": {
users: map[string]int{ users: spec,
"user1": 2,
"user2": 2,
},
policy: policyv1.ACLPolicy{ policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{ ACLs: []policyv1.ACL{
{ {
@ -269,12 +257,11 @@ func TestACLHostsInNetMapTable(t *testing.T) {
for name, testCase := range tests { for name, testCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait()) caseSpec := testCase.users
scenario, err := NewScenario(caseSpec)
require.NoError(t, err) require.NoError(t, err)
spec := testCase.users err = scenario.CreateHeadscaleEnv(
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithACLPolicy(&testCase.policy), hsic.WithACLPolicy(&testCase.policy),
) )
@ -944,6 +931,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
for name, testCase := range tests { for name, testCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
scenario := aclScenario(t, &testCase.policy, 1) scenario := aclScenario(t, &testCase.policy, 1)
defer scenario.ShutdownAssertNoPanics(t)
test1ip := netip.MustParseAddr("100.64.0.1") test1ip := netip.MustParseAddr("100.64.0.1")
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
@ -1022,16 +1010,16 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv(
"user1": 1,
"user2": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{ []tsic.Option{
// Alpine containers dont have ip6tables set up, which causes // Alpine containers dont have ip6tables set up, which causes
// tailscaled to stop configuring the wgengine, causing it // tailscaled to stop configuring the wgengine, causing it

View File

@ -19,15 +19,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
for _, https := range []bool{true, false} { for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
opts := []hsic.Option{hsic.WithTestName("pingallbyip")} opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
if https { if https {
opts = append(opts, []hsic.Option{ opts = append(opts, []hsic.Option{
@ -35,7 +35,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}...) }...)
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -84,7 +84,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
time.Sleep(5 * time.Minute) time.Sleep(5 * time.Minute)
} }
for userName := range spec { for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userName, true, false) key, err := scenario.CreatePreAuthKey(userName, true, false)
if err != nil { if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
@ -152,16 +152,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{},
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{},
hsic.WithTestName("keyrelognewuser"), hsic.WithTestName("keyrelognewuser"),
hsic.WithTLS(), hsic.WithTLS(),
) )
@ -203,7 +203,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
// Log in all clients as user1, iterating over the spec only returns the // Log in all clients as user1, iterating over the spec only returns the
// clients, not the usernames. // clients, not the usernames.
for userName := range spec { for _, userName := range spec.Users {
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil { if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
@ -235,15 +235,15 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
for _, https := range []bool{true, false} { for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
opts := []hsic.Option{hsic.WithTestName("pingallbyip")} opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
if https { if https {
opts = append(opts, []hsic.Option{ opts = append(opts, []hsic.Option{
@ -251,7 +251,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}...) }...)
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -300,7 +300,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
time.Sleep(5 * time.Minute) time.Sleep(5 * time.Minute)
} }
for userName := range spec { for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userName, true, false) key, err := scenario.CreatePreAuthKey(userName, true, false)
if err != nil { if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)

View File

@ -1,93 +1,58 @@
package integration package integration
import ( import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt" "fmt"
"io"
"log"
"net"
"net/http"
"net/http/cookiejar"
"net/netip" "net/netip"
"net/url"
"sort" "sort"
"strconv"
"testing" "testing"
"time" "time"
"maps"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc" "github.com/oauth2-proxy/mockoidc"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const (
dockerContextPath = "../."
hsicOIDCMockHashLength = 6
defaultAccessTTL = 10 * time.Minute
)
var errStatusCodeNotOK = errors.New("status code not OK")
type AuthOIDCScenario struct {
*Scenario
mockOIDC *dockertest.Resource
}
func TestOIDCAuthenticationPingAll(t *testing.T) { func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
// Logins to MockOIDC is served by a queue with a strict order, // Logins to MockOIDC is served by a queue with a strict order,
// if we use more than one node per user, the order of the logins // if we use more than one node per user, the order of the logins
// will not be deterministic and the test will fail. // will not be deterministic and the test will fail.
spec := map[string]int{ spec := ScenarioSpec{
"user1": 1, NodesPerUser: 1,
"user2": 1, Users: []string{"user1", "user2"},
} OIDCUsers: []mockoidc.MockUser{
mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true), oidcMockUser("user1", true),
oidcMockUser("user2", false), oidcMockUser("user2", false),
},
} }
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) scenario, err := NewScenario(spec)
assertNoErrf(t, "failed to run mock OIDC server: %s", err) assertNoErr(t, err)
defer scenario.mockOIDC.Close()
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp", "CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
} }
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnvWithLoginURL(
spec, nil,
hsic.WithTestName("oidcauthping"), hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -126,7 +91,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
Name: "user1", Name: "user1",
Email: "user1@headscale.net", Email: "user1@headscale.net",
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1", ProviderId: scenario.mockOIDC.Issuer() + "/user1",
}, },
{ {
Id: 3, Id: 3,
@ -138,7 +103,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
Name: "user2", Name: "user2",
Email: "", // Unverified Email: "", // Unverified
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user2", ProviderId: scenario.mockOIDC.Issuer() + "/user2",
}, },
} }
@ -158,37 +123,29 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
shortAccessTTL := 5 * time.Minute shortAccessTTL := 5 * time.Minute
baseScenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
assertNoErr(t, err) NodesPerUser: 1,
Users: []string{"user1", "user2"},
baseScenario.pool.MaxWait = 5 * time.Minute OIDCUsers: []mockoidc.MockUser{
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
"user2": 1,
}
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{
oidcMockUser("user1", true), oidcMockUser("user1", true),
oidcMockUser("user2", false), oidcMockUser("user2", false),
}) },
assertNoErrf(t, "failed to run mock OIDC server: %s", err) OIDCAccessTTL: shortAccessTTL,
defer scenario.mockOIDC.Close() }
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret, "HEADSCALE_OIDC_CLIENT_SECRET": scenario.mockOIDC.ClientSecret(),
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1", "HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
} }
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnvWithLoginURL(
spec, nil,
hsic.WithTestName("oidcexpirenodes"), hsic.WithTestName("oidcexpirenodes"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
) )
@ -334,45 +291,35 @@ func TestOIDC024UserCreation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
baseScenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
assertNoErr(t, err) NodesPerUser: 1,
scenario := AuthOIDCScenario{
Scenario: baseScenario,
} }
for _, user := range tt.cliUsers {
spec.Users = append(spec.Users, user)
}
for _, user := range tt.oidcUsers {
spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified))
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{}
for _, user := range tt.cliUsers {
spec[user] = 1
}
var mockusers []mockoidc.MockUser
for _, user := range tt.oidcUsers {
mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified))
}
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp", "CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
} }
maps.Copy(oidcMap, tt.config)
for k, v := range tt.config { err = scenario.CreateHeadscaleEnvWithLoginURL(
oidcMap[k] = v nil,
}
err = scenario.CreateHeadscaleEnv(
spec,
hsic.WithTestName("oidcmigration"), hsic.WithTestName("oidcmigration"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -384,7 +331,7 @@ func TestOIDC024UserCreation(t *testing.T) {
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) assertNoErr(t, err)
want := tt.want(oidcConfig.Issuer) want := tt.want(scenario.mockOIDC.Issuer())
listUsers, err := headscale.ListUsers() listUsers, err := headscale.ListUsers()
assertNoErr(t, err) assertNoErr(t, err)
@ -404,41 +351,33 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait()) // Single user with one node for testing PKCE flow
assertNoErr(t, err) spec := ScenarioSpec{
NodesPerUser: 1,
scenario := AuthOIDCScenario{ Users: []string{"user1"},
Scenario: baseScenario, OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
} }
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
// Single user with one node for testing PKCE flow
spec := map[string]int{
"user1": 1,
}
mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true),
}
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"CREDENTIALS_DIRECTORY_TEST": "/tmp", "CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE "HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
} }
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnvWithLoginURL(
spec, nil,
hsic.WithTestName("oidcauthpkce"), hsic.WithTestName("oidcauthpkce"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -464,43 +403,33 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
// Create no nodes and no users // Create no nodes and no users
spec := map[string]int{} scenario, err := NewScenario(ScenarioSpec{
// First login creates the first OIDC user // First login creates the first OIDC user
// Second login logs in the same node, which creates a new node // Second login logs in the same node, which creates a new node
// Third login logs in the same node back into the original user // Third login logs in the same node back into the original user
mockusers := []mockoidc.MockUser{ OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true), oidcMockUser("user1", true),
oidcMockUser("user2", true), oidcMockUser("user2", true),
oidcMockUser("user1", true), oidcMockUser("user1", true),
} },
})
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) assertNoErr(t, err)
assertNoErrf(t, "failed to run mock OIDC server: %s", err) defer scenario.ShutdownAssertNoPanics(t)
// defer scenario.mockOIDC.Close()
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp", "CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
} }
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnvWithLoginURL(
spec, nil,
hsic.WithTestName("oidcauthrelog"), hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -512,7 +441,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, listUsers, 0) assert.Len(t, listUsers, 0)
ts, err := scenario.CreateTailscaleNode("unstable") ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
assertNoErr(t, err) assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint()) u, err := ts.LoginWithURL(headscale.GetEndpoint())
@ -530,7 +459,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
Name: "user1", Name: "user1",
Email: "user1@headscale.net", Email: "user1@headscale.net",
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1", ProviderId: scenario.mockOIDC.Issuer() + "/user1",
}, },
} }
@ -575,14 +504,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
Name: "user1", Name: "user1",
Email: "user1@headscale.net", Email: "user1@headscale.net",
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1", ProviderId: scenario.mockOIDC.Issuer() + "/user1",
}, },
{ {
Id: 2, Id: 2,
Name: "user2", Name: "user2",
Email: "user2@headscale.net", Email: "user2@headscale.net",
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user2", ProviderId: scenario.mockOIDC.Issuer() + "/user2",
}, },
} }
@ -632,14 +561,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
Name: "user1", Name: "user1",
Email: "user1@headscale.net", Email: "user1@headscale.net",
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1", ProviderId: scenario.mockOIDC.Issuer() + "/user1",
}, },
{ {
Id: 2, Id: 2,
Name: "user2", Name: "user2",
Email: "user2@headscale.net", Email: "user2@headscale.net",
Provider: "oidc", Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user2", ProviderId: scenario.mockOIDC.Issuer() + "/user2",
}, },
} }
@ -678,254 +607,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey)
} }
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
users map[string]int,
opts ...hsic.Option,
) error {
headscale, err := s.Headscale(opts...)
if err != nil {
return err
}
err = headscale.WaitForRunning()
if err != nil {
return err
}
for userName, clientCount := range users {
if clientCount != 1 {
// OIDC scenario only supports one client per user.
// This is because the MockOIDC server can only serve login
// requests based on a queue it has been given on startup.
// We currently only populates it with one login request per user.
return fmt.Errorf("client count must be 1 for OIDC scenario.")
}
log.Printf("creating user %s with %d clients", userName, clientCount)
err = s.CreateUser(userName)
if err != nil {
return err
}
err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
if err != nil {
return err
}
err = s.runTailscaleUp(userName, headscale.GetEndpoint())
if err != nil {
return err
}
}
return nil
}
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
port, err := dockertestutil.RandomFreeHostPort()
if err != nil {
log.Fatalf("could not find an open port: %s", err)
}
portNotation := fmt.Sprintf("%d/tcp", port)
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
usersJSON, err := json.Marshal(users)
if err != nil {
return nil, err
}
mockOidcOptions := &dockertest.RunOptions{
Name: hostname,
Cmd: []string{"headscale", "mockoidc"},
ExposedPorts: []string{portNotation},
PortBindings: map[docker.Port][]docker.PortBinding{
docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
},
Networks: []*dockertest.Network{s.Scenario.network},
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
}
err = s.pool.RemoveContainerByName(hostname)
if err != nil {
return nil, err
}
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
dockertestutil.DockerRestartPolicy); err == nil {
s.mockOIDC = pmockoidc
} else {
return nil, err
}
log.Println("Waiting for headscale mock oidc to be ready for tests")
hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port)
if err := s.pool.Retry(func() error {
oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
httpClient := &http.Client{}
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("headscale mock OIDC tests is not ready: %s\n", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errStatusCodeNotOK
}
return nil
}); err != nil {
return nil, err
}
log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
return &types.OIDCConfig{
Issuer: fmt.Sprintf(
"http://%s/oidc",
net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)),
),
ClientID: "superclient",
ClientSecret: "supersecret",
OnlyStartIfOIDCIsAvailable: true,
}, nil
}
type LoggingRoundTripper struct{}
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
noTls := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}
resp, err := noTls.RoundTrip(req)
if err != nil {
return nil, err
}
log.Printf("---")
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
return resp, nil
}
func (s *AuthOIDCScenario) runTailscaleUp(
userStr, loginServer string,
) error {
log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
tsc := client
user.joinWaitGroup.Go(func() error {
loginURL, err := tsc.LoginWithURL(loginServer)
if err != nil {
log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
}
_, err = doLoginURL(tsc.Hostname(), loginURL)
if err != nil {
return err
}
return nil
})
log.Printf("client %s is ready", client.Hostname())
}
if err := user.joinWaitGroup.Wait(); err != nil {
return err
}
for _, client := range user.Clients {
err := client.WaitForRunning()
if err != nil {
return fmt.Errorf(
"%s tailscale node has not reached running: %w",
client.Hostname(),
err,
)
}
}
return nil
}
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
// doLoginURL visits the given login URL and returns the body as a
// string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
var err error
hc := &http.Client{
Transport: LoggingRoundTripper{},
}
hc.Jar, err = cookiejar.New(nil)
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
}
log.Printf("%s logging in with url", hostname)
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := hc.Do(req)
if err != nil {
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
}
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body)
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("%s failed to read response body: %s", hostname, err)
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
return string(body), nil
}
func (s *AuthOIDCScenario) Shutdown() {
err := s.pool.Purge(s.mockOIDC)
if err != nil {
log.Printf("failed to remove mock oidc container")
}
s.Scenario.Shutdown()
}
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
t.Helper() t.Helper()

View File

@ -1,47 +1,33 @@
package integration package integration
import ( import (
"errors"
"fmt"
"log"
"net/netip" "net/netip"
"net/url"
"strings"
"testing" "testing"
"slices"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
var errParseAuthPage = errors.New("failed to parse auth page")
type AuthWebFlowScenario struct {
*Scenario
}
func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
if err != nil { if err != nil {
t.Fatalf("failed to create scenario: %s", err) t.Fatalf("failed to create scenario: %s", err)
} }
scenario := AuthWebFlowScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec, nil,
hsic.WithTestName("webauthping"), hsic.WithTestName("webauthping"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(), hsic.WithTLS(),
@ -71,20 +57,17 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
assertNoErr(t, err) NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
scenario := AuthWebFlowScenario{
Scenario: baseScenario,
} }
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv(
"user1": len(MustTestVersions), nil,
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
hsic.WithTestName("weblogout"), hsic.WithTestName("weblogout"),
hsic.WithTLS(), hsic.WithTLS(),
) )
@ -137,8 +120,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("all clients logged out") t.Logf("all clients logged out")
for userName := range spec { for _, userName := range spec.Users {
err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) err = scenario.RunTailscaleUpWithURL(userName, headscale.GetEndpoint())
if err != nil { if err != nil {
t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err) t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err)
} }
@ -172,14 +155,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
} }
for _, ip := range ips { for _, ip := range ips {
found := false found := slices.Contains(clientIPs[client], ip)
for _, oldIP := range clientIPs[client] {
if ip == oldIP {
found = true
break
}
}
if !found { if !found {
t.Fatalf( t.Fatalf(
@ -194,122 +170,3 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("all clients IPs are the same") t.Logf("all clients IPs are the same")
} }
func (s *AuthWebFlowScenario) CreateHeadscaleEnv(
users map[string]int,
opts ...hsic.Option,
) error {
headscale, err := s.Headscale(opts...)
if err != nil {
return err
}
err = headscale.WaitForRunning()
if err != nil {
return err
}
for userName, clientCount := range users {
log.Printf("creating user %s with %d clients", userName, clientCount)
err = s.CreateUser(userName)
if err != nil {
return err
}
err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
if err != nil {
return err
}
err = s.runTailscaleUp(userName, headscale.GetEndpoint())
if err != nil {
return err
}
}
return nil
}
func (s *AuthWebFlowScenario) runTailscaleUp(
userStr, loginServer string,
) error {
log.Printf("running tailscale up for user %q", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
c := client
user.joinWaitGroup.Go(func() error {
log.Printf("logging %q into %q", c.Hostname(), loginServer)
loginURL, err := c.LoginWithURL(loginServer)
if err != nil {
log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err)
return err
}
err = s.runHeadscaleRegister(userStr, loginURL)
if err != nil {
log.Printf("failed to register client (%s): %s", c.Hostname(), err)
return err
}
return nil
})
err := client.WaitForRunning()
if err != nil {
log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err)
}
}
if err := user.joinWaitGroup.Wait(); err != nil {
return err
}
for _, client := range user.Clients {
err := client.WaitForRunning()
if err != nil {
return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err)
}
}
return nil
}
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error {
body, err := doLoginURL("web-auth-not-set", loginURL)
if err != nil {
return err
}
// see api.go HTML template
codeSep := strings.Split(string(body), "</code>")
if len(codeSep) != 2 {
return errParseAuthPage
}
keySep := strings.Split(codeSep[0], "key ")
if len(keySep) != 2 {
return errParseAuthPage
}
key := keySep[1]
log.Printf("registering node %s", key)
if headscale, err := s.Headscale(); err == nil {
_, err = headscale.Execute(
[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key},
)
if err != nil {
log.Printf("failed to register node: %s", err)
return err
}
return nil
}
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
}

View File

@ -48,16 +48,15 @@ func TestUserCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"user1": 0,
"user2": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -247,15 +246,15 @@ func TestPreAuthKeyCommand(t *testing.T) {
user := "preauthkeyspace" user := "preauthkeyspace"
count := 3 count := 3
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{user},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak"))
user: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -388,16 +387,15 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
t.Parallel() t.Parallel()
user := "pre-auth-key-without-exp-user" user := "pre-auth-key-without-exp-user"
spec := ScenarioSpec{
Users: []string{user},
}
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp"))
user: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -451,16 +449,15 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
t.Parallel() t.Parallel()
user := "pre-auth-key-reus-ephm-user" user := "pre-auth-key-reus-ephm-user"
spec := ScenarioSpec{
Users: []string{user},
}
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph"))
user: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -530,17 +527,16 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
user1 := "user1" user1 := "user1"
user2 := "user2" user2 := "user2"
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{user1},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user1: 1,
user2: 0,
}
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("clipak"), hsic.WithTestName("clipak"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
@ -551,6 +547,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) assertNoErr(t, err)
err = headscale.CreateUser(user2)
assertNoErr(t, err)
var user2Key v1.PreAuthKey var user2Key v1.PreAuthKey
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -573,10 +572,15 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
) )
assertNoErr(t, err) assertNoErr(t, err)
listNodes, err := headscale.ListNodes()
require.Nil(t, err)
require.Len(t, listNodes, 1)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err) assertNoErrListClients(t, err)
assert.Len(t, allClients, 1) require.Len(t, allClients, 1)
client := allClients[0] client := allClients[0]
@ -606,12 +610,11 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String()) t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
} }
listNodes, err := headscale.ListNodes() listNodes, err = headscale.ListNodes()
assert.Nil(t, err) require.Nil(t, err)
assert.Len(t, listNodes, 2) require.Len(t, listNodes, 2)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
assert.Equal(t, "user1", listNodes[0].GetUser().GetName()) assert.Equal(t, user2, listNodes[1].GetUser().GetName())
assert.Equal(t, "user2", listNodes[1].GetUser().GetName())
} }
func TestApiKeyCommand(t *testing.T) { func TestApiKeyCommand(t *testing.T) {
@ -620,16 +623,15 @@ func TestApiKeyCommand(t *testing.T) {
count := 5 count := 5
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"user1": 0,
"user2": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -788,15 +790,15 @@ func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"user1": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -977,15 +979,16 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv(
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithTags([]string{"tag:test"})}, []tsic.Option{tsic.WithTags([]string{"tag:test"})},
hsic.WithTestName("cliadvtags"), hsic.WithTestName("cliadvtags"),
hsic.WithACLPolicy(tt.policy), hsic.WithACLPolicy(tt.policy),
@ -996,7 +999,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
// Test list all nodes after added seconds // Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"]) resultMachines := make([]*v1.Node, spec.NodesPerUser)
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
[]string{ []string{
@ -1029,16 +1032,15 @@ func TestNodeCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"node-user", "other-user"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"node-user": 0,
"other-user": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -1269,15 +1271,15 @@ func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"node-expire-user"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"node-expire-user": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -1395,15 +1397,15 @@ func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"node-rename-command"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"node-rename-command": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -1560,16 +1562,15 @@ func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"old-user", "new-user"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
"old-user": 0,
"new-user": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -1721,16 +1722,15 @@ func TestPolicyCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
}
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("clins"), hsic.WithTestName("clins"),
hsic.WithConfigEnv(map[string]string{ hsic.WithConfigEnv(map[string]string{
@ -1808,16 +1808,16 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("clins"), hsic.WithTestName("clins"),
hsic.WithConfigEnv(map[string]string{ hsic.WithConfigEnv(map[string]string{

View File

@ -24,5 +24,4 @@ type ControlServer interface {
ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error)
GetCert() []byte GetCert() []byte
GetHostname() string GetHostname() string
GetIP() string
} }

View File

@ -31,14 +31,15 @@ func TestDERPVerifyEndpoint(t *testing.T) {
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
assertNoErr(t, err) assertNoErr(t, err)
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
}
derper, err := scenario.CreateDERPServer("head", derper, err := scenario.CreateDERPServer("head",
dsic.WithCACert(certHeadscale), dsic.WithCACert(certHeadscale),
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
@ -65,7 +66,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
}, },
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())}, err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithCACert(derper.GetCert())},
hsic.WithHostname(hostname), hsic.WithHostname(hostname),
hsic.WithPort(headscalePort), hsic.WithPort(headscalePort),
hsic.WithCustomTLS(certHeadscale, keyHeadscale), hsic.WithCustomTLS(certHeadscale, keyHeadscale),

View File

@ -17,16 +17,16 @@ func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns"))
"magicdns1": len(MustTestVersions),
"magicdns2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -87,15 +87,15 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"magicdns1": 1,
"magicdns2": 1,
}
const erPath = "/tmp/extra_records.json" const erPath = "/tmp/extra_records.json"
extraRecords := []tailcfg.DNSRecord{ extraRecords := []tailcfg.DNSRecord{
@ -107,7 +107,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
} }
b, _ := json.Marshal(extraRecords) b, _ := json.Marshal(extraRecords)
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{ err = scenario.CreateHeadscaleEnv([]tsic.Option{
tsic.WithDockerEntrypoint([]string{ tsic.WithDockerEntrypoint([]string{
"/bin/sh", "/bin/sh",
"-c", "-c",
@ -364,16 +364,16 @@ func TestValidateResolvConf(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf))
"resolvconf1": 3,
"resolvconf2": 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()

View File

@ -35,7 +35,7 @@ type DERPServerInContainer struct {
pool *dockertest.Pool pool *dockertest.Pool
container *dockertest.Resource container *dockertest.Resource
network *dockertest.Network networks []*dockertest.Network
stunPort int stunPort int
derpPort int derpPort int
@ -63,22 +63,22 @@ func WithCACert(cert []byte) Option {
// isolating the DERPer, will be created. If a network is // isolating the DERPer, will be created. If a network is
// passed, the DERPer instance will join the given network. // passed, the DERPer instance will join the given network.
func WithOrCreateNetwork(network *dockertest.Network) Option { func WithOrCreateNetwork(network *dockertest.Network) Option {
return func(tsic *DERPServerInContainer) { return func(dsic *DERPServerInContainer) {
if network != nil { if network != nil {
tsic.network = network dsic.networks = append(dsic.networks, network)
return return
} }
network, err := dockertestutil.GetFirstOrCreateNetwork( network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool, dsic.pool,
tsic.hostname+"-network", dsic.hostname+"-network",
) )
if err != nil { if err != nil {
log.Fatalf("failed to create network: %s", err) log.Fatalf("failed to create network: %s", err)
} }
tsic.network = network dsic.networks = append(dsic.networks, network)
} }
} }
@ -107,7 +107,7 @@ func WithExtraHosts(hosts []string) Option {
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
version string, version string,
network *dockertest.Network, networks []*dockertest.Network,
opts ...Option, opts ...Option,
) (*DERPServerInContainer, error) { ) (*DERPServerInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength) hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength)
@ -124,7 +124,7 @@ func New(
version: version, version: version,
hostname: hostname, hostname: hostname,
pool: pool, pool: pool,
network: network, networks: networks,
tlsCert: tlsCert, tlsCert: tlsCert,
tlsKey: tlsKey, tlsKey: tlsKey,
stunPort: 3478, //nolint stunPort: 3478, //nolint
@ -148,7 +148,7 @@ func New(
runOptions := &dockertest.RunOptions{ runOptions := &dockertest.RunOptions{
Name: hostname, Name: hostname,
Networks: []*dockertest.Network{dsic.network}, Networks: dsic.networks,
ExtraHosts: dsic.withExtraHosts, ExtraHosts: dsic.withExtraHosts,
// we currently need to give us some time to inject the certificate further down. // we currently need to give us some time to inject the certificate further down.
Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()}, Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},

View File

@ -1,18 +1,12 @@
package integration package integration
import ( import (
"fmt"
"log"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3"
) )
type ClientsSpec struct { type ClientsSpec struct {
@ -20,21 +14,18 @@ type ClientsSpec struct {
WebsocketDERP int WebsocketDERP int
} }
type EmbeddedDERPServerScenario struct {
*Scenario
tsicNetworks map[string]*dockertest.Network
}
func TestDERPServerScenario(t *testing.T) { func TestDERPServerScenario(t *testing.T) {
spec := map[string]ClientsSpec{ spec := ScenarioSpec{
"user1": { NodesPerUser: 1,
Plain: len(MustTestVersions), Users: []string{"user1", "user2", "user3"},
WebsocketDERP: 0, Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
"usernet3": {"user3"},
}, },
} }
derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) { derpServerScenario(t, spec, false, func(scenario *Scenario) {
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err) assertNoErrListClients(t, err)
t.Logf("checking %d clients for websocket connections", len(allClients)) t.Logf("checking %d clients for websocket connections", len(allClients))
@ -52,14 +43,17 @@ func TestDERPServerScenario(t *testing.T) {
} }
func TestDERPServerWebsocketScenario(t *testing.T) { func TestDERPServerWebsocketScenario(t *testing.T) {
spec := map[string]ClientsSpec{ spec := ScenarioSpec{
"user1": { NodesPerUser: 1,
Plain: 0, Users: []string{"user1", "user2", "user3"},
WebsocketDERP: 2, Networks: map[string][]string{
"usernet1": []string{"user1"},
"usernet2": []string{"user2"},
"usernet3": []string{"user3"},
}, },
} }
derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) { derpServerScenario(t, spec, true, func(scenario *Scenario) {
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err) assertNoErrListClients(t, err)
t.Logf("checking %d clients for websocket connections", len(allClients)) t.Logf("checking %d clients for websocket connections", len(allClients))
@ -83,23 +77,22 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
//nolint:thelper //nolint:thelper
func derpServerScenario( func derpServerScenario(
t *testing.T, t *testing.T,
spec map[string]ClientsSpec, spec ScenarioSpec,
furtherAssertions ...func(*EmbeddedDERPServerScenario), websocket bool,
furtherAssertions ...func(*Scenario),
) { ) {
IntegrationSkip(t) IntegrationSkip(t)
// t.Parallel() // t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
scenario := EmbeddedDERPServerScenario{
Scenario: baseScenario,
tsicNetworks: map[string]*dockertest.Network{},
}
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec, []tsic.Option{
tsic.WithWebsocketDERP(websocket),
},
hsic.WithTestName("derpserver"), hsic.WithTestName("derpserver"),
hsic.WithExtraPorts([]string{"3478/udp"}), hsic.WithExtraPorts([]string{"3478/udp"}),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
@ -185,182 +178,6 @@ func derpServerScenario(
t.Logf("Run2: %d successful pings out of %d", success, len(allClients)*len(allHostnames)) t.Logf("Run2: %d successful pings out of %d", success, len(allClients)*len(allHostnames))
for _, check := range furtherAssertions { for _, check := range furtherAssertions {
check(&scenario) check(scenario)
} }
} }
func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
users map[string]ClientsSpec,
opts ...hsic.Option,
) error {
hsServer, err := s.Headscale(opts...)
if err != nil {
return err
}
headscaleEndpoint := hsServer.GetEndpoint()
headscaleURL, err := url.Parse(headscaleEndpoint)
if err != nil {
return err
}
headscaleURL.Host = fmt.Sprintf("%s:%s", hsServer.GetHostname(), headscaleURL.Port())
err = hsServer.WaitForRunning()
if err != nil {
return err
}
log.Printf("headscale server ip address: %s", hsServer.GetIP())
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
if err != nil {
return err
}
for userName, clientCount := range users {
err = s.CreateUser(userName)
if err != nil {
return err
}
if clientCount.Plain > 0 {
// Containers that use default DERP config
err = s.CreateTailscaleIsolatedNodesInUser(
hash,
userName,
"all",
clientCount.Plain,
)
if err != nil {
return err
}
}
if clientCount.WebsocketDERP > 0 {
// Containers that use DERP-over-WebSocket
// Note that these clients *must* be built
// from source, which is currently
// only done for HEAD.
err = s.CreateTailscaleIsolatedNodesInUser(
hash,
userName,
tsic.VersionHead,
clientCount.WebsocketDERP,
tsic.WithWebsocketDERP(true),
)
if err != nil {
return err
}
}
key, err := s.CreatePreAuthKey(userName, true, false)
if err != nil {
return err
}
err = s.RunTailscaleUp(userName, headscaleURL.String(), key.GetKey())
if err != nil {
return err
}
}
return nil
}
func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser(
hash string,
userStr string,
requestedVersion string,
count int,
opts ...tsic.Option,
) error {
hsServer, err := s.Headscale()
if err != nil {
return err
}
if user, ok := s.users[userStr]; ok {
for clientN := 0; clientN < count; clientN++ {
networkName := fmt.Sprintf("tsnet-%s-%s-%d",
hash,
userStr,
clientN,
)
network, err := dockertestutil.GetFirstOrCreateNetwork(
s.pool,
networkName,
)
if err != nil {
return fmt.Errorf("failed to create or get %s network: %w", networkName, err)
}
s.tsicNetworks[networkName] = network
err = hsServer.ConnectToNetwork(network)
if err != nil {
return fmt.Errorf("failed to connect headscale to %s network: %w", networkName, err)
}
version := requestedVersion
if requestedVersion == "all" {
version = MustTestVersions[clientN%len(MustTestVersions)]
}
cert := hsServer.GetCert()
opts = append(opts,
tsic.WithCACert(cert),
)
user.createWaitGroup.Go(func() error {
tsClient, err := tsic.New(
s.pool,
version,
network,
opts...,
)
if err != nil {
return fmt.Errorf(
"failed to create tailscale (%s) node: %w",
tsClient.Hostname(),
err,
)
}
err = tsClient.WaitForNeedsLogin()
if err != nil {
return fmt.Errorf(
"failed to wait for tailscaled (%s) to need login: %w",
tsClient.Hostname(),
err,
)
}
s.mu.Lock()
user.Clients[tsClient.Hostname()] = tsClient
s.mu.Unlock()
return nil
})
}
if err := user.createWaitGroup.Wait(); err != nil {
return err
}
return nil
}
return fmt.Errorf("failed to add tailscale nodes: %w", errNoUserAvailable)
}
func (s *EmbeddedDERPServerScenario) Shutdown() {
for _, network := range s.tsicNetworks {
err := s.pool.RemoveNetwork(network)
if err != nil {
log.Printf("failed to remove DERP network %s", network.Network.Name)
}
}
s.Scenario.Shutdown()
}

View File

@ -28,18 +28,17 @@ func TestPingAllByIP(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
MaxWait: dockertestMaxWait(),
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second err = scenario.CreateHeadscaleEnv(
// get created? maybe only when many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("pingallbyip"), hsic.WithTestName("pingallbyip"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
@ -71,16 +70,16 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv(
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("pingallbyippubderp"), hsic.WithTestName("pingallbyippubderp"),
) )
@ -121,25 +120,25 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
headscale, err := scenario.Headscale(opts...) headscale, err := scenario.Headscale(opts...)
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
for userName, clientCount := range spec { for _, userName := range spec.Users {
err = scenario.CreateUser(userName) err = scenario.CreateUser(userName)
if err != nil { if err != nil {
t.Fatalf("failed to create user %s: %s", userName, err) t.Fatalf("failed to create user %s: %s", userName, err)
} }
err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
if err != nil { if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
} }
@ -194,15 +193,15 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
headscale, err := scenario.Headscale( headscale, err := scenario.Headscale(
hsic.WithTestName("ephemeral2006"), hsic.WithTestName("ephemeral2006"),
hsic.WithConfigEnv(map[string]string{ hsic.WithConfigEnv(map[string]string{
@ -211,13 +210,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
for userName, clientCount := range spec { for _, userName := range spec.Users {
err = scenario.CreateUser(userName) err = scenario.CreateUser(userName)
if err != nil { if err != nil {
t.Fatalf("failed to create user %s: %s", userName, err) t.Fatalf("failed to create user %s: %s", userName, err)
} }
err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
if err != nil { if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
} }
@ -287,7 +286,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// registered. // registered.
time.Sleep(3 * time.Minute) time.Sleep(3 * time.Minute)
for userName := range spec { for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName) nodes, err := headscale.ListNodes(userName)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -308,16 +307,16 @@ func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname"))
"user3": len(MustTestVersions),
"user4": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname"))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -357,15 +356,16 @@ func TestTaildrop(t *testing.T) {
return err return err
} }
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{},
"taildrop": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{},
hsic.WithTestName("taildrop"), hsic.WithTestName("taildrop"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(), hsic.WithTLS(),
@ -522,23 +522,22 @@ func TestUpdateHostnameFromClient(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
user := "update-hostname-from-client"
hostnames := map[string]string{ hostnames := map[string]string{
"1": "user1-host", "1": "user1-host",
"2": "User2-Host", "2": "User2-Host",
"3": "user3-host", "3": "user3-host",
} }
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErrf(t, "failed to create scenario: %s", err) assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname"))
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("updatehostname"))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -650,15 +649,16 @@ func TestExpireNode(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode"))
"user1": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("expirenode"))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -684,7 +684,7 @@ func TestExpireNode(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
// Assert that we have the original count - self // Assert that we have the original count - self
assert.Len(t, status.Peers(), spec["user1"]-1) assert.Len(t, status.Peers(), spec.NodesPerUser-1)
} }
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
@ -776,15 +776,16 @@ func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online"))
"user1": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online"))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -891,18 +892,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second err = scenario.CreateHeadscaleEnv(
// get created? maybe only when many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("pingallbyipmany"), hsic.WithTestName("pingallbyipmany"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
@ -973,18 +972,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second err = scenario.CreateHeadscaleEnv(
// get created? maybe only when many?
spec := map[string]int{
"user1": 1,
"user2": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{}, []tsic.Option{},
hsic.WithTestName("deletenocrash"), hsic.WithTestName("deletenocrash"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),

View File

@ -56,7 +56,7 @@ type HeadscaleInContainer struct {
pool *dockertest.Pool pool *dockertest.Pool
container *dockertest.Resource container *dockertest.Resource
network *dockertest.Network networks []*dockertest.Network
pgContainer *dockertest.Resource pgContainer *dockertest.Resource
@ -268,7 +268,7 @@ func WithTimezone(timezone string) Option {
// New returns a new HeadscaleInContainer instance. // New returns a new HeadscaleInContainer instance.
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
network *dockertest.Network, networks []*dockertest.Network,
opts ...Option, opts ...Option,
) (*HeadscaleInContainer, error) { ) (*HeadscaleInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength) hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength)
@ -283,7 +283,7 @@ func New(
port: headscaleDefaultPort, port: headscaleDefaultPort,
pool: pool, pool: pool,
network: network, networks: networks,
env: DefaultConfigEnv(), env: DefaultConfigEnv(),
filesInContainer: []fileInContainer{}, filesInContainer: []fileInContainer{},
@ -315,7 +315,7 @@ func New(
Name: fmt.Sprintf("postgres-%s", hash), Name: fmt.Sprintf("postgres-%s", hash),
Repository: "postgres", Repository: "postgres",
Tag: "latest", Tag: "latest",
Networks: []*dockertest.Network{network}, Networks: networks,
Env: []string{ Env: []string{
"POSTGRES_USER=headscale", "POSTGRES_USER=headscale",
"POSTGRES_PASSWORD=headscale", "POSTGRES_PASSWORD=headscale",
@ -357,7 +357,7 @@ func New(
runOptions := &dockertest.RunOptions{ runOptions := &dockertest.RunOptions{
Name: hsic.hostname, Name: hsic.hostname,
ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...), ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...),
Networks: []*dockertest.Network{network}, Networks: networks,
// Cmd: []string{"headscale", "serve"}, // Cmd: []string{"headscale", "serve"},
// TODO(kradalby): Get rid of this hack, we currently need to give us some // TODO(kradalby): Get rid of this hack, we currently need to give us some
// to inject the headscale configuration further down. // to inject the headscale configuration further down.
@ -630,11 +630,6 @@ func (t *HeadscaleInContainer) Execute(
return stdout, nil return stdout, nil
} }
// GetIP returns the docker container IP as a string.
func (t *HeadscaleInContainer) GetIP() string {
return t.container.GetIPInNetwork(t.network)
}
// GetPort returns the docker container port as a string. // GetPort returns the docker container port as a string.
func (t *HeadscaleInContainer) GetPort() string { func (t *HeadscaleInContainer) GetPort() string {
return fmt.Sprintf("%d", t.port) return fmt.Sprintf("%d", t.port)

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +1,37 @@
package integration package integration
import ( import (
"context"
"crypto/tls"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"net"
"net/http"
"net/http/cookiejar"
"net/netip" "net/netip"
"net/url"
"os" "os"
"sort" "sort"
"strconv"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/dsic" "github.com/juanfont/headscale/integration/dsic"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -26,6 +39,7 @@ import (
xmaps "golang.org/x/exp/maps" xmaps "golang.org/x/exp/maps"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/util/mak"
) )
const ( const (
@ -87,32 +101,135 @@ type Scenario struct {
users map[string]*User users map[string]*User
pool *dockertest.Pool pool *dockertest.Pool
network *dockertest.Network networks map[string]*dockertest.Network
mockOIDC scenarioOIDC
extraServices map[string][]*dockertest.Resource
mu sync.Mutex mu sync.Mutex
spec ScenarioSpec
userToNetwork map[string]*dockertest.Network
}
// ScenarioSpec describes the users, nodes, and network topology to
// set up for a given scenario.
type ScenarioSpec struct {
// Users is a list of usernames that will be created.
// Each created user will get nodes equivalent to NodesPerUser
Users []string
// NodesPerUser is how many nodes should be attached to each user.
NodesPerUser int
// Networks, if set, is the seperate Docker networks that should be
// created and a list of the users that should be placed in those networks.
// If not set, a single network will be created and all users+nodes will be
// added there.
// Please note that Docker networks are not necessarily routable and
// connections between them might fall back to DERP.
Networks map[string][]string
// ExtraService, if set, is additional a map of network to additional
// container services that should be set up. These container services
// typically dont run Tailscale, e.g. web service to test subnet router.
ExtraService map[string][]extraServiceFunc
// Versions is specific list of versions to use for the test.
Versions []string
// OIDCUsers, if populated, will start a Mock OIDC server and populate
// the user login stack with the given users.
// If the NodesPerUser is set, it should align with this list to ensure
// the correct users are logged in.
// This is because the MockOIDC server can only serve login
// requests based on a queue it has been given on startup.
// We currently only populates it with one login request per user.
OIDCUsers []mockoidc.MockUser
OIDCAccessTTL time.Duration
MaxWait time.Duration
}
var TestHashPrefix = "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength)
var TestDefaultNetwork = TestHashPrefix + "-default"
func prefixedNetworkName(name string) string {
return TestHashPrefix + "-" + name
} }
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
// a set of Users and TailscaleClients. // a set of Users and TailscaleClients.
func NewScenario(maxWait time.Duration) (*Scenario, error) { func NewScenario(spec ScenarioSpec) (*Scenario, error) {
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
if err != nil {
return nil, err
}
pool, err := dockertest.NewPool("") pool, err := dockertest.NewPool("")
if err != nil { if err != nil {
return nil, fmt.Errorf("could not connect to docker: %w", err) return nil, fmt.Errorf("could not connect to docker: %w", err)
} }
pool.MaxWait = maxWait if spec.MaxWait == 0 {
pool.MaxWait = dockertestMaxWait()
networkName := fmt.Sprintf("hs-%s", hash) } else {
if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { pool.MaxWait = spec.MaxWait
networkName = overrideNetworkName
} }
network, err := dockertestutil.GetFirstOrCreateNetwork(pool, networkName) s := &Scenario{
controlServers: xsync.NewMapOf[string, ControlServer](),
users: make(map[string]*User),
pool: pool,
spec: spec,
}
var userToNetwork map[string]*dockertest.Network
if spec.Networks != nil || len(spec.Networks) != 0 {
for name, users := range s.spec.Networks {
networkName := TestHashPrefix + "-" + name
network, err := s.AddNetwork(networkName)
if err != nil {
return nil, err
}
for _, user := range users {
if n2, ok := userToNetwork[user]; ok {
return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name)
}
mak.Set(&userToNetwork, user, network)
}
}
} else {
_, err := s.AddNetwork(TestDefaultNetwork)
if err != nil {
return nil, err
}
}
for network, extras := range spec.ExtraService {
for _, extra := range extras {
svc, err := extra(s, network)
if err != nil {
return nil, err
}
mak.Set(&s.extraServices, prefixedNetworkName(network), append(s.extraServices[prefixedNetworkName(network)], svc))
}
}
s.userToNetwork = userToNetwork
if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 {
ttl := defaultAccessTTL
if spec.OIDCAccessTTL != 0 {
ttl = spec.OIDCAccessTTL
}
err = s.runMockOIDC(ttl, spec.OIDCUsers)
if err != nil {
return nil, err
}
}
return s, nil
}
func (s *Scenario) AddNetwork(name string) (*dockertest.Network, error) {
network, err := dockertestutil.GetFirstOrCreateNetwork(s.pool, name)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create or get network: %w", err) return nil, fmt.Errorf("failed to create or get network: %w", err)
} }
@ -120,18 +237,58 @@ func NewScenario(maxWait time.Duration) (*Scenario, error) {
// We run the test suite in a docker container that calls a couple of endpoints for // We run the test suite in a docker container that calls a couple of endpoints for
// readiness checks, this ensures that we can run the tests with individual networks // readiness checks, this ensures that we can run the tests with individual networks
// and have the client reach the different containers // and have the client reach the different containers
err = dockertestutil.AddContainerToNetwork(pool, network, "headscale-test-suite") // TODO(kradalby): Can the test-suite be renamed so we can have multiple?
err = dockertestutil.AddContainerToNetwork(s.pool, network, "headscale-test-suite")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add test suite container to network: %w", err) return nil, fmt.Errorf("failed to add test suite container to network: %w", err)
} }
return &Scenario{ mak.Set(&s.networks, name, network)
controlServers: xsync.NewMapOf[string, ControlServer](),
users: make(map[string]*User),
pool: pool, return network, nil
network: network, }
}, nil
func (s *Scenario) Networks() []*dockertest.Network {
if len(s.networks) == 0 {
panic("Scenario.Networks called with empty network list")
}
return xmaps.Values(s.networks)
}
func (s *Scenario) Network(name string) (*dockertest.Network, error) {
net, ok := s.networks[prefixedNetworkName(name)]
if !ok {
return nil, fmt.Errorf("no network named: %s", name)
}
return net, nil
}
func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
net, ok := s.networks[prefixedNetworkName(name)]
if !ok {
return nil, fmt.Errorf("no network named: %s", name)
}
for _, ipam := range net.Network.IPAM.Config {
pref, err := netip.ParsePrefix(ipam.Subnet)
if err != nil {
return nil, err
}
return &pref, nil
}
return nil, fmt.Errorf("no prefix found in network: %s", name)
}
func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
res, ok := s.extraServices[prefixedNetworkName(name)]
if !ok {
return nil, fmt.Errorf("no network named: %s", name)
}
return res, nil
} }
func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
@ -184,14 +341,27 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
} }
} }
if err := s.pool.RemoveNetwork(s.network); err != nil { for _, svcs := range s.extraServices {
log.Printf("failed to remove network: %s", err) for _, svc := range svcs {
err := svc.Close()
if err != nil {
log.Printf("failed to tear down service %q: %s", svc.Container.Name, err)
}
}
} }
// TODO(kradalby): This seem redundant to the previous call if s.mockOIDC.r != nil {
// if err := s.network.Close(); err != nil { s.mockOIDC.r.Close()
// return fmt.Errorf("failed to tear down network: %w", err) if err := s.mockOIDC.r.Close(); err != nil {
// } log.Printf("failed to tear down oidc server: %s", err)
}
}
for _, network := range s.networks {
if err := network.Close(); err != nil {
log.Printf("failed to tear down network: %s", err)
}
}
} }
// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient) // Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient)
@ -235,7 +405,7 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) {
opts = append(opts, hsic.WithPolicyV2()) opts = append(opts, hsic.WithPolicyV2())
} }
headscale, err := hsic.New(s.pool, s.network, opts...) headscale, err := hsic.New(s.pool, s.Networks(), opts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create headscale container: %w", err) return nil, fmt.Errorf("failed to create headscale container: %w", err)
} }
@ -312,7 +482,6 @@ func (s *Scenario) CreateTailscaleNode(
tsClient, err := tsic.New( tsClient, err := tsic.New(
s.pool, s.pool,
version, version,
s.network,
opts..., opts...,
) )
if err != nil { if err != nil {
@ -345,11 +514,15 @@ func (s *Scenario) CreateTailscaleNodesInUser(
) error { ) error {
if user, ok := s.users[userStr]; ok { if user, ok := s.users[userStr]; ok {
var versions []string var versions []string
for i := 0; i < count; i++ { for i := range count {
version := requestedVersion version := requestedVersion
if requestedVersion == "all" { if requestedVersion == "all" {
if s.spec.Versions != nil {
version = s.spec.Versions[i%len(s.spec.Versions)]
} else {
version = MustTestVersions[i%len(MustTestVersions)] version = MustTestVersions[i%len(MustTestVersions)]
} }
}
versions = append(versions, version) versions = append(versions, version)
headscale, err := s.Headscale() headscale, err := s.Headscale()
@ -372,14 +545,12 @@ func (s *Scenario) CreateTailscaleNodesInUser(
tsClient, err := tsic.New( tsClient, err := tsic.New(
s.pool, s.pool,
version, version,
s.network,
opts..., opts...,
) )
s.mu.Unlock() s.mu.Unlock()
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
"failed to create tailscale (%s) node: %w", "failed to create tailscale node: %w",
tsClient.Hostname(),
err, err,
) )
} }
@ -492,11 +663,24 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int) error {
return nil return nil
} }
// CreateHeadscaleEnv is a convenient method returning a complete Headcale func (s *Scenario) CreateHeadscaleEnvWithLoginURL(
// test environment with nodes of all versions, joined to the server with X tsOpts []tsic.Option,
// users. opts ...hsic.Option,
) error {
return s.createHeadscaleEnv(true, tsOpts, opts...)
}
func (s *Scenario) CreateHeadscaleEnv( func (s *Scenario) CreateHeadscaleEnv(
users map[string]int, tsOpts []tsic.Option,
opts ...hsic.Option,
) error {
return s.createHeadscaleEnv(false, tsOpts, opts...)
}
// CreateHeadscaleEnv starts the headscale environment and the clients
// according to the ScenarioSpec passed to the Scenario.
func (s *Scenario) createHeadscaleEnv(
withURL bool,
tsOpts []tsic.Option, tsOpts []tsic.Option,
opts ...hsic.Option, opts ...hsic.Option,
) error { ) error {
@ -505,34 +689,188 @@ func (s *Scenario) CreateHeadscaleEnv(
return err return err
} }
usernames := xmaps.Keys(users) sort.Strings(s.spec.Users)
sort.Strings(usernames) for _, user := range s.spec.Users {
for _, username := range usernames { err = s.CreateUser(user)
clientCount := users[username]
err = s.CreateUser(username)
if err != nil { if err != nil {
return err return err
} }
err = s.CreateTailscaleNodesInUser(username, "all", clientCount, tsOpts...) var opts []tsic.Option
if s.userToNetwork != nil {
opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user]))
} else {
opts = append(tsOpts, tsic.WithNetwork(s.networks[TestDefaultNetwork]))
}
err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...)
if err != nil { if err != nil {
return err return err
} }
key, err := s.CreatePreAuthKey(username, true, false) if withURL {
err = s.RunTailscaleUpWithURL(user, headscale.GetEndpoint())
if err != nil {
return err
}
} else {
key, err := s.CreatePreAuthKey(user, true, false)
if err != nil { if err != nil {
return err return err
} }
err = s.RunTailscaleUp(username, headscale.GetEndpoint(), key.GetKey()) err = s.RunTailscaleUp(user, headscale.GetEndpoint(), key.GetKey())
if err != nil { if err != nil {
return err return err
} }
} }
}
return nil return nil
} }
func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
tsc := client
user.joinWaitGroup.Go(func() error {
loginURL, err := tsc.LoginWithURL(loginServer)
if err != nil {
log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
}
body, err := doLoginURL(tsc.Hostname(), loginURL)
if err != nil {
return err
}
// If the URL is not a OIDC URL, then we need to
// run the register command to fully log in the client.
if !strings.Contains(loginURL.String(), "/oidc/") {
s.runHeadscaleRegister(userStr, body)
}
return nil
})
log.Printf("client %s is ready", client.Hostname())
}
if err := user.joinWaitGroup.Wait(); err != nil {
return err
}
for _, client := range user.Clients {
err := client.WaitForRunning()
if err != nil {
return fmt.Errorf(
"%s tailscale node has not reached running: %w",
client.Hostname(),
err,
)
}
}
return nil
}
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
// doLoginURL visits the given login URL and returns the body as a
// string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
var err error
hc := &http.Client{
Transport: LoggingRoundTripper{},
}
hc.Jar, err = cookiejar.New(nil)
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
}
log.Printf("%s logging in with url", hostname)
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := hc.Do(req)
if err != nil {
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
}
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body)
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("%s failed to read response body: %s", hostname, err)
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
return string(body), nil
}
var errParseAuthPage = errors.New("failed to parse auth page")
func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
// see api.go HTML template
codeSep := strings.Split(string(body), "</code>")
if len(codeSep) != 2 {
return errParseAuthPage
}
keySep := strings.Split(codeSep[0], "key ")
if len(keySep) != 2 {
return errParseAuthPage
}
key := keySep[1]
log.Printf("registering node %s", key)
if headscale, err := s.Headscale(); err == nil {
_, err = headscale.Execute(
[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key},
)
if err != nil {
log.Printf("failed to register node: %s", err)
return err
}
return nil
}
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
}
type LoggingRoundTripper struct{}
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
noTls := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}
resp, err := noTls.RoundTrip(req)
if err != nil {
return nil, err
}
log.Printf("---")
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
return resp, nil
}
// GetIPs returns all netip.Addr of TailscaleClients associated with a User // GetIPs returns all netip.Addr of TailscaleClients associated with a User
// in a Scenario. // in a Scenario.
func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) {
@ -670,7 +1008,7 @@ func (s *Scenario) WaitForTailscaleLogout() error {
// CreateDERPServer creates a new DERP server in a container. // CreateDERPServer creates a new DERP server in a container.
func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) { func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) {
derp, err := dsic.New(s.pool, version, s.network, opts...) derp, err := dsic.New(s.pool, version, s.Networks(), opts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create DERP server: %w", err) return nil, fmt.Errorf("failed to create DERP server: %w", err)
} }
@ -684,3 +1022,216 @@ func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.
return derp, nil return derp, nil
} }
type scenarioOIDC struct {
r *dockertest.Resource
cfg *types.OIDCConfig
}
func (o *scenarioOIDC) Issuer() string {
if o.cfg == nil {
panic("OIDC has not been created")
}
return o.cfg.Issuer
}
func (o *scenarioOIDC) ClientSecret() string {
if o.cfg == nil {
panic("OIDC has not been created")
}
return o.cfg.ClientSecret
}
func (o *scenarioOIDC) ClientID() string {
if o.cfg == nil {
panic("OIDC has not been created")
}
return o.cfg.ClientID
}
const (
dockerContextPath = "../."
hsicOIDCMockHashLength = 6
defaultAccessTTL = 10 * time.Minute
)
var errStatusCodeNotOK = errors.New("status code not OK")
func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) error {
port, err := dockertestutil.RandomFreeHostPort()
if err != nil {
log.Fatalf("could not find an open port: %s", err)
}
portNotation := fmt.Sprintf("%d/tcp", port)
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
usersJSON, err := json.Marshal(users)
if err != nil {
return err
}
mockOidcOptions := &dockertest.RunOptions{
Name: hostname,
Cmd: []string{"headscale", "mockoidc"},
ExposedPorts: []string{portNotation},
PortBindings: map[docker.Port][]docker.PortBinding{
docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
},
Networks: s.Networks(),
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
}
err = s.pool.RemoveContainerByName(hostname)
if err != nil {
return err
}
s.mockOIDC = scenarioOIDC{}
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
dockertestutil.DockerRestartPolicy); err == nil {
s.mockOIDC.r = pmockoidc
} else {
return err
}
// headscale needs to set up the provider with a specific
// IP addr to ensure we get the correct config from the well-known
// endpoint.
network := s.Networks()[0]
ipAddr := s.mockOIDC.r.GetIPInNetwork(network)
log.Println("Waiting for headscale mock oidc to be ready for tests")
hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
if err := s.pool.Retry(func() error {
oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
httpClient := &http.Client{}
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("headscale mock OIDC tests is not ready: %s\n", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errStatusCodeNotOK
}
return nil
}); err != nil {
return err
}
s.mockOIDC.cfg = &types.OIDCConfig{
Issuer: fmt.Sprintf(
"http://%s/oidc",
hostEndpoint,
),
ClientID: "superclient",
ClientSecret: "supersecret",
OnlyStartIfOIDCIsAvailable: true,
}
log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
return nil
}
type extraServiceFunc func(*Scenario, string) (*dockertest.Resource, error)
func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
// port, err := dockertestutil.RandomFreeHostPort()
// if err != nil {
// log.Fatalf("could not find an open port: %s", err)
// }
// portNotation := fmt.Sprintf("%d/tcp", port)
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-webservice-%s", hash)
network, ok := s.networks[prefixedNetworkName(networkName)]
if !ok {
return nil, fmt.Errorf("network does not exist: %s", networkName)
}
webOpts := &dockertest.RunOptions{
Name: hostname,
Cmd: []string{"/bin/sh", "-c", "cd / ; python3 -m http.server --bind :: 80"},
// ExposedPorts: []string{portNotation},
// PortBindings: map[docker.Port][]docker.PortBinding{
// docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
// },
Networks: []*dockertest.Network{network},
Env: []string{},
}
webBOpts := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
}
web, err := s.pool.BuildAndRunWithBuildOptions(
webBOpts,
webOpts,
dockertestutil.DockerRestartPolicy)
if err != nil {
return nil, err
}
// headscale needs to set up the provider with a specific
// IP addr to ensure we get the correct config from the well-known
// endpoint.
// ipAddr := web.GetIPInNetwork(network)
// log.Println("Waiting for headscale mock oidc to be ready for tests")
// hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
// if err := s.pool.Retry(func() error {
// oidcConfigURL := fmt.Sprintf("http://%s/etc/hostname", hostEndpoint)
// httpClient := &http.Client{}
// ctx := context.Background()
// req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil)
// resp, err := httpClient.Do(req)
// if err != nil {
// log.Printf("headscale mock OIDC tests is not ready: %s\n", err)
// return err
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return errStatusCodeNotOK
// }
// return nil
// }); err != nil {
// return err
// }
return web, nil
}

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic"
) )
// This file is intended to "test the test framework", by proxy it will also test // This file is intended to "test the test framework", by proxy it will also test
@ -33,7 +34,7 @@ func TestHeadscale(t *testing.T) {
user := "test-space" user := "test-space"
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(ScenarioSpec{})
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
@ -68,38 +69,6 @@ func TestHeadscale(t *testing.T) {
}) })
} }
// If subtests are parallel, then they will start before setup is run.
// This might mean we approach setup slightly wrong, but for now, ignore
// the linter
// nolint:tparallel
func TestCreateTailscale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "only-create-containers"
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
scenario.users[user] = &User{
Clients: make(map[string]TailscaleClient),
}
t.Run("create-tailscale", func(t *testing.T) {
err := scenario.CreateTailscaleNodesInUser(user, "all", 3)
if err != nil {
t.Fatalf("failed to add tailscale nodes: %s", err)
}
if clients := len(scenario.users[user].Clients); clients != 3 {
t.Fatalf("wrong number of tailscale clients: %d != %d", clients, 3)
}
// TODO(kradalby): Test "all" version logic
})
}
// If subtests are parallel, then they will start before setup is run. // If subtests are parallel, then they will start before setup is run.
// This might mean we approach setup slightly wrong, but for now, ignore // This might mean we approach setup slightly wrong, but for now, ignore
// the linter // the linter
@ -114,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
count := 1 count := 1
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(ScenarioSpec{})
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
@ -142,7 +111,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
}) })
t.Run("create-tailscale", func(t *testing.T) { t.Run("create-tailscale", func(t *testing.T) {
err := scenario.CreateTailscaleNodesInUser(user, "unstable", count) err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
if err != nil { if err != nil {
t.Fatalf("failed to add tailscale nodes: %s", err) t.Fatalf("failed to add tailscale nodes: %s", err)
} }

View File

@ -50,15 +50,15 @@ var retry = func(times int, sleepInterval time.Duration,
func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario { func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario {
t.Helper() t.Helper()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: clientsPerUser,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)
spec := map[string]int{ err = scenario.CreateHeadscaleEnv(
"user1": clientsPerUser,
"user2": clientsPerUser,
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{ []tsic.Option{
tsic.WithSSH(), tsic.WithSSH(),

View File

@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
@ -27,6 +28,9 @@ type TailscaleClient interface {
Up() error Up() error
Down() error Down() error
IPs() ([]netip.Addr, error) IPs() ([]netip.Addr, error)
MustIPs() []netip.Addr
MustIPv4() netip.Addr
MustIPv6() netip.Addr
FQDN() (string, error) FQDN() (string, error)
Status(...bool) (*ipnstate.Status, error) Status(...bool) (*ipnstate.Status, error)
MustStatus() *ipnstate.Status MustStatus() *ipnstate.Status
@ -38,6 +42,7 @@ type TailscaleClient interface {
WaitForPeers(expected int) error WaitForPeers(expected int) error
Ping(hostnameOrIP string, opts ...tsic.PingOption) error Ping(hostnameOrIP string, opts ...tsic.PingOption) error
Curl(url string, opts ...tsic.CurlOption) (string, error) Curl(url string, opts ...tsic.CurlOption) (string, error)
Traceroute(netip.Addr) (util.Traceroute, error)
ID() string ID() string
ReadFile(path string) ([]byte, error) ReadFile(path string) ([]byte, error)

View File

@ -13,6 +13,7 @@ import (
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -81,6 +82,7 @@ type TailscaleInContainer struct {
workdir string workdir string
netfilter string netfilter string
extraLoginArgs []string extraLoginArgs []string
withAcceptRoutes bool
// build options, solely for HEAD // build options, solely for HEAD
buildConfig TailscaleInContainerBuildConfig buildConfig TailscaleInContainerBuildConfig
@ -101,26 +103,10 @@ func WithCACert(cert []byte) Option {
} }
} }
// WithOrCreateNetwork sets the Docker container network to use with // WithNetwork sets the Docker container network to use with
// the Tailscale instance, if the parameter is nil, a new network, // the Tailscale instance.
// isolating the TailscaleClient, will be created. If a network is func WithNetwork(network *dockertest.Network) Option {
// passed, the Tailscale instance will join the given network.
func WithOrCreateNetwork(network *dockertest.Network) Option {
return func(tsic *TailscaleInContainer) { return func(tsic *TailscaleInContainer) {
if network != nil {
tsic.network = network
return
}
network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool,
fmt.Sprintf("%s-network", tsic.hostname),
)
if err != nil {
log.Fatalf("failed to create network: %s", err)
}
tsic.network = network tsic.network = network
} }
} }
@ -212,11 +198,17 @@ func WithExtraLoginArgs(args []string) Option {
} }
} }
// WithAcceptRoutes tells the node to accept incomming routes.
func WithAcceptRoutes() Option {
return func(tsic *TailscaleInContainer) {
tsic.withAcceptRoutes = true
}
}
// New returns a new TailscaleInContainer instance. // New returns a new TailscaleInContainer instance.
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
version string, version string,
network *dockertest.Network,
opts ...Option, opts ...Option,
) (*TailscaleInContainer, error) { ) (*TailscaleInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength) hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength)
@ -231,7 +223,6 @@ func New(
hostname: hostname, hostname: hostname,
pool: pool, pool: pool,
network: network,
withEntrypoint: []string{ withEntrypoint: []string{
"/bin/sh", "/bin/sh",
@ -244,6 +235,10 @@ func New(
opt(tsic) opt(tsic)
} }
if tsic.network == nil {
return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack()))
}
tailscaleOptions := &dockertest.RunOptions{ tailscaleOptions := &dockertest.RunOptions{
Name: hostname, Name: hostname,
Networks: []*dockertest.Network{tsic.network}, Networks: []*dockertest.Network{tsic.network},
@ -442,7 +437,7 @@ func (t *TailscaleInContainer) Login(
"--login-server=" + loginServer, "--login-server=" + loginServer,
"--authkey=" + authKey, "--authkey=" + authKey,
"--hostname=" + t.hostname, "--hostname=" + t.hostname,
"--accept-routes=false", fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes),
} }
if t.extraLoginArgs != nil { if t.extraLoginArgs != nil {
@ -597,6 +592,33 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
return ips, nil return ips, nil
} }
func (t *TailscaleInContainer) MustIPs() []netip.Addr {
ips, err := t.IPs()
if err != nil {
panic(err)
}
return ips
}
func (t *TailscaleInContainer) MustIPv4() netip.Addr {
for _, ip := range t.MustIPs() {
if ip.Is4() {
return ip
}
}
panic("no ipv4 found")
}
func (t *TailscaleInContainer) MustIPv6() netip.Addr {
for _, ip := range t.MustIPs() {
if ip.Is6() {
return ip
}
}
panic("no ipv6 found")
}
// Status returns the ipnstate.Status of the Tailscale instance. // Status returns the ipnstate.Status of the Tailscale instance.
func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) { func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
command := []string{ command := []string{
@ -992,6 +1014,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err
), ),
) )
if err != nil { if err != nil {
log.Printf("command: %v", command)
log.Printf( log.Printf(
"failed to run ping command from %s to %s, err: %s", "failed to run ping command from %s to %s, err: %s",
t.Hostname(), t.Hostname(),
@ -1108,6 +1131,26 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
return result, nil return result, nil
} }
func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error) {
command := []string{
"traceroute",
ip.String(),
}
var result util.Traceroute
stdout, stderr, err := t.Execute(command)
if err != nil {
return result, err
}
result, err = util.ParseTraceroute(stdout + stderr)
if err != nil {
return result, err
}
return result, nil
}
// WriteFile save file inside the Tailscale container. // WriteFile save file inside the Tailscale container.
func (t *TailscaleInContainer) WriteFile(path string, data []byte) error { func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data) return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)