stability and race conditions in auth and node store (#2781)

This PR addresses some consistency issues that were introduced or discovered in the nodestore.

nodestore:
Put and update operations now return the node as it exists once the write has finished. This closes a race condition where reading the node back did not necessarily include the change that was just written, and it ensures we also see the other updates applied in the same batch write.
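
As a rough sketch of the pattern (the type and method names below are hypothetical, not headscale's actual NodeStore API): the write path hands the caller the node as it exists after the batched write, so nothing has to be read back.

```go
package nodestore

import "sync"

// Node is a simplified stand-in for headscale's node type (illustration only).
type Node struct {
	ID       uint64
	Hostname string
}

// Store applies updates under a lock; UpdateNode returns the node as it exists
// after the write, instead of forcing callers to read it back afterwards.
type Store struct {
	mu    sync.Mutex
	nodes map[uint64]Node
}

func NewStore() *Store {
	return &Store{nodes: make(map[uint64]Node)}
}

// UpdateNode mutates the stored node and returns the resulting value.
// Returning the post-write node closes the read-after-write race: a separate
// lookup might already include other updates from the same batch, or miss
// the change entirely.
func (s *Store) UpdateNode(id uint64, update func(*Node)) (Node, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	n, ok := s.nodes[id]
	if !ok {
		return Node{}, false
	}
	update(&n)
	s.nodes[id] = n
	return n, true
}
```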

auth:
Authentication paths have been unified and simplified. This removes a lot of problematic branches and ensures we only do the minimal required work.
A comprehensive auth test suite has been created so we no longer have to run integration tests to validate auth, and it has allowed us to write test cases for all the branches we currently know of.
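
As a loose illustration of what such a table-driven suite looks like (the case fields and the `register` helper are hypothetical, not the actual headscale auth code), each known branch of the registration flow becomes one explicit case:

```go
package auth_test

import "testing"

// authCase is a hypothetical stand-in for one branch of the registration flow.
type authCase struct {
	name         string
	preAuthKey   bool
	keyExpired   bool
	wantRegister bool
}

func TestRegistrationBranches(t *testing.T) {
	cases := []authCase{
		{name: "valid pre-auth key", preAuthKey: true, wantRegister: true},
		{name: "expired pre-auth key", preAuthKey: true, keyExpired: true, wantRegister: false},
		{name: "interactive login", preAuthKey: false, wantRegister: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// register is a placeholder for exercising the unified auth path.
			got := register(tc)
			if got != tc.wantRegister {
				t.Fatalf("register(%q) = %v, want %v", tc.name, got, tc.wantRegister)
			}
		})
	}
}

// register is a placeholder for the unified registration path under test.
func register(tc authCase) bool {
	if tc.preAuthKey && tc.keyExpired {
		return false
	}
	return true
}
```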

integration:
Added a lot more tooling and checks to validate that nodes reach the expected state as they come up and go down, standardised across the different auth models. Much of this exists to support, or detect issues in, the changes to the nodestore (races) and auth (inconsistencies after login and in reaching the correct state).
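
A condensed sketch of the consistency check this tooling performs; the real helper, `requireAllClientsOnline`, appears at the bottom of this diff, and the types below are simplified stand-ins. A node only counts as being in the expected state when the batcher, the aggregated map responses, and the NodeStore all agree:

```go
package main

import "fmt"

// systemView is a simplified stand-in for the per-node status gathered from
// the batcher, the aggregated map responses, and the NodeStore.
type systemView struct {
	Batcher      bool
	MapResponses bool
	NodeStore    bool
}

// consistentlyInState reports whether every node is in the expected state in
// all three systems, mirroring the check requireAllClientsOnline performs,
// and returns the IDs of nodes that disagree.
func consistentlyInState(nodes map[uint64]systemView, expectedOnline bool) (bool, []uint64) {
	var mismatched []uint64
	for id, v := range nodes {
		if v.Batcher != expectedOnline || v.MapResponses != expectedOnline || v.NodeStore != expectedOnline {
			mismatched = append(mismatched, id)
		}
	}
	return len(mismatched) == 0, mismatched
}

func main() {
	nodes := map[uint64]systemView{
		1: {Batcher: true, MapResponses: true, NodeStore: true},
		2: {Batcher: true, MapResponses: false, NodeStore: true}, // peer maps lagging behind
	}
	ok, bad := consistentlyInState(nodes, true)
	fmt.Println(ok, bad) // false [2]
}
```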

This PR was assisted, particularly on the tests, by Claude Code.
Kristoffer Dalby
2025-10-16 12:17:43 +02:00
committed by GitHub
parent 881a6b9227
commit fddc7117e4
34 changed files with 7408 additions and 1876 deletions


@@ -10,19 +10,15 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/types/key"
@@ -38,7 +34,7 @@ func TestPingAllByIP(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -48,16 +44,16 @@ func TestPingAllByIP(t *testing.T) {
hsic.WithTLS(),
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
hs, err := scenario.Headscale()
require.NoError(t, err)
@@ -80,7 +76,7 @@ func TestPingAllByIP(t *testing.T) {
// Get headscale instance for batcher debug check
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Test our DebugBatcher functionality
t.Logf("Testing DebugBatcher functionality...")
@@ -99,23 +95,23 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("pingallbyippubderp"),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -148,11 +144,11 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
headscale, err := scenario.Headscale(opts...)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
for _, userName := range spec.Users {
user, err := scenario.CreateUser(userName)
@@ -177,13 +173,13 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -200,7 +196,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
}
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
requireNoErrLogout(t, err)
t.Logf("all clients logged out")
@@ -222,7 +218,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
headscale, err := scenario.Headscale(
@@ -231,7 +227,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s",
}),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
for _, userName := range spec.Users {
user, err := scenario.CreateUser(userName)
@@ -256,13 +252,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -344,22 +340,22 @@ func TestPingAllByHostname(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
success := pingAllHelper(t, allClients, allHostnames)
@@ -379,7 +375,7 @@ func TestTaildrop(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
@@ -387,17 +383,17 @@ func TestTaildrop(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// This will essentially fetch and cache all the FQDNs
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range allClients {
if !strings.Contains(client.Hostname(), "head") {
@@ -498,7 +494,7 @@ func TestTaildrop(t *testing.T) {
)
result, _, err := client.Execute(command)
assertNoErrf(t, "failed to execute command to ls taildrop: %s", err)
require.NoErrorf(t, err, "failed to execute command to ls taildrop")
log.Printf("Result for %s: %s\n", peer.Hostname(), result)
if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result {
@@ -528,25 +524,25 @@ func TestUpdateHostnameFromClient(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErrf(t, "failed to create scenario: %s", err)
require.NoErrorf(t, err, "failed to create scenario")
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
// update hostnames using the up command
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
command := []string{
"tailscale",
@@ -554,11 +550,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
"--hostname=" + hostnames[string(status.Self.ID)],
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to set hostname: %s", err)
require.NoErrorf(t, err, "failed to set hostname")
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
@@ -597,7 +593,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
"--identifier",
strconv.FormatUint(node.GetId(), 10),
})
assertNoErr(t, err)
require.NoError(t, err)
}
// Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
@@ -643,7 +639,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
command := []string{
"tailscale",
@@ -651,11 +647,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
"--hostname=" + hostnames[string(status.Self.ID)] + "NEW",
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to set hostname: %s", err)
require.NoErrorf(t, err, "failed to set hostname")
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
@@ -696,20 +692,20 @@ func TestExpireNode(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -731,22 +727,22 @@ func TestExpireNode(t *testing.T) {
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// TODO(kradalby): This is Headscale specific and would not play nicely
// with other implementations of the ControlServer interface
result, err := headscale.Execute([]string{
"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
})
assertNoErr(t, err)
require.NoError(t, err)
var node v1.Node
err = json.Unmarshal([]byte(result), &node)
assertNoErr(t, err)
require.NoError(t, err)
var expiredNodeKey key.NodePublic
err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey()))
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
@@ -773,14 +769,14 @@ func TestExpireNode(t *testing.T) {
// Verify that the expired node has been marked in all peers list.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
if client.Hostname() != node.GetName() {
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
// Ensures that the node is present, and that it is expired.
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
assertNotNil(t, peerStatus.Expired)
requireNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry)
t.Logf(
@@ -840,20 +836,20 @@ func TestNodeOnlineStatus(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -866,14 +862,14 @@ func TestNodeOnlineStatus(t *testing.T) {
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Duration is chosen arbitrarily, 10m is reported in #1561
testDuration := 12 * time.Minute
@@ -963,7 +959,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -973,16 +969,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
hsic.WithDERPAsIP(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -992,7 +988,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
// Get headscale instance for batcher debug checks
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Initial check: all nodes should be connected to batcher
// Extract node IDs for validation
@@ -1000,7 +996,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
require.NoError(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second)
@@ -1072,7 +1068,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -1081,16 +1077,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -1100,7 +1096,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Test list all nodes after added otherUser
var nodeList []v1.Node
@@ -1170,159 +1166,3 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
assert.True(t, nodeListAfter[0].GetOnline())
assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
}
// NodeSystemStatus represents the online status of a node across different systems
type NodeSystemStatus struct {
Batcher bool
BatcherConnCount int
MapResponses bool
NodeStore bool
}
// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format(TimestampFormat), message)
var prevReport string
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get batcher state
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
// Get map responses
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
// Validate node counts first
expectedCount := len(expectedNodes)
assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch")
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")
// Check that we have map responses for expected nodes
mapResponseCount := len(mapResponses)
assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch")
// Build status map for each node
nodeStatus := make(map[types.NodeID]NodeSystemStatus)
// Initialize all expected nodes
for _, nodeID := range expectedNodes {
nodeStatus[nodeID] = NodeSystemStatus{}
}
// Check batcher state
for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes {
nodeID := types.MustParseNodeID(nodeIDStr)
if status, exists := nodeStatus[nodeID]; exists {
status.Batcher = nodeInfo.Connected
status.BatcherConnCount = nodeInfo.ActiveConnections
nodeStatus[nodeID] = status
}
}
// Check map responses using buildExpectedOnlineMap
onlineFromMaps := make(map[types.NodeID]bool)
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
for nodeID := range nodeStatus {
NODE_STATUS:
for id, peerMap := range onlineMap {
if id == nodeID {
continue
}
online := peerMap[nodeID]
// If the node is offline in any map response, we consider it offline
if !online {
onlineFromMaps[nodeID] = false
continue NODE_STATUS
}
onlineFromMaps[nodeID] = true
}
}
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
// Update status with map response data
for nodeID, online := range onlineFromMaps {
if status, exists := nodeStatus[nodeID]; exists {
status.MapResponses = online
nodeStatus[nodeID] = status
}
}
// Check nodestore state
for nodeID, node := range nodeStore {
if status, exists := nodeStatus[nodeID]; exists {
// Check if node is online in nodestore
status.NodeStore = node.IsOnline != nil && *node.IsOnline
nodeStatus[nodeID] = status
}
}
// Verify all systems show nodes in expected state and report failures
allMatch := true
var failureReport strings.Builder
ids := types.NodeIDs(maps.Keys(nodeStatus))
slices.Sort(ids)
for _, nodeID := range ids {
status := nodeStatus[nodeID]
systemsMatch := (status.Batcher == expectedOnline) &&
(status.MapResponses == expectedOnline) &&
(status.NodeStore == expectedOnline)
if !systemsMatch {
allMatch = false
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr))
failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher))
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses))
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore))
}
}
if !allMatch {
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
t.Log("Diff between reports:")
t.Logf("Prev report: \n%s\n", prevReport)
t.Logf("New report: \n%s\n", failureReport.String())
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
prevReport = failureReport.String()
}
failureReport.WriteString("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
assert.Fail(c, failureReport.String())
}
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr))
}, timeout, 2*time.Second, message)
endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message)
}