integration: replace time.Sleep with assert.EventuallyWithT (#2680)

This commit is contained in:
Kristoffer Dalby 2025-07-10 23:38:55 +02:00 committed by GitHub
parent b904276f2b
commit c6d7b512bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
73 changed files with 584 additions and 573 deletions

View File

@ -48,5 +48,4 @@ jobs:
- name: Deploy stable docs from tag - name: Deploy stable docs from tag
if: startsWith(github.ref, 'refs/tags/v') if: startsWith(github.ref, 'refs/tags/v')
# This assumes that only newer tags are pushed # This assumes that only newer tags are pushed
run: run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest
mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest

View File

@ -75,7 +75,7 @@ jobs:
# Some of the jobs might still require manual restart as they are really # Some of the jobs might still require manual restart as they are really
# slow and this will cause them to eventually be killed by Github actions. # slow and this will cause them to eventually be killed by Github actions.
attempt_delay: 300000 # 5 min attempt_delay: 300000 # 5 min
attempt_limit: 3 attempt_limit: 2
command: | command: |
nix develop --command -- hi run "^${{ inputs.test }}$" \ nix develop --command -- hi run "^${{ inputs.test }}$" \
--timeout=120m \ --timeout=120m \

View File

@ -36,8 +36,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
if: steps.changed-files.outputs.files == 'true' if: steps.changed-files.outputs.files == 'true'
run: run: nix develop --command -- golangci-lint run
nix develop --command -- golangci-lint run
--new-from-rev=${{github.event.pull_request.base.sha}} --new-from-rev=${{github.event.pull_request.base.sha}}
--format=colored-line-number --format=colored-line-number
@ -75,8 +74,7 @@ jobs:
- name: Prettify code - name: Prettify code
if: steps.changed-files.outputs.files == 'true' if: steps.changed-files.outputs.files == 'true'
run: run: nix develop --command -- prettier --no-error-on-unmatched-pattern
nix develop --command -- prettier --no-error-on-unmatched-pattern
--ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html} --ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
proto-lint: proto-lint:

View File

@ -117,7 +117,7 @@ var createNodeCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()), "Cannot create node: "+status.Convert(err).Message(),
output, output,
) )
} }

View File

@ -2,6 +2,7 @@ package cli
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -68,7 +69,7 @@ func mockOIDC() error {
userStr := os.Getenv("MOCKOIDC_USERS") userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" { if userStr == "" {
return fmt.Errorf("MOCKOIDC_USERS not defined") return errors.New("MOCKOIDC_USERS not defined")
} }
var users []mockoidc.MockUser var users []mockoidc.MockUser

View File

@ -184,7 +184,7 @@ var listNodesCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), "Cannot get nodes: "+status.Convert(err).Message(),
output, output,
) )
} }
@ -398,10 +398,7 @@ var deleteNodeCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Error getting node node: "+status.Convert(err).Message(),
"Error getting node node: %s",
status.Convert(err).Message(),
),
output, output,
) )
@ -437,10 +434,7 @@ var deleteNodeCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Error deleting node: "+status.Convert(err).Message(),
"Error deleting node: %s",
status.Convert(err).Message(),
),
output, output,
) )
@ -498,10 +492,7 @@ var moveNodeCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Error getting node: "+status.Convert(err).Message(),
"Error getting node: %s",
status.Convert(err).Message(),
),
output, output,
) )
@ -517,10 +508,7 @@ var moveNodeCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Error moving node: "+status.Convert(err).Message(),
"Error moving node: %s",
status.Convert(err).Message(),
),
output, output,
) )
@ -567,10 +555,7 @@ be assigned to nodes.`,
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Error backfilling IPs: "+status.Convert(err).Message(),
"Error backfilling IPs: %s",
status.Convert(err).Message(),
),
output, output,
) )

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv"
survey "github.com/AlecAivazis/survey/v2" survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -27,10 +28,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
err := errors.New("--name or --identifier flag is required") err := errors.New("--name or --identifier flag is required")
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Cannot rename user: "+status.Convert(err).Message(),
"Cannot rename user: %s",
status.Convert(err).Message(),
),
"", "",
) )
} }
@ -114,10 +112,7 @@ var createUserCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Cannot create user: "+status.Convert(err).Message(),
"Cannot create user: %s",
status.Convert(err).Message(),
),
output, output,
) )
} }
@ -147,16 +142,16 @@ var destroyUserCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()), "Error: "+status.Convert(err).Message(),
output, output,
) )
} }
if len(users.GetUsers()) != 1 { if len(users.GetUsers()) != 1 {
err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID") err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()), "Error: "+status.Convert(err).Message(),
output, output,
) )
} }
@ -185,10 +180,7 @@ var destroyUserCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Cannot destroy user: "+status.Convert(err).Message(),
"Cannot destroy user: %s",
status.Convert(err).Message(),
),
output, output,
) )
} }
@ -233,7 +225,7 @@ var listUsersCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()), "Cannot get users: "+status.Convert(err).Message(),
output, output,
) )
} }
@ -247,7 +239,7 @@ var listUsersCmd = &cobra.Command{
tableData = append( tableData = append(
tableData, tableData,
[]string{ []string{
fmt.Sprintf("%d", user.GetId()), strconv.FormatUint(user.GetId(), 10),
user.GetDisplayName(), user.GetDisplayName(),
user.GetName(), user.GetName(),
user.GetEmail(), user.GetEmail(),
@ -287,16 +279,16 @@ var renameUserCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()), "Error: "+status.Convert(err).Message(),
output, output,
) )
} }
if len(users.GetUsers()) != 1 { if len(users.GetUsers()) != 1 {
err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID") err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()), "Error: "+status.Convert(err).Message(),
output, output,
) )
} }
@ -312,10 +304,7 @@ var renameUserCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf( "Cannot rename user: "+status.Convert(err).Message(),
"Cannot rename user: %s",
status.Convert(err).Message(),
),
output, output,
) )
} }

View File

@ -66,7 +66,7 @@ func killTestContainers(ctx context.Context) error {
if cont.State == "running" { if cont.State == "running" {
_ = cli.ContainerKill(ctx, cont.ID, "KILL") _ = cli.ContainerKill(ctx, cont.ID, "KILL")
} }
// Then remove the container with retry logic // Then remove the container with retry logic
if removeContainerWithRetry(ctx, cli, cont.ID) { if removeContainerWithRetry(ctx, cli, cont.ID) {
removed++ removed++
@ -87,25 +87,25 @@ func killTestContainers(ctx context.Context) error {
func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool { func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool {
maxRetries := 3 maxRetries := 3
baseDelay := 100 * time.Millisecond baseDelay := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := range maxRetries {
err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{ err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
Force: true, Force: true,
}) })
if err == nil { if err == nil {
return true return true
} }
// If this is the last attempt, don't wait // If this is the last attempt, don't wait
if attempt == maxRetries-1 { if attempt == maxRetries-1 {
break break
} }
// Wait with exponential backoff // Wait with exponential backoff
delay := baseDelay * time.Duration(1<<attempt) delay := baseDelay * time.Duration(1<<attempt)
time.Sleep(delay) time.Sleep(delay)
} }
return false return false
} }

View File

@ -156,10 +156,10 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
projectRoot := findProjectRoot(pwd) projectRoot := findProjectRoot(pwd)
runID := dockertestutil.ExtractRunIDFromContainerName(containerName) runID := dockertestutil.ExtractRunIDFromContainerName(containerName)
env := []string{ env := []string{
fmt.Sprintf("HEADSCALE_INTEGRATION_POSTGRES=%d", boolToInt(config.UsePostgres)), fmt.Sprintf("HEADSCALE_INTEGRATION_POSTGRES=%d", boolToInt(config.UsePostgres)),
fmt.Sprintf("HEADSCALE_INTEGRATION_RUN_ID=%s", runID), "HEADSCALE_INTEGRATION_RUN_ID=" + runID,
} }
containerConfig := &container.Config{ containerConfig := &container.Config{
Image: "golang:" + config.GoVersion, Image: "golang:" + config.GoVersion,
@ -175,7 +175,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
// Get the correct Docker socket path from the current context // Get the correct Docker socket path from the current context
dockerSocketPath := getDockerSocketPath() dockerSocketPath := getDockerSocketPath()
if config.Verbose { if config.Verbose {
log.Printf("Using Docker socket: %s", dockerSocketPath) log.Printf("Using Docker socket: %s", dockerSocketPath)
} }
@ -184,7 +184,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
AutoRemove: false, // We'll remove manually for better control AutoRemove: false, // We'll remove manually for better control
Binds: []string{ Binds: []string{
fmt.Sprintf("%s:%s", projectRoot, projectRoot), fmt.Sprintf("%s:%s", projectRoot, projectRoot),
fmt.Sprintf("%s:/var/run/docker.sock", dockerSocketPath), dockerSocketPath + ":/var/run/docker.sock",
logsDir + ":/tmp/control", logsDir + ":/tmp/control",
}, },
Mounts: []mount.Mount{ Mounts: []mount.Mount{
@ -237,7 +237,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
} }
testContainers := getCurrentTestContainers(containers, testContainerID, verbose) testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
// Wait for all test containers to reach a final state // Wait for all test containers to reach a final state
maxWaitTime := 10 * time.Second maxWaitTime := 10 * time.Second
checkInterval := 500 * time.Millisecond checkInterval := 500 * time.Millisecond
@ -254,7 +254,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
return nil return nil
case <-ticker.C: case <-ticker.C:
allFinalized := true allFinalized := true
for _, testCont := range testContainers { for _, testCont := range testContainers {
inspect, err := cli.ContainerInspect(ctx, testCont.ID) inspect, err := cli.ContainerInspect(ctx, testCont.ID)
if err != nil { if err != nil {
@ -263,17 +263,18 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
} }
continue continue
} }
// Check if container is in a final state // Check if container is in a final state
if !isContainerFinalized(inspect.State) { if !isContainerFinalized(inspect.State) {
allFinalized = false allFinalized = false
if verbose { if verbose {
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status) log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
} }
break break
} }
} }
if allFinalized { if allFinalized {
if verbose { if verbose {
log.Printf("All test containers finalized, ready for artifact extraction") log.Printf("All test containers finalized, ready for artifact extraction")
@ -290,7 +291,6 @@ func isContainerFinalized(state *container.State) bool {
return !state.Running && state.FinishedAt != "" return !state.Running && state.FinishedAt != ""
} }
// findProjectRoot locates the project root by finding the directory containing go.mod. // findProjectRoot locates the project root by finding the directory containing go.mod.
func findProjectRoot(startPath string) string { func findProjectRoot(startPath string) string {
current := startPath current := startPath
@ -427,7 +427,7 @@ func listControlFiles(logsDir string) {
} }
if entry.IsDir() { if entry.IsDir() {
// Include directories (pprof, mapresponses) // Include directories (pprof, mapresponses)
if strings.Contains(name, "-pprof") || strings.Contains(name, "-mapresponses") { if strings.Contains(name, "-pprof") || strings.Contains(name, "-mapresponses") {
dataDirs = append(dataDirs, name) dataDirs = append(dataDirs, name)
} }
@ -510,7 +510,7 @@ type testContainer struct {
// getCurrentTestContainers filters containers to only include those from the current test run. // getCurrentTestContainers filters containers to only include those from the current test run.
func getCurrentTestContainers(containers []container.Summary, testContainerID string, verbose bool) []testContainer { func getCurrentTestContainers(containers []container.Summary, testContainerID string, verbose bool) []testContainer {
var testRunContainers []testContainer var testRunContainers []testContainer
// Find the test container to get its run ID label // Find the test container to get its run ID label
var runID string var runID string
for _, cont := range containers { for _, cont := range containers {
@ -521,16 +521,16 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
break break
} }
} }
if runID == "" { if runID == "" {
log.Printf("Error: test container %s missing required hi.run-id label", testContainerID[:12]) log.Printf("Error: test container %s missing required hi.run-id label", testContainerID[:12])
return testRunContainers return testRunContainers
} }
if verbose { if verbose {
log.Printf("Looking for containers with run ID: %s", runID) log.Printf("Looking for containers with run ID: %s", runID)
} }
// Find all containers with the same run ID // Find all containers with the same run ID
for _, cont := range containers { for _, cont := range containers {
for _, name := range cont.Names { for _, name := range cont.Names {
@ -546,18 +546,19 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
log.Printf("Including container %s (run ID: %s)", containerName, runID) log.Printf("Including container %s (run ID: %s)", containerName, runID)
} }
} }
break break
} }
} }
} }
return testRunContainers return testRunContainers
} }
// extractContainerArtifacts saves logs and tar files from a container. // extractContainerArtifacts saves logs and tar files from a container.
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Ensure the logs directory exists // Ensure the logs directory exists
if err := os.MkdirAll(logsDir, 0755); err != nil { if err := os.MkdirAll(logsDir, 0o755); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err) return fmt.Errorf("failed to create logs directory: %w", err)
} }
@ -608,12 +609,12 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
} }
// Write stdout logs // Write stdout logs
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0644); err != nil { if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stdout log: %w", err) return fmt.Errorf("failed to write stdout log: %w", err)
} }
// Write stderr logs // Write stderr logs
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0644); err != nil { if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stderr log: %w", err) return fmt.Errorf("failed to write stderr log: %w", err)
} }
@ -626,7 +627,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
// extractContainerFiles extracts database file and directories from headscale containers. // extractContainerFiles extracts database file and directories from headscale containers.
// Note: The actual file extraction is now handled by the integration tests themselves // Note: The actual file extraction is now handled by the integration tests themselves
// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go // via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go.
func extractContainerFiles(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { func extractContainerFiles(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Files are now extracted directly by the integration tests // Files are now extracted directly by the integration tests
// This function is kept for potential future use or other file types // This function is kept for potential future use or other file types
@ -677,7 +678,7 @@ func extractDirectory(ctx context.Context, cli *client.Client, containerID, sour
// Create target directory // Create target directory
targetDir := filepath.Join(logsDir, dirName) targetDir := filepath.Join(logsDir, dirName)
if err := os.MkdirAll(targetDir, 0755); err != nil { if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err) return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
} }

View File

@ -10,10 +10,8 @@ import (
"strings" "strings"
) )
var ( // ErrFileNotFoundInTar indicates a file was not found in the tar archive.
// ErrFileNotFoundInTar indicates a file was not found in the tar archive. var ErrFileNotFoundInTar = errors.New("file not found in tar")
ErrFileNotFoundInTar = errors.New("file not found in tar")
)
// extractFileFromTar extracts a single file from a tar reader. // extractFileFromTar extracts a single file from a tar reader.
func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error { func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error {
@ -42,6 +40,7 @@ func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error
if _, err := io.Copy(outFile, tr); err != nil { if _, err := io.Copy(outFile, tr); err != nil {
return fmt.Errorf("failed to copy file contents: %w", err) return fmt.Errorf("failed to copy file contents: %w", err)
} }
return nil return nil
} }
} }
@ -98,4 +97,4 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
} }
return nil return nil
} }

View File

@ -143,6 +143,7 @@
yq-go yq-go
ripgrep ripgrep
postgresql postgresql
traceroute
# 'dot' is needed for pprof graphs # 'dot' is needed for pprof graphs
# go tool pprof -http=: <source> # go tool pprof -http=: <source>

View File

@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode(
return nil, nil return nil, nil
} }
} }
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry) n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey( node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq, regReq,
machineKey, machineKey,
@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey(
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
} }
if perr, ok := err.(types.PAKError); ok { var perr types.PAKError
if errors.As(err, &perr) {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
} }
return nil, err return nil, err
} }

View File

@ -1,11 +1,10 @@
package capver package capver
import ( import (
"slices"
"sort" "sort"
"strings" "strings"
"slices"
xmaps "golang.org/x/exp/maps" xmaps "golang.org/x/exp/maps"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/util/set" "tailscale.com/util/set"

View File

@ -1,6 +1,6 @@
package capver package capver
//Generated DO NOT EDIT // Generated DO NOT EDIT
import "tailscale.com/tailcfg" import "tailscale.com/tailcfg"
@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.82.5": 115, "v1.82.5": 115,
} }
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
87: "v1.60.0", 87: "v1.60.0",
88: "v1.62.0", 88: "v1.62.0",
90: "v1.64.0", 90: "v1.64.0",
95: "v1.66.0", 95: "v1.66.0",
97: "v1.68.0", 97: "v1.68.0",
102: "v1.70.0", 102: "v1.70.0",
104: "v1.72.0", 104: "v1.72.0",
106: "v1.74.0", 106: "v1.74.0",
109: "v1.78.0", 109: "v1.78.0",
113: "v1.80.0", 113: "v1.80.0",
115: "v1.82.0", 115: "v1.82.0",
} }

View File

@ -764,13 +764,13 @@ AND auth_key_id NOT IN (
// Drop all indexes first to avoid conflicts // Drop all indexes first to avoid conflicts
indexesToDrop := []string{ indexesToDrop := []string{
"idx_users_deleted_at", "idx_users_deleted_at",
"idx_provider_identifier", "idx_provider_identifier",
"idx_name_provider_identifier", "idx_name_provider_identifier",
"idx_name_no_provider_identifier", "idx_name_no_provider_identifier",
"idx_api_keys_prefix", "idx_api_keys_prefix",
"idx_policies_deleted_at", "idx_policies_deleted_at",
} }
for _, index := range indexesToDrop { for _, index := range indexesToDrop {
_ = tx.Exec("DROP INDEX IF EXISTS " + index).Error _ = tx.Exec("DROP INDEX IF EXISTS " + index).Error
} }
@ -927,6 +927,7 @@ AND auth_key_id NOT IN (
} }
log.Info().Msg("Schema recreation completed successfully") log.Info().Msg("Schema recreation completed successfully")
return nil return nil
}, },
Rollback: func(db *gorm.DB) error { return nil }, Rollback: func(db *gorm.DB) error { return nil },

View File

@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
Avoid: false, Avoid: false,
Nodes: []*tailcfg.DERPNode{ Nodes: []*tailcfg.DERPNode{
{ {
Name: fmt.Sprintf("%d", d.cfg.ServerRegionID), Name: strconv.Itoa(d.cfg.ServerRegionID),
RegionID: d.cfg.ServerRegionID, RegionID: d.cfg.ServerRegionID,
HostName: host, HostName: host,
DERPPort: port, DERPPort: port,

View File

@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() {
return struct{}{}, nil return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff())) }, backoff.WithBackOff(backoff.NewExponentialBackOff()))
if err != nil { if err != nil {
log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete") log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
continue continue

View File

@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode(
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
log.Trace(). log.Trace().

View File

@ -32,7 +32,7 @@ const (
reservedResponseHeaderSize = 4 reservedResponseHeaderSize = 4
) )
// httpError logs an error and sends an HTTP error response with the given // httpError logs an error and sends an HTTP error response with the given.
func httpError(w http.ResponseWriter, err error) { func httpError(w http.ResponseWriter, err error) {
var herr HTTPError var herr HTTPError
if errors.As(err, &herr) { if errors.As(err, &herr) {
@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest(
resp := &tailcfg.DERPAdmitClientResponse{ resp := &tailcfg.DERPAdmitClientResponse{
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
} }
return json.NewEncoder(writer).Encode(resp) return json.NewEncoder(writer).Encode(resp)
} }

View File

@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
} }
// ListNodes queries the database for either all nodes if no parameters are given // ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter // or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...) nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil { if err != nil {

View File

@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
} }
} }
// mockState is a mock implementation that provides the required methods // mockState is a mock implementation that provides the required methods.
type mockState struct { type mockState struct {
polMan policy.PolicyManager polMan policy.PolicyManager
derpMap *tailcfg.DERPMap derpMap *tailcfg.DERPMap
@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
} }
} }
} }
return filtered, nil return filtered, nil
} }
// Return all peers except the node itself // Return all peers except the node itself
@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
filtered = append(filtered, peer) filtered = append(filtered, peer)
} }
} }
return filtered, nil return filtered, nil
} }
@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
} }
} }
} }
return filtered, nil return filtered, nil
} }
return m.nodes, nil return m.nodes, nil
} }

View File

@ -11,7 +11,7 @@ import (
"tailscale.com/types/views" "tailscale.com/types/views"
) )
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag // NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
type NodeCanHaveTagChecker interface { type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node types.NodeView, tag string) bool NodeCanHaveTag(node types.NodeView, tag string) bool
} }

View File

@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
} }
n, err := r.ResponseWriter.Write(b) n, err := r.ResponseWriter.Write(b)
r.written += int64(n) r.written += int64(n)
return n, err return n, err
} }

View File

@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
n.b = b n.b = b
go b.doWork() go b.doWork()
return n return n
} }
@ -72,7 +73,7 @@ func (n *Notifier) Close() {
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate) n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
} }
// safeCloseChannel closes a channel and panic recovers if already closed // safeCloseChannel closes a channel and panic recovers if already closed.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) { func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
if val, ok := n.connected.Load(nodeID); ok { if val, ok := n.connected.Load(nodeID); ok {
return val return val
} }
return false return false
} }
@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return false return false
} }
// LikelyConnectedMap returns a thread safe map of connected nodes // LikelyConnectedMap returns a thread safe map of connected nodes.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] { func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected return n.connected
} }

View File

@ -1,17 +1,15 @@
package notifier package notifier
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"net/netip" "net/netip"
"slices"
"sort" "sort"
"sync" "sync"
"testing" "testing"
"time" "time"
"slices"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) {
defer n.RemoveNode(1, ch) defer n.RemoveNode(1, ch)
for _, u := range tt.updates { for _, u := range tt.updates {
n.NotifyAll(context.Background(), u) n.NotifyAll(t.Context(), u)
} }
n.b.flush() n.b.flush()
@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) {
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected // TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to // Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes // close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting // network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) { func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier // mock config for the notifier
cfg := &types.Config{ cfg := &types.Config{
@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) {
for range iterations { for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected // Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node // while others add/remove the node
if routineID%3 == 0 { switch routineID % 3 {
case 0:
// This goroutine checks connection status // This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID) isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false { if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected) errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
} }
} else if routineID%3 == 1 { case 1:
// This goroutine removes the node // This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan) notifier.RemoveNode(nodeID, updateChan)
} else { default:
// This goroutine adds the node back // This goroutine adds the node back
notifier.AddNode(nodeID, updateChan) notifier.AddNode(nodeID, updateChan)
} }

View File

@ -84,11 +84,8 @@ func NewAuthProviderOIDC(
ClientID: cfg.ClientID, ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret, ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(), Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf( RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
"%s/oidc/callback", Scopes: cfg.Scope,
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
} }
registrationCache := zcache.New[string, RegistrationInfo]( registrationCache := zcache.New[string, RegistrationInfo](
@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
req *http.Request, req *http.Request,
) { ) {
vars := mux.Vars(req) vars := mux.Vars(req)
registrationIdStr, _ := vars["registration_id"] registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that // We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render // is passed as a key is not parsable/validated as a NodePublic key, then fail to render
@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
} }
oauth2Token, err := a.getOauth2Token(req.Context(), code, state) oauth2Token, err := a.getOauth2Token(req.Context(), code, state)
if err != nil { if err != nil {
httpError(writer, err) httpError(writer, err)
return return
@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning // Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node. // that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return return
} }
@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
if err != nil { if err != nil {
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
} }
return oauth2Token, err return oauth2Token, err
} }

View File

@ -2,9 +2,8 @@ package matcher
import ( import (
"net/netip" "net/netip"
"strings"
"slices" "slices"
"strings"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
@ -28,6 +27,7 @@ func (m Match) DebugString() string {
for _, prefix := range m.dests.Prefixes() { for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n") sb.WriteString(" " + prefix.String() + "\n")
} }
return sb.String() return sb.String()
} }
@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
for _, rule := range rules { for _, rule := range rules {
matches = append(matches, MatchFromFilterRule(rule)) matches = append(matches, MatchFromFilterRule(rule))
} }
return matches return matches
} }

View File

@ -4,7 +4,6 @@ import (
"net/netip" "net/netip"
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"

View File

@ -5,7 +5,6 @@ import (
"slices" "slices"
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo" "github.com/samber/lo"
@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
// AutoApproveRoutes approves any route that can be autoapproved from // AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy. // the nodes perspective according to the given policy.
// It reports true if any routes were approved. // It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes // Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
if pm == nil { if pm == nil {
return false return false

View File

@ -7,9 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) {
} }
} }
} }
func TestReduceRoutes(t *testing.T) { func TestReduceRoutes(t *testing.T) {
type args struct { type args struct {
node *types.Node node *types.Node

View File

@ -13,9 +13,7 @@ import (
"tailscale.com/types/views" "tailscale.com/types/views"
) )
var ( var ErrInvalidAction = errors.New("invalid action")
ErrInvalidAction = errors.New("invalid action")
)
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a // compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
var destPorts []tailcfg.NetPortRange var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations { for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes) ips, err := dest.Resolve(pol, users, nodes)
if err != nil { if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips") log.Trace().Err(err).Msgf("resolving destination ips")
} }
@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
for _, pref := range ips.Prefixes() { for _, pref := range ips.Prefixes() {
out = append(out, pref.String()) out = append(out, pref.String())
} }
return out return out
} }

View File

@ -4,19 +4,17 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"strings" "strings"
"sync" "sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"slices"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/util/deephash"
"tailscale.com/types/views" "tailscale.com/types/views"
"tailscale.com/util/deephash"
) )
type PolicyManager struct { type PolicyManager struct {
@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
return pm.filter, pm.matchers return pm.filter, pm.matchers
} }
@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
pm.users = users pm.users = users
return pm.updateLocked() return pm.updateLocked()
} }
@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
pm.nodes = nodes pm.nodes = nodes
return pm.updateLocked() return pm.updateLocked()
} }
@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// cannot just lookup in the prefix map and have to check // cannot just lookup in the prefix map and have to check
// if there is a "parent" prefix available. // if there is a "parent" prefix available.
for prefix, approveAddrs := range pm.autoApproveMap { for prefix, approveAddrs := range pm.autoApproveMap {
// Check if prefix is larger (so containing) and then overlaps // Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover // the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {

View File

@ -1,10 +1,10 @@
package v2 package v2
import ( import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"

View File

@ -6,9 +6,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"slices" "slices"
"strconv"
"strings"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
// Check if it's the wildcard port range // Check if it's the wildcard port range
if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 { if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
return json.Marshal(fmt.Sprintf("%s:*", alias)) return json.Marshal(alias + ":*")
} }
// Otherwise, format as "alias:ports" // Otherwise, format as "alias:ports"
var ports []string var ports []string
for _, port := range a.Ports { for _, port := range a.Ports {
if port.First == port.Last { if port.First == port.Last {
ports = append(ports, fmt.Sprintf("%d", port.First)) ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
} else { } else {
ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last)) ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
} }
@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
if err := u.Validate(); err != nil { if err := u.Validate(); err != nil {
return err return err
} }
return nil return nil
} }
@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
return buildIPSetMultiErr(&ips, errs) return buildIPSetMultiErr(&ips, errs)
} }
// Group is a special string which is always prefixed with `group:` // Group is a special string which is always prefixed with `group:`.
type Group string type Group string
func (g Group) Validate() error { func (g Group) Validate() error {
@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
if err := g.Validate(); err != nil { if err := g.Validate(); err != nil {
return err return err
} }
return nil return nil
} }
@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
return buildIPSetMultiErr(&ips, errs) return buildIPSetMultiErr(&ips, errs)
} }
// Tag is a special string which is always prefixed with `tag:` // Tag is a special string which is always prefixed with `tag:`.
type Tag string type Tag string
func (t Tag) Validate() error { func (t Tag) Validate() error {
@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
if err := t.Validate(); err != nil { if err := t.Validate(); err != nil {
return err return err
} }
return nil return nil
} }
@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
if err := h.Validate(); err != nil { if err := h.Validate(); err != nil {
return err return err
} }
return nil return nil
} }
@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
} }
*p = Prefix(addrPref) *p = Prefix(addrPref)
return nil return nil
} }
@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
return err return err
} }
*p = Prefix(pref) *p = Prefix(pref)
return nil return nil
} }
@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
if err := p.Validate(); err != nil { if err := p.Validate(); err != nil {
return err return err
} }
return nil return nil
} }
@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
} }
} }
// AutoGroup is a special string which is always prefixed with `autogroup:` // AutoGroup is a special string which is always prefixed with `autogroup:`.
type AutoGroup string type AutoGroup string
const ( const (
@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
if err := ag.Validate(); err != nil { if err := ag.Validate(); err != nil {
return err return err
} }
return nil return nil
} }
@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil { if err != nil {
return err return err
} }
if err := ve.Alias.Validate(); err != nil { if err := ve.Validate(); err != nil {
return err return err
} }
default: default:
return fmt.Errorf("type %T not supported", vs) return fmt.Errorf("type %T not supported", vs)
} }
return nil return nil
} }
@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
return err return err
} }
ve.Alias = ptr ve.Alias = ptr
return nil return nil
} }
@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
for i, alias := range aliases { for i, alias := range aliases {
(*a)[i] = alias.Alias (*a)[i] = alias.Alias
} }
return nil return nil
} }
@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
return ips, multierr.New(append(errs, err)...) return ips, multierr.New(append(errs, err)...)
} }
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer // Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
func unmarshalPointer[T any]( func unmarshalPointer[T any](
b []byte, b []byte,
parseFunc func(string) (T, error), parseFunc func(string) (T, error),
@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
for i, autoApprover := range autoApprovers { for i, autoApprover := range autoApprovers {
(*aa)[i] = autoApprover.AutoApprover (*aa)[i] = autoApprover.AutoApprover
} }
return nil return nil
} }
@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
return err return err
} }
ve.AutoApprover = ptr ve.AutoApprover = ptr
return nil return nil
} }
@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
return err return err
} }
ve.Owner = ptr ve.Owner = ptr
return nil return nil
} }
@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
for i, owner := range owners { for i, owner := range owners {
(*o)[i] = owner.Owner (*o)[i] = owner.Owner
} }
return nil return nil
} }
@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
case isGroup(s): case isGroup(s):
return ptr.To(Group(s)), nil return ptr.To(Group(s)), nil
} }
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
- user (containing an "@") - user (containing an "@")
- group (starting with "group:") - group (starting with "group:")
@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
(*g)[group] = usernames (*g)[group] = usernames
} }
return nil return nil
} }
@ -1252,7 +1269,7 @@ type Policy struct {
// We use the default JSON marshalling behavior provided by the Go runtime. // We use the default JSON marshalling behavior provided by the Go runtime.
var ( var (
// TODO(kradalby): Add these checks for tagOwners and autoApprovers // TODO(kradalby): Add these checks for tagOwners and autoApprovers.
autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged} autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
} }
if src.Is(AutoGroupInternet) { if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`) return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
} }
if !slices.Contains(autogroupForSrc, *src) { if !slices.Contains(autogroupForSrc, *src) {
@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
} }
if src.Is(AutoGroupInternet) { if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
} }
if !slices.Contains(autogroupForSSHSrc, *src) { if !slices.Contains(autogroupForSSHSrc, *src) {
@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
} }
if dst.Is(AutoGroupInternet) { if dst.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
} }
if !slices.Contains(autogroupForSSHDst, *dst) { if !slices.Contains(autogroupForSSHDst, *dst) {
@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
for _, acl := range p.ACLs { for _, acl := range p.ACLs {
for _, src := range acl.Sources { for _, src := range acl.Sources {
switch src.(type) { switch src := src.(type) {
case *Host: case *Host:
h := src.(*Host) h := src
if !p.Hosts.exist(*h) { if !p.Hosts.exist(*h) {
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
} }
case *AutoGroup: case *AutoGroup:
ag := src.(*AutoGroup) ag := src
if err := validateAutogroupSupported(ag); err != nil { if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err) errs = append(errs, err)
@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
continue continue
} }
case *Group: case *Group:
g := src.(*Group) g := src
if err := p.Groups.Contains(g); err != nil { if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
case *Tag: case *Tag:
tagOwner := src.(*Tag) tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil { if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
} }
for _, src := range ssh.Sources { for _, src := range ssh.Sources {
switch src.(type) { switch src := src.(type) {
case *AutoGroup: case *AutoGroup:
ag := src.(*AutoGroup) ag := src
if err := validateAutogroupSupported(ag); err != nil { if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err) errs = append(errs, err)
@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
continue continue
} }
case *Group: case *Group:
g := src.(*Group) g := src
if err := p.Groups.Contains(g); err != nil { if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
case *Tag: case *Tag:
tagOwner := src.(*Tag) tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil { if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
} }
} }
for _, dst := range ssh.Destinations { for _, dst := range ssh.Destinations {
switch dst.(type) { switch dst := dst.(type) {
case *AutoGroup: case *AutoGroup:
ag := dst.(*AutoGroup) ag := dst
if err := validateAutogroupSupported(ag); err != nil { if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err) errs = append(errs, err)
continue continue
@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
continue continue
} }
case *Tag: case *Tag:
tagOwner := dst.(*Tag) tagOwner := dst
if err := p.TagOwners.Contains(tagOwner); err != nil { if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
for _, tagOwners := range p.TagOwners { for _, tagOwners := range p.TagOwners {
for _, tagOwner := range tagOwners { for _, tagOwner := range tagOwners {
switch tagOwner.(type) { switch tagOwner := tagOwner.(type) {
case *Group: case *Group:
g := tagOwner.(*Group) g := tagOwner
if err := p.Groups.Contains(g); err != nil { if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
for _, approvers := range p.AutoApprovers.Routes { for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers { for _, approver := range approvers {
switch approver.(type) { switch approver := approver.(type) {
case *Group: case *Group:
g := approver.(*Group) g := approver
if err := p.Groups.Contains(g); err != nil { if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
case *Tag: case *Tag:
tagOwner := approver.(*Tag) tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil { if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
} }
for _, approver := range p.AutoApprovers.ExitNode { for _, approver := range p.AutoApprovers.ExitNode {
switch approver.(type) { switch approver := approver.(type) {
case *Group: case *Group:
g := approver.(*Group) g := approver
if err := p.Groups.Contains(g); err != nil { if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
case *Tag: case *Tag:
tagOwner := approver.(*Tag) tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil { if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
} }
p.validated = true p.validated = true
return nil return nil
} }
@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
) )
} }
} }
return nil return nil
} }
@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
) )
} }
} }
return nil return nil
} }

View File

@ -5,13 +5,13 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model" "github.com/prometheus/common/model"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go4.org/netipx" "go4.org/netipx"
@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
// Marshal the policy to JSON // Marshal the policy to JSON
marshalled, err := json.MarshalIndent(policy, "", " ") marshalled, err := json.MarshalIndent(policy, "", " ")
require.NoError(t, err) require.NoError(t, err)
// Make sure all expected fields are present in the JSON // Make sure all expected fields are present in the JSON
jsonString := string(marshalled) jsonString := string(marshalled)
assert.Contains(t, jsonString, "group:example") assert.Contains(t, jsonString, "group:example")
@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
assert.Contains(t, jsonString, "accept") assert.Contains(t, jsonString, "accept")
assert.Contains(t, jsonString, "tcp") assert.Contains(t, jsonString, "tcp")
assert.Contains(t, jsonString, "80") assert.Contains(t, jsonString, "80")
// Unmarshal back to verify round trip // Unmarshal back to verify round trip
var roundTripped Policy var roundTripped Policy
err = json.Unmarshal(marshalled, &roundTripped) err = json.Unmarshal(marshalled, &roundTripped)
require.NoError(t, err) require.NoError(t, err)
// Compare the original and round-tripped policies // Compare the original and round-tripped policies
cmps := append(util.Comparers, cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool { cmp.Comparer(func(x, y Prefix) bool {
return x == y return x == y
}), }),
cmpopts.IgnoreUnexported(Policy{}), cmpopts.IgnoreUnexported(Policy{}),
cmpopts.EquateEmpty(), cmpopts.EquateEmpty(),
) )
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" { if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
} }
@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
}, },
} }
cmps := append(util.Comparers, cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool { cmp.Comparer(func(x, y Prefix) bool {
return x == y return x == y
}), }),
cmpopts.IgnoreUnexported(Policy{}), cmpopts.IgnoreUnexported(Policy{}),
) )
// For round-trip testing, we'll normalize the policies before comparing // For round-trip testing, we'll normalize the policies before comparing
for _, tt := range tests { for _, tt := range tests {
@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
} else if !strings.Contains(err.Error(), tt.wantErr) { } else if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr) t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
} }
return // Skip the rest of the test if we expected an error return // Skip the rest of the test if we expected an error
} }
@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("round-trip unmarshalling: %v", err) t.Fatalf("round-trip unmarshalling: %v", err)
} }
// Add EquateEmpty to handle nil vs empty maps/slices // Add EquateEmpty to handle nil vs empty maps/slices
roundTripCmps := append(cmps, roundTripCmps := append(cmps,
cmpopts.EquateEmpty(), cmpopts.EquateEmpty(),
cmpopts.IgnoreUnexported(Policy{}), cmpopts.IgnoreUnexported(Policy{}),
) )
@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
builder.AddPrefix(mp(p)) builder.AddPrefix(mp(p))
} }
ipSet, _ := builder.IPSet() ipSet, _ := builder.IPSet()
return ipSet return ipSet
} }

View File

@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
expected []tailcfg.PortRange expected []tailcfg.PortRange
err string err string
}{ }{
{"80", []tailcfg.PortRange{{80, 80}}, ""}, {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
{"80-90", []tailcfg.PortRange{{80, 90}}, ""}, {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""}, {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""}, {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"}, {"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"}, {"-90", nil, "invalid port range format"},

View File

@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
} }
tsaddr.SortPrefixes(routes) tsaddr.SortPrefixes(routes)
return routes return routes
} }

View File

@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
if err != nil { if err != nil {
return types.NodeView{}, err return types.NodeView{}, err
} }
return node.View(), nil return node.View(), nil
} }
@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er
if err != nil { if err != nil {
return types.NodeView{}, err return types.NodeView{}, err
} }
return node.View(), nil return node.View(), nil
} }
@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey(
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) { if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
nodeToRegister.Expiry = &regReq.Expiry nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() { } else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout), // If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired // don't set expiry so the node won't be considered expired
log.Debug(). log.Debug().
Time("requested_expiry", regReq.Expiry). Time("requested_expiry", regReq.Expiry).

View File

@ -2,6 +2,7 @@ package hscontrol
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
certDomains := tsNode.CertDomains() certDomains := tsNode.CertDomains()
if len(certDomains) == 0 { if len(certDomains) == 0 {
return fmt.Errorf("no cert domains available for HTTPS") return errors.New("no cert domains available for HTTPS")
} }
base := "https://" + certDomains[0] base := "https://" + certDomains[0]
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
logf("TailSQL started") logf("TailSQL started")
<-ctx.Done() <-ctx.Done()
logf("TailSQL shutting down...") logf("TailSQL shutting down...")
return tsNode.Close() return tsNode.Close()
} }

View File

@ -62,7 +62,7 @@ func Apple(url string) *elem.Element {
), ),
elem.Pre(nil, elem.Pre(nil,
elem.Code(nil, elem.Code(nil,
elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)), elem.Text("tailscale login --login-server "+url),
), ),
), ),
headerTwo("GUI"), headerTwo("GUI"),
@ -143,10 +143,7 @@ func Apple(url string) *elem.Element {
elem.Code( elem.Code(
nil, nil,
elem.Text( elem.Text(
fmt.Sprintf( "defaults write io.tailscale.ipn.macos ControlURL "+url,
`defaults write io.tailscale.ipn.macos ControlURL %s`,
url,
),
), ),
), ),
), ),
@ -155,10 +152,7 @@ func Apple(url string) *elem.Element {
elem.Code( elem.Code(
nil, nil,
elem.Text( elem.Text(
fmt.Sprintf( "defaults write io.tailscale.ipn.macsys ControlURL "+url,
`defaults write io.tailscale.ipn.macsys ControlURL %s`,
url,
),
), ),
), ),
), ),

View File

@ -1,8 +1,6 @@
package templates package templates
import ( import (
"fmt"
"github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/attrs"
) )
@ -31,7 +29,7 @@ func Windows(url string) *elem.Element {
), ),
elem.Pre(nil, elem.Pre(nil,
elem.Code(nil, elem.Code(nil,
elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)), elem.Text("tailscale login --login-server "+url),
), ),
), ),
), ),

View File

@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return rid return rid
} }

View File

@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error {
log.Warn().Msg("No config file found, using defaults") log.Warn().Msg("No config file found, using defaults")
return nil return nil
} }
return fmt.Errorf("fatal error reading config file: %w", err) return fmt.Errorf("fatal error reading config file: %w", err)
} }
@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) {
} }
if prefix4 == nil && prefix6 == nil { if prefix4 == nil && prefix6 == nil {
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required") return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
} }
allocStr := viper.GetString("prefixes.allocation") allocStr := viper.GetString("prefixes.allocation")
@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
s := len(serverDomainParts) s := len(serverDomainParts)
b := len(baseDomainParts) b := len(baseDomainParts)
for i := range len(baseDomainParts) { for i := range baseDomainParts {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
return nil return nil
} }

View File

@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) {
assert.Equal(t, "trace", viper.GetString("log.level")) assert.Equal(t, "trace", viper.GetString("log.level"))
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4")) assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log")) assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
return nil, nil return nil, nil
}, },
want: nil, want: nil,

View File

@ -28,8 +28,10 @@ var (
ErrNodeUserHasNoName = errors.New("node user has no name") ErrNodeUserHasNoName = errors.New("node user has no name")
) )
type NodeID uint64 type (
type NodeIDs []NodeID NodeID uint64
NodeIDs []NodeID
)
func (n NodeIDs) Len() int { return len(n) } func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
return true return true
} }
} }
return false return false
} }
@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
// and therefore should not be treated as a // and therefore should not be treated as a
// user owned device. // user owned device.
// Currently, this function only handles tags set // Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys) // via CLI ("forced tags" and preauthkeys).
func (node *Node) IsTagged() bool { func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 { if len(node.ForcedTags) > 0 {
return true return true
@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool {
// HasTag reports if a node has a given tag. // HasTag reports if a node has a given tag.
// Currently, this function only handles tags set // Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys) // via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool { func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag) return slices.Contains(node.Tags(), tag)
} }
@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string {
sb.WriteString(node.DebugString()) sb.WriteString(node.DebugString())
sb.WriteString("\n") sb.WriteString("\n")
} }
return sb.String() return sb.String()
} }
@ -590,6 +594,7 @@ func (node Node) DebugString() string {
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes()) fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
sb.WriteString("\n") sb.WriteString("\n")
return sb.String() return sb.String()
} }
@ -689,7 +694,7 @@ func (v NodeView) Tags() []string {
// and therefore should not be treated as a // and therefore should not be treated as a
// user owned device. // user owned device.
// Currently, this function only handles tags set // Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys) // via CLI ("forced tags" and preauthkeys).
func (v NodeView) IsTagged() bool { func (v NodeView) IsTagged() bool {
if !v.Valid() { if !v.Valid() {
return false return false
@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
// GetFQDN returns the fully qualified domain name for the node. // GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) { func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() { if !v.Valid() {
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid") return "", errors.New("failed to create valid FQDN: node view is invalid")
} }
return v.ж.GetFQDN(baseDomain) return v.ж.GetFQDN(baseDomain)
} }
@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string {
} }
return v.ж.IPsAsString() return v.ж.IPsAsString()
} }

View File

@ -2,7 +2,6 @@ package types
import ( import (
"fmt" "fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
@ -10,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"

View File

@ -11,7 +11,7 @@ import (
type PAKError string type PAKError string
func (e PAKError) Error() string { return string(e) } func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) } func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
// PreAuthKey describes a pre-authorization key usable in a particular user. // PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct { type PreAuthKey struct {

View File

@ -1,6 +1,7 @@
package types package types
import ( import (
"errors"
"testing" "testing"
"time" "time"
@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("expected error but got none") t.Errorf("expected error but got none")
} else { } else {
httpErr, ok := err.(PAKError) var httpErr PAKError
ok := errors.As(err, &httpErr)
if !ok { if !ok {
t.Errorf("expected HTTPError but got %T", err) t.Errorf("expected HTTPError but got %T", err)
} else { } else {

View File

@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string {
// - Remove empty path segments // - Remove empty path segments
// - For non-URL identifiers, it joins non-empty segments with a single slash // - For non-URL identifiers, it joins non-empty segments with a single slash
// - Returns empty string for identifiers with only slashes // - Returns empty string for identifiers with only slashes
// - Normalize URL schemes to lowercase // - Normalize URL schemes to lowercase.
func CleanIdentifier(identifier string) string { func CleanIdentifier(identifier string) string {
if identifier == "" { if identifier == "" {
return identifier return identifier
@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, part) cleanParts = append(cleanParts, part)
} }
} }
if len(cleanParts) == 0 { if len(cleanParts) == 0 {
u.Path = "" u.Path = ""
} else { } else {
@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string {
} }
// Ensure scheme is lowercase // Ensure scheme is lowercase
u.Scheme = strings.ToLower(u.Scheme) u.Scheme = strings.ToLower(u.Scheme)
return u.String() return u.String()
} }
@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string {
if len(cleanParts) == 0 { if len(cleanParts) == 0 {
return "" return ""
} }
return strings.Join(cleanParts, "/") return strings.Join(cleanParts, "/")
} }

View File

@ -1,4 +1,6 @@
package types package types
var Version = "dev" var (
var GitCommitHash = "dev" Version = "dev"
GitCommitHash = "dev"
)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"regexp" "regexp"
"strconv"
"strings" "strings"
"unicode" "unicode"
@ -21,8 +22,10 @@ const (
LabelHostnameLength = 63 LabelHostnameLength = 63
) )
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var (
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
var ErrInvalidUserName = errors.New("invalid user name") var ErrInvalidUserName = errors.New("invalid user name")
@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.) // here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{} rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- { for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i])) rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
} }
rdnsSlice = append(rdnsSlice, "in-addr.arpa.") rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".") rdnsBase := strings.Join(rdnsSlice, ".")
@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) { makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".") prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix)) return dnsname.ToFQDN(prefix + ".ip6.arpa")
} }
var fqdns []dnsname.FQDN var fqdns []dnsname.FQDN

View File

@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
"rowsAffected": rowsAffected, "rowsAffected": rowsAffected,
} }
if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) { if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) {
l.Logger.Error().Err(err).Fields(fields).Msgf("") l.Logger.Error().Err(err).Fields(fields).Msgf("")
return return
} }

View File

@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet {
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16")) internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet() theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet return theInternetSet
}) })

View File

@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
} }
type TraceroutePath struct { type TraceroutePath struct {
// Hop is the current jump in the total traceroute. // Hop is the current jump in the total traceroute.
Hop int Hop int
// Hostname is the resolved hostname or IP address identifying the jump // Hostname is the resolved hostname or IP address identifying the jump
Hostname string Hostname string
// IP is the IP address of the jump // IP is the IP address of the jump
IP netip.Addr IP netip.Addr
// Latencies is a list of the latencies for this jump // Latencies is a list of the latencies for this jump
Latencies []time.Duration Latencies []time.Duration
} }
type Traceroute struct { type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target // Hostname is the resolved hostname or IP address identifying the target
Hostname string Hostname string
// IP is the IP address of the target // IP is the IP address of the target
IP netip.Addr IP netip.Addr
// Route is the path taken to reach the target if successful. The list is ordered by the path taken. // Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath Route []TraceroutePath
// Success indicates if the traceroute was successful. // Success indicates if the traceroute was successful.
Success bool Success bool
// Err contains an error if the traceroute was not successful. // Err contains an error if the traceroute was not successful.
Err error Err error
} }
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct // ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
func ParseTraceroute(output string) (Traceroute, error) { func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n") lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 { if len(lines) < 1 {
@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
} }
// Parse each hop line // Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
for i := 1; i < len(lines); i++ { for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i]) matches := hopRegex.FindStringSubmatch(lines[i])

View File

@ -1077,7 +1077,6 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,
@ -1213,7 +1212,6 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
func TestACLAutogroupMember(t *testing.T) { func TestACLAutogroupMember(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := aclScenario(t, scenario := aclScenario(t,
&policyv2.Policy{ &policyv2.Policy{
@ -1271,7 +1269,6 @@ func TestACLAutogroupMember(t *testing.T) {
func TestACLAutogroupTagged(t *testing.T) { func TestACLAutogroupTagged(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := aclScenario(t, scenario := aclScenario(t,
&policyv2.Policy{ &policyv2.Policy{

View File

@ -3,12 +3,11 @@ package integration
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"slices"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
@ -19,7 +18,6 @@ import (
func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
for _, https := range []bool{true, false} { for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@ -66,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
assertNoErrGetHeadscale(t, err) assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients)) assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes) nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout) t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -161,12 +159,11 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
} }
}) })
} }
} }
func assertLastSeenSet(t *testing.T, node *v1.Node) { func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node) assert.NotNil(t, node)
assert.NotNil(t, node.LastSeen) assert.NotNil(t, node.GetLastSeen())
} }
// This test will first log in two sets of nodes to two sets of users, then // This test will first log in two sets of nodes to two sets of users, then
@ -175,7 +172,6 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
// still has nodes, but they are not connected. // still has nodes, but they are not connected.
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -204,7 +200,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
assertNoErrGetHeadscale(t, err) assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients)) assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes) nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout) t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -259,7 +255,6 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
for _, https := range []bool{true, false} { for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@ -303,7 +298,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
assertNoErrGetHeadscale(t, err) assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients)) assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes) nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout) t.Logf("node count before logout: %d", nodeCountBeforeLogout)

View File

@ -1,14 +1,12 @@
package integration package integration
import ( import (
"fmt" "maps"
"net/netip" "net/netip"
"sort" "sort"
"testing" "testing"
"time" "time"
"maps"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -21,7 +19,6 @@ import (
func TestOIDCAuthenticationPingAll(t *testing.T) { func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
// Logins to MockOIDC is served by a queue with a strict order, // Logins to MockOIDC is served by a queue with a strict order,
// if we use more than one node per user, the order of the logins // if we use more than one node per user, the order of the logins
@ -119,7 +116,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
// This test is really flaky. // This test is really flaky.
func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
shortAccessTTL := 5 * time.Minute shortAccessTTL := 5 * time.Minute
@ -174,9 +170,13 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
// of safety reasons) before checking if the clients have logged out. // of safety reasons) before checking if the clients have logged out.
// The Wait function can't do it itself as it has an upper bound of 1 // The Wait function can't do it itself as it has an upper bound of 1
// min. // min.
time.Sleep(shortAccessTTL + 10*time.Second) assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
assertTailscaleNodesLogout(t, allClients) status, err := client.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}
}, shortAccessTTL+10*time.Second, 5*time.Second)
} }
func TestOIDC024UserCreation(t *testing.T) { func TestOIDC024UserCreation(t *testing.T) {
@ -295,9 +295,7 @@ func TestOIDC024UserCreation(t *testing.T) {
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,
} }
for _, user := range tt.cliUsers { spec.Users = append(spec.Users, tt.cliUsers...)
spec.Users = append(spec.Users, user)
}
for _, user := range tt.oidcUsers { for _, user := range tt.oidcUsers {
spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified)) spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified))
@ -350,7 +348,6 @@ func TestOIDC024UserCreation(t *testing.T) {
func TestOIDCAuthenticationWithPKCE(t *testing.T) { func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
// Single user with one node for testing PKCE flow // Single user with one node for testing PKCE flow
spec := ScenarioSpec{ spec := ScenarioSpec{
@ -402,7 +399,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
func TestOIDCReloginSameNodeNewUser(t *testing.T) { func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
// Create no nodes and no users // Create no nodes and no users
scenario, err := NewScenario(ScenarioSpec{ scenario, err := NewScenario(ScenarioSpec{
@ -440,7 +436,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
listUsers, err := headscale.ListUsers() listUsers, err := headscale.ListUsers()
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, listUsers, 0) assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err) assertNoErr(t, err)
@ -482,7 +478,13 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout() err = ts.Logout()
assertNoErr(t, err) assertNoErr(t, err)
time.Sleep(5 * time.Second) // Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and // TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it // logs in immediately after the first logout and I cannot reproduce it
@ -530,16 +532,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Machine key is the same as the "machine" has not changed, // Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node // but Node key is not as it is a new node
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey) assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
// Log out user2, and log into user1, no new node should be created, // Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again // the node should now "become" node1 again
err = ts.Logout() err = ts.Logout()
assertNoErr(t, err) assertNoErr(t, err)
time.Sleep(5 * time.Second) // Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and // TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it // logs in immediately after the first logout and I cannot reproduce it
@ -588,24 +596,24 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Validate that the machine we had when we logged in the first time, has the same // Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same // machine key, but a different ID than the newly logged in version of the same
// machine. // machine.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey) assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id) assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id) assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id) assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
// Even tho we are logging in again with the same user, the previous key has been expired // Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same // and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches. // as the user + machinekey still matches.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey) assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey) assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id) assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
// The "logged back in" machine should have the same machinekey but a different nodekey // The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user. // than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey) assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
} }
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
@ -623,7 +631,7 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
return mockoidc.MockUser{ return mockoidc.MockUser{
Subject: username, Subject: username,
PreferredUsername: username, PreferredUsername: username,
Email: fmt.Sprintf("%s@headscale.net", username), Email: username + "@headscale.net",
EmailVerified: emailVerified, EmailVerified: emailVerified,
} }
} }

View File

@ -2,9 +2,8 @@ package integration
import ( import (
"net/netip" "net/netip"
"testing"
"slices" "slices"
"testing"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/samber/lo" "github.com/samber/lo"
@ -55,7 +54,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -95,7 +93,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
assertNoErrGetHeadscale(t, err) assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients)) assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes) nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout) t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -140,7 +138,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
listNodes, err = headscale.ListNodes() listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes)) require.Len(t, listNodes, nodeCountBeforeLogout)
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
for _, client := range allClients { for _, client := range allClients {

View File

@ -18,8 +18,8 @@ import (
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"tailscale.com/tailcfg"
) )
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@ -30,7 +30,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul
err = json.Unmarshal([]byte(str), result) err = json.Unmarshal([]byte(str), result)
if err != nil { if err != nil {
return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str) return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str)
} }
return nil return nil
@ -48,7 +48,6 @@ func sortWithID[T GRPCSortable](a, b T) int {
func TestUserCommand(t *testing.T) { func TestUserCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"user1", "user2"}, Users: []string{"user1", "user2"},
@ -184,7 +183,7 @@ func TestUserCommand(t *testing.T) {
"--identifier=1", "--identifier=1",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed") assert.Contains(t, deleteResult, "User destroyed")
var listAfterIDDelete []*v1.User var listAfterIDDelete []*v1.User
@ -222,7 +221,7 @@ func TestUserCommand(t *testing.T) {
"--name=newname", "--name=newname",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed") assert.Contains(t, deleteResult, "User destroyed")
var listAfterNameDelete []v1.User var listAfterNameDelete []v1.User
@ -238,12 +237,11 @@ func TestUserCommand(t *testing.T) {
) )
assertNoErr(t, err) assertNoErr(t, err)
require.Len(t, listAfterNameDelete, 0) require.Empty(t, listAfterNameDelete)
} }
func TestPreAuthKeyCommand(t *testing.T) { func TestPreAuthKeyCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
user := "preauthkeyspace" user := "preauthkeyspace"
count := 3 count := 3
@ -347,7 +345,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
continue continue
} }
assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
} }
// Test key expiry // Test key expiry
@ -386,7 +384,6 @@ func TestPreAuthKeyCommand(t *testing.T) {
func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
user := "pre-auth-key-without-exp-user" user := "pre-auth-key-without-exp-user"
spec := ScenarioSpec{ spec := ScenarioSpec{
@ -448,7 +445,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
user := "pre-auth-key-reus-ephm-user" user := "pre-auth-key-reus-ephm-user"
spec := ScenarioSpec{ spec := ScenarioSpec{
@ -524,7 +520,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
user1 := "user1" user1 := "user1"
user2 := "user2" user2 := "user2"
@ -575,7 +570,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
require.Nil(t, err) require.NoError(t, err)
require.Len(t, listNodes, 1) require.Len(t, listNodes, 1)
assert.Equal(t, user1, listNodes[0].GetUser().GetName()) assert.Equal(t, user1, listNodes[0].GetUser().GetName())
@ -613,7 +608,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
} }
listNodes, err = headscale.ListNodes() listNodes, err = headscale.ListNodes()
require.Nil(t, err) require.NoError(t, err)
require.Len(t, listNodes, 2) require.Len(t, listNodes, 2)
assert.Equal(t, user1, listNodes[0].GetUser().GetName()) assert.Equal(t, user1, listNodes[0].GetUser().GetName())
assert.Equal(t, user2, listNodes[1].GetUser().GetName()) assert.Equal(t, user2, listNodes[1].GetUser().GetName())
@ -621,7 +616,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
func TestApiKeyCommand(t *testing.T) { func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
count := 5 count := 5
@ -653,7 +647,7 @@ func TestApiKeyCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.NotEmpty(t, apiResult) assert.NotEmpty(t, apiResult)
keys[idx] = apiResult keys[idx] = apiResult
@ -672,7 +666,7 @@ func TestApiKeyCommand(t *testing.T) {
}, },
&listedAPIKeys, &listedAPIKeys,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listedAPIKeys, 5) assert.Len(t, listedAPIKeys, 5)
@ -728,7 +722,7 @@ func TestApiKeyCommand(t *testing.T) {
listedAPIKeys[idx].GetPrefix(), listedAPIKeys[idx].GetPrefix(),
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
} }
@ -744,7 +738,7 @@ func TestApiKeyCommand(t *testing.T) {
}, },
&listedAfterExpireAPIKeys, &listedAfterExpireAPIKeys,
) )
assert.Nil(t, err) assert.NoError(t, err)
for index := range listedAfterExpireAPIKeys { for index := range listedAfterExpireAPIKeys {
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
@ -770,7 +764,7 @@ func TestApiKeyCommand(t *testing.T) {
"--prefix", "--prefix",
listedAPIKeys[0].GetPrefix(), listedAPIKeys[0].GetPrefix(),
}) })
assert.Nil(t, err) assert.NoError(t, err)
var listedAPIKeysAfterDelete []v1.ApiKey var listedAPIKeysAfterDelete []v1.ApiKey
err = executeAndUnmarshal(headscale, err = executeAndUnmarshal(headscale,
@ -783,14 +777,13 @@ func TestApiKeyCommand(t *testing.T) {
}, },
&listedAPIKeysAfterDelete, &listedAPIKeysAfterDelete,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listedAPIKeysAfterDelete, 4) assert.Len(t, listedAPIKeysAfterDelete, 4)
} }
func TestNodeTagCommand(t *testing.T) { func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"user1"}, Users: []string{"user1"},
@ -811,7 +804,7 @@ func TestNodeTagCommand(t *testing.T) {
types.MustRegistrationID().String(), types.MustRegistrationID().String(),
} }
nodes := make([]*v1.Node, len(regIDs)) nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err) assert.NoError(t, err)
for index, regID := range regIDs { for index, regID := range regIDs {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -829,7 +822,7 @@ func TestNodeTagCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -847,7 +840,7 @@ func TestNodeTagCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -866,7 +859,7 @@ func TestNodeTagCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
@ -894,7 +887,7 @@ func TestNodeTagCommand(t *testing.T) {
}, },
&resultMachines, &resultMachines,
) )
assert.Nil(t, err) assert.NoError(t, err)
found := false found := false
for _, node := range resultMachines { for _, node := range resultMachines {
if node.GetForcedTags() != nil { if node.GetForcedTags() != nil {
@ -905,19 +898,15 @@ func TestNodeTagCommand(t *testing.T) {
} }
} }
} }
assert.Equal( assert.True(
t, t,
true,
found, found,
"should find a node with the tag 'tag:test' in the list of nodes", "should find a node with the tag 'tag:test' in the list of nodes",
) )
} }
func TestNodeAdvertiseTagCommand(t *testing.T) { func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
tests := []struct { tests := []struct {
name string name string
@ -1024,7 +1013,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
}, },
&resultMachines, &resultMachines,
) )
assert.Nil(t, err) assert.NoError(t, err)
found := false found := false
for _, node := range resultMachines { for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil { if tags := node.GetValidTags(); tags != nil {
@ -1043,7 +1032,6 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
func TestNodeCommand(t *testing.T) { func TestNodeCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"node-user", "other-user"}, Users: []string{"node-user", "other-user"},
@ -1067,7 +1055,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(), types.MustRegistrationID().String(),
} }
nodes := make([]*v1.Node, len(regIDs)) nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err) assert.NoError(t, err)
for index, regID := range regIDs { for index, regID := range regIDs {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -1085,7 +1073,7 @@ func TestNodeCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1103,7 +1091,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -1123,7 +1111,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listAll, &listAll,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listAll, 5) assert.Len(t, listAll, 5)
@ -1144,7 +1132,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(), types.MustRegistrationID().String(),
} }
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
assert.Nil(t, err) assert.NoError(t, err)
for index, regID := range otherUserRegIDs { for index, regID := range otherUserRegIDs {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -1162,7 +1150,7 @@ func TestNodeCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1180,7 +1168,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
otherUserMachines[index] = &node otherUserMachines[index] = &node
} }
@ -1200,7 +1188,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listAllWithotherUser, &listAllWithotherUser,
) )
assert.Nil(t, err) assert.NoError(t, err)
// All nodes, nodes + otherUser // All nodes, nodes + otherUser
assert.Len(t, listAllWithotherUser, 7) assert.Len(t, listAllWithotherUser, 7)
@ -1226,7 +1214,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listOnlyotherUserMachineUser, &listOnlyotherUserMachineUser,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listOnlyotherUserMachineUser, 2) assert.Len(t, listOnlyotherUserMachineUser, 2)
@ -1258,7 +1246,7 @@ func TestNodeCommand(t *testing.T) {
"--force", "--force",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
// Test: list main user after node is deleted // Test: list main user after node is deleted
var listOnlyMachineUserAfterDelete []v1.Node var listOnlyMachineUserAfterDelete []v1.Node
@ -1275,14 +1263,13 @@ func TestNodeCommand(t *testing.T) {
}, },
&listOnlyMachineUserAfterDelete, &listOnlyMachineUserAfterDelete,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listOnlyMachineUserAfterDelete, 4) assert.Len(t, listOnlyMachineUserAfterDelete, 4)
} }
func TestNodeExpireCommand(t *testing.T) { func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"node-expire-user"}, Users: []string{"node-expire-user"},
@ -1323,7 +1310,7 @@ func TestNodeExpireCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1341,7 +1328,7 @@ func TestNodeExpireCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -1360,7 +1347,7 @@ func TestNodeExpireCommand(t *testing.T) {
}, },
&listAll, &listAll,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listAll, 5) assert.Len(t, listAll, 5)
@ -1377,10 +1364,10 @@ func TestNodeExpireCommand(t *testing.T) {
"nodes", "nodes",
"expire", "expire",
"--identifier", "--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()), strconv.FormatUint(listAll[idx].GetId(), 10),
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
} }
var listAllAfterExpiry []v1.Node var listAllAfterExpiry []v1.Node
@ -1395,7 +1382,7 @@ func TestNodeExpireCommand(t *testing.T) {
}, },
&listAllAfterExpiry, &listAllAfterExpiry,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listAllAfterExpiry, 5) assert.Len(t, listAllAfterExpiry, 5)
@ -1408,7 +1395,6 @@ func TestNodeExpireCommand(t *testing.T) {
func TestNodeRenameCommand(t *testing.T) { func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"node-rename-command"}, Users: []string{"node-rename-command"},
@ -1432,7 +1418,7 @@ func TestNodeRenameCommand(t *testing.T) {
types.MustRegistrationID().String(), types.MustRegistrationID().String(),
} }
nodes := make([]*v1.Node, len(regIDs)) nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err) assert.NoError(t, err)
for index, regID := range regIDs { for index, regID := range regIDs {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -1487,7 +1473,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&listAll, &listAll,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listAll, 5) assert.Len(t, listAll, 5)
@ -1504,11 +1490,11 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes", "nodes",
"rename", "rename",
"--identifier", "--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()), strconv.FormatUint(listAll[idx].GetId(), 10),
fmt.Sprintf("newnode-%d", idx+1), fmt.Sprintf("newnode-%d", idx+1),
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Contains(t, res, "Node renamed") assert.Contains(t, res, "Node renamed")
} }
@ -1525,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&listAllAfterRename, &listAllAfterRename,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listAllAfterRename, 5) assert.Len(t, listAllAfterRename, 5)
@ -1542,7 +1528,7 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes", "nodes",
"rename", "rename",
"--identifier", "--identifier",
fmt.Sprintf("%d", listAll[4].GetId()), strconv.FormatUint(listAll[4].GetId(), 10),
strings.Repeat("t", 64), strings.Repeat("t", 64),
}, },
) )
@ -1560,7 +1546,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&listAllAfterRenameAttempt, &listAllAfterRenameAttempt,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, listAllAfterRenameAttempt, 5) assert.Len(t, listAllAfterRenameAttempt, 5)
@ -1573,7 +1559,6 @@ func TestNodeRenameCommand(t *testing.T) {
func TestNodeMoveCommand(t *testing.T) { func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"old-user", "new-user"}, Users: []string{"old-user", "new-user"},
@ -1610,7 +1595,7 @@ func TestNodeMoveCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1628,13 +1613,13 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, uint64(1), node.GetId()) assert.Equal(t, uint64(1), node.GetId())
assert.Equal(t, "nomad-node", node.GetName()) assert.Equal(t, "nomad-node", node.GetName())
assert.Equal(t, node.GetUser().GetName(), "old-user") assert.Equal(t, "old-user", node.GetUser().GetName())
nodeID := fmt.Sprintf("%d", node.GetId()) nodeID := strconv.FormatUint(node.GetId(), 10)
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
@ -1651,9 +1636,9 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "new-user") assert.Equal(t, "new-user", node.GetUser().GetName())
var allNodes []v1.Node var allNodes []v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1667,13 +1652,13 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&allNodes, &allNodes,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Len(t, allNodes, 1) assert.Len(t, allNodes, 1)
assert.Equal(t, allNodes[0].GetId(), node.GetId()) assert.Equal(t, allNodes[0].GetId(), node.GetId())
assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") assert.Equal(t, "new-user", allNodes[0].GetUser().GetName())
_, err = headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
@ -1693,7 +1678,7 @@ func TestNodeMoveCommand(t *testing.T) {
err, err,
"user not found", "user not found",
) )
assert.Equal(t, node.GetUser().GetName(), "new-user") assert.Equal(t, "new-user", node.GetUser().GetName())
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
@ -1710,9 +1695,9 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user") assert.Equal(t, "old-user", node.GetUser().GetName())
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
@ -1729,14 +1714,13 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user") assert.Equal(t, "old-user", node.GetUser().GetName())
} }
func TestPolicyCommand(t *testing.T) { func TestPolicyCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
Users: []string{"user1"}, Users: []string{"user1"},
@ -1817,7 +1801,6 @@ func TestPolicyCommand(t *testing.T) {
func TestPolicyBrokenConfigCommand(t *testing.T) { func TestPolicyBrokenConfigCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,

View File

@ -1,7 +1,6 @@
package integration package integration
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -104,7 +103,7 @@ func DERPVerify(
defer c.Close() defer c.Close()
var result error var result error
if err := c.Connect(context.Background()); err != nil { if err := c.Connect(t.Context()); err != nil {
result = fmt.Errorf("client Connect: %w", err) result = fmt.Errorf("client Connect: %w", err)
} }
if m, err := c.Recv(); err != nil { if m, err := c.Recv(); err != nil {

View File

@ -15,7 +15,6 @@ import (
func TestResolveMagicDNS(t *testing.T) { func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -49,7 +48,7 @@ func TestResolveMagicDNS(t *testing.T) {
// It is safe to ignore this error as we handled it when caching it // It is safe to ignore this error as we handled it when caching it
peerFQDN, _ := peer.FQDN() peerFQDN, _ := peer.FQDN()
assert.Equal(t, fmt.Sprintf("%s.headscale.net.", peer.Hostname()), peerFQDN) assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN)
command := []string{ command := []string{
"tailscale", "tailscale",
@ -85,7 +84,6 @@ func TestResolveMagicDNS(t *testing.T) {
func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,
@ -222,12 +220,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
_, err = hs.Execute([]string{"rm", erPath}) _, err = hs.Execute([]string{"rm", erPath})
assertNoErr(t, err) assertNoErr(t, err)
time.Sleep(2 * time.Second)
// The same paths should still be available as it is not cleared on delete. // The same paths should still be available as it is not cleared on delete.
for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) {
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9") for _, client := range allClients {
} result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"})
assert.NoError(ct, err)
assert.Contains(ct, result, "9.9.9.9")
}
}, 10*time.Second, 1*time.Second)
// Write a new file, the backoff mechanism should make the filewatcher pick it up // Write a new file, the backoff mechanism should make the filewatcher pick it up
// again. // again.

View File

@ -33,26 +33,27 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
} }
// GenerateRunID creates a unique run identifier with timestamp and random hash. // GenerateRunID creates a unique run identifier with timestamp and random hash.
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3) // Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3).
func GenerateRunID() string { func GenerateRunID() string {
now := time.Now() now := time.Now()
timestamp := now.Format("20060102-150405") timestamp := now.Format("20060102-150405")
// Add a short random hash to ensure uniqueness // Add a short random hash to ensure uniqueness
randomHash := util.MustGenerateRandomStringDNSSafe(6) randomHash := util.MustGenerateRandomStringDNSSafe(6)
return fmt.Sprintf("%s-%s", timestamp, randomHash) return fmt.Sprintf("%s-%s", timestamp, randomHash)
} }
// ExtractRunIDFromContainerName extracts the run ID from container name. // ExtractRunIDFromContainerName extracts the run ID from container name.
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH" // Expects format: "prefix-YYYYMMDD-HHMMSS-HASH".
func ExtractRunIDFromContainerName(containerName string) string { func ExtractRunIDFromContainerName(containerName string) string {
parts := strings.Split(containerName, "-") parts := strings.Split(containerName, "-")
if len(parts) >= 3 { if len(parts) >= 3 {
// Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH) // Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH)
return strings.Join(parts[len(parts)-3:], "-") return strings.Join(parts[len(parts)-3:], "-")
} }
panic(fmt.Sprintf("unexpected container name format: %s", containerName)) panic("unexpected container name format: " + containerName)
} }
// IsRunningInContainer checks if the current process is running inside a Docker container. // IsRunningInContainer checks if the current process is running inside a Docker container.
@ -62,4 +63,4 @@ func IsRunningInContainer() bool {
// This could be improved with more robust detection if needed // This could be improved with more robust detection if needed
_, err := os.Stat("/.dockerenv") _, err := os.Stat("/.dockerenv")
return err == nil return err == nil
} }

View File

@ -30,7 +30,7 @@ func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption {
}) })
} }
// buffer is a goroutine safe bytes.buffer // buffer is a goroutine safe bytes.buffer.
type buffer struct { type buffer struct {
store bytes.Buffer store bytes.Buffer
mutex sync.Mutex mutex sync.Mutex
@ -58,8 +58,8 @@ func ExecuteCommand(
env []string, env []string,
options ...ExecuteCommandOption, options ...ExecuteCommandOption,
) (string, string, error) { ) (string, string, error) {
var stdout = buffer{} stdout := buffer{}
var stderr = buffer{} stderr := buffer{}
execConfig := ExecuteCommandConfig{ execConfig := ExecuteCommandConfig{
timeout: dockerExecuteTimeout, timeout: dockerExecuteTimeout,

View File

@ -159,7 +159,6 @@ func New(
}, },
} }
if dsic.workdir != "" { if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir runOptions.WorkingDir = dsic.workdir
} }
@ -192,7 +191,7 @@ func New(
} }
// Add integration test labels if running under hi tool // Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "derp") dockertestutil.DockerAddIntegrationLabels(runOptions, "derp")
container, err = pool.BuildAndRunWithBuildOptions( container, err = pool.BuildAndRunWithBuildOptions(
buildOptions, buildOptions,
runOptions, runOptions,

View File

@ -2,13 +2,13 @@ package integration
import ( import (
"strings" "strings"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"testing" "testing"
"time" "time"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
type ClientsSpec struct { type ClientsSpec struct {
@ -71,9 +71,9 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
NodesPerUser: 1, NodesPerUser: 1,
Users: []string{"user1", "user2", "user3"}, Users: []string{"user1", "user2", "user3"},
Networks: map[string][]string{ Networks: map[string][]string{
"usernet1": []string{"user1"}, "usernet1": {"user1"},
"usernet2": []string{"user2"}, "usernet2": {"user2"},
"usernet3": []string{"user3"}, "usernet3": {"user3"},
}, },
} }
@ -106,7 +106,6 @@ func derpServerScenario(
furtherAssertions ...func(*Scenario), furtherAssertions ...func(*Scenario),
) { ) {
IntegrationSkip(t) IntegrationSkip(t)
// t.Parallel()
scenario, err := NewScenario(spec) scenario, err := NewScenario(spec)
assertNoErr(t, err) assertNoErr(t, err)

View File

@ -26,7 +26,6 @@ import (
func TestPingAllByIP(t *testing.T) { func TestPingAllByIP(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -68,7 +67,6 @@ func TestPingAllByIP(t *testing.T) {
func TestPingAllByIPPublicDERP(t *testing.T) { func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -118,7 +116,6 @@ func TestEphemeralInAlternateTimezone(t *testing.T) {
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -191,7 +188,6 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
// deleted by accident if they are still online and active. // deleted by accident if they are still online and active.
func TestEphemeral2006DeletedTooQuickly(t *testing.T) { func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -260,18 +256,21 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// Wait a bit and bring up the clients again before the expiry // Wait a bit and bring up the clients again before the expiry
// time of the ephemeral nodes. // time of the ephemeral nodes.
// Nodes should be able to reconnect and work fine. // Nodes should be able to reconnect and work fine.
time.Sleep(30 * time.Second)
for _, client := range allClients { for _, client := range allClients {
err := client.Up() err := client.Up()
if err != nil { if err != nil {
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
} }
} }
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success = pingAllHelper(t, allClients, allAddrs) // Wait for clients to sync and be able to ping each other after reconnection
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
success = pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping")
}, 60*time.Second, 2*time.Second)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
// Take down all clients, this should start an expiry timer for each. // Take down all clients, this should start an expiry timer for each.
@ -284,7 +283,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// This time wait for all of the nodes to expire and check that they are no longer // This time wait for all of the nodes to expire and check that they are no longer
// registered. // registered.
time.Sleep(3 * time.Minute) assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
assert.NoError(ct, err)
assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName)
}
}, 4*time.Minute, 10*time.Second)
for _, userName := range spec.Users { for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName) nodes, err := headscale.ListNodes(userName)
@ -305,7 +310,6 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
func TestPingAllByHostname(t *testing.T) { func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -341,20 +345,6 @@ func TestPingAllByHostname(t *testing.T) {
// nolint:tparallel // nolint:tparallel
func TestTaildrop(t *testing.T) { func TestTaildrop(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
retry := func(times int, sleepInterval time.Duration, doWork func() error) error {
var err error
for range times {
err = doWork()
if err == nil {
return nil
}
time.Sleep(sleepInterval)
}
return err
}
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -396,40 +386,27 @@ func TestTaildrop(t *testing.T) {
"/var/run/tailscale/tailscaled.sock", "/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets", "http://local-tailscaled.sock/localapi/v0/file-targets",
} }
err = retry(10, 1*time.Second, func() error { assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, _, err := client.Execute(curlCommand) result, _, err := client.Execute(curlCommand)
if err != nil { assert.NoError(ct, err)
return err
}
var fts []apitype.FileTarget var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts) err = json.Unmarshal([]byte(result), &fts)
if err != nil { assert.NoError(ct, err)
return err
}
if len(fts) != len(allClients)-1 { if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname()) ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts { for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
} }
return fmt.Errorf( assert.Failf(ct, "client %s does not have all its peers as FileTargets",
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", "got %d, want: %d\n%s",
client.Hostname(),
len(fts), len(fts),
len(allClients)-1, len(allClients)-1,
ftStr, ftStr,
) )
} }
}, 10*time.Second, 1*time.Second)
return err
})
if err != nil {
t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
}
} }
for _, client := range allClients { for _, client := range allClients {
@ -454,24 +431,15 @@ func TestTaildrop(t *testing.T) {
fmt.Sprintf("%s:", peerFQDN), fmt.Sprintf("%s:", peerFQDN),
} }
err := retry(10, 1*time.Second, func() error { assert.EventuallyWithT(t, func(ct *assert.CollectT) {
t.Logf( t.Logf(
"Sending file from %s to %s\n", "Sending file from %s to %s\n",
client.Hostname(), client.Hostname(),
peer.Hostname(), peer.Hostname(),
) )
_, _, err := client.Execute(command) _, _, err := client.Execute(command)
assert.NoError(ct, err)
return err }, 10*time.Second, 1*time.Second)
})
if err != nil {
t.Fatalf(
"failed to send taildrop file on %s with command %q, err: %s",
client.Hostname(),
strings.Join(command, " "),
err,
)
}
}) })
} }
} }
@ -520,7 +488,6 @@ func TestTaildrop(t *testing.T) {
func TestUpdateHostnameFromClient(t *testing.T) { func TestUpdateHostnameFromClient(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
hostnames := map[string]string{ hostnames := map[string]string{
"1": "user1-host", "1": "user1-host",
@ -603,9 +570,47 @@ func TestUpdateHostnameFromClient(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
} }
time.Sleep(5 * time.Second) // Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Build a map of expected DNSNames by node ID
expectedDNSNames := make(map[string]string)
for _, node := range nodes {
nodeID := strconv.FormatUint(node.GetId(), 10)
expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId())
}
// Verify from each client's perspective
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
// Check self node
selfID := string(status.Self.ID)
expectedDNS := expectedDNSNames[selfID]
assert.Equal(ct, expectedDNS, status.Self.DNSName,
"Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID)
// HostName should remain as the original client-reported hostname
originalHostname := hostnames[selfID]
assert.Equal(ct, originalHostname, status.Self.HostName,
"Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID)
// Check peers
for _, peer := range status.Peer {
peerID := string(peer.ID)
if expectedDNS, ok := expectedDNSNames[peerID]; ok {
assert.Equal(ct, expectedDNS, peer.DNSName,
"Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname())
// HostName should remain as the original client-reported hostname
originalHostname := hostnames[peerID]
assert.Equal(ct, originalHostname, peer.HostName,
"Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname())
}
}
}
}, 60*time.Second, 2*time.Second)
// Verify that the clients can see the new hostname, but no givenName
for _, client := range allClients { for _, client := range allClients {
status, err := client.Status() status, err := client.Status()
assertNoErr(t, err) assertNoErr(t, err)
@ -647,7 +652,6 @@ func TestUpdateHostnameFromClient(t *testing.T) {
func TestExpireNode(t *testing.T) { func TestExpireNode(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -707,7 +711,23 @@ func TestExpireNode(t *testing.T) {
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
time.Sleep(2 * time.Minute) // Verify that the expired node has been marked in all peers list.
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
if client.Hostname() != node.GetName() {
// Check if the expired node appears as expired in this client's peer list
for key, peer := range status.Peer {
if key == expiredNodeKey {
assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname())
break
}
}
}
}
}, 3*time.Minute, 10*time.Second)
now := time.Now() now := time.Now()
@ -774,7 +794,6 @@ func TestExpireNode(t *testing.T) {
func TestNodeOnlineStatus(t *testing.T) { func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -890,7 +909,6 @@ func TestNodeOnlineStatus(t *testing.T) {
// five times ensuring they are able to restablish connectivity. // five times ensuring they are able to restablish connectivity.
func TestPingAllByIPManyUpDown(t *testing.T) { func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions), NodesPerUser: len(MustTestVersions),
@ -944,8 +962,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err) t.Fatalf("failed to take down all nodes: %s", err)
} }
time.Sleep(5 * time.Second)
for _, client := range allClients { for _, client := range allClients {
c := client c := client
wg.Go(func() error { wg.Go(func() error {
@ -958,10 +974,14 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err) t.Fatalf("failed to take down all nodes: %s", err)
} }
time.Sleep(5 * time.Second) // Wait for sync and successful pings after nodes come back up
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
err = scenario.WaitForTailscaleSync() success := pingAllHelper(t, allClients, allAddrs)
assertNoErrSync(t, err) assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
}, 30*time.Second, 2*time.Second)
success := pingAllHelper(t, allClients, allAddrs) success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
@ -970,7 +990,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
func Test2118DeletingOnlineNodePanics(t *testing.T) { func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,
@ -1042,10 +1061,24 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(2 * time.Second)
// Ensure that the node has been deleted, this did not occur due to a panic. // Ensure that the node has been deleted, this did not occur due to a panic.
var nodeListAfter []v1.Node var nodeListAfter []v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeListAfter,
)
assert.NoError(ct, err)
assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list")
}, 10*time.Second, 1*time.Second)
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
[]string{ []string{

View File

@ -191,7 +191,7 @@ func WithPostgres() Option {
} }
} }
// WithPolicy sets the policy mode for headscale // WithPolicy sets the policy mode for headscale.
func WithPolicyMode(mode types.PolicyMode) Option { func WithPolicyMode(mode types.PolicyMode) Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
hsic.policyMode = mode hsic.policyMode = mode
@ -279,7 +279,7 @@ func New(
return nil, err return nil, err
} }
hostname := fmt.Sprintf("hs-%s", hash) hostname := "hs-" + hash
hsic := &HeadscaleInContainer{ hsic := &HeadscaleInContainer{
hostname: hostname, hostname: hostname,
@ -308,14 +308,14 @@ func New(
if hsic.postgres { if hsic.postgres {
hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres" hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres"
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = fmt.Sprintf("postgres-%s", hash) hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash
hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale"
delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH") delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH")
pgRunOptions := &dockertest.RunOptions{ pgRunOptions := &dockertest.RunOptions{
Name: fmt.Sprintf("postgres-%s", hash), Name: "postgres-" + hash,
Repository: "postgres", Repository: "postgres",
Tag: "latest", Tag: "latest",
Networks: networks, Networks: networks,
@ -328,7 +328,7 @@ func New(
// Add integration test labels if running under hi tool // Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres") dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres")
pg, err := pool.RunWithOptions(pgRunOptions) pg, err := pool.RunWithOptions(pgRunOptions)
if err != nil { if err != nil {
return nil, fmt.Errorf("starting postgres container: %w", err) return nil, fmt.Errorf("starting postgres container: %w", err)
@ -373,7 +373,6 @@ func New(
Env: env, Env: env,
} }
if len(hsic.hostPortBindings) > 0 { if len(hsic.hostPortBindings) > 0 {
runOptions.PortBindings = map[docker.Port][]docker.PortBinding{} runOptions.PortBindings = map[docker.Port][]docker.PortBinding{}
for port, hostPorts := range hsic.hostPortBindings { for port, hostPorts := range hsic.hostPortBindings {
@ -396,7 +395,7 @@ func New(
// Add integration test labels if running under hi tool // Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale") dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale")
container, err := pool.BuildAndRunWithBuildOptions( container, err := pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions, headscaleBuildOptions,
runOptions, runOptions,
@ -566,7 +565,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {
// extractTarToDirectory extracts a tar archive to a directory. // extractTarToDirectory extracts a tar archive to a directory.
func extractTarToDirectory(tarData []byte, targetDir string) error { func extractTarToDirectory(tarData []byte, targetDir string) error {
if err := os.MkdirAll(targetDir, 0755); err != nil { if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err) return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
} }
@ -624,6 +623,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
} }
targetDir := path.Join(savePath, t.hostname+"-pprof") targetDir := path.Join(savePath, t.hostname+"-pprof")
return extractTarToDirectory(tarFile, targetDir) return extractTarToDirectory(tarFile, targetDir)
} }
@ -634,6 +634,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
} }
targetDir := path.Join(savePath, t.hostname+"-mapresponses") targetDir := path.Join(savePath, t.hostname+"-mapresponses")
return extractTarToDirectory(tarFile, targetDir) return extractTarToDirectory(tarFile, targetDir)
} }
@ -672,17 +673,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err) return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err)
} }
if strings.TrimSpace(schemaCheck) == "" { if strings.TrimSpace(schemaCheck) == "" {
return fmt.Errorf("database file exists but has no schema (empty database)") return errors.New("database file exists but has no schema (empty database)")
} }
// Show a preview of the schema (first 500 chars) // Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck schemaPreview := schemaCheck
if len(schemaPreview) > 500 { if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..." schemaPreview = schemaPreview[:500] + "..."
} }
log.Printf("Database schema preview:\n%s", schemaPreview)
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil { if err != nil {
@ -727,7 +727,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
} }
} }
return fmt.Errorf("no regular file found in database tar archive") return errors.New("no regular file found in database tar archive")
} }
// Execute runs a command inside the Headscale container and returns the // Execute runs a command inside the Headscale container and returns the
@ -756,13 +756,13 @@ func (t *HeadscaleInContainer) Execute(
// GetPort returns the docker container port as a string. // GetPort returns the docker container port as a string.
func (t *HeadscaleInContainer) GetPort() string { func (t *HeadscaleInContainer) GetPort() string {
return fmt.Sprintf("%d", t.port) return strconv.Itoa(t.port)
} }
// GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer // GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer
// instance. // instance.
func (t *HeadscaleInContainer) GetHealthEndpoint() string { func (t *HeadscaleInContainer) GetHealthEndpoint() string {
return fmt.Sprintf("%s/health", t.GetEndpoint()) return t.GetEndpoint() + "/health"
} }
// GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer. // GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer.
@ -772,10 +772,10 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
t.port) t.port)
if t.hasTLS() { if t.hasTLS() {
return fmt.Sprintf("https://%s", hostEndpoint) return "https://" + hostEndpoint
} }
return fmt.Sprintf("http://%s", hostEndpoint) return "http://" + hostEndpoint
} }
// GetCert returns the public certificate of the HeadscaleInContainer. // GetCert returns the public certificate of the HeadscaleInContainer.
@ -910,6 +910,7 @@ func (t *HeadscaleInContainer) ListNodes(
} }
ret = append(ret, nodes...) ret = append(ret, nodes...)
return nil return nil
} }
@ -932,6 +933,7 @@ func (t *HeadscaleInContainer) ListNodes(
sort.Slice(ret, func(i, j int) bool { sort.Slice(ret, func(i, j int) bool {
return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1
}) })
return ret, nil return ret, nil
} }
@ -943,10 +945,10 @@ func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
var userMap map[string][]*v1.Node var userMap map[string][]*v1.Node
for _, node := range nodes { for _, node := range nodes {
if _, ok := userMap[node.User.Name]; !ok { if _, ok := userMap[node.GetUser().GetName()]; !ok {
mak.Set(&userMap, node.User.Name, []*v1.Node{node}) mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node})
} else { } else {
userMap[node.User.Name] = append(userMap[node.User.Name], node) userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node)
} }
} }
@ -999,7 +1001,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) {
var userMap map[string]*v1.User var userMap map[string]*v1.User
for _, user := range users { for _, user := range users {
mak.Set(&userMap, user.Name, user) mak.Set(&userMap, user.GetName(), user)
} }
return userMap, nil return userMap, nil
@ -1095,7 +1097,7 @@ func (h *HeadscaleInContainer) PID() (int, error) {
case 1: case 1:
return pids[0], nil return pids[0], nil
default: default:
return 0, fmt.Errorf("multiple headscale processes running") return 0, errors.New("multiple headscale processes running")
} }
} }
@ -1121,7 +1123,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
"headscale", "nodes", "approve-routes", "headscale", "nodes", "approve-routes",
"--output", "json", "--output", "json",
"--identifier", strconv.FormatUint(id, 10), "--identifier", strconv.FormatUint(id, 10),
fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")), "--routes=" + strings.Join(util.PrefixesToString(routes), ","),
} }
result, _, err := dockertestutil.ExecuteCommand( result, _, err := dockertestutil.ExecuteCommand(

View File

@ -4,13 +4,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"sort" "sort"
"strings" "strings"
"testing" "testing"
"time" "time"
"slices"
cmpdiff "github.com/google/go-cmp/cmp" cmpdiff "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -37,7 +36,6 @@ var allPorts = filter.PortRange{First: 0, Last: 0xffff}
// routes. // routes.
func TestEnablingRoutes(t *testing.T) { func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 3, NodesPerUser: 3,
@ -182,11 +180,12 @@ func TestEnablingRoutes(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" { switch peerStatus.ID {
case "1":
requirePeerSubnetRoutes(t, peerStatus, nil) requirePeerSubnetRoutes(t, peerStatus, nil)
} else if peerStatus.ID == "2" { case "2":
requirePeerSubnetRoutes(t, peerStatus, nil) requirePeerSubnetRoutes(t, peerStatus, nil)
} else { default:
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")}) requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")})
} }
} }
@ -195,7 +194,6 @@ func TestEnablingRoutes(t *testing.T) {
func TestHASubnetRouterFailover(t *testing.T) { func TestHASubnetRouterFailover(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 3, NodesPerUser: 3,
@ -779,7 +777,6 @@ func TestHASubnetRouterFailover(t *testing.T) {
// https://github.com/juanfont/headscale/issues/1604 // https://github.com/juanfont/headscale/issues/1604
func TestSubnetRouteACL(t *testing.T) { func TestSubnetRouteACL(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
user := "user4" user := "user4"
@ -1003,7 +1000,6 @@ func TestSubnetRouteACL(t *testing.T) {
// set during login instead of set. // set during login instead of set.
func TestEnablingExitRoutes(t *testing.T) { func TestEnablingExitRoutes(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
user := "user2" user := "user2"
@ -1097,7 +1093,6 @@ func TestEnablingExitRoutes(t *testing.T) {
// subnet router is working as expected. // subnet router is working as expected.
func TestSubnetRouterMultiNetwork(t *testing.T) { func TestSubnetRouterMultiNetwork(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,
@ -1177,7 +1172,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
// Enable route // Enable route
_, err = headscale.ApproveRoutes( _, err = headscale.ApproveRoutes(
nodes[0].Id, nodes[0].GetId(),
[]netip.Prefix{*pref}, []netip.Prefix{*pref},
) )
require.NoError(t, err) require.NoError(t, err)
@ -1224,7 +1219,6 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{ spec := ScenarioSpec{
NodesPerUser: 1, NodesPerUser: 1,
@ -1300,7 +1294,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
} }
// Enable route // Enable route
_, err = headscale.ApproveRoutes(nodes[0].Id, []netip.Prefix{tsaddr.AllIPv4()}) _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -1719,7 +1713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
assertNoErr(t, err) assertNoErr(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key) err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
assertNoErr(t, err) assertNoErr(t, err)
} }
// extra creation end. // extra creation end.
@ -2065,7 +2059,6 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
// that are explicitly allowed in the ACL. // that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) { func TestSubnetRouteACLFiltering(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
// Use router and node users for better clarity // Use router and node users for better clarity
routerUser := "router" routerUser := "router"
@ -2090,7 +2083,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
// Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24) // Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24)
aclPolicyStr := fmt.Sprintf(`{ aclPolicyStr := `{
"hosts": { "hosts": {
"router": "100.64.0.1/32", "router": "100.64.0.1/32",
"node": "100.64.0.2/32" "node": "100.64.0.2/32"
@ -2115,7 +2108,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
] ]
} }
] ]
}`) }`
route, err := scenario.SubnetOfNetwork("usernet1") route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err) require.NoError(t, err)

View File

@ -123,7 +123,7 @@ type ScenarioSpec struct {
// NodesPerUser is how many nodes should be attached to each user. // NodesPerUser is how many nodes should be attached to each user.
NodesPerUser int NodesPerUser int
// Networks, if set, is the seperate Docker networks that should be // Networks, if set, is the separate Docker networks that should be
// created and a list of the users that should be placed in those networks. // created and a list of the users that should be placed in those networks.
// If not set, a single network will be created and all users+nodes will be // If not set, a single network will be created and all users+nodes will be
// added there. // added there.
@ -1077,7 +1077,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash) hostname := "hs-oidcmock-" + hash
usersJSON, err := json.Marshal(users) usersJSON, err := json.Marshal(users)
if err != nil { if err != nil {
@ -1093,16 +1093,15 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
}, },
Networks: s.Networks(), Networks: s.Networks(),
Env: []string{ Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), "MOCKOIDC_ADDR=" + hostname,
fmt.Sprintf("MOCKOIDC_PORT=%d", port), fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret", "MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), "MOCKOIDC_ACCESS_TTL=" + accessTTL.String(),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), "MOCKOIDC_USERS=" + string(usersJSON),
}, },
} }
headscaleBuildOptions := &dockertest.BuildOptions{ headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName, Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath, ContextDir: dockerContextPath,
@ -1117,7 +1116,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
// Add integration test labels if running under hi tool // Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc")
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions, headscaleBuildOptions,
mockOidcOptions, mockOidcOptions,
@ -1184,7 +1183,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-webservice-%s", hash) hostname := "hs-webservice-" + hash
network, ok := s.networks[s.prefixedNetworkName(networkName)] network, ok := s.networks[s.prefixedNetworkName(networkName)]
if !ok { if !ok {

View File

@ -28,7 +28,6 @@ func IntegrationSkip(t *testing.T) {
// nolint:tparallel // nolint:tparallel
func TestHeadscale(t *testing.T) { func TestHeadscale(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
var err error var err error
@ -75,7 +74,6 @@ func TestHeadscale(t *testing.T) {
// nolint:tparallel // nolint:tparallel
func TestTailscaleNodesJoiningHeadcale(t *testing.T) { func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
var err error var err error

View File

@ -22,35 +22,6 @@ func isSSHNoAccessStdError(stderr string) bool {
strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node") strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node")
} }
var retry = func(times int, sleepInterval time.Duration,
doWork func() (string, string, error),
) (string, string, error) {
var result string
var stderr string
var err error
for range times {
tempResult, tempStderr, err := doWork()
result += tempResult
stderr += tempStderr
if err == nil {
return result, stderr, nil
}
// If we get a permission denied error, we can fail immediately
// since that is something we won-t recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return result, stderr, err
}
time.Sleep(sleepInterval)
}
return result, stderr, err
}
func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario { func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario {
t.Helper() t.Helper()
@ -92,7 +63,6 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
func TestSSHOneUserToAll(t *testing.T) { func TestSSHOneUserToAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t, scenario := sshScenario(t,
&policyv2.Policy{ &policyv2.Policy{
@ -160,7 +130,6 @@ func TestSSHOneUserToAll(t *testing.T) {
func TestSSHMultipleUsersAllToAll(t *testing.T) { func TestSSHMultipleUsersAllToAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t, scenario := sshScenario(t,
&policyv2.Policy{ &policyv2.Policy{
@ -216,7 +185,6 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
func TestSSHNoSSHConfigured(t *testing.T) { func TestSSHNoSSHConfigured(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t, scenario := sshScenario(t,
&policyv2.Policy{ &policyv2.Policy{
@ -261,7 +229,6 @@ func TestSSHNoSSHConfigured(t *testing.T) {
func TestSSHIsBlockedInACL(t *testing.T) { func TestSSHIsBlockedInACL(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t, scenario := sshScenario(t,
&policyv2.Policy{ &policyv2.Policy{
@ -313,7 +280,6 @@ func TestSSHIsBlockedInACL(t *testing.T) {
func TestSSHUserOnlyIsolation(t *testing.T) { func TestSSHUserOnlyIsolation(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t, scenario := sshScenario(t,
&policyv2.Policy{ &policyv2.Policy{
@ -404,6 +370,14 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
} }
func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, true)
}
func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, false)
}
func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) {
t.Helper() t.Helper()
peerFQDN, _ := peer.FQDN() peerFQDN, _ := peer.FQDN()
@ -417,9 +391,29 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string,
log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname()) log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname())
log.Printf("Command: %s", strings.Join(command, " ")) log.Printf("Command: %s", strings.Join(command, " "))
return retry(10, 1*time.Second, func() (string, string, error) { var result, stderr string
return client.Execute(command) var err error
})
if retry {
// Use assert.EventuallyWithT to retry SSH connections for success cases
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, stderr, err = client.Execute(command)
// If we get a permission denied error, we can fail immediately
// since that is something we won't recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return // Don't retry permission denied errors
}
// For all other errors, assert no error to trigger retry
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
} else {
// For failure cases, just execute once
result, stderr, err = client.Execute(command)
}
return result, stderr, err
} }
func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) { func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) {
@ -434,7 +428,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper() t.Helper()
result, stderr, err := doSSH(t, client, peer) result, stderr, err := doSSHWithoutRetry(t, client, peer)
assert.Empty(t, result) assert.Empty(t, result)
@ -444,7 +438,7 @@ func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer Tailsc
func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) { func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper() t.Helper()
result, stderr, _ := doSSH(t, client, peer) result, stderr, _ := doSSHWithoutRetry(t, client, peer)
assert.Empty(t, result) assert.Empty(t, result)

View File

@ -251,7 +251,6 @@ func New(
Env: []string{}, Env: []string{},
} }
if tsic.withWebsocketDERP { if tsic.withWebsocketDERP {
if version != VersionHead { if version != VersionHead {
return tsic, errInvalidClientConfig return tsic, errInvalidClientConfig
@ -463,7 +462,7 @@ func (t *TailscaleInContainer) buildLoginCommand(
if len(t.withTags) > 0 { if len(t.withTags) > 0 {
command = append(command, command = append(command,
fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")), "--advertise-tags="+strings.Join(t.withTags, ","),
) )
} }
@ -685,7 +684,7 @@ func (t *TailscaleInContainer) MustID() types.NodeID {
// Panics if version is lower then minimum. // Panics if version is lower then minimum.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version)) panic("tsic.Netmap() called with unsupported version: " + t.version)
} }
command := []string{ command := []string{
@ -1026,7 +1025,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err
"tailscale", "ping", "tailscale", "ping",
fmt.Sprintf("--timeout=%s", args.timeout), fmt.Sprintf("--timeout=%s", args.timeout),
fmt.Sprintf("--c=%d", args.count), fmt.Sprintf("--c=%d", args.count),
fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)), "--until-direct=" + strconv.FormatBool(args.direct),
} }
command = append(command, hostnameOrIP) command = append(command, hostnameOrIP)
@ -1131,11 +1130,11 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
command := []string{ command := []string{
"curl", "curl",
"--silent", "--silent",
"--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())), "--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())),
"--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())), "--max-time", strconv.Itoa(int(args.maxTime.Seconds())),
"--retry", fmt.Sprintf("%d", args.retry), "--retry", strconv.Itoa(args.retry),
"--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())), "--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())),
"--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())), "--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())),
url, url,
} }
@ -1230,7 +1229,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
} }
if out.Len() == 0 { if out.Len() == 0 {
return nil, fmt.Errorf("file is empty") return nil, errors.New("file is empty")
} }
return out.Bytes(), nil return out.Bytes(), nil
@ -1259,5 +1258,6 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
if err = json.Unmarshal(currentProfile, &p); err != nil { if err = json.Unmarshal(currentProfile, &p); err != nil {
return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err) return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err)
} }
return &p.Persist.PrivateNodeKey, nil return &p.Persist.PrivateNodeKey, nil
} }

View File

@ -3,7 +3,6 @@ package integration
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
@ -267,7 +266,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) {
// This isn't really relevant for Self as it won't be in its own socket/wireguard. // This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname()) // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
for _, peer := range status.Peer { for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
@ -311,7 +310,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper() t.Helper()
_, err := backoff.Retry(context.Background(), func() (struct{}, error) { _, err := backoff.Retry(t.Context(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command) stdout, stderr, err := c.Execute(command)
if err != nil { if err != nil {
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
@ -492,6 +491,7 @@ func groupApprover(name string) policyv2.AutoApprover {
func tagApprover(name string) policyv2.AutoApprover { func tagApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Tag(name)) return ptr.To(policyv2.Tag(name))
} }
// //
// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus // // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given hostname. If no peer is found, nil is returned. // // if there is a peer with the given hostname. If no peer is found, nil is returned.