From c6d7b512bd3c8059a863db214f263877487ab83a Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 10 Jul 2025 23:38:55 +0200 Subject: [PATCH] integration: replace time.Sleep with assert.EventuallyWithT (#2680) --- .github/workflows/docs-deploy.yml | 3 +- .../workflows/integration-test-template.yml | 2 +- .github/workflows/lint.yml | 6 +- cmd/headscale/cli/debug.go | 2 +- cmd/headscale/cli/mockoidc.go | 3 +- cmd/headscale/cli/nodes.go | 27 +-- cmd/headscale/cli/users.go | 37 ++-- cmd/hi/cleanup.go | 12 +- cmd/hi/docker.go | 41 ++-- cmd/hi/tar_utils.go | 9 +- flake.nix | 1 + hscontrol/auth.go | 6 +- hscontrol/capver/capver.go | 3 +- hscontrol/capver/capver_generated.go | 25 ++- hscontrol/db/db.go | 5 +- hscontrol/derp/server/derp_server.go | 2 +- hscontrol/dns/extrarecords.go | 1 - hscontrol/grpcv1.go | 5 +- hscontrol/handlers.go | 3 +- hscontrol/mapper/mapper.go | 2 +- hscontrol/mapper/mapper_test.go | 6 +- hscontrol/mapper/tail.go | 2 +- hscontrol/metrics.go | 1 + hscontrol/notifier/notifier.go | 6 +- hscontrol/notifier/notifier_test.go | 15 +- hscontrol/oidc.go | 12 +- hscontrol/policy/matcher/matcher.go | 5 +- hscontrol/policy/pm.go | 1 - hscontrol/policy/policy.go | 3 +- hscontrol/policy/policy_test.go | 4 +- hscontrol/policy/v2/filter.go | 7 +- hscontrol/policy/v2/policy.go | 10 +- hscontrol/policy/v2/policy_test.go | 2 +- hscontrol/policy/v2/types.go | 86 +++++---- hscontrol/policy/v2/types_test.go | 22 ++- hscontrol/policy/v2/utils_test.go | 8 +- hscontrol/routes/primary.go | 1 + hscontrol/state/state.go | 4 +- hscontrol/tailsql.go | 4 +- hscontrol/templates/apple.go | 12 +- hscontrol/templates/windows.go | 4 +- hscontrol/types/common.go | 1 + hscontrol/types/config.go | 5 +- hscontrol/types/config_test.go | 1 + hscontrol/types/node.go | 18 +- hscontrol/types/node_test.go | 2 +- hscontrol/types/preauth_key.go | 2 +- hscontrol/types/preauth_key_test.go | 4 +- hscontrol/types/users.go | 6 +- hscontrol/types/version.go | 6 +- hscontrol/util/dns.go | 11 +- hscontrol/util/log.go | 2 +- hscontrol/util/net.go | 1 + hscontrol/util/util.go | 40 ++-- integration/acl_test.go | 3 - integration/auth_key_test.go | 15 +- integration/auth_oidc_test.go | 70 ++++--- integration/auth_web_flow_test.go | 8 +- integration/cli_test.go | 133 ++++++------- integration/derp_verify_endpoint_test.go | 3 +- integration/dns_test.go | 16 +- integration/dockertestutil/config.go | 13 +- integration/dockertestutil/execute.go | 6 +- integration/dsic/dsic.go | 3 +- integration/embedded_derp_test.go | 11 +- integration/general_test.go | 177 +++++++++++------- integration/hsic/hsic.go | 48 ++--- integration/route_test.go | 27 +-- integration/scenario.go | 15 +- integration/scenario_test.go | 2 - integration/ssh_test.go | 72 ++++--- integration/tsic/tsic.go | 20 +- integration/utils.go | 6 +- 73 files changed, 584 insertions(+), 573 deletions(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 15637069..7d06b6a6 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -48,5 +48,4 @@ jobs: - name: Deploy stable docs from tag if: startsWith(github.ref, 'refs/tags/v') # This assumes that only newer tags are pushed - run: - mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest + run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest diff --git a/.github/workflows/integration-test-template.yml b/.github/workflows/integration-test-template.yml index 1c621192..939451d4 100644 --- a/.github/workflows/integration-test-template.yml +++ b/.github/workflows/integration-test-template.yml @@ -75,7 +75,7 @@ jobs: # Some of the jobs might still require manual restart as they are really # slow and this will cause them to eventually be killed by Github actions. attempt_delay: 300000 # 5 min - attempt_limit: 3 + attempt_limit: 2 command: | nix develop --command -- hi run "^${{ inputs.test }}$" \ --timeout=120m \ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 49334233..1e06f4de 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -36,8 +36,7 @@ jobs: - name: golangci-lint if: steps.changed-files.outputs.files == 'true' - run: - nix develop --command -- golangci-lint run + run: nix develop --command -- golangci-lint run --new-from-rev=${{github.event.pull_request.base.sha}} --format=colored-line-number @@ -75,8 +74,7 @@ jobs: - name: Prettify code if: steps.changed-files.outputs.files == 'true' - run: - nix develop --command -- prettier --no-error-on-unmatched-pattern + run: nix develop --command -- prettier --no-error-on-unmatched-pattern --ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html} proto-lint: diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 41b46fb0..8ce5f237 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -117,7 +117,7 @@ var createNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()), + "Cannot create node: "+status.Convert(err).Message(), output, ) } diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 309ad67d..9969f7c6 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -2,6 +2,7 @@ package cli import ( "encoding/json" + "errors" "fmt" "net" "net/http" @@ -68,7 +69,7 @@ func mockOIDC() error { userStr := os.Getenv("MOCKOIDC_USERS") if userStr == "" { - return fmt.Errorf("MOCKOIDC_USERS not defined") + return errors.New("MOCKOIDC_USERS not defined") } var users []mockoidc.MockUser diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 00d803b2..fb49f4a3 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -184,7 +184,7 @@ var listNodesCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), + "Cannot get nodes: "+status.Convert(err).Message(), output, ) } @@ -398,10 +398,7 @@ var deleteNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error getting node node: %s", - status.Convert(err).Message(), - ), + "Error getting node node: "+status.Convert(err).Message(), output, ) @@ -437,10 +434,7 @@ var deleteNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error deleting node: %s", - status.Convert(err).Message(), - ), + "Error deleting node: "+status.Convert(err).Message(), output, ) @@ -498,10 +492,7 @@ var moveNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error getting node: %s", - status.Convert(err).Message(), - ), + "Error getting node: "+status.Convert(err).Message(), output, ) @@ -517,10 +508,7 @@ var moveNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error moving node: %s", - status.Convert(err).Message(), - ), + "Error moving node: "+status.Convert(err).Message(), output, ) @@ -567,10 +555,7 @@ be assigned to nodes.`, if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error backfilling IPs: %s", - status.Convert(err).Message(), - ), + "Error backfilling IPs: "+status.Convert(err).Message(), output, ) diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index b5f1bc49..c482299c 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/url" + "strconv" survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -27,10 +28,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { err := errors.New("--name or --identifier flag is required") ErrorOutput( err, - fmt.Sprintf( - "Cannot rename user: %s", - status.Convert(err).Message(), - ), + "Cannot rename user: "+status.Convert(err).Message(), "", ) } @@ -114,10 +112,7 @@ var createUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot create user: %s", - status.Convert(err).Message(), - ), + "Cannot create user: "+status.Convert(err).Message(), output, ) } @@ -147,16 +142,16 @@ var destroyUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } if len(users.GetUsers()) != 1 { - err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID") + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } @@ -185,10 +180,7 @@ var destroyUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot destroy user: %s", - status.Convert(err).Message(), - ), + "Cannot destroy user: "+status.Convert(err).Message(), output, ) } @@ -233,7 +225,7 @@ var listUsersCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()), + "Cannot get users: "+status.Convert(err).Message(), output, ) } @@ -247,7 +239,7 @@ var listUsersCmd = &cobra.Command{ tableData = append( tableData, []string{ - fmt.Sprintf("%d", user.GetId()), + strconv.FormatUint(user.GetId(), 10), user.GetDisplayName(), user.GetName(), user.GetEmail(), @@ -287,16 +279,16 @@ var renameUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } if len(users.GetUsers()) != 1 { - err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID") + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } @@ -312,10 +304,7 @@ var renameUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot rename user: %s", - status.Convert(err).Message(), - ), + "Cannot rename user: "+status.Convert(err).Message(), output, ) } diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 080266d8..fd78c66f 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -66,7 +66,7 @@ func killTestContainers(ctx context.Context) error { if cont.State == "running" { _ = cli.ContainerKill(ctx, cont.ID, "KILL") } - + // Then remove the container with retry logic if removeContainerWithRetry(ctx, cli, cont.ID) { removed++ @@ -87,25 +87,25 @@ func killTestContainers(ctx context.Context) error { func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool { maxRetries := 3 baseDelay := 100 * time.Millisecond - - for attempt := 0; attempt < maxRetries; attempt++ { + + for attempt := range maxRetries { err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{ Force: true, }) if err == nil { return true } - + // If this is the last attempt, don't wait if attempt == maxRetries-1 { break } - + // Wait with exponential backoff delay := baseDelay * time.Duration(1< diff --git a/hscontrol/auth.go b/hscontrol/auth.go index f9de67e7..986bbabc 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode( return nil, nil } - } n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry) @@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, changed, err := h.state.HandleNodeFromPreAuthKey( regReq, machineKey, @@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } - if perr, ok := err.(types.PAKError); ok { + var perr types.PAKError + if errors.As(err, &perr) { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } + return nil, err } diff --git a/hscontrol/capver/capver.go b/hscontrol/capver/capver.go index 7ad5074d..347ec981 100644 --- a/hscontrol/capver/capver.go +++ b/hscontrol/capver/capver.go @@ -1,11 +1,10 @@ package capver import ( + "slices" "sort" "strings" - "slices" - xmaps "golang.org/x/exp/maps" "tailscale.com/tailcfg" "tailscale.com/util/set" diff --git a/hscontrol/capver/capver_generated.go b/hscontrol/capver/capver_generated.go index f192fad4..687e3d51 100644 --- a/hscontrol/capver/capver_generated.go +++ b/hscontrol/capver/capver_generated.go @@ -1,6 +1,6 @@ package capver -//Generated DO NOT EDIT +// Generated DO NOT EDIT import "tailscale.com/tailcfg" @@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ "v1.82.5": 115, } - var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ - 87: "v1.60.0", - 88: "v1.62.0", - 90: "v1.64.0", - 95: "v1.66.0", - 97: "v1.68.0", - 102: "v1.70.0", - 104: "v1.72.0", - 106: "v1.74.0", - 109: "v1.78.0", - 113: "v1.80.0", - 115: "v1.82.0", + 87: "v1.60.0", + 88: "v1.62.0", + 90: "v1.64.0", + 95: "v1.66.0", + 97: "v1.68.0", + 102: "v1.70.0", + 104: "v1.72.0", + 106: "v1.74.0", + 109: "v1.78.0", + 113: "v1.80.0", + 115: "v1.82.0", } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 56d7860b..abda802c 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -764,13 +764,13 @@ AND auth_key_id NOT IN ( // Drop all indexes first to avoid conflicts indexesToDrop := []string{ "idx_users_deleted_at", - "idx_provider_identifier", + "idx_provider_identifier", "idx_name_provider_identifier", "idx_name_no_provider_identifier", "idx_api_keys_prefix", "idx_policies_deleted_at", } - + for _, index := range indexesToDrop { _ = tx.Exec("DROP INDEX IF EXISTS " + index).Error } @@ -927,6 +927,7 @@ AND auth_key_id NOT IN ( } log.Info().Msg("Schema recreation completed successfully") + return nil }, Rollback: func(db *gorm.DB) error { return nil }, diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index ae7bf03e..fee395f1 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { Avoid: false, Nodes: []*tailcfg.DERPNode{ { - Name: fmt.Sprintf("%d", d.cfg.ServerRegionID), + Name: strconv.Itoa(d.cfg.ServerRegionID), RegionID: d.cfg.ServerRegionID, HostName: host, DERPPort: port, diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 6ea3aa35..82b3078b 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() { return struct{}{}, nil }, backoff.WithBackOff(backoff.NewExponentialBackOff())) - if err != nil { log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete") continue diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index e098b766..7df4c92e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode( api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } - ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) + ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname) + api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID) + + ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname) api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) log.Trace(). diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index f32aea96..590541b0 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -32,7 +32,7 @@ const ( reservedResponseHeaderSize = 4 ) -// httpError logs an error and sends an HTTP error response with the given +// httpError logs an error and sends an HTTP error response with the given. func httpError(w http.ResponseWriter, err error) { var herr HTTPError if errors.As(err, &herr) { @@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest( resp := &tailcfg.DERPAdmitClientResponse{ Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), } + return json.NewEncoder(writer).Encode(resp) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 49a99351..553658f5 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types. } // ListNodes queries the database for either all nodes if no parameters are given -// or for the given nodes if at least one node ID is given as parameter +// or for the given nodes if at least one node ID is given as parameter. func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { nodes, err := m.state.ListNodes(nodeIDs...) if err != nil { diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 71b9e4b9..b5747c2b 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) { } } -// mockState is a mock implementation that provides the required methods +// mockState is a mock implementation that provides the required methods. type mockState struct { polMan policy.PolicyManager derpMap *tailcfg.DERPMap @@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ } } } + return filtered, nil } // Return all peers except the node itself @@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ filtered = append(filtered, peer) } } + return filtered, nil } @@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { } } } + return filtered, nil } + return m.nodes, nil } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 9b58ad34..9729301d 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -11,7 +11,7 @@ import ( "tailscale.com/types/views" ) -// NodeCanHaveTagChecker is an interface for checking if a node can have a tag +// NodeCanHaveTagChecker is an interface for checking if a node can have a tag. type NodeCanHaveTagChecker interface { NodeCanHaveTag(node types.NodeView, tag string) bool } diff --git a/hscontrol/metrics.go b/hscontrol/metrics.go index cb01838c..ef427afb 100644 --- a/hscontrol/metrics.go +++ b/hscontrol/metrics.go @@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) { } n, err := r.ResponseWriter.Write(b) r.written += int64(n) + return n, err } diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 2e6b9b0b..6bd990c7 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier { n.b = b go b.doWork() + return n } @@ -72,7 +73,7 @@ func (n *Notifier) Close() { n.nodes = make(map[types.NodeID]chan<- types.StateUpdate) } -// safeCloseChannel closes a channel and panic recovers if already closed +// safeCloseChannel closes a channel and panic recovers if already closed. func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) { defer func() { if r := recover(); r != nil { @@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool { if val, ok := n.connected.Load(nodeID); ok { return val } + return false } @@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { return false } -// LikelyConnectedMap returns a thread safe map of connected nodes +// LikelyConnectedMap returns a thread safe map of connected nodes. func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] { return n.connected } diff --git a/hscontrol/notifier/notifier_test.go b/hscontrol/notifier/notifier_test.go index 9654cfc8..c3e96a8d 100644 --- a/hscontrol/notifier/notifier_test.go +++ b/hscontrol/notifier/notifier_test.go @@ -1,17 +1,15 @@ package notifier import ( - "context" "fmt" "math/rand" "net/netip" + "slices" "sort" "sync" "testing" "time" - "slices" - "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) { defer n.RemoveNode(1, ch) for _, u := range tt.updates { - n.NotifyAll(context.Background(), u) + n.NotifyAll(t.Context(), u) } n.b.flush() @@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) { // TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected // Multiple goroutines calling AddNode and RemoveNode cause panics when trying to // close a channel that was already closed, which can happen when a node changes -// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting +// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting. func TestIsLikelyConnectedRaceCondition(t *testing.T) { // mock config for the notifier cfg := &types.Config{ @@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) { for range iterations { // Simulate race by having some goroutines check IsLikelyConnected // while others add/remove the node - if routineID%3 == 0 { + switch routineID % 3 { + case 0: // This goroutine checks connection status isConnected := notifier.IsLikelyConnected(nodeID) if isConnected != true && isConnected != false { errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected) } - } else if routineID%3 == 1 { + case 1: // This goroutine removes the node notifier.RemoveNode(nodeID, updateChan) - } else { + default: // This goroutine adds the node back notifier.AddNode(nodeID, updateChan) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 1f08adf8..5f1935e5 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -84,11 +84,8 @@ func NewAuthProviderOIDC( ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, Endpoint: oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf( - "%s/oidc/callback", - strings.TrimSuffix(serverURL, "/"), - ), - Scopes: cfg.Scope, + RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback", + Scopes: cfg.Scope, } registrationCache := zcache.New[string, RegistrationInfo]( @@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler( req *http.Request, ) { vars := mux.Vars(req) - registrationIdStr, _ := vars["registration_id"] + registrationIdStr := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render @@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } oauth2Token, err := a.getOauth2Token(req.Context(), code, state) - if err != nil { httpError(writer, err) return @@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Neither node nor machine key was found in the state cache meaning // that we could not reauth nor register the node. httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + return } @@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token( if err != nil { return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) } + return oauth2Token, err } diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index d246d5e2..aac5a5f3 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -2,9 +2,8 @@ package matcher import ( "net/netip" - "strings" - "slices" + "strings" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" @@ -28,6 +27,7 @@ func (m Match) DebugString() string { for _, prefix := range m.dests.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } + return sb.String() } @@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { for _, rule := range rules { matches = append(matches, MatchFromFilterRule(rule)) } + return matches } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index cfeb65a1..3a59b25f 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -4,7 +4,6 @@ import ( "net/netip" "github.com/juanfont/headscale/hscontrol/policy/matcher" - policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 4efd1e01..5a9103e5 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -5,7 +5,6 @@ import ( "slices" "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/samber/lo" @@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf // AutoApproveRoutes approves any route that can be autoapproved from // the nodes perspective according to the given policy. // It reports true if any routes were approved. -// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes +// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes. func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { if pm == nil { return false diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9f2f7573..f19ac3d3 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -7,9 +7,8 @@ import ( "testing" "time" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) { } } } + func TestReduceRoutes(t *testing.T) { type args struct { node *types.Node diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 1825926f..9d838e56 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -13,9 +13,7 @@ import ( "tailscale.com/types/views" ) -var ( - ErrInvalidAction = errors.New("invalid action") -) +var ErrInvalidAction = errors.New("invalid action") // compileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. @@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules( var destPorts []tailcfg.NetPortRange for _, dest := range acl.Destinations { - ips, err := dest.Alias.Resolve(pol, users, nodes) + ips, err := dest.Resolve(pol, users, nodes) if err != nil { log.Trace().Err(err).Msgf("resolving destination ips") } @@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string { for _, pref := range ips.Prefixes() { out = append(out, pref.String()) } + return out } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index cbc34215..2f4be34e 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -4,19 +4,17 @@ import ( "encoding/json" "fmt" "net/netip" + "slices" "strings" "sync" "github.com/juanfont/headscale/hscontrol/policy/matcher" - - "slices" - "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/util/deephash" "tailscale.com/types/views" + "tailscale.com/util/deephash" ) type PolicyManager struct { @@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { pm.mu.Lock() defer pm.mu.Unlock() + return pm.filter, pm.matchers } @@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() pm.users = users + return pm.updateLocked() } @@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes + return pm.updateLocked() } @@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // cannot just lookup in the prefix map and have to check // if there is a "parent" prefix available. for prefix, approveAddrs := range pm.autoApproveMap { - // Check if prefix is larger (so containing) and then overlaps // the route to see if the node can approve a subset of an autoapprover if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index b3540e63..a91831ad 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -1,10 +1,10 @@ package v2 import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "testing" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/require" "gorm.io/gorm" diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 550287c2..c38d1991 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -6,9 +6,9 @@ import ( "errors" "fmt" "net/netip" - "strings" - "slices" + "strconv" + "strings" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { // Check if it's the wildcard port range if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 { - return json.Marshal(fmt.Sprintf("%s:*", alias)) + return json.Marshal(alias + ":*") } // Otherwise, format as "alias:ports" var ports []string for _, port := range a.Ports { if port.First == port.Last { - ports = append(ports, fmt.Sprintf("%d", port.First)) + ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) } else { ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last)) } @@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error { if err := u.Validate(); err != nil { return err } + return nil } @@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. return buildIPSetMultiErr(&ips, errs) } -// Group is a special string which is always prefixed with `group:` +// Group is a special string which is always prefixed with `group:`. type Group string func (g Group) Validate() error { @@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error { if err := g.Validate(); err != nil { return err } + return nil } @@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod return buildIPSetMultiErr(&ips, errs) } -// Tag is a special string which is always prefixed with `tag:` +// Tag is a special string which is always prefixed with `tag:`. type Tag string func (t Tag) Validate() error { @@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error { if err := t.Validate(); err != nil { return err } + return nil } @@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error { if err := h.Validate(); err != nil { return err } + return nil } @@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error { } *p = Prefix(addrPref) + return nil } @@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error { return err } *p = Prefix(pref) + return nil } @@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { if err := p.Validate(); err != nil { return err } + return nil } @@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild } } -// AutoGroup is a special string which is always prefixed with `autogroup:` +// AutoGroup is a special string which is always prefixed with `autogroup:`. type AutoGroup string const ( @@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { if err := ag.Validate(); err != nil { return err } + return nil } @@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } - if err := ve.Alias.Validate(); err != nil { + if err := ve.Validate(); err != nil { return err } default: return fmt.Errorf("type %T not supported", vs) } + return nil } @@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { return err } ve.Alias = ptr + return nil } @@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { for i, alias := range aliases { (*a)[i] = alias.Alias } + return nil } @@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I return ips, multierr.New(append(errs, err)...) } -// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer +// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer. func unmarshalPointer[T any]( b []byte, parseFunc func(string) (T, error), @@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { for i, autoApprover := range autoApprovers { (*aa)[i] = autoApprover.AutoApprover } + return nil } @@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { return err } ve.AutoApprover = ptr + return nil } @@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { return err } ve.Owner = ptr + return nil } @@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error { for i, owner := range owners { (*o)[i] = owner.Owner } + return nil } @@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) { case isGroup(s): return ptr.To(Group(s)), nil } + return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: - user (containing an "@") - group (starting with "group:") @@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { (*g)[group] = usernames } + return nil } @@ -1252,7 +1269,7 @@ type Policy struct { // We use the default JSON marshalling behavior provided by the Go runtime. var ( - // TODO(kradalby): Add these checks for tagOwners and autoApprovers + // TODO(kradalby): Add these checks for tagOwners and autoApprovers. autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged} autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} @@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`) } if !slices.Contains(autogroupForSrc, *src) { @@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) } if !slices.Contains(autogroupForSSHSrc, *src) { @@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error { } if dst.Is(AutoGroupInternet) { - return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) } if !slices.Contains(autogroupForSSHDst, *dst) { @@ -1360,14 +1377,14 @@ func (p *Policy) validate() error { for _, acl := range p.ACLs { for _, src := range acl.Sources { - switch src.(type) { + switch src := src.(type) { case *Host: - h := src.(*Host) + h := src if !p.Hosts.exist(*h) { errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) } case *AutoGroup: - ag := src.(*AutoGroup) + ag := src if err := validateAutogroupSupported(ag); err != nil { errs = append(errs, err) @@ -1379,12 +1396,12 @@ func (p *Policy) validate() error { continue } case *Group: - g := src.(*Group) + g := src if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := src.(*Tag) + tagOwner := src if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1440,9 +1457,9 @@ func (p *Policy) validate() error { } for _, src := range ssh.Sources { - switch src.(type) { + switch src := src.(type) { case *AutoGroup: - ag := src.(*AutoGroup) + ag := src if err := validateAutogroupSupported(ag); err != nil { errs = append(errs, err) @@ -1454,21 +1471,21 @@ func (p *Policy) validate() error { continue } case *Group: - g := src.(*Group) + g := src if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := src.(*Tag) + tagOwner := src if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } } } for _, dst := range ssh.Destinations { - switch dst.(type) { + switch dst := dst.(type) { case *AutoGroup: - ag := dst.(*AutoGroup) + ag := dst if err := validateAutogroupSupported(ag); err != nil { errs = append(errs, err) continue @@ -1479,7 +1496,7 @@ func (p *Policy) validate() error { continue } case *Tag: - tagOwner := dst.(*Tag) + tagOwner := dst if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1489,9 +1506,9 @@ func (p *Policy) validate() error { for _, tagOwners := range p.TagOwners { for _, tagOwner := range tagOwners { - switch tagOwner.(type) { + switch tagOwner := tagOwner.(type) { case *Group: - g := tagOwner.(*Group) + g := tagOwner if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } @@ -1501,14 +1518,14 @@ func (p *Policy) validate() error { for _, approvers := range p.AutoApprovers.Routes { for _, approver := range approvers { - switch approver.(type) { + switch approver := approver.(type) { case *Group: - g := approver.(*Group) + g := approver if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := approver.(*Tag) + tagOwner := approver if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1517,14 +1534,14 @@ func (p *Policy) validate() error { } for _, approver := range p.AutoApprovers.ExitNode { - switch approver.(type) { + switch approver := approver.(type) { case *Group: - g := approver.(*Group) + g := approver if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := approver.(*Tag) + tagOwner := approver if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1536,6 +1553,7 @@ func (p *Policy) validate() error { } p.validated = true + return nil } @@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { ) } } + return nil } @@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { ) } } + return nil } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 8cddfeba..4aca150e 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -5,13 +5,13 @@ import ( "net/netip" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go4.org/netipx" @@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) { // Marshal the policy to JSON marshalled, err := json.MarshalIndent(policy, "", " ") require.NoError(t, err) - + // Make sure all expected fields are present in the JSON jsonString := string(marshalled) assert.Contains(t, jsonString, "group:example") @@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) { assert.Contains(t, jsonString, "accept") assert.Contains(t, jsonString, "tcp") assert.Contains(t, jsonString, "80") - + // Unmarshal back to verify round trip var roundTripped Policy err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) - + // Compare the original and round-tripped policies - cmps := append(util.Comparers, + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { return x == y }), cmpopts.IgnoreUnexported(Policy{}), cmpopts.EquateEmpty(), ) - + if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" { t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) } @@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) { }, } - cmps := append(util.Comparers, + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { return x == y }), cmpopts.IgnoreUnexported(Policy{}), ) - + // For round-trip testing, we'll normalize the policies before comparing for _, tt := range tests { @@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) { } else if !strings.Contains(err.Error(), tt.wantErr) { t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr) } + return // Skip the rest of the test if we expected an error } @@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) { if err != nil { t.Fatalf("round-trip unmarshalling: %v", err) } - + // Add EquateEmpty to handle nil vs empty maps/slices - roundTripCmps := append(cmps, + roundTripCmps := append(cmps, cmpopts.EquateEmpty(), cmpopts.IgnoreUnexported(Policy{}), ) @@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet { builder.AddPrefix(mp(p)) } ipSet, _ := builder.IPSet() + return ipSet } diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index d1645071..2084b22f 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) { expected []tailcfg.PortRange err string }{ - {"80", []tailcfg.PortRange{{80, 80}}, ""}, - {"80-90", []tailcfg.PortRange{{80, 90}}, ""}, - {"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""}, - {"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""}, + {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""}, + {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""}, + {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""}, + {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""}, {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, {"80-", nil, "invalid port range format"}, {"-90", nil, "invalid port range format"}, diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 67eb8d1f..f65d9122 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } tsaddr.SortPrefixes(routes) + return routes } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 0d8a2a8e..b754e594 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) { if err != nil { return types.NodeView{}, err } + return node.View(), nil } @@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er if err != nil { return types.NodeView{}, err } + return node.View(), nil } @@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey( if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) { nodeToRegister.Expiry = ®Req.Expiry } else if !regReq.Expiry.IsZero() { - // If client is sending an expired time (e.g., after logout), + // If client is sending an expired time (e.g., after logout), // don't set expiry so the node won't be considered expired log.Debug(). Time("requested_expiry", regReq.Expiry). diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 82e82d78..1a949173 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -2,6 +2,7 @@ package hscontrol import ( "context" + "errors" "fmt" "net/http" "os" @@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. certDomains := tsNode.CertDomains() if len(certDomains) == 0 { - return fmt.Errorf("no cert domains available for HTTPS") + return errors.New("no cert domains available for HTTPS") } base := "https://" + certDomains[0] go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s logf("TailSQL started") <-ctx.Done() logf("TailSQL shutting down...") + return tsNode.Close() } diff --git a/hscontrol/templates/apple.go b/hscontrol/templates/apple.go index 99b1cc8e..84928ed5 100644 --- a/hscontrol/templates/apple.go +++ b/hscontrol/templates/apple.go @@ -62,7 +62,7 @@ func Apple(url string) *elem.Element { ), elem.Pre(nil, elem.Code(nil, - elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)), + elem.Text("tailscale login --login-server "+url), ), ), headerTwo("GUI"), @@ -143,10 +143,7 @@ func Apple(url string) *elem.Element { elem.Code( nil, elem.Text( - fmt.Sprintf( - `defaults write io.tailscale.ipn.macos ControlURL %s`, - url, - ), + "defaults write io.tailscale.ipn.macos ControlURL "+url, ), ), ), @@ -155,10 +152,7 @@ func Apple(url string) *elem.Element { elem.Code( nil, elem.Text( - fmt.Sprintf( - `defaults write io.tailscale.ipn.macsys ControlURL %s`, - url, - ), + "defaults write io.tailscale.ipn.macsys ControlURL "+url, ), ), ), diff --git a/hscontrol/templates/windows.go b/hscontrol/templates/windows.go index 680d6655..ecf7d77c 100644 --- a/hscontrol/templates/windows.go +++ b/hscontrol/templates/windows.go @@ -1,8 +1,6 @@ package templates import ( - "fmt" - "github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go/attrs" ) @@ -31,7 +29,7 @@ func Windows(url string) *elem.Element { ), elem.Pre(nil, elem.Code(nil, - elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)), + elem.Text("tailscale login --login-server "+url), ), ), ), diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 69c298b9..51e11757 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID { if err != nil { panic(err) } + return rid } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 03c1e7ea..1e35303e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error { log.Warn().Msg("No config file found, using defaults") return nil } + return fmt.Errorf("fatal error reading config file: %w", err) } @@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) { } if prefix4 == nil && prefix6 == nil { - return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") } allocStr := viper.GetString("prefixes.allocation") @@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error { s := len(serverDomainParts) b := len(baseDomainParts) - for i := range len(baseDomainParts) { + for i := range baseDomainParts { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { return nil } diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 7ae3db59..6b9fc2ef 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) { assert.Equal(t, "trace", viper.GetString("log.level")) assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4")) assert.False(t, viper.GetBool("database.sqlite.write_ahead_log")) + return nil, nil }, want: nil, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 11383950..32f0274c 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -28,8 +28,10 @@ var ( ErrNodeUserHasNoName = errors.New("node user has no name") ) -type NodeID uint64 -type NodeIDs []NodeID +type ( + NodeID uint64 + NodeIDs []NodeID +) func (n NodeIDs) Len() int { return len(n) } func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } @@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool { return true } } + return false } @@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool { // and therefore should not be treated as a // user owned device. // Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) +// via CLI ("forced tags" and preauthkeys). func (node *Node) IsTagged() bool { if len(node.ForcedTags) > 0 { return true @@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool { // HasTag reports if a node has a given tag. // Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) +// via CLI ("forced tags" and preauthkeys). func (node *Node) HasTag(tag string) bool { return slices.Contains(node.Tags(), tag) } @@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string { sb.WriteString(node.DebugString()) sb.WriteString("\n") } + return sb.String() } @@ -590,6 +594,7 @@ func (node Node) DebugString() string { fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes()) sb.WriteString("\n") + return sb.String() } @@ -689,7 +694,7 @@ func (v NodeView) Tags() []string { // and therefore should not be treated as a // user owned device. // Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) +// via CLI ("forced tags" and preauthkeys). func (v NodeView) IsTagged() bool { if !v.Valid() { return false @@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC // GetFQDN returns the fully qualified domain name for the node. func (v NodeView) GetFQDN(baseDomain string) (string, error) { if !v.Valid() { - return "", fmt.Errorf("failed to create valid FQDN: node view is invalid") + return "", errors.New("failed to create valid FQDN: node view is invalid") } return v.ж.GetFQDN(baseDomain) } @@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string { } return v.ж.IPsAsString() } - diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index c7261587..f6d1d027 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -2,7 +2,6 @@ package types import ( "fmt" - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "strings" "testing" @@ -10,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" "tailscale.com/types/key" diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 51c474eb..e47666ff 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -11,7 +11,7 @@ import ( type PAKError string func (e PAKError) Error() string { return string(e) } -func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) } +func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) } // PreAuthKey describes a pre-authorization key usable in a particular user. type PreAuthKey struct { diff --git a/hscontrol/types/preauth_key_test.go b/hscontrol/types/preauth_key_test.go index 3f7eb269..4ab1c717 100644 --- a/hscontrol/types/preauth_key_test.go +++ b/hscontrol/types/preauth_key_test.go @@ -1,6 +1,7 @@ package types import ( + "errors" "testing" "time" @@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) { if err == nil { t.Errorf("expected error but got none") } else { - httpErr, ok := err.(PAKError) + var httpErr PAKError + ok := errors.As(err, &httpErr) if !ok { t.Errorf("expected HTTPError but got %T", err) } else { diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 6cd2c41a..69377b95 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string { // - Remove empty path segments // - For non-URL identifiers, it joins non-empty segments with a single slash // - Returns empty string for identifiers with only slashes -// - Normalize URL schemes to lowercase +// - Normalize URL schemes to lowercase. func CleanIdentifier(identifier string) string { if identifier == "" { return identifier @@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, part) } } - + if len(cleanParts) == 0 { u.Path = "" } else { @@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string { } // Ensure scheme is lowercase u.Scheme = strings.ToLower(u.Scheme) + return u.String() } @@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string { if len(cleanParts) == 0 { return "" } + return strings.Join(cleanParts, "/") } diff --git a/hscontrol/types/version.go b/hscontrol/types/version.go index e84087fb..7fe23250 100644 --- a/hscontrol/types/version.go +++ b/hscontrol/types/version.go @@ -1,4 +1,6 @@ package types -var Version = "dev" -var GitCommitHash = "dev" +var ( + Version = "dev" + GitCommitHash = "dev" +) diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index 3a08fc3a..65194720 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "regexp" + "strconv" "strings" "unicode" @@ -21,8 +22,10 @@ const ( LabelHostnameLength = 63 ) -var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") -var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") +var ( + invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") + invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") +) var ErrInvalidUserName = errors.New("invalid user name") @@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.) rdnsSlice := []string{} for i := lastOctet - 1; i >= 0; i-- { - rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i])) + rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10)) } rdnsSlice = append(rdnsSlice, "in-addr.arpa.") rdnsBase := strings.Join(rdnsSlice, ".") @@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) { prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".") - return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix)) + return dnsname.ToFQDN(prefix + ".ip6.arpa") } var fqdns []dnsname.FQDN diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go index 12f646b1..936b374c 100644 --- a/hscontrol/util/log.go +++ b/hscontrol/util/log.go @@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq "rowsAffected": rowsAffected, } - if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) { + if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) { l.Logger.Error().Err(err).Fields(fields).Msgf("") return } diff --git a/hscontrol/util/net.go b/hscontrol/util/net.go index 0d6b4412..e28bb00b 100644 --- a/hscontrol/util/net.go +++ b/hscontrol/util/net.go @@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet { internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16")) theInternetSet, _ := internetBuilder.IPSet() + return theInternetSet }) diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 4f6660be..a44a6e97 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { } type TraceroutePath struct { - // Hop is the current jump in the total traceroute. - Hop int + // Hop is the current jump in the total traceroute. + Hop int - // Hostname is the resolved hostname or IP address identifying the jump - Hostname string + // Hostname is the resolved hostname or IP address identifying the jump + Hostname string - // IP is the IP address of the jump - IP netip.Addr + // IP is the IP address of the jump + IP netip.Addr - // Latencies is a list of the latencies for this jump - Latencies []time.Duration + // Latencies is a list of the latencies for this jump + Latencies []time.Duration } type Traceroute struct { - // Hostname is the resolved hostname or IP address identifying the target - Hostname string + // Hostname is the resolved hostname or IP address identifying the target + Hostname string - // IP is the IP address of the target - IP netip.Addr + // IP is the IP address of the target + IP netip.Addr - // Route is the path taken to reach the target if successful. The list is ordered by the path taken. - Route []TraceroutePath + // Route is the path taken to reach the target if successful. The list is ordered by the path taken. + Route []TraceroutePath - // Success indicates if the traceroute was successful. - Success bool + // Success indicates if the traceroute was successful. + Success bool - // Err contains an error if the traceroute was not successful. - Err error + // Err contains an error if the traceroute was not successful. + Err error } -// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct +// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct. func ParseTraceroute(output string) (Traceroute, error) { lines := strings.Split(strings.TrimSpace(output), "\n") if len(lines) < 1 { @@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Parse each hop line - hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) + hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?") for i := 1; i < len(lines); i++ { matches := hopRegex.FindStringSubmatch(lines[i]) diff --git a/integration/acl_test.go b/integration/acl_test.go index 193b6669..3aef521e 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1077,7 +1077,6 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1213,7 +1212,6 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { func TestACLAutogroupMember(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := aclScenario(t, &policyv2.Policy{ @@ -1271,7 +1269,6 @@ func TestACLAutogroupMember(t *testing.T) { func TestACLAutogroupTagged(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := aclScenario(t, &policyv2.Policy{ diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index d54ff593..ac69a6f5 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -3,12 +3,11 @@ package integration import ( "fmt" "net/netip" + "slices" "strconv" "testing" "time" - "slices" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -19,7 +18,6 @@ import ( func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() for _, https := range []bool{true, false} { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { @@ -66,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -161,12 +159,11 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } }) } - } func assertLastSeenSet(t *testing.T, node *v1.Node) { assert.NotNil(t, node) - assert.NotNil(t, node.LastSeen) + assert.NotNil(t, node.GetLastSeen()) } // This test will first log in two sets of nodes to two sets of users, then @@ -175,7 +172,6 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) { // still has nodes, but they are not connected. func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -204,7 +200,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -259,7 +255,6 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { IntegrationSkip(t) - t.Parallel() for _, https := range []bool{true, false} { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { @@ -303,7 +298,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 53c74577..d118b643 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1,14 +1,12 @@ package integration import ( - "fmt" + "maps" "net/netip" "sort" "testing" "time" - "maps" - "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -21,7 +19,6 @@ import ( func TestOIDCAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Logins to MockOIDC is served by a queue with a strict order, // if we use more than one node per user, the order of the logins @@ -119,7 +116,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { // This test is really flaky. func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { IntegrationSkip(t) - t.Parallel() shortAccessTTL := 5 * time.Minute @@ -174,9 +170,13 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { // of safety reasons) before checking if the clients have logged out. // The Wait function can't do it itself as it has an upper bound of 1 // min. - time.Sleep(shortAccessTTL + 10*time.Second) - - assertTailscaleNodesLogout(t, allClients) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + } + }, shortAccessTTL+10*time.Second, 5*time.Second) } func TestOIDC024UserCreation(t *testing.T) { @@ -295,9 +295,7 @@ func TestOIDC024UserCreation(t *testing.T) { spec := ScenarioSpec{ NodesPerUser: 1, } - for _, user := range tt.cliUsers { - spec.Users = append(spec.Users, user) - } + spec.Users = append(spec.Users, tt.cliUsers...) for _, user := range tt.oidcUsers { spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified)) @@ -350,7 +348,6 @@ func TestOIDC024UserCreation(t *testing.T) { func TestOIDCAuthenticationWithPKCE(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Single user with one node for testing PKCE flow spec := ScenarioSpec{ @@ -402,7 +399,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { func TestOIDCReloginSameNodeNewUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Create no nodes and no users scenario, err := NewScenario(ScenarioSpec{ @@ -440,7 +436,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assertNoErr(t, err) - assert.Len(t, listUsers, 0) + assert.Empty(t, listUsers) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) assertNoErr(t, err) @@ -482,7 +478,13 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { err = ts.Logout() assertNoErr(t, err) - time.Sleep(5 * time.Second) + // Wait for logout to complete and then do second logout + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + }, 5*time.Second, 1*time.Second) // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it @@ -530,16 +532,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Machine key is the same as the "machine" has not changed, // but Node key is not as it is a new node - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) - assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) - assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) + assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) + assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) // Log out user2, and log into user1, no new node should be created, // the node should now "become" node1 again err = ts.Logout() assertNoErr(t, err) - time.Sleep(5 * time.Second) + // Wait for logout to complete and then do second logout + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + }, 5*time.Second, 1*time.Second) // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it @@ -588,24 +596,24 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Validate that the machine we had when we logged in the first time, has the same // machine key, but a different ID than the newly logged in version of the same // machine. - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) - assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey) - assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id) - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) - assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id) - assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) + assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) + assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) + assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) + assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) // Even tho we are logging in again with the same user, the previous key has been expired // and a new one has been generated. The node entry in the database should be the same // as the user + machinekey still matches. - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey) - assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey) - assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) + assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) + assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) // The "logged back in" machine should have the same machinekey but a different nodekey // than the version logged in with a different user. - assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey) - assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) + assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) + assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) } func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { @@ -623,7 +631,7 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { return mockoidc.MockUser{ Subject: username, PreferredUsername: username, - Email: fmt.Sprintf("%s@headscale.net", username), + Email: username + "@headscale.net", EmailVerified: emailVerified, } } diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 64cace7b..83413e0d 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -2,9 +2,8 @@ package integration import ( "net/netip" - "testing" - "slices" + "testing" "github.com/juanfont/headscale/integration/hsic" "github.com/samber/lo" @@ -55,7 +54,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -95,7 +93,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -140,7 +138,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) listNodes, err = headscale.ListNodes() - require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + require.Len(t, listNodes, nodeCountBeforeLogout) t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) for _, client := range allClients { diff --git a/integration/cli_test.go b/integration/cli_test.go index 2cff0500..fd9c49a7 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -18,8 +18,8 @@ import ( "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tailscale.com/tailcfg" "golang.org/x/exp/slices" + "tailscale.com/tailcfg" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -30,7 +30,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul err = json.Unmarshal([]byte(str), result) if err != nil { - return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str) + return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str) } return nil @@ -48,7 +48,6 @@ func sortWithID[T GRPCSortable](a, b T) int { func TestUserCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"user1", "user2"}, @@ -184,7 +183,7 @@ func TestUserCommand(t *testing.T) { "--identifier=1", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterIDDelete []*v1.User @@ -222,7 +221,7 @@ func TestUserCommand(t *testing.T) { "--name=newname", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterNameDelete []v1.User @@ -238,12 +237,11 @@ func TestUserCommand(t *testing.T) { ) assertNoErr(t, err) - require.Len(t, listAfterNameDelete, 0) + require.Empty(t, listAfterNameDelete) } func TestPreAuthKeyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "preauthkeyspace" count := 3 @@ -347,7 +345,7 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) + assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags()) } // Test key expiry @@ -386,7 +384,6 @@ func TestPreAuthKeyCommand(t *testing.T) { func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "pre-auth-key-without-exp-user" spec := ScenarioSpec{ @@ -448,7 +445,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "pre-auth-key-reus-ephm-user" spec := ScenarioSpec{ @@ -524,7 +520,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() user1 := "user1" user2 := "user2" @@ -575,7 +570,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { assertNoErr(t, err) listNodes, err := headscale.ListNodes() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, listNodes, 1) assert.Equal(t, user1, listNodes[0].GetUser().GetName()) @@ -613,7 +608,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { } listNodes, err = headscale.ListNodes() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, listNodes, 2) assert.Equal(t, user1, listNodes[0].GetUser().GetName()) assert.Equal(t, user2, listNodes[1].GetUser().GetName()) @@ -621,7 +616,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { func TestApiKeyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() count := 5 @@ -653,7 +647,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -672,7 +666,7 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAPIKeys, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listedAPIKeys, 5) @@ -728,7 +722,7 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - assert.Nil(t, err) + assert.NoError(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } @@ -744,7 +738,7 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAfterExpireAPIKeys, ) - assert.Nil(t, err) + assert.NoError(t, err) for index := range listedAfterExpireAPIKeys { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { @@ -770,7 +764,7 @@ func TestApiKeyCommand(t *testing.T) { "--prefix", listedAPIKeys[0].GetPrefix(), }) - assert.Nil(t, err) + assert.NoError(t, err) var listedAPIKeysAfterDelete []v1.ApiKey err = executeAndUnmarshal(headscale, @@ -783,14 +777,13 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAPIKeysAfterDelete, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listedAPIKeysAfterDelete, 4) } func TestNodeTagCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"user1"}, @@ -811,7 +804,7 @@ func TestNodeTagCommand(t *testing.T) { types.MustRegistrationID().String(), } nodes := make([]*v1.Node, len(regIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -829,7 +822,7 @@ func TestNodeTagCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -847,7 +840,7 @@ func TestNodeTagCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) nodes[index] = &node } @@ -866,7 +859,7 @@ func TestNodeTagCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) @@ -894,7 +887,7 @@ func TestNodeTagCommand(t *testing.T) { }, &resultMachines, ) - assert.Nil(t, err) + assert.NoError(t, err) found := false for _, node := range resultMachines { if node.GetForcedTags() != nil { @@ -905,19 +898,15 @@ func TestNodeTagCommand(t *testing.T) { } } } - assert.Equal( + assert.True( t, - true, found, "should find a node with the tag 'tag:test' in the list of nodes", ) } - - func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() tests := []struct { name string @@ -1024,7 +1013,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { }, &resultMachines, ) - assert.Nil(t, err) + assert.NoError(t, err) found := false for _, node := range resultMachines { if tags := node.GetValidTags(); tags != nil { @@ -1043,7 +1032,6 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { func TestNodeCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"node-user", "other-user"}, @@ -1067,7 +1055,7 @@ func TestNodeCommand(t *testing.T) { types.MustRegistrationID().String(), } nodes := make([]*v1.Node, len(regIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1085,7 +1073,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1103,7 +1091,7 @@ func TestNodeCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) nodes[index] = &node } @@ -1123,7 +1111,7 @@ func TestNodeCommand(t *testing.T) { }, &listAll, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAll, 5) @@ -1144,7 +1132,7 @@ func TestNodeCommand(t *testing.T) { types.MustRegistrationID().String(), } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range otherUserRegIDs { _, err := headscale.Execute( @@ -1162,7 +1150,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1180,7 +1168,7 @@ func TestNodeCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) otherUserMachines[index] = &node } @@ -1200,7 +1188,7 @@ func TestNodeCommand(t *testing.T) { }, &listAllWithotherUser, ) - assert.Nil(t, err) + assert.NoError(t, err) // All nodes, nodes + otherUser assert.Len(t, listAllWithotherUser, 7) @@ -1226,7 +1214,7 @@ func TestNodeCommand(t *testing.T) { }, &listOnlyotherUserMachineUser, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listOnlyotherUserMachineUser, 2) @@ -1258,7 +1246,7 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - assert.Nil(t, err) + assert.NoError(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node @@ -1275,14 +1263,13 @@ func TestNodeCommand(t *testing.T) { }, &listOnlyMachineUserAfterDelete, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listOnlyMachineUserAfterDelete, 4) } func TestNodeExpireCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"node-expire-user"}, @@ -1323,7 +1310,7 @@ func TestNodeExpireCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1341,7 +1328,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) nodes[index] = &node } @@ -1360,7 +1347,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &listAll, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAll, 5) @@ -1377,10 +1364,10 @@ func TestNodeExpireCommand(t *testing.T) { "nodes", "expire", "--identifier", - fmt.Sprintf("%d", listAll[idx].GetId()), + strconv.FormatUint(listAll[idx].GetId(), 10), }, ) - assert.Nil(t, err) + assert.NoError(t, err) } var listAllAfterExpiry []v1.Node @@ -1395,7 +1382,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &listAllAfterExpiry, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAllAfterExpiry, 5) @@ -1408,7 +1395,6 @@ func TestNodeExpireCommand(t *testing.T) { func TestNodeRenameCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"node-rename-command"}, @@ -1432,7 +1418,7 @@ func TestNodeRenameCommand(t *testing.T) { types.MustRegistrationID().String(), } nodes := make([]*v1.Node, len(regIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1487,7 +1473,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAll, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAll, 5) @@ -1504,11 +1490,11 @@ func TestNodeRenameCommand(t *testing.T) { "nodes", "rename", "--identifier", - fmt.Sprintf("%d", listAll[idx].GetId()), + strconv.FormatUint(listAll[idx].GetId(), 10), fmt.Sprintf("newnode-%d", idx+1), }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Contains(t, res, "Node renamed") } @@ -1525,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAllAfterRename, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAllAfterRename, 5) @@ -1542,7 +1528,7 @@ func TestNodeRenameCommand(t *testing.T) { "nodes", "rename", "--identifier", - fmt.Sprintf("%d", listAll[4].GetId()), + strconv.FormatUint(listAll[4].GetId(), 10), strings.Repeat("t", 64), }, ) @@ -1560,7 +1546,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAllAfterRenameAttempt, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAllAfterRenameAttempt, 5) @@ -1573,7 +1559,6 @@ func TestNodeRenameCommand(t *testing.T) { func TestNodeMoveCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"old-user", "new-user"}, @@ -1610,7 +1595,7 @@ func TestNodeMoveCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1628,13 +1613,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint64(1), node.GetId()) assert.Equal(t, "nomad-node", node.GetName()) - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Equal(t, "old-user", node.GetUser().GetName()) - nodeID := fmt.Sprintf("%d", node.GetId()) + nodeID := strconv.FormatUint(node.GetId(), 10) err = executeAndUnmarshal( headscale, @@ -1651,9 +1636,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, node.GetUser().GetName(), "new-user") + assert.Equal(t, "new-user", node.GetUser().GetName()) var allNodes []v1.Node err = executeAndUnmarshal( @@ -1667,13 +1652,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &allNodes, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, allNodes, 1) assert.Equal(t, allNodes[0].GetId(), node.GetId()) assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) - assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") + assert.Equal(t, "new-user", allNodes[0].GetUser().GetName()) _, err = headscale.Execute( []string{ @@ -1693,7 +1678,7 @@ func TestNodeMoveCommand(t *testing.T) { err, "user not found", ) - assert.Equal(t, node.GetUser().GetName(), "new-user") + assert.Equal(t, "new-user", node.GetUser().GetName()) err = executeAndUnmarshal( headscale, @@ -1710,9 +1695,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Equal(t, "old-user", node.GetUser().GetName()) err = executeAndUnmarshal( headscale, @@ -1729,14 +1714,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Equal(t, "old-user", node.GetUser().GetName()) } func TestPolicyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"user1"}, @@ -1817,7 +1801,6 @@ func TestPolicyCommand(t *testing.T) { func TestPolicyBrokenConfigCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 23879d56..4a5e52ae 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -1,7 +1,6 @@ package integration import ( - "context" "fmt" "net" "strconv" @@ -104,7 +103,7 @@ func DERPVerify( defer c.Close() var result error - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(t.Context()); err != nil { result = fmt.Errorf("client Connect: %w", err) } if m, err := c.Recv(); err != nil { diff --git a/integration/dns_test.go b/integration/dns_test.go index ef6c479b..456895cc 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -15,7 +15,6 @@ import ( func TestResolveMagicDNS(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -49,7 +48,7 @@ func TestResolveMagicDNS(t *testing.T) { // It is safe to ignore this error as we handled it when caching it peerFQDN, _ := peer.FQDN() - assert.Equal(t, fmt.Sprintf("%s.headscale.net.", peer.Hostname()), peerFQDN) + assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN) command := []string{ "tailscale", @@ -85,7 +84,6 @@ func TestResolveMagicDNS(t *testing.T) { func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -222,12 +220,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { _, err = hs.Execute([]string{"rm", erPath}) assertNoErr(t, err) - time.Sleep(2 * time.Second) - // The same paths should still be available as it is not cleared on delete. - for _, client := range allClients { - assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9") - } + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"}) + assert.NoError(ct, err) + assert.Contains(ct, result, "9.9.9.9") + } + }, 10*time.Second, 1*time.Second) // Write a new file, the backoff mechanism should make the filewatcher pick it up // again. diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index f8bbde5f..dc8391d7 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -33,26 +33,27 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { } // GenerateRunID creates a unique run identifier with timestamp and random hash. -// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3) +// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3). func GenerateRunID() string { now := time.Now() timestamp := now.Format("20060102-150405") - + // Add a short random hash to ensure uniqueness randomHash := util.MustGenerateRandomStringDNSSafe(6) + return fmt.Sprintf("%s-%s", timestamp, randomHash) } // ExtractRunIDFromContainerName extracts the run ID from container name. -// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH" +// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH". func ExtractRunIDFromContainerName(containerName string) string { parts := strings.Split(containerName, "-") if len(parts) >= 3 { // Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH) return strings.Join(parts[len(parts)-3:], "-") } - - panic(fmt.Sprintf("unexpected container name format: %s", containerName)) + + panic("unexpected container name format: " + containerName) } // IsRunningInContainer checks if the current process is running inside a Docker container. @@ -62,4 +63,4 @@ func IsRunningInContainer() bool { // This could be improved with more robust detection if needed _, err := os.Stat("/.dockerenv") return err == nil -} \ No newline at end of file +} diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index e77b7cb8..e4b39efb 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -30,7 +30,7 @@ func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption { }) } -// buffer is a goroutine safe bytes.buffer +// buffer is a goroutine safe bytes.buffer. type buffer struct { store bytes.Buffer mutex sync.Mutex @@ -58,8 +58,8 @@ func ExecuteCommand( env []string, options ...ExecuteCommandOption, ) (string, string, error) { - var stdout = buffer{} - var stderr = buffer{} + stdout := buffer{} + stderr := buffer{} execConfig := ExecuteCommandConfig{ timeout: dockerExecuteTimeout, diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index 857a5def..dd6c6978 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -159,7 +159,6 @@ func New( }, } - if dsic.workdir != "" { runOptions.WorkingDir = dsic.workdir } @@ -192,7 +191,7 @@ func New( } // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(runOptions, "derp") - + container, err = pool.BuildAndRunWithBuildOptions( buildOptions, runOptions, diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index ca4e8a14..b1d947cd 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -2,13 +2,13 @@ package integration import ( "strings" - "tailscale.com/tailcfg" - "tailscale.com/types/key" "testing" "time" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "tailscale.com/tailcfg" + "tailscale.com/types/key" ) type ClientsSpec struct { @@ -71,9 +71,9 @@ func TestDERPServerWebsocketScenario(t *testing.T) { NodesPerUser: 1, Users: []string{"user1", "user2", "user3"}, Networks: map[string][]string{ - "usernet1": []string{"user1"}, - "usernet2": []string{"user2"}, - "usernet3": []string{"user3"}, + "usernet1": {"user1"}, + "usernet2": {"user2"}, + "usernet3": {"user3"}, }, } @@ -106,7 +106,6 @@ func derpServerScenario( furtherAssertions ...func(*Scenario), ) { IntegrationSkip(t) - // t.Parallel() scenario, err := NewScenario(spec) assertNoErr(t, err) diff --git a/integration/general_test.go b/integration/general_test.go index 292eb5ca..c60c2f46 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -26,7 +26,6 @@ import ( func TestPingAllByIP(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -68,7 +67,6 @@ func TestPingAllByIP(t *testing.T) { func TestPingAllByIPPublicDERP(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -118,7 +116,6 @@ func TestEphemeralInAlternateTimezone(t *testing.T) { func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -191,7 +188,6 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { // deleted by accident if they are still online and active. func TestEphemeral2006DeletedTooQuickly(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -260,18 +256,21 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { // Wait a bit and bring up the clients again before the expiry // time of the ephemeral nodes. // Nodes should be able to reconnect and work fine. - time.Sleep(30 * time.Second) - for _, client := range allClients { err := client.Up() if err != nil { t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) } } - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - success = pingAllHelper(t, allClients, allAddrs) + // Wait for clients to sync and be able to ping each other after reconnection + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = scenario.WaitForTailscaleSync() + assert.NoError(ct, err) + + success = pingAllHelper(t, allClients, allAddrs) + assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping") + }, 60*time.Second, 2*time.Second) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) // Take down all clients, this should start an expiry timer for each. @@ -284,7 +283,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { // This time wait for all of the nodes to expire and check that they are no longer // registered. - time.Sleep(3 * time.Minute) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, userName := range spec.Users { + nodes, err := headscale.ListNodes(userName) + assert.NoError(ct, err) + assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName) + } + }, 4*time.Minute, 10*time.Second) for _, userName := range spec.Users { nodes, err := headscale.ListNodes(userName) @@ -305,7 +310,6 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -341,20 +345,6 @@ func TestPingAllByHostname(t *testing.T) { // nolint:tparallel func TestTaildrop(t *testing.T) { IntegrationSkip(t) - t.Parallel() - - retry := func(times int, sleepInterval time.Duration, doWork func() error) error { - var err error - for range times { - err = doWork() - if err == nil { - return nil - } - time.Sleep(sleepInterval) - } - - return err - } spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -396,40 +386,27 @@ func TestTaildrop(t *testing.T) { "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets", } - err = retry(10, 1*time.Second, func() error { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { result, _, err := client.Execute(curlCommand) - if err != nil { - return err - } + assert.NoError(ct, err) + var fts []apitype.FileTarget err = json.Unmarshal([]byte(result), &fts) - if err != nil { - return err - } + assert.NoError(ct, err) if len(fts) != len(allClients)-1 { ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname()) for _, ft := range fts { ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) } - return fmt.Errorf( - "client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", - client.Hostname(), + assert.Failf(ct, "client %s does not have all its peers as FileTargets", + "got %d, want: %d\n%s", len(fts), len(allClients)-1, ftStr, ) } - - return err - }) - if err != nil { - t.Errorf( - "failed to query localapi for filetarget on %s, err: %s", - client.Hostname(), - err, - ) - } + }, 10*time.Second, 1*time.Second) } for _, client := range allClients { @@ -454,24 +431,15 @@ func TestTaildrop(t *testing.T) { fmt.Sprintf("%s:", peerFQDN), } - err := retry(10, 1*time.Second, func() error { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { t.Logf( "Sending file from %s to %s\n", client.Hostname(), peer.Hostname(), ) _, _, err := client.Execute(command) - - return err - }) - if err != nil { - t.Fatalf( - "failed to send taildrop file on %s with command %q, err: %s", - client.Hostname(), - strings.Join(command, " "), - err, - ) - } + assert.NoError(ct, err) + }, 10*time.Second, 1*time.Second) }) } } @@ -520,7 +488,6 @@ func TestTaildrop(t *testing.T) { func TestUpdateHostnameFromClient(t *testing.T) { IntegrationSkip(t) - t.Parallel() hostnames := map[string]string{ "1": "user1-host", @@ -603,9 +570,47 @@ func TestUpdateHostnameFromClient(t *testing.T) { assertNoErr(t, err) } - time.Sleep(5 * time.Second) + // Verify that the server-side rename is reflected in DNSName while HostName remains unchanged + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Build a map of expected DNSNames by node ID + expectedDNSNames := make(map[string]string) + for _, node := range nodes { + nodeID := strconv.FormatUint(node.GetId(), 10) + expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId()) + } + + // Verify from each client's perspective + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + + // Check self node + selfID := string(status.Self.ID) + expectedDNS := expectedDNSNames[selfID] + assert.Equal(ct, expectedDNS, status.Self.DNSName, + "Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID) + + // HostName should remain as the original client-reported hostname + originalHostname := hostnames[selfID] + assert.Equal(ct, originalHostname, status.Self.HostName, + "Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID) + + // Check peers + for _, peer := range status.Peer { + peerID := string(peer.ID) + if expectedDNS, ok := expectedDNSNames[peerID]; ok { + assert.Equal(ct, expectedDNS, peer.DNSName, + "Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname()) + + // HostName should remain as the original client-reported hostname + originalHostname := hostnames[peerID] + assert.Equal(ct, originalHostname, peer.HostName, + "Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname()) + } + } + } + }, 60*time.Second, 2*time.Second) - // Verify that the clients can see the new hostname, but no givenName for _, client := range allClients { status, err := client.Status() assertNoErr(t, err) @@ -647,7 +652,6 @@ func TestUpdateHostnameFromClient(t *testing.T) { func TestExpireNode(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -707,7 +711,23 @@ func TestExpireNode(t *testing.T) { t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) - time.Sleep(2 * time.Minute) + // Verify that the expired node has been marked in all peers list. + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + + if client.Hostname() != node.GetName() { + // Check if the expired node appears as expired in this client's peer list + for key, peer := range status.Peer { + if key == expiredNodeKey { + assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname()) + break + } + } + } + } + }, 3*time.Minute, 10*time.Second) now := time.Now() @@ -774,7 +794,6 @@ func TestExpireNode(t *testing.T) { func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -890,7 +909,6 @@ func TestNodeOnlineStatus(t *testing.T) { // five times ensuring they are able to restablish connectivity. func TestPingAllByIPManyUpDown(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -944,8 +962,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) { t.Fatalf("failed to take down all nodes: %s", err) } - time.Sleep(5 * time.Second) - for _, client := range allClients { c := client wg.Go(func() error { @@ -958,10 +974,14 @@ func TestPingAllByIPManyUpDown(t *testing.T) { t.Fatalf("failed to take down all nodes: %s", err) } - time.Sleep(5 * time.Second) + // Wait for sync and successful pings after nodes come back up + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = scenario.WaitForTailscaleSync() + assert.NoError(ct, err) - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + success := pingAllHelper(t, allClients, allAddrs) + assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up") + }, 30*time.Second, 2*time.Second) success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) @@ -970,7 +990,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) { func Test2118DeletingOnlineNodePanics(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1042,10 +1061,24 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { ) require.NoError(t, err) - time.Sleep(2 * time.Second) - // Ensure that the node has been deleted, this did not occur due to a panic. var nodeListAfter []v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &nodeListAfter, + ) + assert.NoError(ct, err) + assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list") + }, 10*time.Second, 1*time.Second) + err = executeAndUnmarshal( headscale, []string{ diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 9c6816fa..c300a205 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -191,7 +191,7 @@ func WithPostgres() Option { } } -// WithPolicy sets the policy mode for headscale +// WithPolicy sets the policy mode for headscale. func WithPolicyMode(mode types.PolicyMode) Option { return func(hsic *HeadscaleInContainer) { hsic.policyMode = mode @@ -279,7 +279,7 @@ func New( return nil, err } - hostname := fmt.Sprintf("hs-%s", hash) + hostname := "hs-" + hash hsic := &HeadscaleInContainer{ hostname: hostname, @@ -308,14 +308,14 @@ func New( if hsic.postgres { hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres" - hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = fmt.Sprintf("postgres-%s", hash) + hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale" delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH") pgRunOptions := &dockertest.RunOptions{ - Name: fmt.Sprintf("postgres-%s", hash), + Name: "postgres-" + hash, Repository: "postgres", Tag: "latest", Networks: networks, @@ -328,7 +328,7 @@ func New( // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres") - + pg, err := pool.RunWithOptions(pgRunOptions) if err != nil { return nil, fmt.Errorf("starting postgres container: %w", err) @@ -373,7 +373,6 @@ func New( Env: env, } - if len(hsic.hostPortBindings) > 0 { runOptions.PortBindings = map[docker.Port][]docker.PortBinding{} for port, hostPorts := range hsic.hostPortBindings { @@ -396,7 +395,7 @@ func New( // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale") - + container, err := pool.BuildAndRunWithBuildOptions( headscaleBuildOptions, runOptions, @@ -566,7 +565,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { // extractTarToDirectory extracts a tar archive to a directory. func extractTarToDirectory(tarData []byte, targetDir string) error { - if err := os.MkdirAll(targetDir, 0755); err != nil { + if err := os.MkdirAll(targetDir, 0o755); err != nil { return fmt.Errorf("failed to create directory %s: %w", targetDir, err) } @@ -624,6 +623,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error { } targetDir := path.Join(savePath, t.hostname+"-pprof") + return extractTarToDirectory(tarFile, targetDir) } @@ -634,6 +634,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { } targetDir := path.Join(savePath, t.hostname+"-mapresponses") + return extractTarToDirectory(tarFile, targetDir) } @@ -672,17 +673,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { if err != nil { return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err) } - + if strings.TrimSpace(schemaCheck) == "" { - return fmt.Errorf("database file exists but has no schema (empty database)") + return errors.New("database file exists but has no schema (empty database)") } - + // Show a preview of the schema (first 500 chars) schemaPreview := schemaCheck if len(schemaPreview) > 500 { schemaPreview = schemaPreview[:500] + "..." } - log.Printf("Database schema preview:\n%s", schemaPreview) tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") if err != nil { @@ -727,7 +727,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } } - return fmt.Errorf("no regular file found in database tar archive") + return errors.New("no regular file found in database tar archive") } // Execute runs a command inside the Headscale container and returns the @@ -756,13 +756,13 @@ func (t *HeadscaleInContainer) Execute( // GetPort returns the docker container port as a string. func (t *HeadscaleInContainer) GetPort() string { - return fmt.Sprintf("%d", t.port) + return strconv.Itoa(t.port) } // GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer // instance. func (t *HeadscaleInContainer) GetHealthEndpoint() string { - return fmt.Sprintf("%s/health", t.GetEndpoint()) + return t.GetEndpoint() + "/health" } // GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer. @@ -772,10 +772,10 @@ func (t *HeadscaleInContainer) GetEndpoint() string { t.port) if t.hasTLS() { - return fmt.Sprintf("https://%s", hostEndpoint) + return "https://" + hostEndpoint } - return fmt.Sprintf("http://%s", hostEndpoint) + return "http://" + hostEndpoint } // GetCert returns the public certificate of the HeadscaleInContainer. @@ -910,6 +910,7 @@ func (t *HeadscaleInContainer) ListNodes( } ret = append(ret, nodes...) + return nil } @@ -932,6 +933,7 @@ func (t *HeadscaleInContainer) ListNodes( sort.Slice(ret, func(i, j int) bool { return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 }) + return ret, nil } @@ -943,10 +945,10 @@ func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) { var userMap map[string][]*v1.Node for _, node := range nodes { - if _, ok := userMap[node.User.Name]; !ok { - mak.Set(&userMap, node.User.Name, []*v1.Node{node}) + if _, ok := userMap[node.GetUser().GetName()]; !ok { + mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node}) } else { - userMap[node.User.Name] = append(userMap[node.User.Name], node) + userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node) } } @@ -999,7 +1001,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) { var userMap map[string]*v1.User for _, user := range users { - mak.Set(&userMap, user.Name, user) + mak.Set(&userMap, user.GetName(), user) } return userMap, nil @@ -1095,7 +1097,7 @@ func (h *HeadscaleInContainer) PID() (int, error) { case 1: return pids[0], nil default: - return 0, fmt.Errorf("multiple headscale processes running") + return 0, errors.New("multiple headscale processes running") } } @@ -1121,7 +1123,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( "headscale", "nodes", "approve-routes", "--output", "json", "--identifier", strconv.FormatUint(id, 10), - fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")), + "--routes=" + strings.Join(util.PrefixesToString(routes), ","), } result, _, err := dockertestutil.ExecuteCommand( diff --git a/integration/route_test.go b/integration/route_test.go index 053b4582..64677aec 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -4,13 +4,12 @@ import ( "encoding/json" "fmt" "net/netip" + "slices" "sort" "strings" "testing" "time" - "slices" - cmpdiff "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -37,7 +36,6 @@ var allPorts = filter.PortRange{First: 0, Last: 0xffff} // routes. func TestEnablingRoutes(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 3, @@ -182,11 +180,12 @@ func TestEnablingRoutes(t *testing.T) { for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] - if peerStatus.ID == "1" { + switch peerStatus.ID { + case "1": requirePeerSubnetRoutes(t, peerStatus, nil) - } else if peerStatus.ID == "2" { + case "2": requirePeerSubnetRoutes(t, peerStatus, nil) - } else { + default: requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")}) } } @@ -195,7 +194,6 @@ func TestEnablingRoutes(t *testing.T) { func TestHASubnetRouterFailover(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 3, @@ -779,7 +777,6 @@ func TestHASubnetRouterFailover(t *testing.T) { // https://github.com/juanfont/headscale/issues/1604 func TestSubnetRouteACL(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "user4" @@ -1003,7 +1000,6 @@ func TestSubnetRouteACL(t *testing.T) { // set during login instead of set. func TestEnablingExitRoutes(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "user2" @@ -1097,7 +1093,6 @@ func TestEnablingExitRoutes(t *testing.T) { // subnet router is working as expected. func TestSubnetRouterMultiNetwork(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1177,7 +1172,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Enable route _, err = headscale.ApproveRoutes( - nodes[0].Id, + nodes[0].GetId(), []netip.Prefix{*pref}, ) require.NoError(t, err) @@ -1224,7 +1219,6 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1300,7 +1294,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { } // Enable route - _, err = headscale.ApproveRoutes(nodes[0].Id, []netip.Prefix{tsaddr.AllIPv4()}) + _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()}) require.NoError(t, err) time.Sleep(5 * time.Second) @@ -1719,7 +1713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) assertNoErr(t, err) - err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key) + err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) assertNoErr(t, err) } // extra creation end. @@ -2065,7 +2059,6 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub // that are explicitly allowed in the ACL. func TestSubnetRouteACLFiltering(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Use router and node users for better clarity routerUser := "router" @@ -2090,7 +2083,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) // Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24) - aclPolicyStr := fmt.Sprintf(`{ + aclPolicyStr := `{ "hosts": { "router": "100.64.0.1/32", "node": "100.64.0.2/32" @@ -2115,7 +2108,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { ] } ] - }`) + }` route, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) diff --git a/integration/scenario.go b/integration/scenario.go index 358291ff..b235cf34 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -123,7 +123,7 @@ type ScenarioSpec struct { // NodesPerUser is how many nodes should be attached to each user. NodesPerUser int - // Networks, if set, is the seperate Docker networks that should be + // Networks, if set, is the separate Docker networks that should be // created and a list of the users that should be placed in those networks. // If not set, a single network will be created and all users+nodes will be // added there. @@ -1077,7 +1077,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) - hostname := fmt.Sprintf("hs-oidcmock-%s", hash) + hostname := "hs-oidcmock-" + hash usersJSON, err := json.Marshal(users) if err != nil { @@ -1093,16 +1093,15 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse }, Networks: s.Networks(), Env: []string{ - fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), + "MOCKOIDC_ADDR=" + hostname, fmt.Sprintf("MOCKOIDC_PORT=%d", port), "MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_SECRET=supersecret", - fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), - fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), + "MOCKOIDC_ACCESS_TTL=" + accessTTL.String(), + "MOCKOIDC_USERS=" + string(usersJSON), }, } - headscaleBuildOptions := &dockertest.BuildOptions{ Dockerfile: hsic.IntegrationTestDockerFileName, ContextDir: dockerContextPath, @@ -1117,7 +1116,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") - + if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( headscaleBuildOptions, mockOidcOptions, @@ -1184,7 +1183,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) - hostname := fmt.Sprintf("hs-webservice-%s", hash) + hostname := "hs-webservice-" + hash network, ok := s.networks[s.prefixedNetworkName(networkName)] if !ok { diff --git a/integration/scenario_test.go b/integration/scenario_test.go index ac0ff238..ead3f1fd 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -28,7 +28,6 @@ func IntegrationSkip(t *testing.T) { // nolint:tparallel func TestHeadscale(t *testing.T) { IntegrationSkip(t) - t.Parallel() var err error @@ -75,7 +74,6 @@ func TestHeadscale(t *testing.T) { // nolint:tparallel func TestTailscaleNodesJoiningHeadcale(t *testing.T) { IntegrationSkip(t) - t.Parallel() var err error diff --git a/integration/ssh_test.go b/integration/ssh_test.go index cf08613d..236aba20 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -22,35 +22,6 @@ func isSSHNoAccessStdError(stderr string) bool { strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node") } -var retry = func(times int, sleepInterval time.Duration, - doWork func() (string, string, error), -) (string, string, error) { - var result string - var stderr string - var err error - - for range times { - tempResult, tempStderr, err := doWork() - - result += tempResult - stderr += tempStderr - - if err == nil { - return result, stderr, nil - } - - // If we get a permission denied error, we can fail immediately - // since that is something we won-t recover from by retrying. - if err != nil && isSSHNoAccessStdError(stderr) { - return result, stderr, err - } - - time.Sleep(sleepInterval) - } - - return result, stderr, err -} - func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario { t.Helper() @@ -92,7 +63,6 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce func TestSSHOneUserToAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -160,7 +130,6 @@ func TestSSHOneUserToAll(t *testing.T) { func TestSSHMultipleUsersAllToAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -216,7 +185,6 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { func TestSSHNoSSHConfigured(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -261,7 +229,6 @@ func TestSSHNoSSHConfigured(t *testing.T) { func TestSSHIsBlockedInACL(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -313,7 +280,6 @@ func TestSSHIsBlockedInACL(t *testing.T) { func TestSSHUserOnlyIsolation(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -404,6 +370,14 @@ func TestSSHUserOnlyIsolation(t *testing.T) { } func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + return doSSHWithRetry(t, client, peer, true) +} + +func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + return doSSHWithRetry(t, client, peer, false) +} + +func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) { t.Helper() peerFQDN, _ := peer.FQDN() @@ -417,9 +391,29 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname()) log.Printf("Command: %s", strings.Join(command, " ")) - return retry(10, 1*time.Second, func() (string, string, error) { - return client.Execute(command) - }) + var result, stderr string + var err error + + if retry { + // Use assert.EventuallyWithT to retry SSH connections for success cases + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, stderr, err = client.Execute(command) + + // If we get a permission denied error, we can fail immediately + // since that is something we won't recover from by retrying. + if err != nil && isSSHNoAccessStdError(stderr) { + return // Don't retry permission denied errors + } + + // For all other errors, assert no error to trigger retry + assert.NoError(ct, err) + }, 10*time.Second, 1*time.Second) + } else { + // For failure cases, just execute once + result, stderr, err = client.Execute(command) + } + + return result, stderr, err } func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) { @@ -434,7 +428,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() - result, stderr, err := doSSH(t, client, peer) + result, stderr, err := doSSHWithoutRetry(t, client, peer) assert.Empty(t, result) @@ -444,7 +438,7 @@ func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer Tailsc func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() - result, stderr, _ := doSSH(t, client, peer) + result, stderr, _ := doSSHWithoutRetry(t, client, peer) assert.Empty(t, result) diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index d2738c55..3e4847eb 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -251,7 +251,6 @@ func New( Env: []string{}, } - if tsic.withWebsocketDERP { if version != VersionHead { return tsic, errInvalidClientConfig @@ -463,7 +462,7 @@ func (t *TailscaleInContainer) buildLoginCommand( if len(t.withTags) > 0 { command = append(command, - fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")), + "--advertise-tags="+strings.Join(t.withTags, ","), ) } @@ -685,7 +684,7 @@ func (t *TailscaleInContainer) MustID() types.NodeID { // Panics if version is lower then minimum. func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { - panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version)) + panic("tsic.Netmap() called with unsupported version: " + t.version) } command := []string{ @@ -1026,7 +1025,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err "tailscale", "ping", fmt.Sprintf("--timeout=%s", args.timeout), fmt.Sprintf("--c=%d", args.count), - fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)), + "--until-direct=" + strconv.FormatBool(args.direct), } command = append(command, hostnameOrIP) @@ -1131,11 +1130,11 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err command := []string{ "curl", "--silent", - "--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())), - "--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())), - "--retry", fmt.Sprintf("%d", args.retry), - "--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())), - "--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())), + "--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())), + "--max-time", strconv.Itoa(int(args.maxTime.Seconds())), + "--retry", strconv.Itoa(args.retry), + "--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())), + "--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())), url, } @@ -1230,7 +1229,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) { } if out.Len() == 0 { - return nil, fmt.Errorf("file is empty") + return nil, errors.New("file is empty") } return out.Bytes(), nil @@ -1259,5 +1258,6 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { if err = json.Unmarshal(currentProfile, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err) } + return &p.Persist.PrivateNodeKey, nil } diff --git a/integration/utils.go b/integration/utils.go index bcf488e2..c19f6459 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -3,7 +3,6 @@ package integration import ( "bufio" "bytes" - "context" "fmt" "io" "net/netip" @@ -267,7 +266,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) { // This isn't really relevant for Self as it won't be in its own socket/wireguard. // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) - // assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname()) + // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) for _, peer := range status.Peer { assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) @@ -311,7 +310,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) { func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { t.Helper() - _, err := backoff.Retry(context.Background(), func() (struct{}, error) { + _, err := backoff.Retry(t.Context(), func() (struct{}, error) { stdout, stderr, err := c.Execute(command) if err != nil { return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) @@ -492,6 +491,7 @@ func groupApprover(name string) policyv2.AutoApprover { func tagApprover(name string) policyv2.AutoApprover { return ptr.To(policyv2.Tag(name)) } + // // // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus // // if there is a peer with the given hostname. If no peer is found, nil is returned.