Multi network integration tests (#2464)

This commit is contained in:
Kristoffer Dalby 2025-03-21 11:49:32 +01:00 committed by GitHub
parent 707438f25e
commit 603f3ad490
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 2385 additions and 1449 deletions

View File

@ -70,8 +70,9 @@ jobs:
- TestAutoApprovedSubRoute2068
- TestSubnetRouteACL
- TestEnablingExitRoutes
- TestSubnetRouterMultiNetwork
- TestSubnetRouterMultiNetworkExitNode
- TestHeadscale
- TestCreateTailscale
- TestTailscaleNodesJoiningHeadcale
- TestSSHOneUserToAll
- TestSSHMultipleUsersAllToAll

View File

@ -70,8 +70,9 @@ jobs:
- TestAutoApprovedSubRoute2068
- TestSubnetRouteACL
- TestEnablingExitRoutes
- TestSubnetRouterMultiNetwork
- TestSubnetRouterMultiNetworkExitNode
- TestHeadscale
- TestCreateTailscale
- TestTailscaleNodesJoiningHeadcale
- TestSSHOneUserToAll
- TestSSHMultipleUsersAllToAll

View File

@ -165,9 +165,13 @@ func Test_fullMapResponse(t *testing.T) {
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",

View File

@ -2,13 +2,13 @@ package mapper
import (
"fmt"
"net/netip"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
@ -49,14 +49,6 @@ func tailNode(
) (*tailcfg.Node, error) {
addrs := node.Prefixes()
allowedIPs := append(
[]netip.Prefix{},
addrs...) // we append the node own IP, as it is required by the clients
for _, route := range node.SubnetRoutes() {
allowedIPs = append(allowedIPs, netip.Prefix(route))
}
var derp int
// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077
@ -89,6 +81,10 @@ func tailNode(
}
tags = lo.Uniq(append(tags, node.ForcedTags...))
allowed := append(node.Prefixes(), primary.PrimaryRoutes(node.ID)...)
allowed = append(allowed, node.ExitRoutes()...)
tsaddr.SortPrefixes(allowed)
tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: node.ID.StableID(),
@ -104,7 +100,7 @@ func tailNode(
DiscoKey: node.DiscoKey,
Addresses: addrs,
PrimaryRoutes: primary.PrimaryRoutes(node.ID),
AllowedIPs: allowedIPs,
AllowedIPs: allowed,
Endpoints: node.Endpoints,
HomeDERP: derp,
LegacyDERPString: legacyDERP,

View File

@ -67,8 +67,6 @@ func TestTailNode(t *testing.T) {
want: &tailcfg.Node{
Name: "empty",
StableID: "0",
Addresses: []netip.Prefix{},
AllowedIPs: []netip.Prefix{},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
@ -139,9 +137,13 @@ func TestTailNode(t *testing.T) {
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
@ -156,10 +158,6 @@ func TestTailNode(t *testing.T) {
Tags: []string{},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
LastSeen: &lastSeen,
MachineAuthorized: true,

View File

@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
xmaps "golang.org/x/exp/maps"
"tailscale.com/net/tsaddr"
"tailscale.com/util/set"
)
@ -74,18 +75,12 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
// If the current primary is not available, select a new one.
for prefix, nodes := range allPrimaries {
if node, ok := pr.primaries[prefix]; ok {
if len(nodes) < 2 {
delete(pr.primaries, prefix)
changed = true
continue
}
// If the current primary is still available, continue.
if slices.Contains(nodes, node) {
continue
}
}
if len(nodes) >= 2 {
if len(nodes) >= 1 {
pr.primaries[prefix] = nodes[0]
changed = true
}
@ -107,12 +102,16 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
return changed
}
func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bool {
// SetRoutes sets the routes for a given Node ID and recalculates the primary routes
// of the headscale.
// It returns true if there was a change in primary routes.
// All exit routes are ignored as they are not used in primary route context.
func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) bool {
pr.mu.Lock()
defer pr.mu.Unlock()
// If no routes are being set, remove the node from the routes map.
if len(prefix) == 0 {
if len(prefixes) == 0 {
if _, ok := pr.routes[node]; ok {
delete(pr.routes, node)
return pr.updatePrimaryLocked()
@ -121,12 +120,17 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bo
return false
}
if _, ok := pr.routes[node]; !ok {
pr.routes[node] = make(set.Set[netip.Prefix], len(prefix))
rs := make(set.Set[netip.Prefix], len(prefixes))
for _, prefix := range prefixes {
if !tsaddr.IsExitRoute(prefix) {
rs.Add(prefix)
}
}
for _, p := range prefix {
pr.routes[node].Add(p)
if rs.Len() != 0 {
pr.routes[node] = rs
} else {
delete(pr.routes, node)
}
return pr.updatePrimaryLocked()
@ -153,6 +157,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
}
}
tsaddr.SortPrefixes(routes)
return routes
}

View File

@ -6,8 +6,10 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/util/set"
)
// mp is a helper function that wraps netip.MustParsePrefix.
@ -17,20 +19,34 @@ func mp(prefix string) netip.Prefix {
func TestPrimaryRoutes(t *testing.T) {
tests := []struct {
name string
operations func(pr *PrimaryRoutes) bool
nodeID types.NodeID
expectedRoutes []netip.Prefix
expectedChange bool
name string
operations func(pr *PrimaryRoutes) bool
expectedRoutes map[types.NodeID]set.Set[netip.Prefix]
expectedPrimaries map[netip.Prefix]types.NodeID
expectedIsPrimary map[types.NodeID]bool
expectedChange bool
// primaries is a map of prefixes to the node that is the primary for that prefix.
primaries map[netip.Prefix]types.NodeID
isPrimary map[types.NodeID]bool
}{
{
name: "single-node-registers-single-route",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1, mp("192.168.1.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: true,
},
{
name: "multiple-nodes-register-different-routes",
@ -38,19 +54,45 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24"))
return pr.SetRoutes(2, mp("192.168.2.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.2.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
mp("192.168.2.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
2: true,
},
expectedChange: true,
},
{
name: "multiple-nodes-register-overlapping-routes",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
return pr.SetRoutes(2, mp("192.168.1.0/24")) // true
pr.SetRoutes(1, mp("192.168.1.0/24")) // true
return pr.SetRoutes(2, mp("192.168.1.0/24")) // false
},
nodeID: 1,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: true,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: false,
},
{
name: "node-deregisters-a-route",
@ -58,9 +100,10 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24"))
return pr.SetRoutes(1) // Deregister by setting no routes
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
expectedRoutes: nil,
expectedPrimaries: nil,
expectedIsPrimary: nil,
expectedChange: true,
},
{
name: "node-deregisters-one-of-multiple-routes",
@ -68,9 +111,18 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24"))
return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.2.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.2.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: true,
},
{
name: "node-registers-and-deregisters-routes-in-sequence",
@ -80,18 +132,23 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1) // Deregister by setting no routes
return pr.SetRoutes(1, mp("192.168.3.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "no-change-in-primary-routes",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1, mp("192.168.1.0/24"))
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.3.0/24"): {},
},
2: {
mp("192.168.2.0/24"): {},
},
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.2.0/24"): 2,
mp("192.168.3.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
2: true,
},
expectedChange: true,
},
{
name: "multiple-nodes-register-same-route",
@ -100,21 +157,24 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
},
nodeID: 1,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: false,
},
{
name: "register-multiple-routes-shift-primary-check-old-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: true,
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: false,
},
{
name: "register-multiple-routes-shift-primary-check-primary",
@ -124,20 +184,20 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
},
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: true,
},
{
name: "register-multiple-routes-shift-primary-check-non-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
2: true,
},
nodeID: 3,
expectedRoutes: nil,
expectedChange: true,
},
{
@ -150,8 +210,17 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(2) // true, no primary
},
nodeID: 2,
expectedRoutes: nil,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 3,
},
expectedIsPrimary: map[types.NodeID]bool{
3: true,
},
expectedChange: true,
},
{
@ -165,9 +234,7 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(3) // false, no primary
},
nodeID: 2,
expectedRoutes: nil,
expectedChange: false,
expectedChange: true,
},
{
name: "primary-route-map-is-cleared-up",
@ -179,8 +246,17 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(2) // true, no primary
},
nodeID: 2,
expectedRoutes: nil,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 3,
},
expectedIsPrimary: map[types.NodeID]bool{
3: true,
},
expectedChange: true,
},
{
@ -193,8 +269,23 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
},
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
2: true,
},
expectedChange: false,
},
{
@ -207,8 +298,23 @@ func TestPrimaryRoutes(t *testing.T) {
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
},
nodeID: 1,
expectedRoutes: nil,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
2: true,
},
expectedChange: false,
},
{
@ -218,15 +324,30 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
pr.SetRoutes(2) // true, no primary
pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(1) // true, 2 primary
pr.SetRoutes(2) // true, 3 primary
pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 3 primary
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 3 primary
pr.SetRoutes(1) // true, 3 primary
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 3 primary
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
3: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 3,
},
expectedIsPrimary: map[types.NodeID]bool{
3: true,
},
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: false,
},
{
@ -235,16 +356,27 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24"))
return pr.SetRoutes(2, mp("192.168.1.0/24"))
},
nodeID: 1,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: true,
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.1.0/24"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
expectedChange: false,
},
{
name: "deregister-non-existent-route",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1) // Deregister by setting no routes
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
@ -253,17 +385,27 @@ func TestPrimaryRoutes(t *testing.T) {
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1)
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "deregister-empty-prefix-list",
name: "exit-nodes",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1)
pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("10.0.0.0/16"): {},
},
},
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("10.0.0.0/16"): 1,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
@ -284,19 +426,23 @@ func TestPrimaryRoutes(t *testing.T) {
return change1 || change2
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "no-routes-registered",
operations: func(pr *PrimaryRoutes) bool {
// No operations
return false
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
1: {
mp("192.168.1.0/24"): {},
},
2: {
mp("192.168.2.0/24"): {},
},
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
expectedPrimaries: map[netip.Prefix]types.NodeID{
mp("192.168.1.0/24"): 1,
mp("192.168.2.0/24"): 2,
},
expectedIsPrimary: map[types.NodeID]bool{
1: true,
2: true,
},
expectedChange: true,
},
}
@ -307,9 +453,15 @@ func TestPrimaryRoutes(t *testing.T) {
if change != tt.expectedChange {
t.Errorf("change = %v, want %v", change, tt.expectedChange)
}
routes := pr.PrimaryRoutes(tt.nodeID)
if diff := cmp.Diff(tt.expectedRoutes, routes, util.Comparers...); diff != "" {
t.Errorf("PrimaryRoutes() mismatch (-want +got):\n%s", diff)
comps := append(util.Comparers, cmpopts.EquateEmpty())
if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
t.Errorf("routes mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
}
})
}

View File

@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -213,7 +214,7 @@ func (node *Node) RequestTags() []string {
}
func (node *Node) Prefixes() []netip.Prefix {
addrs := []netip.Prefix{}
var addrs []netip.Prefix
for _, nodeAddress := range node.IPs() {
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
addrs = append(addrs, ip)
@ -222,6 +223,19 @@ func (node *Node) Prefixes() []netip.Prefix {
return addrs
}
// ExitRoutes returns a list of both exit routes if the
// node has any exit routes enabled.
// If none are enabled, it will return nil.
func (node *Node) ExitRoutes() []netip.Prefix {
for _, route := range node.SubnetRoutes() {
if tsaddr.IsExitRoute(route) {
return tsaddr.ExitRoutes()
}
}
return nil
}
func (node *Node) IPsAsString() []string {
var ret []string

View File

@ -57,6 +57,15 @@ func GenerateRandomStringDNSSafe(size int) (string, error) {
return str[:size], nil
}
func MustGenerateRandomStringDNSSafe(size int) string {
hash, err := GenerateRandomStringDNSSafe(size)
if err != nil {
panic(err)
}
return hash
}
func TailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes))

View File

@ -3,8 +3,12 @@ package util
import (
"errors"
"fmt"
"net/netip"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"tailscale.com/util/cmpver"
)
@ -46,3 +50,126 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
return loginURL, nil
}
type TraceroutePath struct {
// Hop is the current jump in the total traceroute.
Hop int
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// IP is the IP address of the jump
IP netip.Addr
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
}
type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// IP is the IP address of the target
IP netip.Addr
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Success indicates if the traceroute was successful.
Success bool
// Err contains an error if the traceroute was not successful.
Err error
}
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 {
return Traceroute{}, errors.New("empty traceroute output")
}
// Parse the header line
headerRegex := regexp.MustCompile(`traceroute to ([^ ]+) \(([^)]+)\)`)
headerMatches := headerRegex.FindStringSubmatch(lines[0])
if len(headerMatches) != 3 {
return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0])
}
hostname := headerMatches[1]
ipStr := headerMatches[2]
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err)
}
result := Traceroute{
Hostname: hostname,
IP: ip,
Route: []TraceroutePath{},
Success: false,
}
// Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])
if len(matches) == 0 {
continue
}
hop, err := strconv.Atoi(matches[1])
if err != nil {
return Traceroute{}, fmt.Errorf("parsing hop number: %w", err)
}
var hopHostname string
var hopIP netip.Addr
var latencies []time.Duration
// Handle hostname and IP
if matches[2] != "" && matches[3] != "" {
hopHostname = matches[2]
hopIP, err = netip.ParseAddr(matches[3])
if err != nil {
return Traceroute{}, fmt.Errorf("parsing hop IP address %s: %w", matches[3], err)
}
} else if matches[4] == "*" {
hopHostname = "*"
// No IP for timeouts
}
// Parse latencies
for j := 5; j <= 7; j++ {
if matches[j] != "" {
ms, err := strconv.ParseFloat(matches[j], 64)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing latency: %w", err)
}
latencies = append(latencies, time.Duration(ms*float64(time.Millisecond)))
}
}
path := TraceroutePath{
Hop: hop,
Hostname: hopHostname,
IP: hopIP,
Latencies: latencies,
}
result.Route = append(result.Route, path)
// Check if we've reached the target
if hopIP == ip {
result.Success = true
}
}
// If we didn't reach the target, it's unsuccessful
if !result.Success {
result.Err = errors.New("traceroute did not reach target")
}
return result, nil
}

View File

@ -1,6 +1,13 @@
package util
import "testing"
import (
"errors"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
func TestTailscaleVersionNewerOrEqual(t *testing.T) {
type args struct {
@ -178,3 +185,186 @@ Success.`,
})
}
}
func TestParseTraceroute(t *testing.T) {
tests := []struct {
name string
input string
want Traceroute
wantErr bool
}{
{
name: "simple successful traceroute",
input: `traceroute to 172.24.0.3 (172.24.0.3), 30 hops max, 46 byte packets
1 ts-head-hk0urr.headscale.net (100.64.0.1) 1.135 ms 0.922 ms 0.619 ms
2 172.24.0.3 (172.24.0.3) 0.593 ms 0.549 ms 0.522 ms`,
want: Traceroute{
Hostname: "172.24.0.3",
IP: netip.MustParseAddr("172.24.0.3"),
Route: []TraceroutePath{
{
Hop: 1,
Hostname: "ts-head-hk0urr.headscale.net",
IP: netip.MustParseAddr("100.64.0.1"),
Latencies: []time.Duration{
1135 * time.Microsecond,
922 * time.Microsecond,
619 * time.Microsecond,
},
},
{
Hop: 2,
Hostname: "172.24.0.3",
IP: netip.MustParseAddr("172.24.0.3"),
Latencies: []time.Duration{
593 * time.Microsecond,
549 * time.Microsecond,
522 * time.Microsecond,
},
},
},
Success: true,
Err: nil,
},
wantErr: false,
},
{
name: "traceroute with timeouts",
input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets
1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms
2 * * *
3 isp-gateway.net (10.0.0.1) 15.678 ms 14.789 ms 15.432 ms
4 8.8.8.8 (8.8.8.8) 20.123 ms 19.876 ms 20.345 ms`,
want: Traceroute{
Hostname: "8.8.8.8",
IP: netip.MustParseAddr("8.8.8.8"),
Route: []TraceroutePath{
{
Hop: 1,
Hostname: "router.local",
IP: netip.MustParseAddr("192.168.1.1"),
Latencies: []time.Duration{
1234 * time.Microsecond,
1123 * time.Microsecond,
1121 * time.Microsecond,
},
},
{
Hop: 2,
Hostname: "*",
},
{
Hop: 3,
Hostname: "isp-gateway.net",
IP: netip.MustParseAddr("10.0.0.1"),
Latencies: []time.Duration{
15678 * time.Microsecond,
14789 * time.Microsecond,
15432 * time.Microsecond,
},
},
{
Hop: 4,
Hostname: "8.8.8.8",
IP: netip.MustParseAddr("8.8.8.8"),
Latencies: []time.Duration{
20123 * time.Microsecond,
19876 * time.Microsecond,
20345 * time.Microsecond,
},
},
},
Success: true,
Err: nil,
},
wantErr: false,
},
{
name: "unsuccessful traceroute",
input: `traceroute to 10.0.0.99 (10.0.0.99), 5 hops max, 60 byte packets
1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms
2 * * *
3 * * *
4 * * *
5 * * *`,
want: Traceroute{
Hostname: "10.0.0.99",
IP: netip.MustParseAddr("10.0.0.99"),
Route: []TraceroutePath{
{
Hop: 1,
Hostname: "router.local",
IP: netip.MustParseAddr("192.168.1.1"),
Latencies: []time.Duration{
1234 * time.Microsecond,
1123 * time.Microsecond,
1121 * time.Microsecond,
},
},
{
Hop: 2,
Hostname: "*",
},
{
Hop: 3,
Hostname: "*",
},
{
Hop: 4,
Hostname: "*",
},
{
Hop: 5,
Hostname: "*",
},
},
Success: false,
Err: errors.New("traceroute did not reach target"),
},
wantErr: false,
},
{
name: "empty input",
input: "",
want: Traceroute{},
wantErr: true,
},
{
name: "invalid header",
input: "not a valid traceroute output",
want: Traceroute{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseTraceroute(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseTraceroute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
// Special handling for error field since it can't be directly compared with cmp.Diff
gotErr := got.Err
wantErr := tt.want.Err
got.Err = nil
tt.want.Err = nil
if diff := cmp.Diff(tt.want, got, IPComparer); diff != "" {
t.Errorf("ParseTraceroute() mismatch (-want +got):\n%s", diff)
}
// Now check error field separately
if (gotErr == nil) != (wantErr == nil) {
t.Errorf("Error field: got %v, want %v", gotErr, wantErr)
} else if gotErr != nil && wantErr != nil && gotErr.Error() != wantErr.Error() {
t.Errorf("Error message: got %q, want %q", gotErr.Error(), wantErr.Error())
}
})
}
}

View File

@ -54,15 +54,16 @@ func aclScenario(
clientsPerUser int,
) *Scenario {
t.Helper()
scenario, err := NewScenario(dockertestMaxWait())
require.NoError(t, err)
spec := map[string]int{
"user1": clientsPerUser,
"user2": clientsPerUser,
spec := ScenarioSpec{
NodesPerUser: clientsPerUser,
Users: []string{"user1", "user2"},
}
err = scenario.CreateHeadscaleEnv(spec,
scenario, err := NewScenario(spec)
require.NoError(t, err)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
// Alpine containers dont have ip6tables set up, which causes
// tailscaled to stop configuring the wgengine, causing it
@ -96,22 +97,24 @@ func aclScenario(
func TestACLHostsInNetMapTable(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 2,
Users: []string{"user1", "user2"},
}
// NOTE: All want cases currently checks the
// total count of expected peers, this would
// typically be the client count of the users
// they can access minus one (them self).
tests := map[string]struct {
users map[string]int
users ScenarioSpec
policy policyv1.ACLPolicy
want map[string]int
}{
// Test that when we have no ACL, each client netmap has
// the amount of peers of the total amount of clients
"base-acls": {
users: map[string]int{
"user1": 2,
"user2": 2,
},
users: spec,
policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
@ -129,10 +132,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// each other, each node has only the number of pairs from
// their own user.
"two-isolated-users": {
users: map[string]int{
"user1": 2,
"user2": 2,
},
users: spec,
policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
@ -155,10 +155,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// are restricted to a single port, nodes are still present
// in the netmap.
"two-restricted-present-in-netmap": {
users: map[string]int{
"user1": 2,
"user2": 2,
},
users: spec,
policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
@ -192,10 +189,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// of peers. This will still result in all the peers as we
// need them present on the other side for the "return path".
"two-ns-one-isolated": {
users: map[string]int{
"user1": 2,
"user2": 2,
},
users: spec,
policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
@ -220,10 +214,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
},
},
"very-large-destination-prefix-1372": {
users: map[string]int{
"user1": 2,
"user2": 2,
},
users: spec,
policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
@ -248,10 +239,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
},
},
"ipv6-acls-1470": {
users: map[string]int{
"user1": 2,
"user2": 2,
},
users: spec,
policy: policyv1.ACLPolicy{
ACLs: []policyv1.ACL{
{
@ -269,12 +257,11 @@ func TestACLHostsInNetMapTable(t *testing.T) {
for name, testCase := range tests {
t.Run(name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
caseSpec := testCase.users
scenario, err := NewScenario(caseSpec)
require.NoError(t, err)
spec := testCase.users
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithACLPolicy(&testCase.policy),
)
@ -944,6 +931,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
for name, testCase := range tests {
t.Run(name, func(t *testing.T) {
scenario := aclScenario(t, &testCase.policy, 1)
defer scenario.ShutdownAssertNoPanics(t)
test1ip := netip.MustParseAddr("100.64.0.1")
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
@ -1022,16 +1010,16 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
"user2": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
// Alpine containers dont have ip6tables set up, which causes
// tailscaled to stop configuring the wgengine, causing it

View File

@ -19,15 +19,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
if https {
opts = append(opts, []hsic.Option{
@ -35,7 +35,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}...)
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -84,7 +84,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
time.Sleep(5 * time.Minute)
}
for userName := range spec {
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userName, true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
@ -152,16 +152,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{},
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
hsic.WithTestName("keyrelognewuser"),
hsic.WithTLS(),
)
@ -203,7 +203,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
// Log in all clients as user1, iterating over the spec only returns the
// clients, not the usernames.
for userName := range spec {
for _, userName := range spec.Users {
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
@ -235,15 +235,15 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
if https {
opts = append(opts, []hsic.Option{
@ -251,7 +251,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}...)
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -300,7 +300,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
time.Sleep(5 * time.Minute)
}
for userName := range spec {
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userName, true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)

View File

@ -1,93 +1,58 @@
package integration
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/cookiejar"
"net/netip"
"net/url"
"sort"
"strconv"
"testing"
"time"
"maps"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
)
const (
dockerContextPath = "../."
hsicOIDCMockHashLength = 6
defaultAccessTTL = 10 * time.Minute
)
var errStatusCodeNotOK = errors.New("status code not OK")
type AuthOIDCScenario struct {
*Scenario
mockOIDC *dockertest.Resource
}
func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
// Logins to MockOIDC is served by a queue with a strict order,
// if we use more than one node per user, the order of the logins
// will not be deterministic and the test will fail.
spec := map[string]int{
"user1": 1,
"user2": 1,
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
oidcMockUser("user2", false),
},
}
mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true),
oidcMockUser("user2", false),
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
err = scenario.CreateHeadscaleEnv(
spec,
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
)
assertNoErrHeadscaleEnv(t, err)
@ -126,7 +91,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 3,
@ -138,7 +103,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
Name: "user2",
Email: "", // Unverified
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user2",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
@ -158,37 +123,29 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
shortAccessTTL := 5 * time.Minute
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
baseScenario.pool.MaxWait = 5 * time.Minute
scenario := AuthOIDCScenario{
Scenario: baseScenario,
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
oidcMockUser("user2", false),
},
OIDCAccessTTL: shortAccessTTL,
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
"user2": 1,
}
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{
oidcMockUser("user1", true),
oidcMockUser("user2", false),
})
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"HEADSCALE_OIDC_CLIENT_SECRET": scenario.mockOIDC.ClientSecret(),
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
}
err = scenario.CreateHeadscaleEnv(
spec,
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcexpirenodes"),
hsic.WithConfigEnv(oidcMap),
)
@ -334,45 +291,35 @@ func TestOIDC024UserCreation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
spec := ScenarioSpec{
NodesPerUser: 1,
}
for _, user := range tt.cliUsers {
spec.Users = append(spec.Users, user)
}
for _, user := range tt.oidcUsers {
spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified))
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{}
for _, user := range tt.cliUsers {
spec[user] = 1
}
var mockusers []mockoidc.MockUser
for _, user := range tt.oidcUsers {
mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified))
}
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
maps.Copy(oidcMap, tt.config)
for k, v := range tt.config {
oidcMap[k] = v
}
err = scenario.CreateHeadscaleEnv(
spec,
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcmigration"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
)
assertNoErrHeadscaleEnv(t, err)
@ -384,7 +331,7 @@ func TestOIDC024UserCreation(t *testing.T) {
headscale, err := scenario.Headscale()
assertNoErr(t, err)
want := tt.want(oidcConfig.Issuer)
want := tt.want(scenario.mockOIDC.Issuer())
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
@ -404,41 +351,33 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
// Single user with one node for testing PKCE flow
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1"},
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
// Single user with one node for testing PKCE flow
spec := map[string]int{
"user1": 1,
}
mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true),
}
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
}
err = scenario.CreateHeadscaleEnv(
spec,
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthpkce"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
)
assertNoErrHeadscaleEnv(t, err)
@ -464,43 +403,33 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
// Create no nodes and no users
scenario, err := NewScenario(ScenarioSpec{
// First login creates the first OIDC user
// Second login logs in the same node, which creates a new node
// Third login logs in the same node back into the original user
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
oidcMockUser("user2", true),
oidcMockUser("user1", true),
},
})
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
// Create no nodes and no users
spec := map[string]int{}
// First login creates the first OIDC user
// Second login logs in the same node, which creates a new node
// Third login logs in the same node back into the original user
mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true),
oidcMockUser("user2", true),
oidcMockUser("user1", true),
}
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
// defer scenario.mockOIDC.Close()
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
err = scenario.CreateHeadscaleEnv(
spec,
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
assertNoErrHeadscaleEnv(t, err)
@ -512,7 +441,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
assertNoErr(t, err)
assert.Len(t, listUsers, 0)
ts, err := scenario.CreateTailscaleNode("unstable")
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
@ -530,7 +459,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
@ -575,14 +504,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user2",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
@ -632,14 +561,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user1",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: oidcConfig.Issuer + "/user2",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
@ -678,254 +607,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey)
}
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
users map[string]int,
opts ...hsic.Option,
) error {
headscale, err := s.Headscale(opts...)
if err != nil {
return err
}
err = headscale.WaitForRunning()
if err != nil {
return err
}
for userName, clientCount := range users {
if clientCount != 1 {
// OIDC scenario only supports one client per user.
// This is because the MockOIDC server can only serve login
// requests based on a queue it has been given on startup.
// We currently only populates it with one login request per user.
return fmt.Errorf("client count must be 1 for OIDC scenario.")
}
log.Printf("creating user %s with %d clients", userName, clientCount)
err = s.CreateUser(userName)
if err != nil {
return err
}
err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
if err != nil {
return err
}
err = s.runTailscaleUp(userName, headscale.GetEndpoint())
if err != nil {
return err
}
}
return nil
}
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
port, err := dockertestutil.RandomFreeHostPort()
if err != nil {
log.Fatalf("could not find an open port: %s", err)
}
portNotation := fmt.Sprintf("%d/tcp", port)
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
usersJSON, err := json.Marshal(users)
if err != nil {
return nil, err
}
mockOidcOptions := &dockertest.RunOptions{
Name: hostname,
Cmd: []string{"headscale", "mockoidc"},
ExposedPorts: []string{portNotation},
PortBindings: map[docker.Port][]docker.PortBinding{
docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
},
Networks: []*dockertest.Network{s.Scenario.network},
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
}
err = s.pool.RemoveContainerByName(hostname)
if err != nil {
return nil, err
}
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
dockertestutil.DockerRestartPolicy); err == nil {
s.mockOIDC = pmockoidc
} else {
return nil, err
}
log.Println("Waiting for headscale mock oidc to be ready for tests")
hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port)
if err := s.pool.Retry(func() error {
oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
httpClient := &http.Client{}
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("headscale mock OIDC tests is not ready: %s\n", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errStatusCodeNotOK
}
return nil
}); err != nil {
return nil, err
}
log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
return &types.OIDCConfig{
Issuer: fmt.Sprintf(
"http://%s/oidc",
net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)),
),
ClientID: "superclient",
ClientSecret: "supersecret",
OnlyStartIfOIDCIsAvailable: true,
}, nil
}
type LoggingRoundTripper struct{}
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
noTls := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}
resp, err := noTls.RoundTrip(req)
if err != nil {
return nil, err
}
log.Printf("---")
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
return resp, nil
}
func (s *AuthOIDCScenario) runTailscaleUp(
userStr, loginServer string,
) error {
log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
tsc := client
user.joinWaitGroup.Go(func() error {
loginURL, err := tsc.LoginWithURL(loginServer)
if err != nil {
log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
}
_, err = doLoginURL(tsc.Hostname(), loginURL)
if err != nil {
return err
}
return nil
})
log.Printf("client %s is ready", client.Hostname())
}
if err := user.joinWaitGroup.Wait(); err != nil {
return err
}
for _, client := range user.Clients {
err := client.WaitForRunning()
if err != nil {
return fmt.Errorf(
"%s tailscale node has not reached running: %w",
client.Hostname(),
err,
)
}
}
return nil
}
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
// doLoginURL visits the given login URL and returns the body as a
// string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
var err error
hc := &http.Client{
Transport: LoggingRoundTripper{},
}
hc.Jar, err = cookiejar.New(nil)
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
}
log.Printf("%s logging in with url", hostname)
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := hc.Do(req)
if err != nil {
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
}
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body)
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("%s failed to read response body: %s", hostname, err)
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
return string(body), nil
}
func (s *AuthOIDCScenario) Shutdown() {
err := s.pool.Purge(s.mockOIDC)
if err != nil {
log.Printf("failed to remove mock oidc container")
}
s.Scenario.Shutdown()
}
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
t.Helper()

View File

@ -1,47 +1,33 @@
package integration
import (
"errors"
"fmt"
"log"
"net/netip"
"net/url"
"strings"
"testing"
"slices"
"github.com/juanfont/headscale/integration/hsic"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var errParseAuthPage = errors.New("failed to parse auth page")
type AuthWebFlowScenario struct {
*Scenario
}
func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
if err != nil {
t.Fatalf("failed to create scenario: %s", err)
}
scenario := AuthWebFlowScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(
spec,
nil,
hsic.WithTestName("webauthping"),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
@ -71,20 +57,17 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthWebFlowScenario{
Scenario: baseScenario,
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
nil,
hsic.WithTestName("weblogout"),
hsic.WithTLS(),
)
@ -137,8 +120,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("all clients logged out")
for userName := range spec {
err = scenario.runTailscaleUp(userName, headscale.GetEndpoint())
for _, userName := range spec.Users {
err = scenario.RunTailscaleUpWithURL(userName, headscale.GetEndpoint())
if err != nil {
t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err)
}
@ -172,14 +155,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
}
for _, ip := range ips {
found := false
for _, oldIP := range clientIPs[client] {
if ip == oldIP {
found = true
break
}
}
found := slices.Contains(clientIPs[client], ip)
if !found {
t.Fatalf(
@ -194,122 +170,3 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("all clients IPs are the same")
}
func (s *AuthWebFlowScenario) CreateHeadscaleEnv(
users map[string]int,
opts ...hsic.Option,
) error {
headscale, err := s.Headscale(opts...)
if err != nil {
return err
}
err = headscale.WaitForRunning()
if err != nil {
return err
}
for userName, clientCount := range users {
log.Printf("creating user %s with %d clients", userName, clientCount)
err = s.CreateUser(userName)
if err != nil {
return err
}
err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
if err != nil {
return err
}
err = s.runTailscaleUp(userName, headscale.GetEndpoint())
if err != nil {
return err
}
}
return nil
}
func (s *AuthWebFlowScenario) runTailscaleUp(
userStr, loginServer string,
) error {
log.Printf("running tailscale up for user %q", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
c := client
user.joinWaitGroup.Go(func() error {
log.Printf("logging %q into %q", c.Hostname(), loginServer)
loginURL, err := c.LoginWithURL(loginServer)
if err != nil {
log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err)
return err
}
err = s.runHeadscaleRegister(userStr, loginURL)
if err != nil {
log.Printf("failed to register client (%s): %s", c.Hostname(), err)
return err
}
return nil
})
err := client.WaitForRunning()
if err != nil {
log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err)
}
}
if err := user.joinWaitGroup.Wait(); err != nil {
return err
}
for _, client := range user.Clients {
err := client.WaitForRunning()
if err != nil {
return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err)
}
}
return nil
}
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error {
body, err := doLoginURL("web-auth-not-set", loginURL)
if err != nil {
return err
}
// see api.go HTML template
codeSep := strings.Split(string(body), "</code>")
if len(codeSep) != 2 {
return errParseAuthPage
}
keySep := strings.Split(codeSep[0], "key ")
if len(keySep) != 2 {
return errParseAuthPage
}
key := keySep[1]
log.Printf("registering node %s", key)
if headscale, err := s.Headscale(); err == nil {
_, err = headscale.Execute(
[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key},
)
if err != nil {
log.Printf("failed to register node: %s", err)
return err
}
return nil
}
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
}

View File

@ -48,16 +48,15 @@ func TestUserCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
"user2": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -247,15 +246,15 @@ func TestPreAuthKeyCommand(t *testing.T) {
user := "preauthkeyspace"
count := 3
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{user},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -388,16 +387,15 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
t.Parallel()
user := "pre-auth-key-without-exp-user"
spec := ScenarioSpec{
Users: []string{user},
}
scenario, err := NewScenario(dockertestMaxWait())
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -451,16 +449,15 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
t.Parallel()
user := "pre-auth-key-reus-ephm-user"
spec := ScenarioSpec{
Users: []string{user},
}
scenario, err := NewScenario(dockertestMaxWait())
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -530,17 +527,16 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
user1 := "user1"
user2 := "user2"
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{user1},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user1: 1,
user2: 0,
}
err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{},
hsic.WithTestName("clipak"),
hsic.WithEmbeddedDERPServerOnly(),
@ -551,6 +547,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
headscale, err := scenario.Headscale()
assertNoErr(t, err)
err = headscale.CreateUser(user2)
assertNoErr(t, err)
var user2Key v1.PreAuthKey
err = executeAndUnmarshal(
@ -573,10 +572,15 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
)
assertNoErr(t, err)
listNodes, err := headscale.ListNodes()
require.Nil(t, err)
require.Len(t, listNodes, 1)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
assert.Len(t, allClients, 1)
require.Len(t, allClients, 1)
client := allClients[0]
@ -606,12 +610,11 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
}
listNodes, err := headscale.ListNodes()
assert.Nil(t, err)
assert.Len(t, listNodes, 2)
assert.Equal(t, "user1", listNodes[0].GetUser().GetName())
assert.Equal(t, "user2", listNodes[1].GetUser().GetName())
listNodes, err = headscale.ListNodes()
require.Nil(t, err)
require.Len(t, listNodes, 2)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
assert.Equal(t, user2, listNodes[1].GetUser().GetName())
}
func TestApiKeyCommand(t *testing.T) {
@ -620,16 +623,15 @@ func TestApiKeyCommand(t *testing.T) {
count := 5
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
"user2": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -788,15 +790,15 @@ func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -977,15 +979,16 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{tsic.WithTags([]string{"tag:test"})},
hsic.WithTestName("cliadvtags"),
hsic.WithACLPolicy(tt.policy),
@ -996,7 +999,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
assertNoErr(t, err)
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"])
resultMachines := make([]*v1.Node, spec.NodesPerUser)
err = executeAndUnmarshal(
headscale,
[]string{
@ -1029,16 +1032,15 @@ func TestNodeCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"node-user", "other-user"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"node-user": 0,
"other-user": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -1269,15 +1271,15 @@ func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"node-expire-user"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"node-expire-user": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -1395,15 +1397,15 @@ func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"node-rename-command"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"node-rename-command": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -1560,16 +1562,15 @@ func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"old-user", "new-user"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"old-user": 0,
"new-user": 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
@ -1721,16 +1722,15 @@ func TestPolicyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
}
err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{},
hsic.WithTestName("clins"),
hsic.WithConfigEnv(map[string]string{
@ -1808,16 +1808,16 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{},
hsic.WithTestName("clins"),
hsic.WithConfigEnv(map[string]string{

View File

@ -24,5 +24,4 @@ type ControlServer interface {
ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error)
GetCert() []byte
GetHostname() string
GetIP() string
}

View File

@ -31,14 +31,15 @@ func TestDERPVerifyEndpoint(t *testing.T) {
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
assertNoErr(t, err)
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
}
derper, err := scenario.CreateDERPServer("head",
dsic.WithCACert(certHeadscale),
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
@ -65,7 +66,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
},
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())},
err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithCACert(derper.GetCert())},
hsic.WithHostname(hostname),
hsic.WithPort(headscalePort),
hsic.WithCustomTLS(certHeadscale, keyHeadscale),

View File

@ -17,16 +17,16 @@ func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"magicdns1": len(MustTestVersions),
"magicdns2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -87,15 +87,15 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"magicdns1": 1,
"magicdns2": 1,
}
const erPath = "/tmp/extra_records.json"
extraRecords := []tailcfg.DNSRecord{
@ -107,7 +107,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
}
b, _ := json.Marshal(extraRecords)
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{
err = scenario.CreateHeadscaleEnv([]tsic.Option{
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
@ -364,16 +364,16 @@ func TestValidateResolvConf(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"resolvconf1": 3,
"resolvconf2": 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()

View File

@ -35,7 +35,7 @@ type DERPServerInContainer struct {
pool *dockertest.Pool
container *dockertest.Resource
network *dockertest.Network
networks []*dockertest.Network
stunPort int
derpPort int
@ -63,22 +63,22 @@ func WithCACert(cert []byte) Option {
// isolating the DERPer, will be created. If a network is
// passed, the DERPer instance will join the given network.
func WithOrCreateNetwork(network *dockertest.Network) Option {
return func(tsic *DERPServerInContainer) {
return func(dsic *DERPServerInContainer) {
if network != nil {
tsic.network = network
dsic.networks = append(dsic.networks, network)
return
}
network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool,
tsic.hostname+"-network",
dsic.pool,
dsic.hostname+"-network",
)
if err != nil {
log.Fatalf("failed to create network: %s", err)
}
tsic.network = network
dsic.networks = append(dsic.networks, network)
}
}
@ -107,7 +107,7 @@ func WithExtraHosts(hosts []string) Option {
func New(
pool *dockertest.Pool,
version string,
network *dockertest.Network,
networks []*dockertest.Network,
opts ...Option,
) (*DERPServerInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength)
@ -124,7 +124,7 @@ func New(
version: version,
hostname: hostname,
pool: pool,
network: network,
networks: networks,
tlsCert: tlsCert,
tlsKey: tlsKey,
stunPort: 3478, //nolint
@ -148,7 +148,7 @@ func New(
runOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{dsic.network},
Networks: dsic.networks,
ExtraHosts: dsic.withExtraHosts,
// we currently need to give us some time to inject the certificate further down.
Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},

View File

@ -1,18 +1,12 @@
package integration
import (
"fmt"
"log"
"net/url"
"strings"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3"
)
type ClientsSpec struct {
@ -20,21 +14,18 @@ type ClientsSpec struct {
WebsocketDERP int
}
type EmbeddedDERPServerScenario struct {
*Scenario
tsicNetworks map[string]*dockertest.Network
}
func TestDERPServerScenario(t *testing.T) {
spec := map[string]ClientsSpec{
"user1": {
Plain: len(MustTestVersions),
WebsocketDERP: 0,
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2", "user3"},
Networks: map[string][]string{
"usernet1": {"user1"},
"usernet2": {"user2"},
"usernet3": {"user3"},
},
}
derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) {
derpServerScenario(t, spec, false, func(scenario *Scenario) {
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
t.Logf("checking %d clients for websocket connections", len(allClients))
@ -52,14 +43,17 @@ func TestDERPServerScenario(t *testing.T) {
}
func TestDERPServerWebsocketScenario(t *testing.T) {
spec := map[string]ClientsSpec{
"user1": {
Plain: 0,
WebsocketDERP: 2,
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2", "user3"},
Networks: map[string][]string{
"usernet1": []string{"user1"},
"usernet2": []string{"user2"},
"usernet3": []string{"user3"},
},
}
derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) {
derpServerScenario(t, spec, true, func(scenario *Scenario) {
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
t.Logf("checking %d clients for websocket connections", len(allClients))
@ -83,23 +77,22 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
//nolint:thelper
func derpServerScenario(
t *testing.T,
spec map[string]ClientsSpec,
furtherAssertions ...func(*EmbeddedDERPServerScenario),
spec ScenarioSpec,
websocket bool,
furtherAssertions ...func(*Scenario),
) {
IntegrationSkip(t)
// t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
scenario, err := NewScenario(spec)
assertNoErr(t, err)
scenario := EmbeddedDERPServerScenario{
Scenario: baseScenario,
tsicNetworks: map[string]*dockertest.Network{},
}
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
spec,
[]tsic.Option{
tsic.WithWebsocketDERP(websocket),
},
hsic.WithTestName("derpserver"),
hsic.WithExtraPorts([]string{"3478/udp"}),
hsic.WithEmbeddedDERPServerOnly(),
@ -185,182 +178,6 @@ func derpServerScenario(
t.Logf("Run2: %d successful pings out of %d", success, len(allClients)*len(allHostnames))
for _, check := range furtherAssertions {
check(&scenario)
check(scenario)
}
}
func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
users map[string]ClientsSpec,
opts ...hsic.Option,
) error {
hsServer, err := s.Headscale(opts...)
if err != nil {
return err
}
headscaleEndpoint := hsServer.GetEndpoint()
headscaleURL, err := url.Parse(headscaleEndpoint)
if err != nil {
return err
}
headscaleURL.Host = fmt.Sprintf("%s:%s", hsServer.GetHostname(), headscaleURL.Port())
err = hsServer.WaitForRunning()
if err != nil {
return err
}
log.Printf("headscale server ip address: %s", hsServer.GetIP())
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
if err != nil {
return err
}
for userName, clientCount := range users {
err = s.CreateUser(userName)
if err != nil {
return err
}
if clientCount.Plain > 0 {
// Containers that use default DERP config
err = s.CreateTailscaleIsolatedNodesInUser(
hash,
userName,
"all",
clientCount.Plain,
)
if err != nil {
return err
}
}
if clientCount.WebsocketDERP > 0 {
// Containers that use DERP-over-WebSocket
// Note that these clients *must* be built
// from source, which is currently
// only done for HEAD.
err = s.CreateTailscaleIsolatedNodesInUser(
hash,
userName,
tsic.VersionHead,
clientCount.WebsocketDERP,
tsic.WithWebsocketDERP(true),
)
if err != nil {
return err
}
}
key, err := s.CreatePreAuthKey(userName, true, false)
if err != nil {
return err
}
err = s.RunTailscaleUp(userName, headscaleURL.String(), key.GetKey())
if err != nil {
return err
}
}
return nil
}
func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser(
hash string,
userStr string,
requestedVersion string,
count int,
opts ...tsic.Option,
) error {
hsServer, err := s.Headscale()
if err != nil {
return err
}
if user, ok := s.users[userStr]; ok {
for clientN := 0; clientN < count; clientN++ {
networkName := fmt.Sprintf("tsnet-%s-%s-%d",
hash,
userStr,
clientN,
)
network, err := dockertestutil.GetFirstOrCreateNetwork(
s.pool,
networkName,
)
if err != nil {
return fmt.Errorf("failed to create or get %s network: %w", networkName, err)
}
s.tsicNetworks[networkName] = network
err = hsServer.ConnectToNetwork(network)
if err != nil {
return fmt.Errorf("failed to connect headscale to %s network: %w", networkName, err)
}
version := requestedVersion
if requestedVersion == "all" {
version = MustTestVersions[clientN%len(MustTestVersions)]
}
cert := hsServer.GetCert()
opts = append(opts,
tsic.WithCACert(cert),
)
user.createWaitGroup.Go(func() error {
tsClient, err := tsic.New(
s.pool,
version,
network,
opts...,
)
if err != nil {
return fmt.Errorf(
"failed to create tailscale (%s) node: %w",
tsClient.Hostname(),
err,
)
}
err = tsClient.WaitForNeedsLogin()
if err != nil {
return fmt.Errorf(
"failed to wait for tailscaled (%s) to need login: %w",
tsClient.Hostname(),
err,
)
}
s.mu.Lock()
user.Clients[tsClient.Hostname()] = tsClient
s.mu.Unlock()
return nil
})
}
if err := user.createWaitGroup.Wait(); err != nil {
return err
}
return nil
}
return fmt.Errorf("failed to add tailscale nodes: %w", errNoUserAvailable)
}
func (s *EmbeddedDERPServerScenario) Shutdown() {
for _, network := range s.tsicNetworks {
err := s.pool.RemoveNetwork(network)
if err != nil {
log.Printf("failed to remove DERP network %s", network.Network.Name)
}
}
s.Scenario.Shutdown()
}

View File

@ -28,18 +28,17 @@ func TestPingAllByIP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
MaxWait: dockertestMaxWait(),
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("pingallbyip"),
hsic.WithEmbeddedDERPServerOnly(),
@ -71,16 +70,16 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("pingallbyippubderp"),
)
@ -121,25 +120,25 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
headscale, err := scenario.Headscale(opts...)
assertNoErrHeadscaleEnv(t, err)
for userName, clientCount := range spec {
for _, userName := range spec.Users {
err = scenario.CreateUser(userName)
if err != nil {
t.Fatalf("failed to create user %s: %s", userName, err)
}
err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...)
err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
}
@ -194,15 +193,15 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
headscale, err := scenario.Headscale(
hsic.WithTestName("ephemeral2006"),
hsic.WithConfigEnv(map[string]string{
@ -211,13 +210,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
)
assertNoErrHeadscaleEnv(t, err)
for userName, clientCount := range spec {
for _, userName := range spec.Users {
err = scenario.CreateUser(userName)
if err != nil {
t.Fatalf("failed to create user %s: %s", userName, err)
}
err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...)
err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
}
@ -287,7 +286,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// registered.
time.Sleep(3 * time.Minute)
for userName := range spec {
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
if err != nil {
log.Error().
@ -308,16 +307,16 @@ func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user3": len(MustTestVersions),
"user4": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -357,15 +356,16 @@ func TestTaildrop(t *testing.T) {
return err
}
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"taildrop": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{},
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
hsic.WithTestName("taildrop"),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
@ -522,23 +522,22 @@ func TestUpdateHostnameFromClient(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "update-hostname-from-client"
hostnames := map[string]string{
"1": "user1-host",
"2": "User2-Host",
"3": "user3-host",
}
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 3,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("updatehostname"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -650,15 +649,16 @@ func TestExpireNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("expirenode"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -684,7 +684,7 @@ func TestExpireNode(t *testing.T) {
assertNoErr(t, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), spec["user1"]-1)
assert.Len(t, status.Peers(), spec.NodesPerUser-1)
}
headscale, err := scenario.Headscale()
@ -776,15 +776,16 @@ func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online"))
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -891,18 +892,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("pingallbyipmany"),
hsic.WithEmbeddedDERPServerOnly(),
@ -973,18 +972,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{
"user1": 1,
"user2": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("deletenocrash"),
hsic.WithEmbeddedDERPServerOnly(),

View File

@ -56,7 +56,7 @@ type HeadscaleInContainer struct {
pool *dockertest.Pool
container *dockertest.Resource
network *dockertest.Network
networks []*dockertest.Network
pgContainer *dockertest.Resource
@ -268,7 +268,7 @@ func WithTimezone(timezone string) Option {
// New returns a new HeadscaleInContainer instance.
func New(
pool *dockertest.Pool,
network *dockertest.Network,
networks []*dockertest.Network,
opts ...Option,
) (*HeadscaleInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength)
@ -282,8 +282,8 @@ func New(
hostname: hostname,
port: headscaleDefaultPort,
pool: pool,
network: network,
pool: pool,
networks: networks,
env: DefaultConfigEnv(),
filesInContainer: []fileInContainer{},
@ -315,7 +315,7 @@ func New(
Name: fmt.Sprintf("postgres-%s", hash),
Repository: "postgres",
Tag: "latest",
Networks: []*dockertest.Network{network},
Networks: networks,
Env: []string{
"POSTGRES_USER=headscale",
"POSTGRES_PASSWORD=headscale",
@ -357,7 +357,7 @@ func New(
runOptions := &dockertest.RunOptions{
Name: hsic.hostname,
ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...),
Networks: []*dockertest.Network{network},
Networks: networks,
// Cmd: []string{"headscale", "serve"},
// TODO(kradalby): Get rid of this hack, we currently need to give us some
// to inject the headscale configuration further down.
@ -630,11 +630,6 @@ func (t *HeadscaleInContainer) Execute(
return stdout, nil
}
// GetIP returns the docker container IP as a string.
func (t *HeadscaleInContainer) GetIP() string {
return t.container.GetIPInNetwork(t.network)
}
// GetPort returns the docker container port as a string.
func (t *HeadscaleInContainer) GetPort() string {
return fmt.Sprintf("%d", t.port)

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +1,37 @@
package integration
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/cookiejar"
"net/netip"
"net/url"
"os"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/dsic"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/puzpuzpuz/xsync/v3"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
@ -26,6 +39,7 @@ import (
xmaps "golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"
"tailscale.com/envknob"
"tailscale.com/util/mak"
)
const (
@ -86,33 +100,136 @@ type Scenario struct {
users map[string]*User
pool *dockertest.Pool
network *dockertest.Network
pool *dockertest.Pool
networks map[string]*dockertest.Network
mockOIDC scenarioOIDC
extraServices map[string][]*dockertest.Resource
mu sync.Mutex
spec ScenarioSpec
userToNetwork map[string]*dockertest.Network
}
// ScenarioSpec describes the users, nodes, and network topology to
// set up for a given scenario.
type ScenarioSpec struct {
// Users is a list of usernames that will be created.
// Each created user will get nodes equivalent to NodesPerUser
Users []string
// NodesPerUser is how many nodes should be attached to each user.
NodesPerUser int
// Networks, if set, is the seperate Docker networks that should be
// created and a list of the users that should be placed in those networks.
// If not set, a single network will be created and all users+nodes will be
// added there.
// Please note that Docker networks are not necessarily routable and
// connections between them might fall back to DERP.
Networks map[string][]string
// ExtraService, if set, is additional a map of network to additional
// container services that should be set up. These container services
// typically dont run Tailscale, e.g. web service to test subnet router.
ExtraService map[string][]extraServiceFunc
// Versions is specific list of versions to use for the test.
Versions []string
// OIDCUsers, if populated, will start a Mock OIDC server and populate
// the user login stack with the given users.
// If the NodesPerUser is set, it should align with this list to ensure
// the correct users are logged in.
// This is because the MockOIDC server can only serve login
// requests based on a queue it has been given on startup.
// We currently only populates it with one login request per user.
OIDCUsers []mockoidc.MockUser
OIDCAccessTTL time.Duration
MaxWait time.Duration
}
var TestHashPrefix = "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength)
var TestDefaultNetwork = TestHashPrefix + "-default"
func prefixedNetworkName(name string) string {
return TestHashPrefix + "-" + name
}
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
// a set of Users and TailscaleClients.
func NewScenario(maxWait time.Duration) (*Scenario, error) {
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
if err != nil {
return nil, err
}
func NewScenario(spec ScenarioSpec) (*Scenario, error) {
pool, err := dockertest.NewPool("")
if err != nil {
return nil, fmt.Errorf("could not connect to docker: %w", err)
}
pool.MaxWait = maxWait
networkName := fmt.Sprintf("hs-%s", hash)
if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" {
networkName = overrideNetworkName
if spec.MaxWait == 0 {
pool.MaxWait = dockertestMaxWait()
} else {
pool.MaxWait = spec.MaxWait
}
network, err := dockertestutil.GetFirstOrCreateNetwork(pool, networkName)
s := &Scenario{
controlServers: xsync.NewMapOf[string, ControlServer](),
users: make(map[string]*User),
pool: pool,
spec: spec,
}
var userToNetwork map[string]*dockertest.Network
if spec.Networks != nil || len(spec.Networks) != 0 {
for name, users := range s.spec.Networks {
networkName := TestHashPrefix + "-" + name
network, err := s.AddNetwork(networkName)
if err != nil {
return nil, err
}
for _, user := range users {
if n2, ok := userToNetwork[user]; ok {
return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name)
}
mak.Set(&userToNetwork, user, network)
}
}
} else {
_, err := s.AddNetwork(TestDefaultNetwork)
if err != nil {
return nil, err
}
}
for network, extras := range spec.ExtraService {
for _, extra := range extras {
svc, err := extra(s, network)
if err != nil {
return nil, err
}
mak.Set(&s.extraServices, prefixedNetworkName(network), append(s.extraServices[prefixedNetworkName(network)], svc))
}
}
s.userToNetwork = userToNetwork
if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 {
ttl := defaultAccessTTL
if spec.OIDCAccessTTL != 0 {
ttl = spec.OIDCAccessTTL
}
err = s.runMockOIDC(ttl, spec.OIDCUsers)
if err != nil {
return nil, err
}
}
return s, nil
}
func (s *Scenario) AddNetwork(name string) (*dockertest.Network, error) {
network, err := dockertestutil.GetFirstOrCreateNetwork(s.pool, name)
if err != nil {
return nil, fmt.Errorf("failed to create or get network: %w", err)
}
@ -120,18 +237,58 @@ func NewScenario(maxWait time.Duration) (*Scenario, error) {
// We run the test suite in a docker container that calls a couple of endpoints for
// readiness checks, this ensures that we can run the tests with individual networks
// and have the client reach the different containers
err = dockertestutil.AddContainerToNetwork(pool, network, "headscale-test-suite")
// TODO(kradalby): Can the test-suite be renamed so we can have multiple?
err = dockertestutil.AddContainerToNetwork(s.pool, network, "headscale-test-suite")
if err != nil {
return nil, fmt.Errorf("failed to add test suite container to network: %w", err)
}
return &Scenario{
controlServers: xsync.NewMapOf[string, ControlServer](),
users: make(map[string]*User),
mak.Set(&s.networks, name, network)
pool: pool,
network: network,
}, nil
return network, nil
}
func (s *Scenario) Networks() []*dockertest.Network {
if len(s.networks) == 0 {
panic("Scenario.Networks called with empty network list")
}
return xmaps.Values(s.networks)
}
func (s *Scenario) Network(name string) (*dockertest.Network, error) {
net, ok := s.networks[prefixedNetworkName(name)]
if !ok {
return nil, fmt.Errorf("no network named: %s", name)
}
return net, nil
}
func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
net, ok := s.networks[prefixedNetworkName(name)]
if !ok {
return nil, fmt.Errorf("no network named: %s", name)
}
for _, ipam := range net.Network.IPAM.Config {
pref, err := netip.ParsePrefix(ipam.Subnet)
if err != nil {
return nil, err
}
return &pref, nil
}
return nil, fmt.Errorf("no prefix found in network: %s", name)
}
func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
res, ok := s.extraServices[prefixedNetworkName(name)]
if !ok {
return nil, fmt.Errorf("no network named: %s", name)
}
return res, nil
}
func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
@ -184,14 +341,27 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
}
}
if err := s.pool.RemoveNetwork(s.network); err != nil {
log.Printf("failed to remove network: %s", err)
for _, svcs := range s.extraServices {
for _, svc := range svcs {
err := svc.Close()
if err != nil {
log.Printf("failed to tear down service %q: %s", svc.Container.Name, err)
}
}
}
// TODO(kradalby): This seem redundant to the previous call
// if err := s.network.Close(); err != nil {
// return fmt.Errorf("failed to tear down network: %w", err)
// }
if s.mockOIDC.r != nil {
s.mockOIDC.r.Close()
if err := s.mockOIDC.r.Close(); err != nil {
log.Printf("failed to tear down oidc server: %s", err)
}
}
for _, network := range s.networks {
if err := network.Close(); err != nil {
log.Printf("failed to tear down network: %s", err)
}
}
}
// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient)
@ -235,7 +405,7 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) {
opts = append(opts, hsic.WithPolicyV2())
}
headscale, err := hsic.New(s.pool, s.network, opts...)
headscale, err := hsic.New(s.pool, s.Networks(), opts...)
if err != nil {
return nil, fmt.Errorf("failed to create headscale container: %w", err)
}
@ -312,7 +482,6 @@ func (s *Scenario) CreateTailscaleNode(
tsClient, err := tsic.New(
s.pool,
version,
s.network,
opts...,
)
if err != nil {
@ -345,10 +514,14 @@ func (s *Scenario) CreateTailscaleNodesInUser(
) error {
if user, ok := s.users[userStr]; ok {
var versions []string
for i := 0; i < count; i++ {
for i := range count {
version := requestedVersion
if requestedVersion == "all" {
version = MustTestVersions[i%len(MustTestVersions)]
if s.spec.Versions != nil {
version = s.spec.Versions[i%len(s.spec.Versions)]
} else {
version = MustTestVersions[i%len(MustTestVersions)]
}
}
versions = append(versions, version)
@ -372,14 +545,12 @@ func (s *Scenario) CreateTailscaleNodesInUser(
tsClient, err := tsic.New(
s.pool,
version,
s.network,
opts...,
)
s.mu.Unlock()
if err != nil {
return fmt.Errorf(
"failed to create tailscale (%s) node: %w",
tsClient.Hostname(),
"failed to create tailscale node: %w",
err,
)
}
@ -492,11 +663,24 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int) error {
return nil
}
// CreateHeadscaleEnv is a convenient method returning a complete Headcale
// test environment with nodes of all versions, joined to the server with X
// users.
func (s *Scenario) CreateHeadscaleEnvWithLoginURL(
tsOpts []tsic.Option,
opts ...hsic.Option,
) error {
return s.createHeadscaleEnv(true, tsOpts, opts...)
}
func (s *Scenario) CreateHeadscaleEnv(
users map[string]int,
tsOpts []tsic.Option,
opts ...hsic.Option,
) error {
return s.createHeadscaleEnv(false, tsOpts, opts...)
}
// CreateHeadscaleEnv starts the headscale environment and the clients
// according to the ScenarioSpec passed to the Scenario.
func (s *Scenario) createHeadscaleEnv(
withURL bool,
tsOpts []tsic.Option,
opts ...hsic.Option,
) error {
@ -505,34 +689,188 @@ func (s *Scenario) CreateHeadscaleEnv(
return err
}
usernames := xmaps.Keys(users)
sort.Strings(usernames)
for _, username := range usernames {
clientCount := users[username]
err = s.CreateUser(username)
sort.Strings(s.spec.Users)
for _, user := range s.spec.Users {
err = s.CreateUser(user)
if err != nil {
return err
}
err = s.CreateTailscaleNodesInUser(username, "all", clientCount, tsOpts...)
var opts []tsic.Option
if s.userToNetwork != nil {
opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user]))
} else {
opts = append(tsOpts, tsic.WithNetwork(s.networks[TestDefaultNetwork]))
}
err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...)
if err != nil {
return err
}
key, err := s.CreatePreAuthKey(username, true, false)
if err != nil {
return err
}
if withURL {
err = s.RunTailscaleUpWithURL(user, headscale.GetEndpoint())
if err != nil {
return err
}
} else {
key, err := s.CreatePreAuthKey(user, true, false)
if err != nil {
return err
}
err = s.RunTailscaleUp(username, headscale.GetEndpoint(), key.GetKey())
if err != nil {
return err
err = s.RunTailscaleUp(user, headscale.GetEndpoint(), key.GetKey())
if err != nil {
return err
}
}
}
return nil
}
func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
tsc := client
user.joinWaitGroup.Go(func() error {
loginURL, err := tsc.LoginWithURL(loginServer)
if err != nil {
log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
}
body, err := doLoginURL(tsc.Hostname(), loginURL)
if err != nil {
return err
}
// If the URL is not a OIDC URL, then we need to
// run the register command to fully log in the client.
if !strings.Contains(loginURL.String(), "/oidc/") {
s.runHeadscaleRegister(userStr, body)
}
return nil
})
log.Printf("client %s is ready", client.Hostname())
}
if err := user.joinWaitGroup.Wait(); err != nil {
return err
}
for _, client := range user.Clients {
err := client.WaitForRunning()
if err != nil {
return fmt.Errorf(
"%s tailscale node has not reached running: %w",
client.Hostname(),
err,
)
}
}
return nil
}
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
// doLoginURL visits the given login URL and returns the body as a
// string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
var err error
hc := &http.Client{
Transport: LoggingRoundTripper{},
}
hc.Jar, err = cookiejar.New(nil)
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
}
log.Printf("%s logging in with url", hostname)
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := hc.Do(req)
if err != nil {
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
}
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body)
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("%s failed to read response body: %s", hostname, err)
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
return string(body), nil
}
var errParseAuthPage = errors.New("failed to parse auth page")
func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
// see api.go HTML template
codeSep := strings.Split(string(body), "</code>")
if len(codeSep) != 2 {
return errParseAuthPage
}
keySep := strings.Split(codeSep[0], "key ")
if len(keySep) != 2 {
return errParseAuthPage
}
key := keySep[1]
log.Printf("registering node %s", key)
if headscale, err := s.Headscale(); err == nil {
_, err = headscale.Execute(
[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key},
)
if err != nil {
log.Printf("failed to register node: %s", err)
return err
}
return nil
}
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
}
type LoggingRoundTripper struct{}
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
noTls := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}
resp, err := noTls.RoundTrip(req)
if err != nil {
return nil, err
}
log.Printf("---")
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
return resp, nil
}
// GetIPs returns all netip.Addr of TailscaleClients associated with a User
// in a Scenario.
func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) {
@ -670,7 +1008,7 @@ func (s *Scenario) WaitForTailscaleLogout() error {
// CreateDERPServer creates a new DERP server in a container.
func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) {
derp, err := dsic.New(s.pool, version, s.network, opts...)
derp, err := dsic.New(s.pool, version, s.Networks(), opts...)
if err != nil {
return nil, fmt.Errorf("failed to create DERP server: %w", err)
}
@ -684,3 +1022,216 @@ func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.
return derp, nil
}
type scenarioOIDC struct {
r *dockertest.Resource
cfg *types.OIDCConfig
}
func (o *scenarioOIDC) Issuer() string {
if o.cfg == nil {
panic("OIDC has not been created")
}
return o.cfg.Issuer
}
func (o *scenarioOIDC) ClientSecret() string {
if o.cfg == nil {
panic("OIDC has not been created")
}
return o.cfg.ClientSecret
}
func (o *scenarioOIDC) ClientID() string {
if o.cfg == nil {
panic("OIDC has not been created")
}
return o.cfg.ClientID
}
const (
dockerContextPath = "../."
hsicOIDCMockHashLength = 6
defaultAccessTTL = 10 * time.Minute
)
var errStatusCodeNotOK = errors.New("status code not OK")
func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) error {
port, err := dockertestutil.RandomFreeHostPort()
if err != nil {
log.Fatalf("could not find an open port: %s", err)
}
portNotation := fmt.Sprintf("%d/tcp", port)
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
usersJSON, err := json.Marshal(users)
if err != nil {
return err
}
mockOidcOptions := &dockertest.RunOptions{
Name: hostname,
Cmd: []string{"headscale", "mockoidc"},
ExposedPorts: []string{portNotation},
PortBindings: map[docker.Port][]docker.PortBinding{
docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
},
Networks: s.Networks(),
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
}
err = s.pool.RemoveContainerByName(hostname)
if err != nil {
return err
}
s.mockOIDC = scenarioOIDC{}
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
dockertestutil.DockerRestartPolicy); err == nil {
s.mockOIDC.r = pmockoidc
} else {
return err
}
// headscale needs to set up the provider with a specific
// IP addr to ensure we get the correct config from the well-known
// endpoint.
network := s.Networks()[0]
ipAddr := s.mockOIDC.r.GetIPInNetwork(network)
log.Println("Waiting for headscale mock oidc to be ready for tests")
hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
if err := s.pool.Retry(func() error {
oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
httpClient := &http.Client{}
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("headscale mock OIDC tests is not ready: %s\n", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errStatusCodeNotOK
}
return nil
}); err != nil {
return err
}
s.mockOIDC.cfg = &types.OIDCConfig{
Issuer: fmt.Sprintf(
"http://%s/oidc",
hostEndpoint,
),
ClientID: "superclient",
ClientSecret: "supersecret",
OnlyStartIfOIDCIsAvailable: true,
}
log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
return nil
}
type extraServiceFunc func(*Scenario, string) (*dockertest.Resource, error)
func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
// port, err := dockertestutil.RandomFreeHostPort()
// if err != nil {
// log.Fatalf("could not find an open port: %s", err)
// }
// portNotation := fmt.Sprintf("%d/tcp", port)
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-webservice-%s", hash)
network, ok := s.networks[prefixedNetworkName(networkName)]
if !ok {
return nil, fmt.Errorf("network does not exist: %s", networkName)
}
webOpts := &dockertest.RunOptions{
Name: hostname,
Cmd: []string{"/bin/sh", "-c", "cd / ; python3 -m http.server --bind :: 80"},
// ExposedPorts: []string{portNotation},
// PortBindings: map[docker.Port][]docker.PortBinding{
// docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
// },
Networks: []*dockertest.Network{network},
Env: []string{},
}
webBOpts := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
}
web, err := s.pool.BuildAndRunWithBuildOptions(
webBOpts,
webOpts,
dockertestutil.DockerRestartPolicy)
if err != nil {
return nil, err
}
// headscale needs to set up the provider with a specific
// IP addr to ensure we get the correct config from the well-known
// endpoint.
// ipAddr := web.GetIPInNetwork(network)
// log.Println("Waiting for headscale mock oidc to be ready for tests")
// hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
// if err := s.pool.Retry(func() error {
// oidcConfigURL := fmt.Sprintf("http://%s/etc/hostname", hostEndpoint)
// httpClient := &http.Client{}
// ctx := context.Background()
// req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil)
// resp, err := httpClient.Do(req)
// if err != nil {
// log.Printf("headscale mock OIDC tests is not ready: %s\n", err)
// return err
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return errStatusCodeNotOK
// }
// return nil
// }); err != nil {
// return err
// }
return web, nil
}

View File

@ -4,6 +4,7 @@ import (
"testing"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic"
)
// This file is intended to "test the test framework", by proxy it will also test
@ -33,7 +34,7 @@ func TestHeadscale(t *testing.T) {
user := "test-space"
scenario, err := NewScenario(dockertestMaxWait())
scenario, err := NewScenario(ScenarioSpec{})
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -68,38 +69,6 @@ func TestHeadscale(t *testing.T) {
})
}
// If subtests are parallel, then they will start before setup is run.
// This might mean we approach setup slightly wrong, but for now, ignore
// the linter
// nolint:tparallel
func TestCreateTailscale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "only-create-containers"
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
scenario.users[user] = &User{
Clients: make(map[string]TailscaleClient),
}
t.Run("create-tailscale", func(t *testing.T) {
err := scenario.CreateTailscaleNodesInUser(user, "all", 3)
if err != nil {
t.Fatalf("failed to add tailscale nodes: %s", err)
}
if clients := len(scenario.users[user].Clients); clients != 3 {
t.Fatalf("wrong number of tailscale clients: %d != %d", clients, 3)
}
// TODO(kradalby): Test "all" version logic
})
}
// If subtests are parallel, then they will start before setup is run.
// This might mean we approach setup slightly wrong, but for now, ignore
// the linter
@ -114,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
count := 1
scenario, err := NewScenario(dockertestMaxWait())
scenario, err := NewScenario(ScenarioSpec{})
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -142,7 +111,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
})
t.Run("create-tailscale", func(t *testing.T) {
err := scenario.CreateTailscaleNodesInUser(user, "unstable", count)
err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[TestDefaultNetwork]))
if err != nil {
t.Fatalf("failed to add tailscale nodes: %s", err)
}

View File

@ -50,15 +50,15 @@ var retry = func(times int, sleepInterval time.Duration,
func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario {
t.Helper()
scenario, err := NewScenario(dockertestMaxWait())
spec := ScenarioSpec{
NodesPerUser: clientsPerUser,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
spec := map[string]int{
"user1": clientsPerUser,
"user2": clientsPerUser,
}
err = scenario.CreateHeadscaleEnv(spec,
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
tsic.WithSSH(),

View File

@ -5,6 +5,7 @@ import (
"net/netip"
"net/url"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/ipn/ipnstate"
@ -27,6 +28,9 @@ type TailscaleClient interface {
Up() error
Down() error
IPs() ([]netip.Addr, error)
MustIPs() []netip.Addr
MustIPv4() netip.Addr
MustIPv6() netip.Addr
FQDN() (string, error)
Status(...bool) (*ipnstate.Status, error)
MustStatus() *ipnstate.Status
@ -38,6 +42,7 @@ type TailscaleClient interface {
WaitForPeers(expected int) error
Ping(hostnameOrIP string, opts ...tsic.PingOption) error
Curl(url string, opts ...tsic.CurlOption) (string, error)
Traceroute(netip.Addr) (util.Traceroute, error)
ID() string
ReadFile(path string) ([]byte, error)

View File

@ -13,6 +13,7 @@ import (
"net/url"
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"
@ -81,6 +82,7 @@ type TailscaleInContainer struct {
workdir string
netfilter string
extraLoginArgs []string
withAcceptRoutes bool
// build options, solely for HEAD
buildConfig TailscaleInContainerBuildConfig
@ -101,26 +103,10 @@ func WithCACert(cert []byte) Option {
}
}
// WithOrCreateNetwork sets the Docker container network to use with
// the Tailscale instance, if the parameter is nil, a new network,
// isolating the TailscaleClient, will be created. If a network is
// passed, the Tailscale instance will join the given network.
func WithOrCreateNetwork(network *dockertest.Network) Option {
// WithNetwork sets the Docker container network to use with
// the Tailscale instance.
func WithNetwork(network *dockertest.Network) Option {
return func(tsic *TailscaleInContainer) {
if network != nil {
tsic.network = network
return
}
network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool,
fmt.Sprintf("%s-network", tsic.hostname),
)
if err != nil {
log.Fatalf("failed to create network: %s", err)
}
tsic.network = network
}
}
@ -212,11 +198,17 @@ func WithExtraLoginArgs(args []string) Option {
}
}
// WithAcceptRoutes tells the node to accept incomming routes.
func WithAcceptRoutes() Option {
return func(tsic *TailscaleInContainer) {
tsic.withAcceptRoutes = true
}
}
// New returns a new TailscaleInContainer instance.
func New(
pool *dockertest.Pool,
version string,
network *dockertest.Network,
opts ...Option,
) (*TailscaleInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength)
@ -230,8 +222,7 @@ func New(
version: version,
hostname: hostname,
pool: pool,
network: network,
pool: pool,
withEntrypoint: []string{
"/bin/sh",
@ -244,6 +235,10 @@ func New(
opt(tsic)
}
if tsic.network == nil {
return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack()))
}
tailscaleOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{tsic.network},
@ -442,7 +437,7 @@ func (t *TailscaleInContainer) Login(
"--login-server=" + loginServer,
"--authkey=" + authKey,
"--hostname=" + t.hostname,
"--accept-routes=false",
fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes),
}
if t.extraLoginArgs != nil {
@ -597,6 +592,33 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
return ips, nil
}
func (t *TailscaleInContainer) MustIPs() []netip.Addr {
ips, err := t.IPs()
if err != nil {
panic(err)
}
return ips
}
func (t *TailscaleInContainer) MustIPv4() netip.Addr {
for _, ip := range t.MustIPs() {
if ip.Is4() {
return ip
}
}
panic("no ipv4 found")
}
func (t *TailscaleInContainer) MustIPv6() netip.Addr {
for _, ip := range t.MustIPs() {
if ip.Is6() {
return ip
}
}
panic("no ipv6 found")
}
// Status returns the ipnstate.Status of the Tailscale instance.
func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
command := []string{
@ -992,6 +1014,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err
),
)
if err != nil {
log.Printf("command: %v", command)
log.Printf(
"failed to run ping command from %s to %s, err: %s",
t.Hostname(),
@ -1108,6 +1131,26 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
return result, nil
}
func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error) {
command := []string{
"traceroute",
ip.String(),
}
var result util.Traceroute
stdout, stderr, err := t.Execute(command)
if err != nil {
return result, err
}
result, err = util.ParseTraceroute(stdout + stderr)
if err != nil {
return result, err
}
return result, nil
}
// WriteFile save file inside the Tailscale container.
func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)