mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-07 21:02:51 -05:00
policy: fix autogroup:self propagation and optimize cache invalidation (#2807)
This commit is contained in:
@@ -3,12 +3,14 @@ package integration
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
@@ -319,12 +321,14 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
user := status.User[status.Self.UserID].LoginName
|
||||
user := status.User[status.Self.UserID].LoginName
|
||||
|
||||
assert.Len(t, status.Peer, (testCase.want[user]))
|
||||
assert.Len(c, status.Peer, (testCase.want[user]))
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer visibility")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -782,75 +786,87 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
|
||||
|
||||
// test1 can query test3
|
||||
result, err := test1.Curl(test3ip4URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip4URL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test3ip4URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip4URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv4")
|
||||
|
||||
result, err = test1.Curl(test3ip6URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip6URL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test3ip6URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip6URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv6")
|
||||
|
||||
result, err = test1.Curl(test3fqdnURL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3fqdnURL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test3fqdnURL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3fqdnURL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via FQDN")
|
||||
|
||||
// test2 can query test3
|
||||
result, err = test2.Curl(test3ip4URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip4URL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test2.Curl(test3ip4URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip4URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv4")
|
||||
|
||||
result, err = test2.Curl(test3ip6URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip6URL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test2.Curl(test3ip6URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3ip6URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv6")
|
||||
|
||||
result, err = test2.Curl(test3fqdnURL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3fqdnURL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test2.Curl(test3fqdnURL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test3fqdnURL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via FQDN")
|
||||
|
||||
// test3 cannot query test1
|
||||
result, err = test3.Curl(test1ip4URL)
|
||||
result, err := test3.Curl(test1ip4URL)
|
||||
assert.Empty(t, result)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -876,38 +892,44 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
|
||||
// test1 can query test2
|
||||
result, err = test1.Curl(test2ip4URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ip4URL,
|
||||
result,
|
||||
)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test2ip4URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ip4URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4")
|
||||
|
||||
require.NoError(t, err)
|
||||
result, err = test1.Curl(test2ip6URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ip6URL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test2ip6URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ip6URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6")
|
||||
|
||||
result, err = test1.Curl(test2fqdnURL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2fqdnURL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test2fqdnURL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2fqdnURL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN")
|
||||
|
||||
// test2 cannot query test1
|
||||
result, err = test2.Curl(test1ip4URL)
|
||||
@@ -1050,50 +1072,63 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
||||
|
||||
// test1 can query test2
|
||||
result, err := test1.Curl(test2ipURL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ipURL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test2ipURL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ipURL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4")
|
||||
|
||||
result, err = test1.Curl(test2ip6URL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ip6URL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test2ip6URL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2ip6URL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6")
|
||||
|
||||
result, err = test1.Curl(test2fqdnURL)
|
||||
assert.Lenf(
|
||||
t,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2fqdnURL,
|
||||
result,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test1.Curl(test2fqdnURL)
|
||||
assert.NoError(c, err)
|
||||
assert.Lenf(
|
||||
c,
|
||||
result,
|
||||
13,
|
||||
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
|
||||
test2fqdnURL,
|
||||
result,
|
||||
)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN")
|
||||
|
||||
result, err = test2.Curl(test1ipURL)
|
||||
assert.Empty(t, result)
|
||||
require.Error(t, err)
|
||||
// test2 cannot query test1 (negative test case)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test2.Curl(test1ipURL)
|
||||
assert.Error(c, err)
|
||||
assert.Empty(c, result)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv4")
|
||||
|
||||
result, err = test2.Curl(test1ip6URL)
|
||||
assert.Empty(t, result)
|
||||
require.Error(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test2.Curl(test1ip6URL)
|
||||
assert.Error(c, err)
|
||||
assert.Empty(c, result)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv6")
|
||||
|
||||
result, err = test2.Curl(test1fqdnURL)
|
||||
assert.Empty(t, result)
|
||||
require.Error(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := test2.Curl(test1fqdnURL)
|
||||
assert.Error(c, err)
|
||||
assert.Empty(c, result)
|
||||
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via FQDN")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1266,9 +1301,15 @@ func TestACLAutogroupMember(t *testing.T) {
|
||||
|
||||
// Test that untagged nodes can access each other
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
if status.Self.Tags != nil && status.Self.Tags.Len() > 0 {
|
||||
var clientIsUntagged bool
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
clientIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0
|
||||
assert.True(c, clientIsUntagged, "Expected client %s to be untagged for autogroup:member test", client.Hostname())
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for client %s to be untagged", client.Hostname())
|
||||
|
||||
if !clientIsUntagged {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1277,9 +1318,15 @@ func TestACLAutogroupMember(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
status, err := peer.Status()
|
||||
require.NoError(t, err)
|
||||
if status.Self.Tags != nil && status.Self.Tags.Len() > 0 {
|
||||
var peerIsUntagged bool
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := peer.Status()
|
||||
assert.NoError(c, err)
|
||||
peerIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0
|
||||
assert.True(c, peerIsUntagged, "Expected peer %s to be untagged for autogroup:member test", peer.Hostname())
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for peer %s to be untagged", peer.Hostname())
|
||||
|
||||
if !peerIsUntagged {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1468,21 +1515,23 @@ func TestACLAutogroupTagged(t *testing.T) {
|
||||
|
||||
// Explicitly verify tags on tagged nodes
|
||||
for _, client := range taggedClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
|
||||
require.Positive(t, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname())
|
||||
t.Logf("Tagged node %s has tags: %v", client.Hostname(), status.Self.Tags)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
assert.NotNil(c, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
|
||||
assert.Positive(c, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname())
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for tags to be applied to tagged nodes")
|
||||
}
|
||||
|
||||
// Verify untagged nodes have no tags
|
||||
for _, client := range untaggedClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
if status.Self.Tags != nil {
|
||||
require.Equal(t, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname())
|
||||
}
|
||||
t.Logf("Untagged node %s has no tags", client.Hostname())
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
if status.Self.Tags != nil {
|
||||
assert.Equal(c, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname())
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting to verify untagged nodes have no tags")
|
||||
}
|
||||
|
||||
// Test that tagged nodes can communicate with each other
|
||||
@@ -1603,9 +1652,11 @@ func TestACLAutogroupSelf(t *testing.T) {
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
t.Logf("url from %s (user1) to %s (user1)", client.Hostname(), fqdn)
|
||||
|
||||
result, err := client.Curl(url)
|
||||
assert.Len(t, result, 13)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, result, 13)
|
||||
}, 10*time.Second, 200*time.Millisecond, "user1 device should reach other user1 device")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1622,9 +1673,11 @@ func TestACLAutogroupSelf(t *testing.T) {
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
t.Logf("url from %s (user2) to %s (user2)", client.Hostname(), fqdn)
|
||||
|
||||
result, err := client.Curl(url)
|
||||
assert.Len(t, result, 13)
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, result, 13)
|
||||
}, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1657,3 +1710,388 @@ func TestACLAutogroupSelf(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLPolicyPropagationOverTime(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: 2,
|
||||
Users: []string{"user1", "user2"},
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
[]tsic.Option{
|
||||
// Install iptables to enable packet filtering for ACL tests.
|
||||
// Packet filters are essential for testing autogroup:self and other ACL policies.
|
||||
tsic.WithDockerEntrypoint([]string{
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"/bin/sleep 3 ; apk add python3 curl iptables ip6tables ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
|
||||
}),
|
||||
tsic.WithDockerWorkdir("/"),
|
||||
},
|
||||
hsic.WithTestName("aclpropagation"),
|
||||
hsic.WithPolicyMode(types.PolicyModeDB),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
require.NoError(t, err)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
allClients := append(user1Clients, user2Clients...)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Define the four policies we'll cycle through
|
||||
allowAllPolicy := &policyv2.Policy{
|
||||
ACLs: []policyv2.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []policyv2.Alias{wildcard()},
|
||||
Destinations: []policyv2.AliasWithPorts{
|
||||
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
autogroupSelfPolicy := &policyv2.Policy{
|
||||
ACLs: []policyv2.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)},
|
||||
Destinations: []policyv2.AliasWithPorts{
|
||||
aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
user1ToUser2Policy := &policyv2.Policy{
|
||||
ACLs: []policyv2.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []policyv2.Alias{usernamep("user1@")},
|
||||
Destinations: []policyv2.AliasWithPorts{
|
||||
aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Run through the policy cycle 5 times
|
||||
for i := range 5 {
|
||||
iteration := i + 1 // range 5 gives 0-4, we want 1-5 for logging
|
||||
t.Logf("=== Iteration %d/5 ===", iteration)
|
||||
|
||||
// Phase 1: Allow all policy
|
||||
t.Logf("Iteration %d: Setting allow-all policy", iteration)
|
||||
err = headscale.SetPolicy(allowAllPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for peer lists to sync with allow-all policy
|
||||
t.Logf("Iteration %d: Phase 1 - Waiting for peer lists to sync with allow-all policy", iteration)
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
require.NoError(t, err, "iteration %d: Phase 1 - failed to sync after allow-all policy", iteration)
|
||||
|
||||
// Test all-to-all connectivity after state is settled
|
||||
t.Logf("Iteration %d: Phase 1 - Testing all-to-all connectivity", iteration)
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
for _, client := range allClients {
|
||||
for _, peer := range allClients {
|
||||
if client.ContainerID() == peer.ContainerID() {
|
||||
continue
|
||||
}
|
||||
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(ct, err, "iteration %d: %s should reach %s with allow-all policy", iteration, client.Hostname(), fqdn)
|
||||
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn)
|
||||
}
|
||||
}
|
||||
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 1 - all connectivity tests with allow-all policy", iteration)
|
||||
|
||||
// Phase 2: Autogroup:self policy (only same user can access)
|
||||
t.Logf("Iteration %d: Phase 2 - Setting autogroup:self policy", iteration)
|
||||
err = headscale.SetPolicy(autogroupSelfPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for peer lists to sync with autogroup:self - ensures cross-user peers are removed
|
||||
t.Logf("Iteration %d: Phase 2 - Waiting for peer lists to sync with autogroup:self", iteration)
|
||||
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
|
||||
require.NoError(t, err, "iteration %d: Phase 2 - failed to sync after autogroup:self policy", iteration)
|
||||
|
||||
// Test ALL connectivity (positive and negative) in one block after state is settled
|
||||
t.Logf("Iteration %d: Phase 2 - Testing all connectivity with autogroup:self", iteration)
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Positive: user1 can access user1's nodes
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user1Clients {
|
||||
if client.ContainerID() == peer.ContainerID() {
|
||||
continue
|
||||
}
|
||||
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
|
||||
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// Positive: user2 can access user2's nodes
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user2Clients {
|
||||
if client.ContainerID() == peer.ContainerID() {
|
||||
continue
|
||||
}
|
||||
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(ct, err, "iteration %d: user2 %s should reach user2's node %s", iteration, client.Hostname(), fqdn)
|
||||
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn)
|
||||
}
|
||||
}
|
||||
|
||||
// Negative: user1 cannot access user2's nodes
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.Error(ct, err, "iteration %d: user1 %s should NOT reach user2's node %s with autogroup:self", iteration, client.Hostname(), fqdn)
|
||||
assert.Empty(ct, result, "iteration %d: user1 %s->user2 %s should fail", iteration, client.Hostname(), fqdn)
|
||||
}
|
||||
}
|
||||
|
||||
// Negative: user2 cannot access user1's nodes
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
|
||||
assert.Empty(ct, result, "iteration %d: user2->user1 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
|
||||
}
|
||||
}
|
||||
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2 - all connectivity tests with autogroup:self", iteration)
|
||||
|
||||
// Phase 2b: Add a new node to user1 and validate policy propagation
|
||||
t.Logf("Iteration %d: Phase 2b - Adding new node to user1 during autogroup:self policy", iteration)
|
||||
|
||||
// Add a new node with the same options as the initial setup
|
||||
// Get the network to use (scenario uses first network in list)
|
||||
networks := scenario.Networks()
|
||||
require.NotEmpty(t, networks, "scenario should have at least one network")
|
||||
|
||||
newClient := scenario.MustAddAndLoginClient(t, "user1", "all", headscale,
|
||||
tsic.WithNetfilter("off"),
|
||||
tsic.WithDockerEntrypoint([]string{
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
|
||||
}),
|
||||
tsic.WithDockerWorkdir("/"),
|
||||
tsic.WithNetwork(networks[0]),
|
||||
)
|
||||
t.Logf("Iteration %d: Phase 2b - Added and logged in new node %s", iteration, newClient.Hostname())
|
||||
|
||||
// Wait for peer lists to sync after new node addition (now 3 user1 nodes, still autogroup:self)
|
||||
t.Logf("Iteration %d: Phase 2b - Waiting for peer lists to sync after new node addition", iteration)
|
||||
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
|
||||
require.NoError(t, err, "iteration %d: Phase 2b - failed to sync after new node addition", iteration)
|
||||
|
||||
// Test ALL connectivity (positive and negative) in one block after state is settled
|
||||
t.Logf("Iteration %d: Phase 2b - Testing all connectivity after new node addition", iteration)
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Re-fetch client list to ensure latest state
|
||||
user1ClientsWithNew, err := scenario.ListTailscaleClients("user1")
|
||||
assert.NoError(ct, err, "iteration %d: failed to list user1 clients", iteration)
|
||||
assert.Len(ct, user1ClientsWithNew, 3, "iteration %d: user1 should have 3 nodes", iteration)
|
||||
|
||||
// Positive: all user1 nodes can access each other
|
||||
for _, client := range user1ClientsWithNew {
|
||||
for _, peer := range user1ClientsWithNew {
|
||||
if client.ContainerID() == peer.ContainerID() {
|
||||
continue
|
||||
}
|
||||
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
|
||||
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// Negative: user1 nodes cannot access user2's nodes
|
||||
for _, client := range user1ClientsWithNew {
|
||||
for _, peer := range user2Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.Error(ct, err, "iteration %d: user1 node %s should NOT reach user2 node %s", iteration, client.Hostname(), peer.Hostname())
|
||||
assert.Empty(ct, result, "iteration %d: user1->user2 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
|
||||
}
|
||||
}
|
||||
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - all connectivity tests after new node addition", iteration)
|
||||
|
||||
// Delete the newly added node before Phase 3
|
||||
t.Logf("Iteration %d: Phase 2b - Deleting the newly added node from user1", iteration)
|
||||
|
||||
// Get the node list and find the newest node (highest ID)
|
||||
var nodeList []*v1.Node
|
||||
var nodeToDeleteID uint64
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
nodeList, err = headscale.ListNodes("user1")
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, nodeList, 3, "should have 3 user1 nodes before deletion")
|
||||
|
||||
// Find the node with the highest ID (the newest one)
|
||||
for _, node := range nodeList {
|
||||
if node.GetId() > nodeToDeleteID {
|
||||
nodeToDeleteID = node.GetId()
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - listing nodes before deletion", iteration)
|
||||
|
||||
// Delete the node via headscale helper
|
||||
t.Logf("Iteration %d: Phase 2b - Deleting node ID %d from headscale", iteration, nodeToDeleteID)
|
||||
err = headscale.DeleteNode(nodeToDeleteID)
|
||||
require.NoError(t, err, "iteration %d: failed to delete node %d", iteration, nodeToDeleteID)
|
||||
|
||||
// Remove the deleted client from the scenario's user.Clients map
|
||||
// This is necessary for WaitForTailscaleSyncPerUser to calculate correct peer counts
|
||||
t.Logf("Iteration %d: Phase 2b - Removing deleted client from scenario", iteration)
|
||||
for clientName, client := range scenario.users["user1"].Clients {
|
||||
status := client.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if nodeID == nodeToDeleteID {
|
||||
delete(scenario.users["user1"].Clients, clientName)
|
||||
t.Logf("Iteration %d: Phase 2b - Removed client %s (node ID %d) from scenario", iteration, clientName, nodeToDeleteID)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the node has been deleted
|
||||
t.Logf("Iteration %d: Phase 2b - Verifying node deletion (expecting 2 user1 nodes)", iteration)
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
nodeListAfter, err := headscale.ListNodes("user1")
|
||||
assert.NoError(ct, err, "failed to list nodes after deletion")
|
||||
assert.Len(ct, nodeListAfter, 2, "iteration %d: should have 2 user1 nodes after deletion, got %d", iteration, len(nodeListAfter))
|
||||
}, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - node should be deleted", iteration)
|
||||
|
||||
// Wait for sync after deletion to ensure peer counts are correct
|
||||
// Use WaitForTailscaleSyncPerUser because autogroup:self is still active,
|
||||
// so nodes only see same-user peers, not all nodes
|
||||
t.Logf("Iteration %d: Phase 2b - Waiting for sync after node deletion (with autogroup:self)", iteration)
|
||||
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
|
||||
require.NoError(t, err, "iteration %d: failed to sync after node deletion", iteration)
|
||||
|
||||
// Refresh client lists after deletion to ensure we don't reference the deleted node
|
||||
user1Clients, err = scenario.ListTailscaleClients("user1")
|
||||
require.NoError(t, err, "iteration %d: failed to refresh user1 client list after deletion", iteration)
|
||||
user2Clients, err = scenario.ListTailscaleClients("user2")
|
||||
require.NoError(t, err, "iteration %d: failed to refresh user2 client list after deletion", iteration)
|
||||
// Create NEW slice instead of appending to old allClients which still has deleted client
|
||||
allClients = make([]TailscaleClient, 0, len(user1Clients)+len(user2Clients))
|
||||
allClients = append(allClients, user1Clients...)
|
||||
allClients = append(allClients, user2Clients...)
|
||||
|
||||
t.Logf("Iteration %d: Phase 2b completed - New node added, validated, and removed successfully", iteration)
|
||||
|
||||
// Phase 3: User1 can access user2 but not reverse
|
||||
t.Logf("Iteration %d: Phase 3 - Setting user1->user2 directional policy", iteration)
|
||||
err = headscale.SetPolicy(user1ToUser2Policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: Cannot use WaitForTailscaleSync() here because directional policy means
|
||||
// user2 nodes don't see user1 nodes in their peer list (asymmetric visibility).
|
||||
// The EventuallyWithT block below will handle waiting for policy propagation.
|
||||
|
||||
// Test ALL connectivity (positive and negative) in one block after policy settles
|
||||
t.Logf("Iteration %d: Phase 3 - Testing all connectivity with directional policy", iteration)
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Positive: user1 can access user2's nodes
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user2 node %s", iteration, client.Hostname(), peer.Hostname())
|
||||
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// Negative: user2 cannot access user1's nodes
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
|
||||
continue
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
result, err := client.Curl(url)
|
||||
assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
|
||||
assert.Empty(ct, result, "iteration %d: user2->user1 from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
|
||||
}
|
||||
}
|
||||
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 3 - all connectivity tests with directional policy", iteration)
|
||||
|
||||
t.Logf("=== Iteration %d/5 completed successfully - All 3 phases passed ===", iteration)
|
||||
}
|
||||
|
||||
t.Log("All 5 iterations completed successfully - ACL propagation is working correctly")
|
||||
}
|
||||
|
||||
@@ -74,14 +74,21 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
clientIPs[client] = ips
|
||||
}
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.Len(t, allClients, len(listNodes))
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
var listNodes []*v1.Node
|
||||
var nodeCountBeforeLogout int
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, len(allClients))
|
||||
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSet(t, node)
|
||||
}
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSetWithCollect(c, node)
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
|
||||
|
||||
nodeCountBeforeLogout = len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
for _, client := range allClients {
|
||||
err := client.Logout()
|
||||
@@ -188,11 +195,16 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
require.Len(t, listNodes, nodeCountBeforeLogout)
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSet(t, node)
|
||||
}
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, nodeCountBeforeLogout)
|
||||
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSetWithCollect(c, node)
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for node list after relogin")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -238,9 +250,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.Len(t, allClients, len(listNodes))
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
var listNodes []*v1.Node
|
||||
var nodeCountBeforeLogout int
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, len(allClients))
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
|
||||
|
||||
nodeCountBeforeLogout = len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
for _, client := range allClients {
|
||||
@@ -371,9 +390,16 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.Len(t, allClients, len(listNodes))
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
var listNodes []*v1.Node
|
||||
var nodeCountBeforeLogout int
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, len(allClients))
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
|
||||
|
||||
nodeCountBeforeLogout = len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
for _, client := range allClients {
|
||||
|
||||
@@ -901,15 +901,18 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
||||
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
|
||||
time.Sleep(2 * time.Minute)
|
||||
|
||||
st, err := ts.Status()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "NeedsLogin", st.BackendState)
|
||||
var newUrl *url.URL
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
st, err := ts.Status()
|
||||
assert.NoError(c, err)
|
||||
assert.Equal(c, "NeedsLogin", st.BackendState)
|
||||
|
||||
// get new AuthURL from daemon
|
||||
newUrl, err := url.Parse(st.AuthURL)
|
||||
require.NoError(t, err)
|
||||
// get new AuthURL from daemon
|
||||
newUrl, err = url.Parse(st.AuthURL)
|
||||
assert.NoError(c, err)
|
||||
|
||||
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
|
||||
assert.NotEqual(c, u.String(), st.AuthURL, "AuthURL should change")
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for registration cache to expire and status to reflect NeedsLogin")
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), newUrl)
|
||||
require.NoError(t, err)
|
||||
@@ -943,9 +946,11 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
||||
t.Fatalf("unexpected users: %s", diff)
|
||||
}
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listNodes, 1)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, 1)
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
|
||||
}
|
||||
|
||||
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,7 @@ type ControlServer interface {
|
||||
CreateUser(user string) (*v1.User, error)
|
||||
CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
|
||||
ListNodes(users ...string) ([]*v1.Node, error)
|
||||
DeleteNode(nodeID uint64) error
|
||||
NodesByUser() (map[string][]*v1.Node, error)
|
||||
NodesByName() (map[string]*v1.Node, error)
|
||||
ListUsers() ([]*v1.User, error)
|
||||
@@ -38,4 +39,5 @@ type ControlServer interface {
|
||||
PrimaryRoutes() (*routes.DebugRoutes, error)
|
||||
DebugBatcher() (*hscontrol.DebugBatcherInfo, error)
|
||||
DebugNodeStore() (map[types.NodeID]types.Node, error)
|
||||
DebugFilter() ([]tailcfg.FilterRule, error)
|
||||
}
|
||||
|
||||
@@ -541,8 +541,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
|
||||
// update hostnames using the up command
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
status := client.MustStatus()
|
||||
|
||||
command := []string{
|
||||
"tailscale",
|
||||
@@ -642,8 +641,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
}, 60*time.Second, 2*time.Second)
|
||||
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
status := client.MustStatus()
|
||||
|
||||
command := []string{
|
||||
"tailscale",
|
||||
@@ -773,26 +771,25 @@ func TestExpireNode(t *testing.T) {
|
||||
|
||||
// Verify that the expired node has been marked in all peers list.
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
if client.Hostname() == node.GetName() {
|
||||
continue
|
||||
}
|
||||
|
||||
if client.Hostname() != node.GetName() {
|
||||
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Ensures that the node is present, and that it is expired.
|
||||
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
|
||||
requireNotNil(t, peerStatus.Expired)
|
||||
assert.NotNil(t, peerStatus.KeyExpiry)
|
||||
peerStatus, ok := status.Peer[expiredNodeKey]
|
||||
assert.True(c, ok, "expired node key should be present in peer list")
|
||||
|
||||
if ok {
|
||||
assert.NotNil(c, peerStatus.Expired)
|
||||
assert.NotNil(c, peerStatus.KeyExpiry)
|
||||
|
||||
t.Logf(
|
||||
"node %q should have a key expire before %s, was %s",
|
||||
peerStatus.HostName,
|
||||
now.String(),
|
||||
peerStatus.KeyExpiry,
|
||||
)
|
||||
if peerStatus.KeyExpiry != nil {
|
||||
assert.Truef(
|
||||
t,
|
||||
c,
|
||||
peerStatus.KeyExpiry.Before(now),
|
||||
"node %q should have a key expire before %s, was %s",
|
||||
peerStatus.HostName,
|
||||
@@ -802,7 +799,7 @@ func TestExpireNode(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Truef(
|
||||
t,
|
||||
c,
|
||||
peerStatus.Expired,
|
||||
"node %q should be expired, expired is %v",
|
||||
peerStatus.HostName,
|
||||
@@ -811,24 +808,14 @@ func TestExpireNode(t *testing.T) {
|
||||
|
||||
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
|
||||
if !strings.Contains(stderr, "node key has expired") {
|
||||
t.Errorf(
|
||||
c.Errorf(
|
||||
"expected to be unable to ping expired host %q from %q",
|
||||
node.GetName(),
|
||||
client.Hostname(),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
|
||||
}
|
||||
} else {
|
||||
if status.Self.KeyExpiry != nil {
|
||||
assert.Truef(t, status.Self.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", status.Self.HostName, now.String(), status.Self.KeyExpiry)
|
||||
}
|
||||
|
||||
// NeedsLogin means that the node has understood that it is no longer
|
||||
// valid.
|
||||
assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expired node status to propagate")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -866,11 +853,13 @@ func TestNodeOnlineStatus(t *testing.T) {
|
||||
t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Assert that we have the original count - self
|
||||
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
|
||||
// Assert that we have the original count - self
|
||||
assert.Len(c, status.Peers(), len(MustTestVersions)-1)
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer count")
|
||||
}
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
|
||||
@@ -507,6 +507,11 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
|
||||
assert.NotNil(t, node.GetLastSeen())
|
||||
}
|
||||
|
||||
func assertLastSeenSetWithCollect(c *assert.CollectT, node *v1.Node) {
|
||||
assert.NotNil(c, node)
|
||||
assert.NotNil(c, node.GetLastSeen())
|
||||
}
|
||||
|
||||
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
|
||||
// are in the logged-out state (NeedsLogin).
|
||||
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
||||
@@ -633,50 +638,50 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||
|
||||
t.Logf("Checking netmap of %q", client.Hostname())
|
||||
|
||||
netmap, err := client.Netmap()
|
||||
if err != nil {
|
||||
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
netmap, err := client.Netmap()
|
||||
assert.NoError(c, err, "getting netmap for %q", client.Hostname())
|
||||
|
||||
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
|
||||
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
||||
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
||||
|
||||
assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
|
||||
|
||||
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
||||
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
||||
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
|
||||
|
||||
for _, peer := range netmap.Peers {
|
||||
assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
|
||||
assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
|
||||
|
||||
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
|
||||
if hi := peer.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
|
||||
|
||||
// Netinfo is not always set
|
||||
// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
|
||||
if ni := hi.NetInfo(); ni.Valid() {
|
||||
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
|
||||
}
|
||||
assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
|
||||
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
||||
assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
||||
|
||||
assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
|
||||
assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
|
||||
|
||||
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
|
||||
}
|
||||
assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
||||
assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
||||
assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
|
||||
|
||||
for _, peer := range netmap.Peers {
|
||||
assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
|
||||
assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
|
||||
|
||||
assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
|
||||
if hi := peer.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
|
||||
|
||||
// Netinfo is not always set
|
||||
// assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
|
||||
if ni := hi.NetInfo(); ni.Valid() {
|
||||
assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
|
||||
}
|
||||
}
|
||||
|
||||
assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
|
||||
|
||||
assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
|
||||
|
||||
assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", client.Hostname())
|
||||
}
|
||||
|
||||
// assertValidStatus validates that a client's status has all required fields for proper operation.
|
||||
@@ -920,3 +925,125 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
|
||||
EmailVerified: emailVerified,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserByName retrieves a user by name from the headscale server.
|
||||
// This is a common pattern used when creating preauth keys or managing users.
|
||||
func GetUserByName(headscale ControlServer, username string) (*v1.User, error) {
|
||||
users, err := headscale.ListUsers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if u.GetName() == username {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("user %s not found", username)
|
||||
}
|
||||
|
||||
// FindNewClient finds a client that is in the new list but not in the original list.
|
||||
// This is useful when dynamically adding nodes during tests and needing to identify
|
||||
// which client was just added.
|
||||
func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) {
|
||||
for _, client := range updated {
|
||||
isOriginal := false
|
||||
for _, origClient := range original {
|
||||
if client.Hostname() == origClient.Hostname() {
|
||||
isOriginal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isOriginal {
|
||||
return client, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no new client found")
|
||||
}
|
||||
|
||||
// AddAndLoginClient adds a new tailscale client to a user and logs it in.
|
||||
// This combines the common pattern of:
|
||||
// 1. Creating a new node
|
||||
// 2. Finding the new node in the client list
|
||||
// 3. Getting the user to create a preauth key
|
||||
// 4. Logging in the new node
|
||||
func (s *Scenario) AddAndLoginClient(
|
||||
t *testing.T,
|
||||
username string,
|
||||
version string,
|
||||
headscale ControlServer,
|
||||
tsOpts ...tsic.Option,
|
||||
) (TailscaleClient, error) {
|
||||
t.Helper()
|
||||
|
||||
// Get the original client list
|
||||
originalClients, err := s.ListTailscaleClients(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list original clients: %w", err)
|
||||
}
|
||||
|
||||
// Create the new node
|
||||
err = s.CreateTailscaleNodesInUser(username, version, 1, tsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tailscale node: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the new node to appear in the client list
|
||||
var newClient TailscaleClient
|
||||
|
||||
_, err = backoff.Retry(t.Context(), func() (struct{}, error) {
|
||||
updatedClients, err := s.ListTailscaleClients(username)
|
||||
if err != nil {
|
||||
return struct{}{}, fmt.Errorf("failed to list updated clients: %w", err)
|
||||
}
|
||||
|
||||
if len(updatedClients) != len(originalClients)+1 {
|
||||
return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients))
|
||||
}
|
||||
|
||||
newClient, err = FindNewClient(originalClients, updatedClients)
|
||||
if err != nil {
|
||||
return struct{}{}, fmt.Errorf("failed to find new client: %w", err)
|
||||
}
|
||||
|
||||
return struct{}{}, nil
|
||||
}, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("timeout waiting for new client: %w", err)
|
||||
}
|
||||
|
||||
// Get the user and create preauth key
|
||||
user, err := GetUserByName(headscale, username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
authKey, err := s.CreatePreAuthKey(user.GetId(), true, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create preauth key: %w", err)
|
||||
}
|
||||
|
||||
// Login the new client
|
||||
err = newClient.Login(headscale.GetEndpoint(), authKey.GetKey())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to login new client: %w", err)
|
||||
}
|
||||
|
||||
return newClient, nil
|
||||
}
|
||||
|
||||
// MustAddAndLoginClient is like AddAndLoginClient but fails the test on error.
|
||||
func (s *Scenario) MustAddAndLoginClient(
|
||||
t *testing.T,
|
||||
username string,
|
||||
version string,
|
||||
headscale ControlServer,
|
||||
tsOpts ...tsic.Option,
|
||||
) TailscaleClient {
|
||||
t.Helper()
|
||||
|
||||
client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...)
|
||||
require.NoError(t, err)
|
||||
return client
|
||||
}
|
||||
|
||||
@@ -1082,6 +1082,30 @@ func (t *HeadscaleInContainer) ListNodes(
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error {
|
||||
command := []string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"delete",
|
||||
"--identifier",
|
||||
fmt.Sprintf("%d", nodeID),
|
||||
"--output",
|
||||
"json",
|
||||
"--force",
|
||||
}
|
||||
|
||||
_, _, err := dockertestutil.ExecuteCommand(
|
||||
t.container,
|
||||
command,
|
||||
[]string{},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute delete node command: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
|
||||
nodes, err := t.ListNodes()
|
||||
if err != nil {
|
||||
@@ -1397,3 +1421,38 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er
|
||||
|
||||
return nodeStore, nil
|
||||
}
|
||||
|
||||
// DebugFilter fetches the current filter rules from the debug endpoint.
|
||||
func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) {
|
||||
// Execute curl inside the container to access the debug endpoint locally
|
||||
command := []string{
|
||||
"curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/filter",
|
||||
}
|
||||
|
||||
result, err := t.Execute(command)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching filter from debug endpoint: %w", err)
|
||||
}
|
||||
|
||||
var filterRules []tailcfg.FilterRule
|
||||
if err := json.Unmarshal([]byte(result), &filterRules); err != nil {
|
||||
return nil, fmt.Errorf("decoding filter response: %w", err)
|
||||
}
|
||||
|
||||
return filterRules, nil
|
||||
}
|
||||
|
||||
// DebugPolicy fetches the current policy from the debug endpoint.
|
||||
func (t *HeadscaleInContainer) DebugPolicy() (string, error) {
|
||||
// Execute curl inside the container to access the debug endpoint locally
|
||||
command := []string{
|
||||
"curl", "-s", "http://localhost:9090/debug/policy",
|
||||
}
|
||||
|
||||
result, err := t.Execute(command)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetching policy from debug endpoint: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -1358,16 +1358,8 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
|
||||
// Sort nodes by ID
|
||||
sort.SliceStable(allClients, func(i, j int) bool {
|
||||
statusI, err := allClients[i].Status()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusJ, err := allClients[j].Status()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusI := allClients[i].MustStatus()
|
||||
statusJ := allClients[j].MustStatus()
|
||||
return statusI.Self.ID < statusJ.Self.ID
|
||||
})
|
||||
|
||||
@@ -1475,9 +1467,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])})
|
||||
}, 5*time.Second, 200*time.Millisecond, "Verifying client can see subnet routes from router")
|
||||
|
||||
clientNm, err := client.Netmap()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for packet filter updates to propagate to client netmap
|
||||
wantClientFilter := []filter.Match{
|
||||
{
|
||||
IPProto: views.SliceOf([]ipproto.Proto{
|
||||
@@ -1503,13 +1493,16 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
|
||||
}
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientNm, err := client.Netmap()
|
||||
assert.NoError(c, err)
|
||||
|
||||
subnetNm, err := subRouter1.Netmap()
|
||||
require.NoError(t, err)
|
||||
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
assert.Fail(c, fmt.Sprintf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff))
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for client packet filter to update")
|
||||
|
||||
// Wait for packet filter updates to propagate to subnet router netmap
|
||||
wantSubnetFilter := []filter.Match{
|
||||
{
|
||||
IPProto: views.SliceOf([]ipproto.Proto{
|
||||
@@ -1553,9 +1546,14 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
|
||||
}
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
subnetNm, err := subRouter1.Netmap()
|
||||
assert.NoError(c, err)
|
||||
|
||||
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
assert.Fail(c, fmt.Sprintf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff))
|
||||
}
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for subnet router packet filter to update")
|
||||
}
|
||||
|
||||
// TestEnablingExitRoutes tests enabling exit routes for clients.
|
||||
@@ -1592,12 +1590,16 @@ func TestEnablingExitRoutes(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
nodes, err := headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 2)
|
||||
var nodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 2, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[1], 2, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 2, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 2, 0, 0)
|
||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for route advertisements to propagate")
|
||||
|
||||
// Verify that no routes has been sent to the client,
|
||||
// they are not yet enabled.
|
||||
|
||||
@@ -693,6 +693,35 @@ func (s *Scenario) WaitForTailscaleSync() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// WaitForTailscaleSyncPerUser blocks execution until each TailscaleClient has the expected
|
||||
// number of peers for its user. This is useful for policies like autogroup:self where nodes
|
||||
// only see same-user peers, not all nodes in the network.
|
||||
func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Duration) error {
|
||||
var allErrors []error
|
||||
|
||||
for _, user := range s.users {
|
||||
// Calculate expected peer count: number of nodes in this user minus 1 (self)
|
||||
expectedPeers := len(user.Clients) - 1
|
||||
|
||||
for _, client := range user.Clients {
|
||||
c := client
|
||||
expectedCount := expectedPeers
|
||||
user.syncWaitGroup.Go(func() error {
|
||||
return c.WaitForPeers(expectedCount, timeout, retryInterval)
|
||||
})
|
||||
}
|
||||
if err := user.syncWaitGroup.Wait(); err != nil {
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
return multierr.New(allErrors...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports
|
||||
// to have all other TailscaleClients present in their netmap.NetworkMap.
|
||||
func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, retryInterval time.Duration) error {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"tailscale.com/net/netcheck"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
// nolint
|
||||
@@ -36,6 +37,7 @@ type TailscaleClient interface {
|
||||
MustIPv4() netip.Addr
|
||||
MustIPv6() netip.Addr
|
||||
FQDN() (string, error)
|
||||
MustFQDN() string
|
||||
Status(...bool) (*ipnstate.Status, error)
|
||||
MustStatus() *ipnstate.Status
|
||||
Netmap() (*netmap.NetworkMap, error)
|
||||
@@ -52,6 +54,7 @@ type TailscaleClient interface {
|
||||
ContainerID() string
|
||||
MustID() types.NodeID
|
||||
ReadFile(path string) ([]byte, error)
|
||||
PacketFilter() ([]filter.Match, error)
|
||||
|
||||
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client
|
||||
// and a bool indicating if the clients online count and peer count is equal.
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
@@ -32,6 +33,7 @@ import (
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/multierr"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -597,28 +599,39 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
|
||||
return t.ips, nil
|
||||
}
|
||||
|
||||
ips := make([]netip.Addr, 0)
|
||||
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"ip",
|
||||
}
|
||||
|
||||
result, _, err := t.Execute(command)
|
||||
if err != nil {
|
||||
return []netip.Addr{}, fmt.Errorf("%s failed to join tailscale client: %w", t.hostname, err)
|
||||
}
|
||||
|
||||
for address := range strings.SplitSeq(result, "\n") {
|
||||
address = strings.TrimSuffix(address, "\n")
|
||||
if len(address) < 1 {
|
||||
continue
|
||||
// Retry with exponential backoff to handle eventual consistency
|
||||
ips, err := backoff.Retry(context.Background(), func() ([]netip.Addr, error) {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"ip",
|
||||
}
|
||||
ip, err := netip.ParseAddr(address)
|
||||
|
||||
result, _, err := t.Execute(command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("%s failed to get IPs: %w", t.hostname, err)
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
|
||||
ips := make([]netip.Addr, 0)
|
||||
for address := range strings.SplitSeq(result, "\n") {
|
||||
address = strings.TrimSuffix(address, "\n")
|
||||
if len(address) < 1 {
|
||||
continue
|
||||
}
|
||||
ip, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse IP %s: %w", address, err)
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
@@ -629,7 +642,6 @@ func (t *TailscaleInContainer) MustIPs() []netip.Addr {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
||||
@@ -646,16 +658,15 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Addr{}, errors.New("no IPv4 address found")
|
||||
return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname)
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) MustIPv4() netip.Addr {
|
||||
for _, ip := range t.MustIPs() {
|
||||
if ip.Is4() {
|
||||
return ip
|
||||
}
|
||||
ip, err := t.IPv4()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
panic("no ipv4 found")
|
||||
return ip
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) MustIPv6() netip.Addr {
|
||||
@@ -900,12 +911,33 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
|
||||
return t.fqdn, nil
|
||||
}
|
||||
|
||||
status, err := t.Status()
|
||||
// Retry with exponential backoff to handle eventual consistency
|
||||
fqdn, err := backoff.Retry(context.Background(), func() (string, error) {
|
||||
status, err := t.Status()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Self.DNSName == "" {
|
||||
return "", fmt.Errorf("FQDN not yet available")
|
||||
}
|
||||
|
||||
return status.Self.DNSName, nil
|
||||
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get FQDN: %w", err)
|
||||
return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err)
|
||||
}
|
||||
|
||||
return status.Self.DNSName, nil
|
||||
return fqdn, nil
|
||||
}
|
||||
|
||||
// MustFQDN returns the FQDN as a string of the Tailscale instance, panicking on error.
|
||||
func (t *TailscaleInContainer) MustFQDN() string {
|
||||
fqdn, err := t.FQDN()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fqdn
|
||||
}
|
||||
|
||||
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client
|
||||
@@ -1353,3 +1385,18 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
|
||||
|
||||
return &p.Persist.PrivateNodeKey, nil
|
||||
}
|
||||
|
||||
// PacketFilter returns the current packet filter rules from the client's network map.
|
||||
// This is useful for verifying that policy changes have propagated to the client.
|
||||
func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) {
|
||||
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
|
||||
return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version)
|
||||
}
|
||||
|
||||
nm, err := t.Netmap()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get netmap: %w", err)
|
||||
}
|
||||
|
||||
return nm.PacketFilter, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user