diff --git a/.dockerignore b/.dockerignore index e3acf996..9ea3e4a4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -17,3 +17,7 @@ LICENSE .vscode *.sock + +node_modules/ +package-lock.json +package.json diff --git a/.github/workflows/check-generated.yml b/.github/workflows/check-generated.yml new file mode 100644 index 00000000..17073a35 --- /dev/null +++ b/.github/workflows/check-generated.yml @@ -0,0 +1,55 @@ +name: Check Generated Files + +on: + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + check-generated: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 2 + - name: Get changed files + id: changed-files + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + with: + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - '**/*.proto' + - 'buf.gen.yaml' + - 'tools/**' + - uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + + - name: Run make generate + if: steps.changed-files.outputs.files == 'true' + run: nix develop --command -- make generate + + - name: Check for uncommitted changes + if: steps.changed-files.outputs.files == 'true' + run: | + if ! git diff --exit-code; then + echo "❌ Generated files are not up to date!" + echo "Please run 'make generate' and commit the changes." + exit 1 + else + echo "✅ All generated files are up to date." + fi diff --git a/.github/workflows/integration-test-template.yml b/.github/workflows/integration-test-template.yml index 939451d4..292985ad 100644 --- a/.github/workflows/integration-test-template.yml +++ b/.github/workflows/integration-test-template.yml @@ -77,7 +77,7 @@ jobs: attempt_delay: 300000 # 5 min attempt_limit: 2 command: | - nix develop --command -- hi run "^${{ inputs.test }}$" \ + nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^${{ inputs.test }}$" \ --timeout=120m \ ${{ inputs.postgres_flag }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 diff --git a/.gitignore b/.gitignore index 2ea56ad7..28d23c09 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ ignored/ tailscale/ .vscode/ +.claude/ + +*.prof # Binaries for programs and plugins *.exe @@ -46,3 +49,7 @@ integration_test/etc/config.dump.yaml /site __debug_bin + +node_modules/ +package-lock.json +package.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bac683b..f00e6934 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Next +**Minimum supported Tailscale client version: v1.64.0** + ### Database integrity improvements This release includes a significant database migration that addresses longstanding diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..8f2571ab --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,395 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
+ +## Overview + +Headscale is an open-source implementation of the Tailscale control server written in Go. It provides self-hosted coordination for Tailscale networks (tailnets), managing node registration, IP allocation, policy enforcement, and DERP routing. + +## Development Commands + +### Quick Setup +```bash +# Recommended: Use Nix for dependency management +nix develop + +# Full development workflow +make dev # runs fmt + lint + test + build +``` + +### Essential Commands +```bash +# Build headscale binary +make build + +# Run tests +make test +go test ./... # All unit tests +go test -race ./... # With race detection + +# Run specific integration test +go run ./cmd/hi run "TestName" --postgres + +# Code formatting and linting +make fmt # Format all code (Go, docs, proto) +make lint # Lint all code (Go, proto) +make fmt-go # Format Go code only +make lint-go # Lint Go code only + +# Protocol buffer generation (after modifying proto/) +make generate + +# Clean build artifacts +make clean +``` + +### Integration Testing +```bash +# Use the hi (Headscale Integration) test runner +go run ./cmd/hi doctor # Check system requirements +go run ./cmd/hi run "TestPattern" # Run specific test +go run ./cmd/hi run "TestPattern" --postgres # With PostgreSQL backend + +# Test artifacts are saved to control_logs/ with logs and debug data +``` + +## Project Structure & Architecture + +### Top-Level Organization + +``` +headscale/ +├── cmd/ # Command-line applications +│ ├── headscale/ # Main headscale server binary +│ └── hi/ # Headscale Integration test runner +├── hscontrol/ # Core control plane logic +├── integration/ # End-to-end Docker-based tests +├── proto/ # Protocol buffer definitions +├── gen/ # Generated code (protobuf) +├── docs/ # Documentation +└── packaging/ # Distribution packaging +``` + +### Core Packages (`hscontrol/`) + +**Main Server (`hscontrol/`)** +- `app.go`: Application setup, dependency injection, server lifecycle +- `handlers.go`: HTTP/gRPC API endpoints for management operations +- `grpcv1.go`: gRPC service implementation for headscale API +- `poll.go`: **Critical** - Handles Tailscale MapRequest/MapResponse protocol +- `noise.go`: Noise protocol implementation for secure client communication +- `auth.go`: Authentication flows (web, OIDC, command-line) +- `oidc.go`: OpenID Connect integration for user authentication + +**State Management (`hscontrol/state/`)** +- `state.go`: Central coordinator for all subsystems (database, policy, IP allocation, DERP) +- `node_store.go`: **Performance-critical** - In-memory cache with copy-on-write semantics +- Thread-safe operations with deadlock detection +- Coordinates between database persistence and real-time operations + +**Database Layer (`hscontrol/db/`)** +- `db.go`: Database abstraction, GORM setup, migration management +- `node.go`: Node lifecycle, registration, expiration, IP assignment +- `users.go`: User management, namespace isolation +- `api_key.go`: API authentication tokens +- `preauth_keys.go`: Pre-authentication keys for automated node registration +- `ip.go`: IP address allocation and management +- `policy.go`: Policy storage and retrieval +- Schema migrations in `schema.sql` with extensive test data coverage + +**Policy Engine (`hscontrol/policy/`)** +- `policy.go`: Core ACL evaluation logic, HuJSON parsing +- `v2/`: Next-generation policy system with improved filtering +- `matcher/`: ACL rule matching and evaluation engine +- Determines peer visibility, route approval, and network access rules +- Supports both 
file-based and database-stored policies + +**Network Management (`hscontrol/`)** +- `derp/`: DERP (Designated Encrypted Relay for Packets) server implementation + - NAT traversal when direct connections fail + - Fallback relay for firewall-restricted environments +- `mapper/`: Converts internal Headscale state to Tailscale's wire protocol format + - `tail.go`: Tailscale-specific data structure generation +- `routes/`: Subnet route management and primary route selection +- `dns/`: DNS record management and MagicDNS implementation + +**Utilities & Support (`hscontrol/`)** +- `types/`: Core data structures, configuration, validation +- `util/`: Helper functions for networking, DNS, key management +- `templates/`: Client configuration templates (Apple, Windows, etc.) +- `notifier/`: Event notification system for real-time updates +- `metrics.go`: Prometheus metrics collection +- `capver/`: Tailscale capability version management + +### Key Subsystem Interactions + +**Node Registration Flow** +1. **Client Connection**: `noise.go` handles secure protocol handshake +2. **Authentication**: `auth.go` validates credentials (web/OIDC/preauth) +3. **State Creation**: `state.go` coordinates IP allocation via `db/ip.go` +4. **Storage**: `db/node.go` persists node, `NodeStore` caches in memory +5. **Network Setup**: `mapper/` generates initial Tailscale network map + +**Ongoing Operations** +1. **Poll Requests**: `poll.go` receives periodic client updates +2. **State Updates**: `NodeStore` maintains real-time node information +3. **Policy Application**: `policy/` evaluates ACL rules for peer relationships +4. **Map Distribution**: `mapper/` sends network topology to all affected clients + +**Route Management** +1. **Advertisement**: Clients announce routes via `poll.go` Hostinfo updates +2. **Storage**: `db/` persists routes, `NodeStore` caches for performance +3. **Approval**: `policy/` auto-approves routes based on ACL rules +4. **Distribution**: `routes/` selects primary routes, `mapper/` distributes to peers + +### Command-Line Tools (`cmd/`) + +**Main Server (`cmd/headscale/`)** +- `headscale.go`: CLI parsing, configuration loading, server startup +- Supports daemon mode, CLI operations (user/node management), database operations + +**Integration Test Runner (`cmd/hi/`)** +- `main.go`: Test execution framework with Docker orchestration +- `run.go`: Individual test execution with artifact collection +- `doctor.go`: System requirements validation +- `docker.go`: Container lifecycle management +- Essential for validating changes against real Tailscale clients + +### Generated & External Code + +**Protocol Buffers (`proto/` → `gen/`)** +- Defines gRPC API for headscale management operations +- Client libraries can generate from these definitions +- Run `make generate` after modifying `.proto` files + +**Integration Testing (`integration/`)** +- `scenario.go`: Docker test environment setup +- `tailscale.go`: Tailscale client container management +- Individual test files for specific functionality areas +- Real end-to-end validation with network isolation + +### Critical Performance Paths + +**High-Frequency Operations** +1. **MapRequest Processing** (`poll.go`): Every 15-60 seconds per client +2. **NodeStore Reads** (`node_store.go`): Every operation requiring node data +3. **Policy Evaluation** (`policy/`): On every peer relationship calculation +4. 
**Route Lookups** (`routes/`): During network map generation + +**Database Write Patterns** +- **Frequent**: Node heartbeats, endpoint updates, route changes +- **Moderate**: User operations, policy updates, API key management +- **Rare**: Schema migrations, bulk operations + +### Configuration & Deployment + +**Configuration** (`hscontrol/types/config.go`)** +- Database connection settings (SQLite/PostgreSQL) +- Network configuration (IP ranges, DNS settings) +- Policy mode (file vs database) +- DERP relay configuration +- OIDC provider settings + +**Key Dependencies** +- **GORM**: Database ORM with migration support +- **Tailscale Libraries**: Core networking and protocol code +- **Zerolog**: Structured logging throughout the application +- **Buf**: Protocol buffer toolchain for code generation + +### Development Workflow Integration + +The architecture supports incremental development: +- **Unit Tests**: Focus on individual packages (`*_test.go` files) +- **Integration Tests**: Validate cross-component interactions +- **Database Tests**: Extensive migration and data integrity validation +- **Policy Tests**: ACL rule evaluation and edge cases +- **Performance Tests**: NodeStore and high-frequency operation validation + +## Integration Test System + +### Overview +Integration tests use Docker containers running real Tailscale clients against a Headscale server. Tests validate end-to-end functionality including routing, ACLs, node lifecycle, and network coordination. + +### Running Integration Tests + +**System Requirements** +```bash +# Check if your system is ready +go run ./cmd/hi doctor +``` +This verifies Docker, Go, required images, and disk space. + +**Test Execution Patterns** +```bash +# Run a single test (recommended for development) +go run ./cmd/hi run "TestSubnetRouterMultiNetwork" + +# Run with PostgreSQL backend (for database-heavy tests) +go run ./cmd/hi run "TestExpireNode" --postgres + +# Run multiple tests with pattern matching +go run ./cmd/hi run "TestSubnet*" + +# Run all integration tests (CI/full validation) +go test ./integration -timeout 30m +``` + +**Test Categories & Timing** +- **Fast tests** (< 2 min): Basic functionality, CLI operations +- **Medium tests** (2-5 min): Route management, ACL validation +- **Slow tests** (5+ min): Node expiration, HA failover +- **Long-running tests** (10+ min): `TestNodeOnlineStatus` (12 min duration) + +### Test Infrastructure + +**Docker Setup** +- Headscale server container with configurable database backend +- Multiple Tailscale client containers with different versions +- Isolated networks per test scenario +- Automatic cleanup after test completion + +**Test Artifacts** +All test runs save artifacts to `control_logs/TIMESTAMP-ID/`: +``` +control_logs/20250713-213106-iajsux/ +├── hs-testname-abc123.stderr.log # Headscale server logs +├── hs-testname-abc123.stdout.log +├── hs-testname-abc123.db # Database snapshot +├── hs-testname-abc123_metrics.txt # Prometheus metrics +├── hs-testname-abc123-mapresponses/ # Protocol debug data +├── ts-client-xyz789.stderr.log # Tailscale client logs +├── ts-client-xyz789.stdout.log +└── ts-client-xyz789_status.json # Client status dump +``` + +### Test Development Guidelines + +**Timing Considerations** +Integration tests involve real network operations and Docker container lifecycle: + +```go +// ❌ Wrong: Immediate assertions after async operations +client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"}) +nodes, _ := headscale.ListNodes() +require.Len(t, 
nodes[0].GetAvailableRoutes(), 1) // May fail due to timing + +// ✅ Correct: Wait for async operations to complete +client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"}) +require.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes[0].GetAvailableRoutes(), 1) +}, 10*time.Second, 100*time.Millisecond, "route should be advertised") +``` + +**Common Test Patterns** +- **Route Advertisement**: Use `EventuallyWithT` for route propagation +- **Node State Changes**: Wait for NodeStore synchronization +- **ACL Policy Changes**: Allow time for policy recalculation +- **Network Connectivity**: Use ping tests with retries + +**Test Data Management** +```go +// Node identification: Don't assume array ordering +expectedRoutes := map[string]string{"1": "10.33.0.0/16"} +for _, node := range nodes { + nodeIDStr := fmt.Sprintf("%d", node.GetId()) + if route, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute { + // Test the node that should have the route + } +} +``` + +### Troubleshooting Integration Tests + +**Common Failure Patterns** +1. **Timing Issues**: Test assertions run before async operations complete + - **Solution**: Use `EventuallyWithT` with appropriate timeouts + - **Timeout Guidelines**: 3-5s for route operations, 10s for complex scenarios + +2. **Infrastructure Problems**: Disk space, Docker issues, network conflicts + - **Check**: `go run ./cmd/hi doctor` for system health + - **Clean**: Remove old test containers and networks + +3. **NodeStore Synchronization**: Tests expecting immediate data availability + - **Key Points**: Route advertisements must propagate through poll requests + - **Fix**: Wait for NodeStore updates after Hostinfo changes + +4. **Database Backend Differences**: SQLite vs PostgreSQL behavior differences + - **Use**: `--postgres` flag for database-intensive tests + - **Note**: Some timing characteristics differ between backends + +**Debugging Failed Tests** +1. **Check test artifacts** in `control_logs/` for detailed logs +2. **Examine MapResponse JSON** files for protocol-level debugging +3. **Review Headscale stderr logs** for server-side error messages +4. **Check Tailscale client status** for network-level issues + +**Resource Management** +- Tests require significant disk space (each run ~100MB of logs) +- Docker containers are cleaned up automatically on success +- Failed tests may leave containers running - clean manually if needed +- Use `docker system prune` periodically to reclaim space + +### Best Practices for Test Modifications + +1. **Always test locally** before committing integration test changes +2. **Use appropriate timeouts** - too short causes flaky tests, too long slows CI +3. **Clean up properly** - ensure tests don't leave persistent state +4. **Handle both success and failure paths** in test scenarios +5. **Document timing requirements** for complex test scenarios + +## NodeStore Implementation Details + +**Key Insight from Recent Work**: The NodeStore is a critical performance optimization that caches node data in memory while ensuring consistency with the database. When working with route advertisements or node state changes: + +1. **Timing Considerations**: Route advertisements need time to propagate from clients to server. Use `require.EventuallyWithT()` patterns in tests instead of immediate assertions. + +2. **Synchronization Points**: NodeStore updates happen at specific points like `poll.go:420` after Hostinfo changes. 
Ensure these are maintained when modifying the polling logic. + +3. **Peer Visibility**: The NodeStore's `peersFunc` determines which nodes are visible to each other. Policy-based filtering is separate from monitoring visibility - expired nodes should remain visible for debugging but marked as expired. + +## Testing Guidelines + +### Integration Test Patterns +```go +// Use EventuallyWithT for async operations +require.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + // Check expected state +}, 10*time.Second, 100*time.Millisecond, "description") + +// Node route checking by actual node properties, not array position +var routeNode *v1.Node +for _, node := range nodes { + if nodeIDStr := fmt.Sprintf("%d", node.GetId()); expectedRoutes[nodeIDStr] != "" { + routeNode = node + break + } +} +``` + +### Running Problematic Tests +- Some tests require significant time (e.g., `TestNodeOnlineStatus` runs for 12 minutes) +- Infrastructure issues like disk space can cause test failures unrelated to code changes +- Use `--postgres` flag when testing database-heavy scenarios + +## Important Notes + +- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting) +- **Protocol Buffers**: Changes to `proto/` require `make generate` and should be committed separately +- **Code Style**: Enforced via golangci-lint with golines (width 88) and gofumpt formatting +- **Database**: Supports both SQLite (development) and PostgreSQL (production/testing) +- **Integration Tests**: Require Docker and can consume significant disk space +- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management + +## Debugging Integration Tests + +Test artifacts are preserved in `control_logs/TIMESTAMP-ID/` including: +- Headscale server logs (stderr/stdout) +- Tailscale client logs and status +- Database dumps and network captures +- MapResponse JSON files for protocol debugging + +When tests fail, check these artifacts first before assuming code issues. diff --git a/Makefile b/Makefile index 563109a6..d9b2c76b 100644 --- a/Makefile +++ b/Makefile @@ -87,10 +87,9 @@ lint-proto: check-deps $(PROTO_SOURCES) # Code generation .PHONY: generate -generate: check-deps $(PROTO_SOURCES) - @echo "Generating code from Protocol Buffers..." - rm -rf gen - buf generate proto +generate: check-deps + @echo "Generating code..." + go generate ./... 
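+# NOTE: 'go generate ./...' scans the tree for //go:generate directives; it is assumed
+# these (together with the tool block added in go.mod) now cover what 'buf generate proto'
+# previously produced.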
# Clean targets .PHONY: clean diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index c482299c..8b32d935 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -212,13 +212,10 @@ var listUsersCmd = &cobra.Command{ switch { case id > 0: request.Id = uint64(id) - break case username != "": request.Name = username - break case email != "": request.Email = email - break } response, err := client.ListUsers(ctx, request) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 9abc6d4f..e7a50485 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -90,6 +90,32 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { log.Printf("Starting test: %s", config.TestPattern) + // Start stats collection for container resource monitoring (if enabled) + var statsCollector *StatsCollector + if config.Stats { + var err error + statsCollector, err = NewStatsCollector() + if err != nil { + if config.Verbose { + log.Printf("Warning: failed to create stats collector: %v", err) + } + statsCollector = nil + } + + if statsCollector != nil { + defer statsCollector.Close() + + // Start stats collection immediately - no need for complex retry logic + // The new implementation monitors Docker events and will catch containers as they start + if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil { + if config.Verbose { + log.Printf("Warning: failed to start stats collection: %v", err) + } + } + defer statsCollector.StopCollection() + } + } + exitCode, err := streamAndWait(ctx, cli, resp.ID) // Ensure all containers have finished and logs are flushed before extracting artifacts @@ -105,6 +131,20 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Always list control files regardless of test outcome listControlFiles(logsDir) + // Print stats summary and check memory limits if enabled + if config.Stats && statsCollector != nil { + violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit) + if len(violations) > 0 { + log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:") + log.Printf("=================================") + for _, violation := range violations { + log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB", + violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) + } + return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations)) + } + } + shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0) if shouldCleanup { if config.Verbose { @@ -379,10 +419,37 @@ func getDockerSocketPath() string { return "/var/run/docker.sock" } -// ensureImageAvailable pulls the specified Docker image to ensure it's available. +// checkImageAvailableLocally checks if the specified Docker image is available locally. +func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) { + _, _, err := cli.ImageInspectWithRaw(ctx, imageName) + if err != nil { + if client.IsErrNotFound(err) { + return false, nil + } + return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err) + } + + return true, nil +} + +// ensureImageAvailable checks if the image is available locally first, then pulls if needed. 
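+// Checking locally first avoids an unnecessary pull and lets runs proceed offline
+// when the image was pulled previously.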
func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error { + // First check if image is available locally + available, err := checkImageAvailableLocally(ctx, cli, imageName) + if err != nil { + return fmt.Errorf("failed to check local image availability: %w", err) + } + + if available { + if verbose { + log.Printf("Image %s is available locally", imageName) + } + return nil + } + + // Image not available locally, try to pull it if verbose { - log.Printf("Pulling image %s...", imageName) + log.Printf("Image %s not found locally, pulling...", imageName) } reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{}) diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index a45bfa8f..8af6051f 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -190,7 +190,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult { } } -// checkGolangImage verifies we can access the golang Docker image. +// checkGolangImage verifies the golang Docker image is available locally or can be pulled. func checkGolangImage(ctx context.Context) DoctorResult { cli, err := createDockerClient() if err != nil { @@ -205,17 +205,40 @@ func checkGolangImage(ctx context.Context) DoctorResult { goVersion := detectGoVersion() imageName := "golang:" + goVersion - // Check if we can pull the image + // First check if image is available locally + available, err := checkImageAvailableLocally(ctx, cli, imageName) + if err != nil { + return DoctorResult{ + Name: "Golang Image", + Status: "FAIL", + Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err), + Suggestions: []string{ + "Check Docker daemon status", + "Try: docker images | grep golang", + }, + } + } + + if available { + return DoctorResult{ + Name: "Golang Image", + Status: "PASS", + Message: fmt.Sprintf("Golang image %s is available locally", imageName), + } + } + + // Image not available locally, try to pull it err = ensureImageAvailable(ctx, cli, imageName, false) if err != nil { return DoctorResult{ Name: "Golang Image", Status: "FAIL", - Message: fmt.Sprintf("Cannot pull golang image %s: %v", imageName, err), + Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err), Suggestions: []string{ "Check internet connectivity", "Verify Docker Hub access", "Try: docker pull " + imageName, + "Or run tests offline if image was pulled previously", }, } } @@ -223,7 +246,7 @@ func checkGolangImage(ctx context.Context) DoctorResult { return DoctorResult{ Name: "Golang Image", Status: "PASS", - Message: fmt.Sprintf("Golang image %s is available", imageName), + Message: fmt.Sprintf("Golang image %s is now available", imageName), } } diff --git a/cmd/hi/run.go b/cmd/hi/run.go index f40f563d..cd06b2d1 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -24,6 +24,9 @@ type RunConfig struct { KeepOnFailure bool `flag:"keep-on-failure,default=false,Keep containers on test failure"` LogsDir string `flag:"logs-dir,default=control_logs,Control logs directory"` Verbose bool `flag:"verbose,default=false,Verbose output"` + Stats bool `flag:"stats,default=false,Collect and display container resource usage statistics"` + HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"` + TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"` } // runIntegrationTest executes the integration test workflow. 
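The three new `RunConfig` fields above surface as ordinary CLI flags on `hi run`. As a rough usage sketch (flag names come from the struct tags above; the 500/300 MB values mirror the CI workflow change earlier in this diff):

```bash
# Run one integration test with container stats collection enabled,
# failing the run if a Headscale container peaks above 500 MB or a
# Tailscale container peaks above 300 MB.
go run ./cmd/hi run "TestSubnetRouterMultiNetwork" --stats \
  --hs-memory-limit=500 \
  --ts-memory-limit=300
```

When `--stats` is set, the collector added in `cmd/hi/stats.go` below prints a per-container CPU/memory summary after the run, and any limit violation makes the run return an error.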
diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go new file mode 100644 index 00000000..ecb3f4fd --- /dev/null +++ b/cmd/hi/stats.go @@ -0,0 +1,468 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "sort" + "strings" + "sync" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/events" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/client" +) + +// ContainerStats represents statistics for a single container +type ContainerStats struct { + ContainerID string + ContainerName string + Stats []StatsSample + mutex sync.RWMutex +} + +// StatsSample represents a single stats measurement +type StatsSample struct { + Timestamp time.Time + CPUUsage float64 // CPU usage percentage + MemoryMB float64 // Memory usage in MB +} + +// StatsCollector manages collection of container statistics +type StatsCollector struct { + client *client.Client + containers map[string]*ContainerStats + stopChan chan struct{} + wg sync.WaitGroup + mutex sync.RWMutex + collectionStarted bool +} + +// NewStatsCollector creates a new stats collector instance +func NewStatsCollector() (*StatsCollector, error) { + cli, err := createDockerClient() + if err != nil { + return nil, fmt.Errorf("failed to create Docker client: %w", err) + } + + return &StatsCollector{ + client: cli, + containers: make(map[string]*ContainerStats), + stopChan: make(chan struct{}), + }, nil +} + +// StartCollection begins monitoring all containers and collecting stats for hs- and ts- containers with matching run ID +func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error { + sc.mutex.Lock() + defer sc.mutex.Unlock() + + if sc.collectionStarted { + return fmt.Errorf("stats collection already started") + } + + sc.collectionStarted = true + + // Start monitoring existing containers + sc.wg.Add(1) + go sc.monitorExistingContainers(ctx, runID, verbose) + + // Start Docker events monitoring for new containers + sc.wg.Add(1) + go sc.monitorDockerEvents(ctx, runID, verbose) + + if verbose { + log.Printf("Started container monitoring for run ID %s", runID) + } + + return nil +} + +// StopCollection stops all stats collection +func (sc *StatsCollector) StopCollection() { + // Check if already stopped without holding lock + sc.mutex.RLock() + if !sc.collectionStarted { + sc.mutex.RUnlock() + return + } + sc.mutex.RUnlock() + + // Signal stop to all goroutines + close(sc.stopChan) + + // Wait for all goroutines to finish + sc.wg.Wait() + + // Mark as stopped + sc.mutex.Lock() + sc.collectionStarted = false + sc.mutex.Unlock() +} + +// monitorExistingContainers checks for existing containers that match our criteria +func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) { + defer sc.wg.Done() + + containers, err := sc.client.ContainerList(ctx, container.ListOptions{}) + if err != nil { + if verbose { + log.Printf("Failed to list existing containers: %v", err) + } + return + } + + for _, cont := range containers { + if sc.shouldMonitorContainer(cont, runID) { + sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose) + } + } +} + +// monitorDockerEvents listens for container start events and begins monitoring relevant containers +func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) { + defer sc.wg.Done() + + filter := filters.NewArgs() + filter.Add("type", "container") + filter.Add("event", "start") + 
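+	// Only container "start" events are subscribed to, so matching hs-/ts- containers
+	// are picked up the moment they launch instead of being discovered by polling.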
+ eventOptions := events.ListOptions{ + Filters: filter, + } + + events, errs := sc.client.Events(ctx, eventOptions) + + for { + select { + case <-sc.stopChan: + return + case <-ctx.Done(): + return + case event := <-events: + if event.Type == "container" && event.Action == "start" { + // Get container details + containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) + if err != nil { + continue + } + + // Convert to types.Container format for consistency + cont := types.Container{ + ID: containerInfo.ID, + Names: []string{containerInfo.Name}, + Labels: containerInfo.Config.Labels, + } + + if sc.shouldMonitorContainer(cont, runID) { + sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose) + } + } + case err := <-errs: + if verbose { + log.Printf("Error in Docker events stream: %v", err) + } + return + } + } +} + +// shouldMonitorContainer determines if a container should be monitored +func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { + // Check if it has the correct run ID label + if cont.Labels == nil || cont.Labels["hi.run-id"] != runID { + return false + } + + // Check if it's an hs- or ts- container + for _, name := range cont.Names { + containerName := strings.TrimPrefix(name, "/") + if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") { + return true + } + } + + return false +} + +// startStatsForContainer begins stats collection for a specific container +func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) { + containerName = strings.TrimPrefix(containerName, "/") + + sc.mutex.Lock() + // Check if we're already monitoring this container + if _, exists := sc.containers[containerID]; exists { + sc.mutex.Unlock() + return + } + + sc.containers[containerID] = &ContainerStats{ + ContainerID: containerID, + ContainerName: containerName, + Stats: make([]StatsSample, 0), + } + sc.mutex.Unlock() + + if verbose { + log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12]) + } + + sc.wg.Add(1) + go sc.collectStatsForContainer(ctx, containerID, verbose) +} + +// collectStatsForContainer collects stats for a specific container using Docker API streaming +func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) { + defer sc.wg.Done() + + // Use Docker API streaming stats - much more efficient than CLI + statsResponse, err := sc.client.ContainerStats(ctx, containerID, true) + if err != nil { + if verbose { + log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err) + } + return + } + defer statsResponse.Body.Close() + + decoder := json.NewDecoder(statsResponse.Body) + var prevStats *container.Stats + + for { + select { + case <-sc.stopChan: + return + case <-ctx.Done(): + return + default: + var stats container.Stats + if err := decoder.Decode(&stats); err != nil { + // EOF is expected when container stops or stream ends + if err.Error() != "EOF" && verbose { + log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) + } + return + } + + // Calculate CPU percentage (only if we have previous stats) + var cpuPercent float64 + if prevStats != nil { + cpuPercent = calculateCPUPercent(prevStats, &stats) + } + + // Calculate memory usage in MB + memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024) + + // Store the sample (skip first sample since CPU calculation needs previous stats) + if prevStats != nil { 
+ // Get container stats reference without holding the main mutex + var containerStats *ContainerStats + var exists bool + + sc.mutex.RLock() + containerStats, exists = sc.containers[containerID] + sc.mutex.RUnlock() + + if exists && containerStats != nil { + containerStats.mutex.Lock() + containerStats.Stats = append(containerStats.Stats, StatsSample{ + Timestamp: time.Now(), + CPUUsage: cpuPercent, + MemoryMB: memoryMB, + }) + containerStats.mutex.Unlock() + } + } + + // Save current stats for next iteration + prevStats = &stats + } + } +} + +// calculateCPUPercent calculates CPU usage percentage from Docker stats +func calculateCPUPercent(prevStats, stats *container.Stats) float64 { + // CPU calculation based on Docker's implementation + cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage) + systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage) + + if systemDelta > 0 && cpuDelta >= 0 { + // Calculate CPU percentage: (container CPU delta / system CPU delta) * number of CPUs * 100 + numCPUs := float64(len(stats.CPUStats.CPUUsage.PercpuUsage)) + if numCPUs == 0 { + // Fallback: if PercpuUsage is not available, assume 1 CPU + numCPUs = 1.0 + } + return (cpuDelta / systemDelta) * numCPUs * 100.0 + } + return 0.0 +} + +// ContainerStatsSummary represents summary statistics for a container +type ContainerStatsSummary struct { + ContainerName string + SampleCount int + CPU StatsSummary + Memory StatsSummary +} + +// MemoryViolation represents a container that exceeded the memory limit +type MemoryViolation struct { + ContainerName string + MaxMemoryMB float64 + LimitMB float64 +} + +// StatsSummary represents min, max, and average for a metric +type StatsSummary struct { + Min float64 + Max float64 + Average float64 +} + +// GetSummary returns a summary of collected statistics +func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { + // Take snapshot of container references without holding main lock long + sc.mutex.RLock() + containerRefs := make([]*ContainerStats, 0, len(sc.containers)) + for _, containerStats := range sc.containers { + containerRefs = append(containerRefs, containerStats) + } + sc.mutex.RUnlock() + + summaries := make([]ContainerStatsSummary, 0, len(containerRefs)) + + for _, containerStats := range containerRefs { + containerStats.mutex.RLock() + stats := make([]StatsSample, len(containerStats.Stats)) + copy(stats, containerStats.Stats) + containerName := containerStats.ContainerName + containerStats.mutex.RUnlock() + + if len(stats) == 0 { + continue + } + + summary := ContainerStatsSummary{ + ContainerName: containerName, + SampleCount: len(stats), + } + + // Calculate CPU stats + cpuValues := make([]float64, len(stats)) + memoryValues := make([]float64, len(stats)) + + for i, sample := range stats { + cpuValues[i] = sample.CPUUsage + memoryValues[i] = sample.MemoryMB + } + + summary.CPU = calculateStatsSummary(cpuValues) + summary.Memory = calculateStatsSummary(memoryValues) + + summaries = append(summaries, summary) + } + + // Sort by container name for consistent output + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].ContainerName < summaries[j].ContainerName + }) + + return summaries +} + +// calculateStatsSummary calculates min, max, and average for a slice of values +func calculateStatsSummary(values []float64) StatsSummary { + if len(values) == 0 { + return StatsSummary{} + } + + min := values[0] + max := values[0] + sum := 0.0 + + for _, value 
:= range values { + if value < min { + min = value + } + if value > max { + max = value + } + sum += value + } + + return StatsSummary{ + Min: min, + Max: max, + Average: sum / float64(len(values)), + } +} + +// PrintSummary prints the statistics summary to the console +func (sc *StatsCollector) PrintSummary() { + summaries := sc.GetSummary() + + if len(summaries) == 0 { + log.Printf("No container statistics collected") + return + } + + log.Printf("Container Resource Usage Summary:") + log.Printf("================================") + + for _, summary := range summaries { + log.Printf("Container: %s (%d samples)", summary.ContainerName, summary.SampleCount) + log.Printf(" CPU Usage: Min: %6.2f%% Max: %6.2f%% Avg: %6.2f%%", + summary.CPU.Min, summary.CPU.Max, summary.CPU.Average) + log.Printf(" Memory Usage: Min: %6.1f MB Max: %6.1f MB Avg: %6.1f MB", + summary.Memory.Min, summary.Memory.Max, summary.Memory.Average) + log.Printf("") + } +} + +// CheckMemoryLimits checks if any containers exceeded their memory limits +func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation { + if hsLimitMB <= 0 && tsLimitMB <= 0 { + return nil + } + + summaries := sc.GetSummary() + var violations []MemoryViolation + + for _, summary := range summaries { + var limitMB float64 + if strings.HasPrefix(summary.ContainerName, "hs-") { + limitMB = hsLimitMB + } else if strings.HasPrefix(summary.ContainerName, "ts-") { + limitMB = tsLimitMB + } else { + continue // Skip containers that don't match our patterns + } + + if limitMB > 0 && summary.Memory.Max > limitMB { + violations = append(violations, MemoryViolation{ + ContainerName: summary.ContainerName, + MaxMemoryMB: summary.Memory.Max, + LimitMB: limitMB, + }) + } + } + + return violations +} + +// PrintSummaryAndCheckLimits prints the statistics summary and returns memory violations if any +func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation { + sc.PrintSummary() + return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB) +} + +// Close closes the stats collector and cleans up resources +func (sc *StatsCollector) Close() error { + sc.StopCollection() + return sc.client.Close() +} \ No newline at end of file diff --git a/flake.nix b/flake.nix index 17a99b56..70b51c7b 100644 --- a/flake.nix +++ b/flake.nix @@ -19,7 +19,7 @@ overlay = _: prev: let pkgs = nixpkgs.legacyPackages.${prev.system}; buildGo = pkgs.buildGo124Module; - vendorHash = "sha256-S2GnCg2dyfjIyi5gXhVEuRs5Bop2JAhZcnhg1fu4/Gg="; + vendorHash = "sha256-83L2NMyOwKCHWqcowStJ7Ze/U9CJYhzleDRLrJNhX2g="; in { headscale = buildGo { pname = "headscale"; diff --git a/go.mod b/go.mod index 399cc807..f719bc0b 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/gorilla/mux v1.8.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0 github.com/jagottsicher/termcolor v1.0.2 - github.com/klauspost/compress v1.18.0 github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 github.com/ory/dockertest/v3 v3.12.0 github.com/philip-bui/grpc-zerolog v1.0.1 @@ -43,11 +42,11 @@ require ( github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97 github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.39.0 + golang.org/x/crypto v0.40.0 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 - golang.org/x/net v0.41.0 + golang.org/x/net v0.42.0 golang.org/x/oauth2 v0.30.0 - golang.org/x/sync v0.15.0 + golang.org/x/sync v0.16.0 
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 google.golang.org/grpc v1.73.0 google.golang.org/protobuf v1.36.6 @@ -55,7 +54,7 @@ require ( gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.30.0 - tailscale.com v1.84.2 + tailscale.com v1.84.3 zgo.at/zcache/v2 v2.2.0 zombiezen.com/go/postgrestest v1.0.1 ) @@ -81,7 +80,7 @@ require ( modernc.org/libc v1.62.1 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.10.0 // indirect - modernc.org/sqlite v1.37.0 // indirect + modernc.org/sqlite v1.37.0 ) require ( @@ -166,6 +165,7 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jsimonetti/rtnetlink v1.4.1 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/lib/pq v1.10.9 // indirect @@ -231,14 +231,19 @@ require ( go.opentelemetry.io/otel/trace v1.36.0 // indirect go.uber.org/multierr v1.11.0 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect - golang.org/x/mod v0.25.0 // indirect - golang.org/x/sys v0.33.0 // indirect - golang.org/x/term v0.32.0 // indirect - golang.org/x/text v0.26.0 // indirect + golang.org/x/mod v0.26.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/term v0.33.0 // indirect + golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.10.0 // indirect - golang.org/x/tools v0.33.0 // indirect + golang.org/x/tools v0.35.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect ) + +tool ( + golang.org/x/tools/cmd/stringer + tailscale.com/cmd/viewer +) diff --git a/go.sum b/go.sum index 3696736b..5571e67f 100644 --- a/go.sum +++ b/go.sum @@ -555,8 +555,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= @@ -567,8 +567,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod 
h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -577,8 +577,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= -golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -587,8 +587,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -615,8 +615,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -624,8 +624,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod 
h1:jbD1KX2456YbFQfuX golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= +golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= +golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -633,8 +633,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -643,8 +643,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -714,6 +714,8 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= tailscale.com v1.84.2 h1:v6aM4RWUgYiV52LRAx6ET+dlGnvO/5lnqPXb7/pMnR0= tailscale.com v1.84.2/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo= +tailscale.com v1.84.3 h1:Ur9LMedSgicwbqpy5xn7t49G8490/s6rqAJOk5Q5AYE= +tailscale.com v1.84.3/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo= zgo.at/zcache/v2 v2.2.0 h1:K29/IPjMniZfveYE+IRXfrl11tMzHkIPuyGrfVZ2fGo= zgo.at/zcache/v2 v2.2.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= diff --git a/hscontrol/app.go b/hscontrol/app.go index 
bb98f82d..2bc42ea0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -28,14 +28,15 @@ import ( derpServer "github.com/juanfont/headscale/hscontrol/derp/server" "github.com/juanfont/headscale/hscontrol/dns" "github.com/juanfont/headscale/hscontrol/mapper" - "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/pkg/profile" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/sasha-s/go-deadlock" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" "golang.org/x/sync/errgroup" @@ -64,6 +65,19 @@ var ( ) ) +var ( + debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK") + debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT") +) + +func init() { + deadlock.Opts.Disable = !debugDeadlock + if debugDeadlock { + deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout() + deadlock.Opts.PrintAllCurrentGoroutines = true + } +} + const ( AuthPrefix = "Bearer " updateInterval = 5 * time.Second @@ -82,9 +96,8 @@ type Headscale struct { // Things that generate changes extraRecordMan *dns.ExtraRecordsMan - mapper *mapper.Mapper - nodeNotifier *notifier.Notifier authProvider AuthProvider + mapBatcher mapper.Batcher pollNetMapStreamWG sync.WaitGroup } @@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { cfg: cfg, noisePrivateKey: noisePrivateKey, pollNetMapStreamWG: sync.WaitGroup{}, - nodeNotifier: notifier.NewNotifier(cfg), state: s, } @@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname) - app.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - + app.Change(policyChanged) log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node") }) app.ephemeralGC = ephemeralGC @@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { defer cancel() oidcProvider, err := NewAuthProviderOIDC( ctx, + &app, cfg.ServerURL, &cfg.OIDC, - app.state, - app.nodeNotifier, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { return case <-expireTicker.C: - var update types.StateUpdate + var expiredNodeChanges []change.ChangeSet var changed bool - lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) + lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) if changed { - log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes") + log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes") - ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") - h.nodeNotifier.NotifyAll(ctx, update) + // Send the changes directly since they're already in the new format + for _, nodeChange := range expiredNodeChanges { + h.Change(nodeChange) + } } case <-derpTickerChan: @@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { derpMap.Regions[region.RegionID] = ®ion } - ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateDERPUpdated, - DERPMap: derpMap, - }) + h.Change(change.DERPSet) case 
records, ok := <-extraRecordsUpdate: if !ok { @@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { } h.cfg.TailcfgDNSConfig.ExtraRecords = records - ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all") - // TODO(kradalby): We can probably do better than sending a full update here, - // but for now this will ensure that all of the nodes get the new records. - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + h.Change(change.ExtraRecordsSet) } } } func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, - req interface{}, + req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, -) (interface{}, error) { +) (any, error) { // Check if the request is coming from the on-server client. // This is not secure, but it is to maintain maintainability // with the "legacy" database-based client @@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } -// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. -// // Maybe we should attempt a new in memory state and not go via the DB? -// // Maybe this should be implemented as an event bus? -// // A bool is returned indicating if a full update was sent to all nodes -// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { -// users, err := db.ListUsers() -// if err != nil { -// return err -// } - -// changed, err := polMan.SetUsers(users) -// if err != nil { -// return err -// } - -// if changed { -// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") -// notif.NotifyAll(ctx, types.UpdateFull()) -// } - -// return nil -// } - -// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. -// // Maybe we should attempt a new in memory state and not go via the DB? -// // Maybe this should be implemented as an event bus? -// // A bool is returned indicating if a full update was sent to all nodes -// func nodesChangedHook( -// db *db.HSDatabase, -// polMan policy.PolicyManager, -// notif *notifier.Notifier, -// ) (bool, error) { -// nodes, err := db.ListNodes() -// if err != nil { -// return false, err -// } - -// filterChanged, err := polMan.SetNodes(nodes) -// if err != nil { -// return false, err -// } - -// if filterChanged { -// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") -// notif.NotifyAll(ctx, types.UpdateFull()) - -// return true, nil -// } - -// return false, nil -// } - // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { capver.CanOldCodeBeCleanedUp() @@ -562,8 +511,9 @@ func (h *Headscale) Serve() error { Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)). Msg("Clients with a lower minimum version will be rejected") - // Fetch an initial DERP Map before we start serving - h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier) + h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() + defer h.mapBatcher.Close() // TODO(kradalby): fix state part. if h.cfg.DERP.ServerEnabled { @@ -838,8 +788,12 @@ func (h *Headscale) Serve() error { log.Info(). 
Msg("ACL policy successfully reloaded, notifying nodes of change") - ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + err = h.state.AutoApproveNodes() + if err != nil { + log.Error().Err(err).Msg("failed to approve routes after new policy") + } + + h.Change(change.PolicySet) } default: info := func(msg string) { log.Info().Msg(msg) } @@ -865,7 +819,6 @@ func (h *Headscale) Serve() error { } info("closing node notifier") - h.nodeNotifier.Close() info("waiting for netmap stream to close") h.pollNetMapStreamWG.Wait() @@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } + +// Change is used to send changes to nodes. +// All change should be enqueued here and empty will be automatically +// ignored. +func (h *Headscale) Change(c change.ChangeSet) { + h.mapBatcher.AddWork(c) +} diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 986bbabc..dcf248d4 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -10,6 +10,8 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -32,6 +34,21 @@ func (h *Headscale) handleRegister( } if node != nil { + // If an existing node is trying to register with an auth key, + // we need to validate the auth key even for existing nodes + if regReq.Auth != nil && regReq.Auth.AuthKey != "" { + resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) + if err != nil { + // Preserve HTTPError types so they can be handled properly by the HTTP layer + var httpErr HTTPError + if errors.As(err, &httpErr) { + return nil, httpErr + } + return nil, fmt.Errorf("handling register with auth key for existing node: %w", err) + } + return resp, nil + } + resp, err := h.handleExistingNode(node, regReq, machineKey) if err != nil { return nil, fmt.Errorf("handling existing node: %w", err) @@ -47,6 +64,11 @@ func (h *Headscale) handleRegister( if regReq.Auth != nil && regReq.Auth.AuthKey != "" { resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) if err != nil { + // Preserve HTTPError types so they can be handled properly by the HTTP layer + var httpErr HTTPError + if errors.As(err, &httpErr) { + return nil, httpErr + } return nil, fmt.Errorf("handling register with auth key: %w", err) } @@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { + if node.MachineKey != machineKey { return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } expired := node.IsExpired() + if !expired && !regReq.Expiry.IsZero() { requestExpiry := regReq.Expiry @@ -82,42 +106,26 @@ func (h *Headscale) handleExistingNode( // If the request expiry is in the past, we consider it a logout. 
if requestExpiry.Before(time.Now()) { if node.IsEphemeral() { - policyChanged, err := h.state.DeleteNode(node) + c, err := h.state.DeleteNode(node) if err != nil { return nil, fmt.Errorf("deleting ephemeral node: %w", err) } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na") - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } else { - ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID)) - } + h.Change(c) return nil, nil } } - n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry) + _, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) if err != nil { return nil, fmt.Errorf("setting node expiry: %w", err) } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na") - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } else { - ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") - h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID) + h.Change(c) } - return nodeToRegisterResponse(n), nil - } - - return nodeToRegisterResponse(node), nil + return nodeToRegisterResponse(node), nil } func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { @@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, changed, err := h.state.HandleNodeFromPreAuthKey( + node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey( regReq, machineKey, ) @@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey( return nil, err } + // If node is nil, it means an ephemeral node was deleted during logout + if node == nil { + h.Change(changed) + return nil, nil + } + // This is a bit of a back and forth, but we have a bit of a chicken and egg // dependency here. // Because the way the policy manager works, we need to have the node @@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey( // ensure we send an update. // This works, but might be another good candidate for doing some sort of // eventbus. - routesChanged := h.state.AutoApproveRoutes(node) + // TODO(kradalby): This needs to be ran as part of the batcher maybe? 
+ // now since we dont update the node/pol here anymore + routeChange := h.state.AutoApproveRoutes(node) if _, _, err := h.state.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } - if routesChanged { - ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname) - h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID)) - } else if changed { - ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname) - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } else { - // Existing node re-registering without route changes - // Still need to notify peers about the node being active again - // Use UpdateFull to ensure all peers get complete peer maps - ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname) - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + if routeChange && changed.Empty() { + changed = change.NodeAdded(node.ID) + } + h.Change(changed) + + // If policy changed due to node registration, send a separate policy change + if policyChanged { + policyChange := change.PolicyChange() + h.Change(policyChange) } return &tailcfg.RegisterResponse{ diff --git a/hscontrol/capver/capver.go b/hscontrol/capver/capver.go index 347ec981..b6bbca5b 100644 --- a/hscontrol/capver/capver.go +++ b/hscontrol/capver/capver.go @@ -1,5 +1,7 @@ package capver +//go:generate go run ../../tools/capver/main.go + import ( "slices" "sort" @@ -10,7 +12,7 @@ import ( "tailscale.com/util/set" ) -const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88 +const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90 // CanOldCodeBeCleanedUp is intended to be called on startup to see if // there are old code that can ble cleaned up, entries should contain diff --git a/hscontrol/capver/capver_generated.go b/hscontrol/capver/capver_generated.go index 687e3d51..79590000 100644 --- a/hscontrol/capver/capver_generated.go +++ b/hscontrol/capver/capver_generated.go @@ -1,14 +1,10 @@ package capver -// Generated DO NOT EDIT +//Generated DO NOT EDIT import "tailscale.com/tailcfg" var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ - "v1.60.0": 87, - "v1.60.1": 87, - "v1.62.0": 88, - "v1.62.1": 88, "v1.64.0": 90, "v1.64.1": 90, "v1.64.2": 90, @@ -36,18 +32,21 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ "v1.80.3": 113, "v1.82.0": 115, "v1.82.5": 115, + "v1.84.0": 116, + "v1.84.1": 116, + "v1.84.2": 116, } + var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ - 87: "v1.60.0", - 88: "v1.62.0", - 90: "v1.64.0", - 95: "v1.66.0", - 97: "v1.68.0", - 102: "v1.70.0", - 104: "v1.72.0", - 106: "v1.74.0", - 109: "v1.78.0", - 113: "v1.80.0", - 115: "v1.82.0", + 90: "v1.64.0", + 95: "v1.66.0", + 97: "v1.68.0", + 102: "v1.70.0", + 104: "v1.72.0", + 106: "v1.74.0", + 109: "v1.78.0", + 113: "v1.80.0", + 115: "v1.82.0", + 116: "v1.84.0", } diff --git a/hscontrol/capver/capver_test.go b/hscontrol/capver/capver_test.go index eb2d06ba..42f1df71 100644 --- a/hscontrol/capver/capver_test.go +++ b/hscontrol/capver/capver_test.go @@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) { stripV bool expected []string }{ - {3, false, []string{"v1.78", "v1.80", "v1.82"}}, - {2, true, []string{"1.80", "1.82"}}, + {3, false, []string{"v1.80", "v1.82", "v1.84"}}, + {2, true, []string{"1.82", "1.84"}}, // Lazy way to see all supported versions {10, true, []string{ - "1.64", "1.66", "1.68", "1.70", @@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) { "1.78", 
"1.80", "1.82", + "1.84", }}, {0, false, nil}, } @@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) { input tailcfg.CapabilityVersion expected string }{ - {88, "v1.62.0"}, {90, "v1.64.0"}, {95, "v1.66.0"}, {106, "v1.74.0"}, diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 86332a0d..47245c39 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -7,7 +7,6 @@ import ( "os/exec" "path/filepath" "slices" - "sort" "strings" "testing" "time" @@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { } if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool { - sort.Sort(sort.StringSlice(a)) - sort.Sort(sort.StringSlice(b)) + slices.Sort(a) + slices.Sort(b) return slices.Equal(a, b) }), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" { t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 2de29e69..83d62d3d 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -7,15 +7,19 @@ import ( "net/netip" "slices" "sort" + "strconv" "sync" + "testing" "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) const ( @@ -39,9 +43,7 @@ var ( // If no peer IDs are given, all peers are returned. // If at least one peer ID is given, only these peer nodes will be returned. func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListPeers(rx, nodeID, peerIDs...) - }) + return ListPeers(hsdb.DB, nodeID, peerIDs...) } // ListPeers returns peers of node, regardless of any Policy or if the node is expired. @@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types // ListNodes queries the database for either all nodes if no parameters are given // or for the given nodes if at least one node ID is given as parameter. func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListNodes(rx, nodeIDs...) - }) + return ListNodes(hsdb.DB, nodeIDs...) } // ListNodes queries the database for either all nodes if no parameters are given @@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) { } func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByID(rx, id) - }) + return GetNodeByID(hsdb.DB, id) } // GetNodeByID finds a Node by ID and returns the Node struct. @@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) { } func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByMachineKey(rx, machineKey) - }) + return GetNodeByMachineKey(hsdb.DB, machineKey) } // GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. 
@@ -163,9 +159,7 @@ func GetNodeByMachineKey( } func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByNodeKey(rx, nodeKey) - }) + return GetNodeByNodeKey(hsdb.DB, nodeKey) } // GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct. @@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( registrationMethod string, ipv4 *netip.Addr, ipv6 *netip.Addr, -) (*types.Node, bool, error) { - var newNode bool +) (*types.Node, change.ChangeSet, error) { + var nodeChange change.ChangeSet node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { if reg, ok := hsdb.regCache.Get(registrationID); ok { if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { @@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( } close(reg.Registered) - newNode = true + nodeChange = change.NodeAdded(node.ID) return node, err } else { @@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( return nil, err } + nodeChange = change.KeyExpiry(node.ID) + return node, nil } } @@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( return nil, ErrNodeNotFoundRegistrationCache }) - return node, newNode, err + return node, nodeChange, err } func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { @@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad if oldNode != nil && oldNode.UserID == node.UserID { node.ID = oldNode.ID node.GivenName = oldNode.GivenName + node.ApprovedRoutes = oldNode.ApprovedRoutes ipv4 = oldNode.IPv4 ipv6 = oldNode.IPv6 } @@ -594,17 +591,18 @@ func ensureUniqueGivenName( // containing the expired nodes, and a boolean indicating if any nodes were found. func ExpireExpiredNodes(tx *gorm.DB, lastCheck time.Time, -) (time.Time, types.StateUpdate, bool) { +) (time.Time, []change.ChangeSet, bool) { // use the time of the start of the function to ensure we // dont miss some nodes by returning it _after_ we have // checked everything. 
started := time.Now() expired := make([]*tailcfg.PeerChange, 0) + var updates []change.ChangeSet nodes, err := ListNodes(tx) if err != nil { - return time.Unix(0, 0), types.StateUpdate{}, false + return time.Unix(0, 0), nil, false } for _, node := range nodes { if node.IsExpired() && node.Expiry.After(lastCheck) { @@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB, NodeID: tailcfg.NodeID(node.ID), KeyExpiry: node.Expiry, }) + updates = append(updates, change.KeyExpiry(node.ID)) } } if len(expired) > 0 { - return started, types.UpdatePeerPatch(expired...), true + return started, updates, true } - return started, types.StateUpdate{}, false + return started, nil, false } // EphemeralGarbageCollector is a garbage collector that will delete nodes after @@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() { } } } + +func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node { + if !testing.Testing() { + panic("CreateNodeForTest can only be called during tests") + } + + if user == nil { + panic("CreateNodeForTest requires a valid user") + } + + nodeName := "testnode" + if len(hostname) > 0 && hostname[0] != "" { + nodeName = hostname[0] + } + + // Create a preauth key for the node + pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + if err != nil { + panic(fmt.Sprintf("failed to create preauth key for test node: %v", err)) + } + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + discoKey := key.NewDisco() + + node := &types.Node{ + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + DiscoKey: discoKey.Public(), + Hostname: nodeName, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + + err = hsdb.DB.Save(node).Error + if err != nil { + panic(fmt.Sprintf("failed to create test node: %v", err)) + } + + return node +} + +func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node { + if !testing.Testing() { + panic("CreateRegisteredNodeForTest can only be called during tests") + } + + node := hsdb.CreateNodeForTest(user, hostname...) 
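Editor's note: the new CreateNodeForTest-style helpers guard themselves with testing.Testing() (Go 1.21+), which reports whether the binary was built by `go test`, and panic otherwise so the fixtures cannot be reached from production code paths. A minimal illustration of the same guard with a hypothetical helper name:

```go
package fixtures

import (
	"fmt"
	"testing"
)

// mustBeTest panics unless the current binary is a test binary.
// testing.Testing() returns true only under `go test` (Go 1.21+).
func mustBeTest(helper string) {
	if !testing.Testing() {
		panic(fmt.Sprintf("%s can only be called during tests", helper))
	}
}

// CreateThingForTest is a hypothetical fixture builder following the
// same pattern as the helpers added in this diff.
func CreateThingForTest(name string) string {
	mustBeTest("CreateThingForTest")
	if name == "" {
		name = "testthing"
	}
	return name
}
```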
+ + err := hsdb.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNode(tx, *node, nil, nil) + return err + }) + if err != nil { + panic(fmt.Sprintf("failed to register test node: %v", err)) + } + + registeredNode, err := hsdb.GetNodeByID(node.ID) + if err != nil { + panic(fmt.Sprintf("failed to get registered test node: %v", err)) + } + + return registeredNode +} + +func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node { + if !testing.Testing() { + panic("CreateNodesForTest can only be called during tests") + } + + if user == nil { + panic("CreateNodesForTest requires a valid user") + } + + prefix := "testnode" + if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { + prefix = hostnamePrefix[0] + } + + nodes := make([]*types.Node, count) + for i := range count { + hostname := prefix + "-" + strconv.Itoa(i) + nodes[i] = hsdb.CreateNodeForTest(user, hostname) + } + + return nodes +} + +func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node { + if !testing.Testing() { + panic("CreateRegisteredNodesForTest can only be called during tests") + } + + if user == nil { + panic("CreateRegisteredNodesForTest requires a valid user") + } + + prefix := "testnode" + if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { + prefix = hostnamePrefix[0] + } + + nodes := make([]*types.Node, count) + for i := range count { + hostname := prefix + "-" + strconv.Itoa(i) + nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname) + } + + return nodes +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 9f10fc1c..8819fbcf 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -6,7 +6,6 @@ import ( "math/big" "net/netip" "regexp" - "strconv" "sync" "testing" "time" @@ -26,82 +25,36 @@ import ( ) func (s *Suite) TestGetNode(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode(types.UserID(user.ID), "testnode") + _, err := db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(node) - c.Assert(trx.Error, check.IsNil) + node := db.CreateNodeForTest(user, "testnode") _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) + c.Assert(node.Hostname, check.Equals, "testnode") } func (s *Suite) TestGetNodeByID(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) + _, err := db.GetNodeByID(0) c.Assert(err, check.NotNil) - nodeKey := key.NewNode() - machineKey := key.NewMachine() + node := db.CreateNodeForTest(user, "testnode") - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - 
c.Assert(trx.Error, check.IsNil) - - _, err = db.GetNodeByID(0) + retrievedNode, err := db.GetNodeByID(node.ID) c.Assert(err, check.IsNil) + c.Assert(retrievedNode.Hostname, check.Equals, "testnode") } func (s *Suite) TestHardDeleteNode(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") + node := db.CreateNodeForTest(user, "testnode3") - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode3", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - err = db.DeleteNode(&node) + err := db.DeleteNode(node) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode3") @@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { } func (s *Suite) TestListPeers(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) + _, err := db.GetNodeByID(0) c.Assert(err, check.NotNil) - for index := range 11 { - nodeKey := key.NewNode() - machineKey := key.NewMachine() + nodes := db.CreateNodesForTest(user, 11, "testnode") - node := types.Node{ - ID: types.NodeID(index), - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode" + strconv.Itoa(index), - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - } - - node0ByID, err := db.GetNodeByID(0) + firstNode := nodes[0] + peersOfFirstNode, err := db.ListPeers(firstNode.ID) c.Assert(err, check.IsNil) - peersOfNode0, err := db.ListPeers(node0ByID.ID) - c.Assert(err, check.IsNil) - - c.Assert(len(peersOfNode0), check.Equals, 9) - c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2") - c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7") - c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") + c.Assert(len(peersOfFirstNode), check.Equals, 10) + c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1") + c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6") + c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10") } func (s *Suite) TestExpireNode(c *check.C) { @@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) { // No parameter means no filter, should return all peers nodes, err = db.ListPeers(1) require.NoError(t, err) - assert.Len(t, nodes, 1) + assert.Equal(t, 1, len(nodes)) assert.Equal(t, "test2", nodes[0].Hostname) // Empty node list should return all peers nodes, err = db.ListPeers(1, types.NodeIDs{}...) require.NoError(t, err) - assert.Len(t, nodes, 1) + assert.Equal(t, 1, len(nodes)) assert.Equal(t, "test2", nodes[0].Hostname) // No match in IDs should return empty list and no error @@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) { // Partial match in IDs nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Len(t, nodes, 1) + assert.Equal(t, 1, len(nodes)) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs, but node ID is still filtered out nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) 
require.NoError(t, err) - assert.Len(t, nodes, 1) + assert.Equal(t, 1, len(nodes)) assert.Equal(t, "test2", nodes[0].Hostname) } @@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) { // No parameter means no filter, should return all nodes nodes, err = db.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 2) + assert.Equal(t, 2, len(nodes)) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) // Empty node list should return all nodes nodes, err = db.ListNodes(types.NodeIDs{}...) require.NoError(t, err) - assert.Len(t, nodes, 2) + assert.Equal(t, 2, len(nodes)) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) @@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) { // Partial match in IDs nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Len(t, nodes, 1) + assert.Equal(t, 1, len(nodes)) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Len(t, nodes, 2) + assert.Equal(t, 2, len(nodes)) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index ee977ae3..2e60de2e 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e } func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { - return GetPreAuthKey(rx, key) - }) + return GetPreAuthKey(hsdb.DB, key) } // GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible @@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { // MarkExpirePreAuthKey marks a PreAuthKey as expired. 
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { - if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { - return err - } - - return nil + now := time.Now() + return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error } func generateKey() (string, error) { diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 7945f090..605e7442 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -1,7 +1,7 @@ package db import ( - "sort" + "slices" "testing" "github.com/juanfont/headscale/hscontrol/types" @@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) { listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) gotTags := listedPaks[0].Proto().GetAclTags() - sort.Sort(sort.StringSlice(gotTags)) + slices.Sort(gotTags) c.Assert(gotTags, check.DeepEquals, tags) } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 76415a9d..1b333792 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -3,6 +3,8 @@ package db import ( "errors" "fmt" + "strconv" + "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { } func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByID(rx, uid) - }) + return GetUserByID(hsdb.DB, uid) } func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) { @@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) { } func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { - return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { - return ListUsers(rx, where...) - }) + return ListUsers(hsdb.DB, where...) } // ListUsers gets all the existing users. 
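Editor's note: ExpirePreAuthKey above now updates the row by primary key with an explicit Where clause instead of `tx.Model(&k)`. Since `k` is already a `*types.PreAuthKey`, `&k` was a pointer-to-pointer, which GORM may not dereference as expected when resolving the target record. A hedged sketch of the updated shape against a trimmed-down model:

```go
package paksketch

import (
	"time"

	"gorm.io/gorm"
)

// PreAuthKey is a trimmed-down stand-in for types.PreAuthKey.
type PreAuthKey struct {
	ID         uint64
	Expiration *time.Time
}

// expire marks the key as expired by addressing the row through its ID,
// rather than handing GORM a pointer-to-pointer via Model(&k).
func expire(tx *gorm.DB, k *PreAuthKey) error {
	now := time.Now()
	return tx.Model(&PreAuthKey{}).
		Where("id = ?", k.ID).
		Update("expiration", now).Error
}
```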
@@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error return nil } + +func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User { + if !testing.Testing() { + panic("CreateUserForTest can only be called during tests") + } + + userName := "testuser" + if len(name) > 0 && name[0] != "" { + userName = name[0] + } + + user, err := hsdb.CreateUser(types.User{Name: userName}) + if err != nil { + panic(fmt.Sprintf("failed to create test user: %v", err)) + } + + return user +} + +func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User { + if !testing.Testing() { + panic("CreateUsersForTest can only be called during tests") + } + + prefix := "testuser" + if len(namePrefix) > 0 && namePrefix[0] != "" { + prefix = namePrefix[0] + } + + users := make([]*types.User, count) + for i := range count { + name := prefix + "-" + strconv.Itoa(i) + users[i] = hsdb.CreateUserForTest(name) + } + + return users +} diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 13b75557..5b2f0c4b 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -11,8 +11,7 @@ import ( ) func (s *Suite) TestCreateAndDestroyUser(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") c.Assert(user.Name, check.Equals, "test") users, err := db.ListUsers() @@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { err := db.DestroyUser(9998) c.Assert(err, check.Equals, ErrUserNotFound) - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) @@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { } func (s *Suite) TestRenameUser(c *check.C) { - userTest, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + userTest := db.CreateUserForTest("test") c.Assert(userTest.Name, check.Equals, "test") users, err := db.ListUsers() @@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) { err = db.RenameUser(99988, "test") c.Assert(err, check.Equals, ErrUserNotFound) - userTest2, err := db.CreateUser(types.User{Name: "test2"}) - c.Assert(err, check.IsNil) + userTest2 := db.CreateUserForTest("test2") c.Assert(userTest2.Name, check.Equals, "test2") want := "UNIQUE constraint failed" @@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) { } func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser, err := db.CreateUser(types.User{Name: "old"}) - c.Assert(err, check.IsNil) - - newUser, err := db.CreateUser(types.User{Name: "new"}) - c.Assert(err, check.IsNil) + oldUser := db.CreateUserForTest("old") + newUser := db.CreateUserForTest("new") pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) c.Assert(err, check.IsNil) diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 038582c8..481ce589 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -17,10 +17,6 @@ import ( func (h *Headscale) debugHTTPServer() *http.Server { debugMux := http.NewServeMux() debug := tsweb.Debugger(debugMux) - debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte(h.nodeNotifier.String())) - })) debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { config, err := json.MarshalIndent(h.cfg, "", " ") if err != nil { diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 9d358598..1ed619ec 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "io" + "maps" "net/http" "net/url" "os" @@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap { } for _, derpMap := range derpMaps { - for id, region := range derpMap.Regions { - result.Regions[id] = region - } + maps.Copy(result.Regions, derpMap.Regions) } return &result diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 7df4c92e..722f8421 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -1,3 +1,5 @@ +//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto + // nolint package hscontrol @@ -27,6 +29,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" ) @@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser( return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) } - // Send policy update notifications if needed + + c := change.UserAdded(types.UserID(user.ID)) if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + c.Change = change.Policy } + api.h.Change(c) + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser( // Send policy update notifications if needed if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName()) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + api.h.Change(change.PolicyChange()) } newUser, err := api.h.state.GetUserByName(request.GetNewName()) @@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser( return nil, err } + api.h.Change(change.UserRemoved(types.UserID(user.ID))) + return &v1.DeleteUserResponse{}, nil } @@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode( return nil, fmt.Errorf("looking up user: %w", err) } - node, _, err := api.h.state.HandleNodeFromAuthPath( + node, nodeChange, err := api.h.state.HandleNodeFromAuthPath( registrationId, types.UserID(user.ID), nil, @@ -267,22 +273,13 @@ func (api headscaleV1APIServer) RegisterNode( // ensure we send an update. // This works, but might be another good candidate for doing some sort of // eventbus. 
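Editor's note: CreateUser above shows the new notification shape: the handler builds a single change.ChangeSet, widens it to a policy-wide change when the user creation also altered the policy, and hands it to h.Change. A simplified sketch of that decision with stand-in types; the real change package certainly differs in detail:

```go
package changesketch

// Change identifies what kind of update peers need.
type Change int

const (
	ChangeUserAdded Change = iota
	ChangePolicy
)

// ChangeSet is a stand-in for change.ChangeSet.
type ChangeSet struct {
	Change Change
	UserID uint64
}

// userCreated mirrors the CreateUser handler: start from a user-scoped
// change and widen it to a policy change if the policy was affected.
func userCreated(userID uint64, policyChanged bool) ChangeSet {
	c := ChangeSet{Change: ChangeUserAdded, UserID: userID}
	if policyChanged {
		c.Change = ChangePolicy
	}
	return c
}
```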
- routesChanged := api.h.state.AutoApproveRoutes(node) - _, policyChanged, err := api.h.state.SaveNode(node) + _ = api.h.state.AutoApproveRoutes(node) + _, _, err = api.h.state.SaveNode(node) if err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } - // Send policy update notifications if needed (from SaveNode or route changes) - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all") - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - if routesChanged { - ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID)) - } + api.h.Change(nodeChange) return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode( // Populate the online field based on // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.ID) + resp.Online = api.h.mapBatcher.IsConnected(node.ID) return &v1.GetNodeResponse{Node: resp}, nil } @@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags( } } - node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags()) + node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags()) if err != nil { return &v1.SetTagsResponse{ Node: nil, }, status.Error(codes.InvalidArgument, err.Error()) } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) + api.h.Change(nodeChange) log.Trace(). Str("node", node.Hostname). @@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes( tsaddr.SortPrefixes(routes) routes = slices.Compact(routes) - node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes) + node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes) if err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } + routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) - if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) 
{ - ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } else { - ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) + // Always propagate node changes from SetApprovedRoutes + api.h.Change(nodeChange) + + // If routes changed, propagate those changes too + if !routeChange.Empty() { + api.h.Change(routeChange) } proto := node.Proto() @@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode( return nil, err } - policyChanged, err := api.h.state.DeleteNode(node) + nodeChange, err := api.h.state.DeleteNode(node) if err != nil { return nil, err } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID)) + api.h.Change(nodeChange) return &v1.DeleteNodeResponse{}, nil } @@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode( ) (*v1.ExpireNodeResponse, error) { now := time.Now() - node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now) + node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now) if err != nil { return nil, err } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) - api.h.nodeNotifier.NotifyByNodeID( - ctx, - types.UpdateSelf(node.ID), - node.ID) - - ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID) + // TODO(kradalby): Ensure that both the selfupdate and peer updates are sent + api.h.Change(nodeChange) log.Trace(). Str("node", node.Hostname). @@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode( ctx context.Context, request *v1.RenameNodeRequest, ) (*v1.RenameNodeResponse, error) { - node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName()) + node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName()) if err != nil { return nil, err } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname) - api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID) - - ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) + // TODO(kradalby): investigate if we need selfupdate + api.h.Change(nodeChange) log.Trace(). Str("node", node.Hostname). @@ -498,7 +456,7 @@ func (api headscaleV1APIServer) ListNodes( // probably be done once. // TODO(kradalby): This should be done in one tx. 
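Editor's note: SetApprovedRoutes above now always forwards the node-level ChangeSet and, when recalculating primary routes produced a second ChangeSet, forwards that one too. A small sketch of the two-step propagation, with hypothetical helpers standing in for the state and batcher:

```go
package routesketch

// ChangeSet is a stand-in; Empty reports whether there is anything to send,
// mirroring change.ChangeSet.Empty() used in the diff.
type ChangeSet struct{ kind string }

func (c ChangeSet) Empty() bool { return c.kind == "" }

// propagate mirrors the handler: the node change always goes out, the route
// change only when the primary-route recalculation actually changed something.
func propagate(send func(ChangeSet), nodeChange, routeChange ChangeSet) {
	send(nodeChange)
	if !routeChange.Empty() {
		send(routeChange)
	}
}
```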
- isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() + IsConnected := api.h.mapBatcher.ConnectedMap() if request.GetUser() != "" { user, err := api.h.state.GetUserByName(request.GetUser()) if err != nil { @@ -510,7 +468,7 @@ func (api headscaleV1APIServer) ListNodes( return nil, err } - response := nodesToProto(api.h.state, isLikelyConnected, nodes) + response := nodesToProto(api.h.state, IsConnected, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } @@ -523,18 +481,18 @@ func (api headscaleV1APIServer) ListNodes( return nodes[i].ID < nodes[j].ID }) - response := nodesToProto(api.h.state, isLikelyConnected, nodes) + response := nodesToProto(api.h.state, IsConnected, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } -func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node { +func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node { response := make([]*v1.Node, len(nodes)) for index, node := range nodes { resp := node.Proto() // Populate the online field based on // currently connected nodes. - if val, ok := isLikelyConnected.Load(node.ID); ok && val { + if val, ok := IsConnected.Load(node.ID); ok && val { resp.Online = true } @@ -556,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode( ctx context.Context, request *v1.MoveNodeRequest, ) (*v1.MoveNodeResponse, error) { - node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser())) + node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser())) if err != nil { return nil, err } - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname) - api.h.nodeNotifier.NotifyByNodeID( - ctx, - types.UpdateSelf(node.ID), - node.ID) - ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) + // TODO(kradalby): Ensure the policy is also sent + // TODO(kradalby): ensure that both the selfupdate and peer updates are sent + api.h.Change(nodeChange) return &v1.MoveNodeResponse{Node: node.Proto()}, nil } @@ -754,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + api.h.Change(change.PolicyChange()) } response := &v1.SetPolicyResponse{ diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go new file mode 100644 index 00000000..21b2209f --- /dev/null +++ b/hscontrol/mapper/batcher.go @@ -0,0 +1,155 @@ +package mapper + +import ( + "fmt" + "time" + + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/puzpuzpuz/xsync/v4" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +type batcherFunc func(cfg *types.Config, state *state.State) Batcher + +// Batcher defines the common interface for all batcher implementations. 
+type Batcher interface { + Start() + Close() + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error + RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) + IsConnected(id types.NodeID) bool + ConnectedMap() *xsync.Map[types.NodeID, bool] + AddWork(c change.ChangeSet) + MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) +} + +func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher { + return &LockFreeBatcher{ + mapper: mapper, + workers: workers, + tick: time.NewTicker(batchTime), + + // The size of this channel is arbitrary chosen, the sizing should be revisited. + workCh: make(chan work, workers*200), + nodes: xsync.NewMap[types.NodeID, *nodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), + pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](), + } +} + +// NewBatcherAndMapper creates a Batcher implementation. +func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher { + m := newMapper(cfg, state) + b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m) + m.batcher = b + return b +} + +// nodeConnection interface for different connection implementations. +type nodeConnection interface { + nodeID() types.NodeID + version() tailcfg.CapabilityVersion + send(data *tailcfg.MapResponse) error +} + +// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet]. +func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) { + if c.Empty() { + return nil, nil + } + + // Validate inputs before processing + if nodeID == 0 { + return nil, fmt.Errorf("invalid nodeID: %d", nodeID) + } + + if mapper == nil { + return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) + } + + var mapResp *tailcfg.MapResponse + var err error + + switch c.Change { + case change.DERP: + mapResp, err = mapper.derpMapResponse(nodeID) + + case change.NodeCameOnline, change.NodeWentOffline: + if c.IsSubnetRouter { + // TODO(kradalby): This can potentially be a peer update of the old and new subnet router. + mapResp, err = mapper.fullMapResponse(nodeID, version) + } else { + mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ + { + NodeID: c.NodeID.NodeID(), + Online: ptr.To(c.Change == change.NodeCameOnline), + }, + }) + } + + case change.NodeNewOrUpdate: + mapResp, err = mapper.fullMapResponse(nodeID, version) + + case change.NodeRemove: + mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID) + + default: + // The following will always hit this: + // change.Full, change.Policy + mapResp, err = mapper.fullMapResponse(nodeID, version) + } + + if err != nil { + return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err) + } + + // TODO(kradalby): Is this necessary? + // Validate the generated map response - only check for nil response + // Note: mapResp.Node can be nil for peer updates, which is valid + if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove { + return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String()) + } + + return mapResp, nil +} + +// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet]. 
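Editor's note: generateMapResponse above dispatches on the change type: DERP changes get a DERP-only response, online/offline transitions become small peer-change patches unless the node is a subnet router, removals become peer-removed responses, and everything else falls back to a full map. A condensed sketch of that dispatch with placeholder response constructors, not the mapper's real API:

```go
package dispatchsketch

import "fmt"

type changeKind int

const (
	changeDERP changeKind = iota
	changeOnline
	changeOffline
	changeRemove
	changeFull
)

type mapResponse struct{ kind string }

// buildResponse is a condensed stand-in for generateMapResponse's switch.
func buildResponse(c changeKind, isSubnetRouter bool) *mapResponse {
	switch c {
	case changeDERP:
		return &mapResponse{kind: "derp-only"}
	case changeOnline, changeOffline:
		if isSubnetRouter {
			// A router flap may move primary routes, so peers get a full map.
			return &mapResponse{kind: "full"}
		}
		return &mapResponse{kind: "peer-patch"}
	case changeRemove:
		return &mapResponse{kind: "peer-removed"}
	default:
		// Full and policy changes always regenerate the whole map.
		return &mapResponse{kind: "full"}
	}
}

func Example() {
	fmt.Println(buildResponse(changeOnline, false).kind) // peer-patch
}
```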
+func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error { + if nc == nil { + return fmt.Errorf("nodeConnection is nil") + } + + nodeID := nc.nodeID() + data, err := generateMapResponse(nodeID, nc.version(), mapper, c) + if err != nil { + return fmt.Errorf("generating map response for node %d: %w", nodeID, err) + } + + if data == nil { + // No data to send is valid for some change types + return nil + } + + // Send the map response + if err := nc.send(data); err != nil { + return fmt.Errorf("sending map response to node %d: %w", nodeID, err) + } + + return nil +} + +// workResult represents the result of processing a change. +type workResult struct { + mapResponse *tailcfg.MapResponse + err error +} + +// work represents a unit of work to be processed by workers. +type work struct { + c change.ChangeSet + nodeID types.NodeID + resultCh chan<- workResult // optional channel for synchronous operations +} diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go new file mode 100644 index 00000000..aeafa001 --- /dev/null +++ b/hscontrol/mapper/batcher_lockfree.go @@ -0,0 +1,491 @@ +package mapper + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. +type LockFreeBatcher struct { + tick *time.Ticker + mapper *mapper + workers int + + // Lock-free concurrent maps + nodes *xsync.Map[types.NodeID, *nodeConn] + connected *xsync.Map[types.NodeID, *time.Time] + + // Work queue channel + workCh chan work + ctx context.Context + cancel context.CancelFunc + + // Batching state + pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet] + batchMutex sync.RWMutex + + // Metrics + totalNodes atomic.Int64 + totalUpdates atomic.Int64 + workQueuedCount atomic.Int64 + workProcessed atomic.Int64 + workErrors atomic.Int64 +} + +// AddNode registers a new node connection with the batcher and sends an initial map response. +// It creates or updates the node's connection data, validates the initial map generation, +// and notifies other nodes that this node has come online. +// TODO(kradalby): See if we can move the isRouter argument somewhere else. +func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error { + // First validate that we can generate initial map before doing anything else + fullSelfChange := change.FullSelf(id) + + // TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange. + // This currently means that the goroutine for the node connection will do the processing + // which means that we might have uncontrolled concurrency. + // When we use MapResponseFromChange, it will be processed by the same worker pool, causing + // it to be processed in a more controlled manner. 
+ initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange) + if err != nil { + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) + } + + // Only after validation succeeds, create or update node connection + newConn := newNodeConn(id, c, version, b.mapper) + + var conn *nodeConn + if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded { + // Update existing connection + existing.updateConnection(c, version) + conn = existing + } else { + b.totalNodes.Add(1) + conn = newConn + } + + // Mark as connected only after validation succeeds + b.connected.Store(id, nil) // nil = connected + + log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher") + + // Send the validated initial map + if initialMap != nil { + if err := conn.send(initialMap); err != nil { + // Clean up the connection state on send failure + b.nodes.Delete(id) + b.connected.Delete(id) + return fmt.Errorf("failed to send initial map to node %d: %w", id, err) + } + + // Notify other nodes that this node came online + b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter}) + } + + return nil +} + +// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. +// It validates the connection channel matches the current one, closes the connection, +// and notifies other nodes that this node has gone offline. +func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) { + // Check if this is the current connection and mark it as closed + if existing, ok := b.nodes.Load(id); ok { + if !existing.matchesChannel(c) { + log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring") + return // Not the current connection, not an error + } + + // Mark the connection as closed to prevent further sends + if connData := existing.connData.Load(); connData != nil { + connData.closed.Store(true) + } + } + + log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline") + + // Remove node and mark disconnected atomically + b.nodes.Delete(id) + b.connected.Store(id, ptr.To(time.Now())) + b.totalNodes.Add(-1) + + // Notify other nodes that this node went offline + b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter}) +} + +// AddWork queues a change to be processed by the batcher. +// Critical changes are processed immediately, while others are batched for efficiency. 
+func (b *LockFreeBatcher) AddWork(c change.ChangeSet) { + b.addWork(c) +} + +func (b *LockFreeBatcher) Start() { + b.ctx, b.cancel = context.WithCancel(context.Background()) + go b.doWork() +} + +func (b *LockFreeBatcher) Close() { + if b.cancel != nil { + b.cancel() + } + close(b.workCh) +} + +func (b *LockFreeBatcher) doWork() { + log.Debug().Msg("batcher doWork loop started") + defer log.Debug().Msg("batcher doWork loop stopped") + + for i := range b.workers { + go b.worker(i + 1) + } + + for { + select { + case <-b.tick.C: + // Process batched changes + b.processBatchedChanges() + case <-b.ctx.Done(): + return + } + } +} + +func (b *LockFreeBatcher) worker(workerID int) { + log.Debug().Int("workerID", workerID).Msg("batcher worker started") + defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped") + + for { + select { + case w, ok := <-b.workCh: + if !ok { + return + } + + startTime := time.Now() + b.workProcessed.Add(1) + + // If the resultCh is set, it means that this is a work request + // where there is a blocking function waiting for the map that + // is being generated. + // This is used for synchronous map generation. + if w.resultCh != nil { + var result workResult + if nc, exists := b.nodes.Load(w.nodeID); exists { + result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c) + if result.err != nil { + b.workErrors.Add(1) + log.Error().Err(result.err). + Int("workerID", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("change", w.c.Change.String()). + Msg("failed to generate map response for synchronous work") + } + } else { + result.err = fmt.Errorf("node %d not found", w.nodeID) + b.workErrors.Add(1) + log.Error().Err(result.err). + Int("workerID", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Msg("node not found for synchronous work") + } + + // Send result + select { + case w.resultCh <- result: + case <-b.ctx.Done(): + return + } + + duration := time.Since(startTime) + if duration > 100*time.Millisecond { + log.Warn(). + Int("workerID", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("change", w.c.Change.String()). + Dur("duration", duration). + Msg("slow synchronous work processing") + } + continue + } + + // If resultCh is nil, this is an asynchronous work request + // that should be processed and sent to the node instead of + // returned to the caller. + if nc, exists := b.nodes.Load(w.nodeID); exists { + // Check if this connection is still active before processing + if connData := nc.connData.Load(); connData != nil && connData.closed.Load() { + log.Debug(). + Int("workerID", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("change", w.c.Change.String()). + Msg("skipping work for closed connection") + continue + } + + err := nc.change(w.c) + if err != nil { + b.workErrors.Add(1) + log.Error().Err(err). + Int("workerID", workerID). + Uint64("node.id", w.c.NodeID.Uint64()). + Str("change", w.c.Change.String()). + Msg("failed to apply change") + } + } else { + log.Debug(). + Int("workerID", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("change", w.c.Change.String()). + Msg("node not found for asynchronous work - node may have disconnected") + } + + duration := time.Since(startTime) + if duration > 100*time.Millisecond { + log.Warn(). + Int("workerID", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("change", w.c.Change.String()). + Dur("duration", duration). 
+ Msg("slow asynchronous work processing") + } + + case <-b.ctx.Done(): + return + } + } +} + +func (b *LockFreeBatcher) addWork(c change.ChangeSet) { + // For critical changes that need immediate processing, send directly + if b.shouldProcessImmediately(c) { + if c.SelfUpdateOnly { + b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil}) + return + } + b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool { + if c.NodeID == nodeID && !c.AlsoSelf() { + return true + } + b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil}) + return true + }) + return + } + + // For non-critical changes, add to batch + b.addToBatch(c) +} + +// queueWork safely queues work +func (b *LockFreeBatcher) queueWork(w work) { + b.workQueuedCount.Add(1) + + select { + case b.workCh <- w: + // Successfully queued + case <-b.ctx.Done(): + // Batcher is shutting down + return + } +} + +// shouldProcessImmediately determines if a change should bypass batching +func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool { + // Process these changes immediately to avoid delaying critical functionality + switch c.Change { + case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy: + return true + default: + return false + } +} + +// addToBatch adds a change to the pending batch +func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { + b.batchMutex.Lock() + defer b.batchMutex.Unlock() + + if c.SelfUpdateOnly { + changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{}) + changes = append(changes, c) + b.pendingChanges.Store(c.NodeID, changes) + return + } + + b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool { + if c.NodeID == nodeID && !c.AlsoSelf() { + return true + } + + changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) + changes = append(changes, c) + b.pendingChanges.Store(nodeID, changes) + return true + }) +} + +// processBatchedChanges processes all pending batched changes +func (b *LockFreeBatcher) processBatchedChanges() { + b.batchMutex.Lock() + defer b.batchMutex.Unlock() + + if b.pendingChanges == nil { + return + } + + // Process all pending changes + b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool { + if len(changes) == 0 { + return true + } + + // Send all batched changes for this node + for _, c := range changes { + b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil}) + } + + // Clear the pending changes for this node + b.pendingChanges.Delete(nodeID) + return true + }) +} + +// IsConnected is lock-free read. +func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { + if val, ok := b.connected.Load(id); ok { + // nil means connected + return val == nil + } + return false +} + +// ConnectedMap returns a lock-free map of all connected nodes. +func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { + ret := xsync.NewMap[types.NodeID, bool]() + + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // nil means connected + ret.Store(id, val == nil) + return true + }) + + return ret +} + +// MapResponseFromChange queues work to generate a map response and waits for the result. +// This allows synchronous map generation using the same worker pool. 
+func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
+	resultCh := make(chan workResult, 1)
+
+	// Queue the work with a result channel using the safe queueing method
+	b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
+
+	// Wait for the result
+	select {
+	case result := <-resultCh:
+		return result.mapResponse, result.err
+	case <-b.ctx.Done():
+		return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
+	}
+}
+
+// connectionData holds the channel and connection parameters.
+type connectionData struct {
+	c       chan<- *tailcfg.MapResponse
+	version tailcfg.CapabilityVersion
+	closed  atomic.Bool // Track if this connection has been closed
+}
+
+// nodeConn describes the node connection and its associated data.
+type nodeConn struct {
+	id     types.NodeID
+	mapper *mapper
+
+	// Atomic pointer to connection data - allows lock-free updates
+	connData atomic.Pointer[connectionData]
+
+	updateCount atomic.Int64
+}
+
+func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
+	nc := &nodeConn{
+		id:     id,
+		mapper: mapper,
+	}
+
+	// Initialize connection data
+	data := &connectionData{
+		c:       c,
+		version: version,
+	}
+	nc.connData.Store(data)
+
+	return nc
+}
+
+// updateConnection atomically updates connection parameters.
+func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
+	newData := &connectionData{
+		c:       c,
+		version: version,
+	}
+	nc.connData.Store(newData)
+}
+
+// matchesChannel checks if the given channel matches the current connection.
+func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
+	data := nc.connData.Load()
+	if data == nil {
+		return false
+	}
+	// Compare channel pointers directly
+	return data.c == c
+}
+
+// version atomically reads the connection's capability version.
+func (nc *nodeConn) version() tailcfg.CapabilityVersion {
+	data := nc.connData.Load()
+	if data == nil {
+		return 0
+	}
+
+	return data.version
+}
+
+func (nc *nodeConn) nodeID() types.NodeID {
+	return nc.id
+}
+
+func (nc *nodeConn) change(c change.ChangeSet) error {
+	return handleNodeChange(nc, nc.mapper, c)
+}
+
+// send sends data to the node's channel.
+// The node will pick it up and send it to the HTTP handler.
+func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
+	connData := nc.connData.Load()
+	if connData == nil {
+		return fmt.Errorf("node %d: no connection data", nc.id)
+	}
+
+	// Check if connection has been closed
+	if connData.closed.Load() {
+		return fmt.Errorf("node %d: connection closed", nc.id)
+	}
+
+	// TODO(kradalby): We might need some sort of timeout here if the client is not reading
+	// the channel. That might mean that we are sending to a node that has gone offline, but
+	// the channel is still open.
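+	// Note: this send blocks until the node's poll handler receives from the channel
+	// (or buffer space frees up), which is why the TODO above suggests adding a timeout.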
+ connData.c <- data + nc.updateCount.Add(1) + return nil +} diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go new file mode 100644 index 00000000..b2a632d4 --- /dev/null +++ b/hscontrol/mapper/batcher_test.go @@ -0,0 +1,1977 @@ +package mapper + +import ( + "fmt" + "net/netip" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "zgo.at/zcache/v2" +) + +// batcherTestCase defines a batcher function with a descriptive name for testing. +type batcherTestCase struct { + name string + fn batcherFunc +} + +// allBatcherFunctions contains all batcher implementations to test. +var allBatcherFunctions = []batcherTestCase{ + {"LockFree", NewBatcherAndMapper}, +} + +// emptyCache creates an empty registration cache for testing. +func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { + return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +} + +// Test configuration constants. +const ( + // Test data configuration. + TEST_USER_COUNT = 3 + TEST_NODES_PER_USER = 2 + + // Load testing configuration. + HIGH_LOAD_NODES = 25 // Increased from 9 + HIGH_LOAD_CYCLES = 100 // Increased from 20 + HIGH_LOAD_UPDATES = 50 // Increased from 20 + + // Extreme load testing configuration. + EXTREME_LOAD_NODES = 50 + EXTREME_LOAD_CYCLES = 200 + EXTREME_LOAD_UPDATES = 100 + + // Timing configuration. + TEST_TIMEOUT = 120 * time.Second // Increased for more intensive tests + UPDATE_TIMEOUT = 5 * time.Second + DEADLOCK_TIMEOUT = 30 * time.Second + + // Channel configuration. + NORMAL_BUFFER_SIZE = 50 + SMALL_BUFFER_SIZE = 3 + TINY_BUFFER_SIZE = 1 // For maximum contention + LARGE_BUFFER_SIZE = 200 + + reservedResponseHeaderSize = 4 +) + +// TestData contains all test entities created for a test scenario. +type TestData struct { + Database *db.HSDatabase + Users []*types.User + Nodes []node + State *state.State + Config *types.Config + Batcher Batcher +} + +type node struct { + n *types.Node + ch chan *tailcfg.MapResponse + + // Update tracking + updateCount int64 + patchCount int64 + fullCount int64 + maxPeersCount int + lastPeerCount int + stop chan struct{} + stopped chan struct{} +} + +// setupBatcherWithTestData creates a comprehensive test environment with real +// database test data including users and registered nodes. +// +// This helper creates a database, populates it with test data, then creates +// a state and batcher using the SAME database for testing. This provides real +// node data for testing full map responses and comprehensive update scenarios. +// +// Returns TestData struct containing all created entities and a cleanup function. 
+func setupBatcherWithTestData(t *testing.T, bf batcherFunc, userCount, nodesPerUser, bufferSize int) (*TestData, func()) { + t.Helper() + + // Create database and populate with test data first + tmpDir := t.TempDir() + dbPath := tmpDir + "/headscale_test.db" + + prefixV4 := netip.MustParsePrefix("100.64.0.0/10") + prefixV6 := netip.MustParsePrefix("fd7a:115c:a1e0::/48") + + cfg := &types.Config{ + Database: types.DatabaseConfig{ + Type: types.DatabaseSqlite, + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + PrefixV4: &prefixV4, + PrefixV6: &prefixV6, + IPAllocation: types.IPAllocationStrategySequential, + BaseDomain: "headscale.test", + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + DERP: types.DERPConfig{ + ServerEnabled: false, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 999: { + RegionID: 999, + }, + }, + }, + }, + Tuning: types.Tuning{ + BatchChangeDelay: 10 * time.Millisecond, + BatcherWorkers: types.DefaultBatcherWorkers(), // Use same logic as config.go + }, + } + + // Create database and populate it with test data + database, err := db.NewHeadscaleDatabase( + cfg.Database, + "", + emptyCache(), + ) + if err != nil { + t.Fatalf("setting up database: %s", err) + } + + // Create test users and nodes in the database + users := database.CreateUsersForTest(userCount, "testuser") + allNodes := make([]node, 0, userCount*nodesPerUser) + for _, user := range users { + dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node") + for i := range dbNodes { + allNodes = append(allNodes, node{ + n: dbNodes[i], + ch: make(chan *tailcfg.MapResponse, bufferSize), + }) + } + } + + // Now create state using the same database + state, err := state.NewState(cfg) + if err != nil { + t.Fatalf("Failed to create state: %v", err) + } + + // Set up a permissive policy that allows all communication for testing + allowAllPolicy := `{ + "acls": [ + { + "action": "accept", + "users": ["*"], + "ports": ["*:*"] + } + ] + }` + + _, err = state.SetPolicy([]byte(allowAllPolicy)) + if err != nil { + t.Fatalf("Failed to set allow-all policy: %v", err) + } + + // Create batcher with the state + batcher := bf(cfg, state) + batcher.Start() + + testData := &TestData{ + Database: database, + Users: users, + Nodes: allNodes, + State: state, + Config: cfg, + Batcher: batcher, + } + + cleanup := func() { + batcher.Close() + state.Close() + database.Close() + } + + return testData, cleanup +} + +type UpdateStats struct { + TotalUpdates int + UpdateSizes []int + LastUpdate time.Time +} + +// updateTracker provides thread-safe tracking of updates per node. +type updateTracker struct { + mu sync.RWMutex + stats map[types.NodeID]*UpdateStats +} + +// newUpdateTracker creates a new update tracker. +func newUpdateTracker() *updateTracker { + return &updateTracker{ + stats: make(map[types.NodeID]*UpdateStats), + } +} + +// recordUpdate records an update for a specific node. +func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) { + ut.mu.Lock() + defer ut.mu.Unlock() + + if ut.stats[nodeID] == nil { + ut.stats[nodeID] = &UpdateStats{} + } + + stats := ut.stats[nodeID] + stats.TotalUpdates++ + stats.UpdateSizes = append(stats.UpdateSizes, updateSize) + stats.LastUpdate = time.Now() +} + +// getStats returns a copy of the statistics for a node. 
+func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
+	ut.mu.RLock()
+	defer ut.mu.RUnlock()
+
+	if stats, exists := ut.stats[nodeID]; exists {
+		// Return a copy to avoid race conditions
+		return UpdateStats{
+			TotalUpdates: stats.TotalUpdates,
+			UpdateSizes:  append([]int{}, stats.UpdateSizes...),
+			LastUpdate:   stats.LastUpdate,
+		}
+	}
+
+	return UpdateStats{}
+}
+
+// getAllStats returns a copy of all statistics.
+func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats {
+	ut.mu.RLock()
+	defer ut.mu.RUnlock()
+
+	result := make(map[types.NodeID]UpdateStats)
+	for nodeID, stats := range ut.stats {
+		result[nodeID] = UpdateStats{
+			TotalUpdates: stats.TotalUpdates,
+			UpdateSizes:  append([]int{}, stats.UpdateSizes...),
+			LastUpdate:   stats.LastUpdate,
+		}
+	}
+
+	return result
+}
+
+func assertDERPMapResponse(t *testing.T, resp *tailcfg.MapResponse) {
+	t.Helper()
+
+	assert.NotNil(t, resp.DERPMap, "DERPMap should not be nil in response")
+	assert.Len(t, resp.DERPMap.Regions, 1, "Expected exactly one DERP region in response")
+	assert.Equal(t, 999, resp.DERPMap.Regions[999].RegionID, "Expected DERP region ID to be 999")
+}
+
+func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected bool) {
+	t.Helper()
+
+	// Check for peer changes patch (new online/offline notifications use patches)
+	if len(resp.PeersChangedPatch) > 0 {
+		require.Len(t, resp.PeersChangedPatch, 1)
+		assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online)
+		return
+	}
+
+	// Fallback to old format for backwards compatibility
+	require.Len(t, resp.Peers, 1)
+	assert.Equal(t, expected, resp.Peers[0].Online)
+}
+
+// UpdateInfo contains parsed information about an update.
+type UpdateInfo struct {
+	IsFull     bool
+	IsPatch    bool
+	IsDERP     bool
+	PeerCount  int
+	PatchCount int
+}
+
+// parseUpdateAndAnalyze parses an update and returns detailed information.
+func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
+	info := UpdateInfo{
+		PeerCount:  len(resp.Peers),
+		PatchCount: len(resp.PeersChangedPatch),
+		IsFull:     len(resp.Peers) > 0,
+		IsPatch:    len(resp.PeersChangedPatch) > 0,
+		IsDERP:     resp.DERPMap != nil,
+	}
+
+	return info, nil
+}
+
+// start begins consuming updates from the node's channel and tracking stats.
+func (n *node) start() {
+	// Prevent multiple starts on the same node
+	if n.stop != nil {
+		return // Already started
+	}
+
+	n.stop = make(chan struct{})
+	n.stopped = make(chan struct{})
+
+	go func() {
+		defer close(n.stopped)
+
+		for {
+			select {
+			case data := <-n.ch:
+				atomic.AddInt64(&n.updateCount, 1)
+
+				// Parse update and track detailed stats
+				if info, err := parseUpdateAndAnalyze(data); err == nil {
+					// Track update types
+					if info.IsFull {
+						atomic.AddInt64(&n.fullCount, 1)
+						n.lastPeerCount = info.PeerCount
+						// Update max peers seen
+						if info.PeerCount > n.maxPeersCount {
+							n.maxPeersCount = info.PeerCount
+						}
+					}
+					if info.IsPatch {
+						atomic.AddInt64(&n.patchCount, 1)
+						// For patches, we track how many patch items
+						if info.PatchCount > n.maxPeersCount {
+							n.maxPeersCount = info.PatchCount
+						}
+					}
+				}
+
+			case <-n.stop:
+				return
+			}
+		}
+	}()
+}
+
+// NodeStats contains final statistics for a node.
+type NodeStats struct {
+	TotalUpdates  int64
+	PatchUpdates  int64
+	FullUpdates   int64
+	MaxPeersSeen  int
+	LastPeerCount int
+}
+
+// cleanup stops the update consumer and returns final stats.
+func (n *node) cleanup() NodeStats { + if n.stop != nil { + close(n.stop) + <-n.stopped // Wait for goroutine to finish + } + + return NodeStats{ + TotalUpdates: atomic.LoadInt64(&n.updateCount), + PatchUpdates: atomic.LoadInt64(&n.patchCount), + FullUpdates: atomic.LoadInt64(&n.fullCount), + MaxPeersSeen: n.maxPeersCount, + LastPeerCount: n.lastPeerCount, + } +} + +// validateUpdateContent validates that the update data contains a proper MapResponse. +func validateUpdateContent(resp *tailcfg.MapResponse) (bool, string) { + if resp == nil { + return false, "nil MapResponse" + } + + // Simple validation - just check if it's a valid MapResponse + return true, "valid" +} + +// TestEnhancedNodeTracking verifies that the enhanced node tracking works correctly. +func TestEnhancedNodeTracking(t *testing.T) { + // Create a simple test node + testNode := node{ + n: &types.Node{ID: 1}, + ch: make(chan *tailcfg.MapResponse, 10), + } + + // Start the enhanced tracking + testNode.start() + + // Create a simple MapResponse that should be parsed correctly + resp := tailcfg.MapResponse{ + KeepAlive: false, + Peers: []*tailcfg.Node{ + {ID: 2}, + {ID: 3}, + }, + } + + // Send the data to the node's channel + testNode.ch <- &resp + + // Give it time to process + time.Sleep(100 * time.Millisecond) + + // Check stats + stats := testNode.cleanup() + t.Logf("Enhanced tracking stats: Total=%d, Full=%d, Patch=%d, MaxPeers=%d", + stats.TotalUpdates, stats.FullUpdates, stats.PatchUpdates, stats.MaxPeersSeen) + + require.Equal(t, int64(1), stats.TotalUpdates, "Expected 1 total update") + require.Equal(t, int64(1), stats.FullUpdates, "Expected 1 full update") + require.Equal(t, 2, stats.MaxPeersSeen, "Expected 2 max peers seen") +} + +// TestEnhancedTrackingWithBatcher verifies enhanced tracking works with a real batcher. +func TestEnhancedTrackingWithBatcher(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with 1 node + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, 10) + defer cleanup() + + batcher := testData.Batcher + testNode := &testData.Nodes[0] + + t.Logf("Testing enhanced tracking with node ID %d", testNode.n.ID) + + // Start enhanced tracking for the node + testNode.start() + + // Connect the node to the batcher + batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100)) + time.Sleep(100 * time.Millisecond) // Let connection settle + + // Generate some work + batcher.AddWork(change.FullSet) + time.Sleep(100 * time.Millisecond) // Let work be processed + + batcher.AddWork(change.PolicySet) + time.Sleep(100 * time.Millisecond) + + batcher.AddWork(change.DERPSet) + time.Sleep(100 * time.Millisecond) + + // Check stats + stats := testNode.cleanup() + t.Logf("Enhanced tracking with batcher: Total=%d, Full=%d, Patch=%d, MaxPeers=%d", + stats.TotalUpdates, stats.FullUpdates, stats.PatchUpdates, stats.MaxPeersSeen) + + if stats.TotalUpdates == 0 { + t.Error("Enhanced tracking with batcher received 0 updates - batcher may not be working") + } + }) + } +} + +// TestBatcherScalabilityAllToAll tests the batcher's ability to handle rapid node joins +// and ensure all nodes can see all other nodes. This is a critical test for mesh network +// functionality where every node must be able to communicate with every other node. 
+func TestBatcherScalabilityAllToAll(t *testing.T) { + // Reduce verbose application logging for cleaner test output + originalLevel := zerolog.GlobalLevel() + defer zerolog.SetGlobalLevel(originalLevel) + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + + // Test cases: different node counts to stress test the all-to-all connectivity + testCases := []struct { + name string + nodeCount int + }{ + {"10_nodes", 10}, + {"50_nodes", 50}, + {"100_nodes", 100}, + // Grinds to a halt because of Database bottleneck + // {"250_nodes", 250}, + // {"500_nodes", 500}, + // {"1000_nodes", 1000}, + // {"5000_nodes", 5000}, + } + + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Logf("ALL-TO-ALL TEST: %d nodes with %s batcher", tc.nodeCount, batcherFunc.name) + + // Create test environment - all nodes from same user so they can be peers + // We need enough users to support the node count (max 1000 nodes per user) + usersNeeded := max(1, (tc.nodeCount+999)/1000) + nodesPerUser := (tc.nodeCount + usersNeeded - 1) / usersNeeded + + // Use large buffer to avoid blocking during rapid joins + // Buffer needs to handle nodeCount * average_updates_per_node + // Estimate: each node receives ~2*nodeCount updates during all-to-all + bufferSize := max(1000, tc.nodeCount*2) + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, usersNeeded, nodesPerUser, bufferSize) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes[:tc.nodeCount] // Limit to requested count + + t.Logf("Created %d nodes across %d users, buffer size: %d", len(allNodes), usersNeeded, bufferSize) + + // Start enhanced tracking for all nodes + for i := range allNodes { + allNodes[i].start() + } + + // Give time for tracking goroutines to start + time.Sleep(100 * time.Millisecond) + + startTime := time.Now() + + // Join all nodes as fast as possible + t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) + for i := range allNodes { + node := &allNodes[i] + batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + + // Issue full update after each join to ensure connectivity + batcher.AddWork(change.FullSet) + + // Add tiny delay for large node counts to prevent overwhelming + if tc.nodeCount > 100 && i%50 == 49 { + time.Sleep(10 * time.Millisecond) + } + } + + joinTime := time.Since(startTime) + t.Logf("All nodes joined in %v, waiting for full connectivity...", joinTime) + + // Wait for all updates to propagate - no timeout, continue until all nodes achieve connectivity + checkInterval := 5 * time.Second + expectedPeers := tc.nodeCount - 1 // Each node should see all others except itself + + for { + time.Sleep(checkInterval) + + // Check if all nodes have seen the expected number of peers + connectedCount := 0 + + for i := range allNodes { + node := &allNodes[i] + // Check current stats without stopping the tracking + currentMaxPeers := node.maxPeersCount + if currentMaxPeers >= expectedPeers { + connectedCount++ + } + } + + progress := float64(connectedCount) / float64(len(allNodes)) * 100 + t.Logf("Progress: %d/%d nodes (%.1f%%) have seen %d+ peers", + connectedCount, len(allNodes), progress, expectedPeers) + + if connectedCount == len(allNodes) { + t.Logf("✅ All nodes achieved full connectivity!") + break + } + } + + totalTime := time.Since(startTime) + + // Disconnect all nodes + for i := range allNodes { + node := &allNodes[i] + batcher.RemoveNode(node.n.ID, 
node.ch, false) + } + + // Give time for final updates to process + time.Sleep(500 * time.Millisecond) + + // Collect final statistics + totalUpdates := int64(0) + totalFull := int64(0) + maxPeersGlobal := 0 + minPeersSeen := tc.nodeCount + successfulNodes := 0 + + nodeDetails := make([]string, 0, min(10, len(allNodes))) + + for i := range allNodes { + node := &allNodes[i] + stats := node.cleanup() + + totalUpdates += stats.TotalUpdates + totalFull += stats.FullUpdates + + if stats.MaxPeersSeen > maxPeersGlobal { + maxPeersGlobal = stats.MaxPeersSeen + } + if stats.MaxPeersSeen < minPeersSeen { + minPeersSeen = stats.MaxPeersSeen + } + + if stats.MaxPeersSeen >= expectedPeers { + successfulNodes++ + } + + // Collect details for first few nodes or failing nodes + if len(nodeDetails) < 10 || stats.MaxPeersSeen < expectedPeers { + nodeDetails = append(nodeDetails, + fmt.Sprintf("Node %d: %d updates (%d full), max %d peers", + node.n.ID, stats.TotalUpdates, stats.FullUpdates, stats.MaxPeersSeen)) + } + } + + // Final results + t.Logf("ALL-TO-ALL RESULTS: %d nodes, %d total updates (%d full)", + len(allNodes), totalUpdates, totalFull) + t.Logf(" Connectivity: %d/%d nodes successful (%.1f%%)", + successfulNodes, len(allNodes), float64(successfulNodes)/float64(len(allNodes))*100) + t.Logf(" Peers seen: min=%d, max=%d, expected=%d", + minPeersSeen, maxPeersGlobal, expectedPeers) + t.Logf(" Timing: join=%v, total=%v", joinTime, totalTime) + + // Show sample of node details + if len(nodeDetails) > 0 { + t.Logf(" Node sample:") + for _, detail := range nodeDetails[:min(5, len(nodeDetails))] { + t.Logf(" %s", detail) + } + if len(nodeDetails) > 5 { + t.Logf(" ... (%d more nodes)", len(nodeDetails)-5) + } + } + + // Final verification: Since we waited until all nodes achieved connectivity, + // this should always pass, but we verify the final state for completeness + if successfulNodes == len(allNodes) { + t.Logf("✅ PASS: All-to-all connectivity achieved for %d nodes", len(allNodes)) + } else { + // This should not happen since we loop until success, but handle it just in case + failedNodes := len(allNodes) - successfulNodes + t.Errorf("❌ UNEXPECTED: %d/%d nodes still failed after waiting for connectivity (expected %d, some saw %d-%d)", + failedNodes, len(allNodes), expectedPeers, minPeersSeen, maxPeersGlobal) + + // Show details of failed nodes for debugging + if len(nodeDetails) > 5 { + t.Logf("Failed nodes details:") + for _, detail := range nodeDetails[5:] { + if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) { + t.Logf(" %s", detail) + } + } + } + } + }) + } + }) + } +} + +// TestBatcherBasicOperations verifies core batcher functionality by testing +// the basic lifecycle of adding nodes, processing updates, and removing nodes. +// +// Enhanced with real database test data, this test creates a registered node +// and tests both DERP updates and full node updates. It validates the fundamental +// add/remove operations and basic work processing pipeline with actual update +// content validation instead of just byte count checks. 
+func TestBatcherBasicOperations(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 8) + defer cleanup() + + batcher := testData.Batcher + tn := testData.Nodes[0] + tn2 := testData.Nodes[1] + + // Test AddNode with real node ID + batcher.AddNode(tn.n.ID, tn.ch, false, 100) + if !batcher.IsConnected(tn.n.ID) { + t.Error("Node should be connected after AddNode") + } + + // Test work processing with DERP change + batcher.AddWork(change.DERPChange()) + + // Wait for update and validate content + select { + case data := <-tn.ch: + assertDERPMapResponse(t, data) + case <-time.After(200 * time.Millisecond): + t.Error("Did not receive expected DERP update") + } + + // Drain any initial messages from first node + drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) + + // Add the second node and verify update message + batcher.AddNode(tn2.n.ID, tn2.ch, false, 100) + assert.True(t, batcher.IsConnected(tn2.n.ID)) + + // First node should get an update that second node has connected. + select { + case data := <-tn.ch: + assertOnlineMapResponse(t, data, true) + case <-time.After(200 * time.Millisecond): + t.Error("Did not receive expected Online response update") + } + + // Second node should receive its initial full map + select { + case data := <-tn2.ch: + // Verify it's a full map response + assert.NotNil(t, data) + assert.True(t, len(data.Peers) >= 1 || data.Node != nil, "Should receive initial full map") + case <-time.After(200 * time.Millisecond): + t.Error("Second node should receive its initial full map") + } + + // Disconnect the second node + batcher.RemoveNode(tn2.n.ID, tn2.ch, false) + assert.False(t, batcher.IsConnected(tn2.n.ID)) + + // First node should get update that second has disconnected. 
+ select { + case data := <-tn.ch: + assertOnlineMapResponse(t, data, false) + case <-time.After(200 * time.Millisecond): + t.Error("Did not receive expected Online response update") + } + + // // Test node-specific update with real node data + // batcher.AddWork(change.NodeKeyChanged(tn.n.ID)) + + // // Wait for node update (may be empty for certain node changes) + // select { + // case data := <-tn.ch: + // t.Logf("Received node update: %d bytes", len(data)) + // if len(data) == 0 { + // t.Logf("Empty node update (expected for some node changes in test environment)") + // } else { + // if valid, updateType := validateUpdateContent(data); !valid { + // t.Errorf("Invalid node update content: %s", updateType) + // } else { + // t.Logf("Valid node update type: %s", updateType) + // } + // } + // case <-time.After(200 * time.Millisecond): + // // Node changes might not always generate updates in test environment + // t.Logf("No node update received (may be expected in test environment)") + // } + + // Test RemoveNode + batcher.RemoveNode(tn.n.ID, tn.ch, false) + if batcher.IsConnected(tn.n.ID) { + t.Error("Node should be disconnected after RemoveNode") + } + }) + } +} + +func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { + count := 0 + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + select { + case data := <-ch: + count++ + // Optional: add debug output if needed + _ = data + case <-timer.C: + return + } + } +} + +// TestBatcherUpdateTypes tests different types of updates and verifies +// that the batcher correctly processes them based on their content. +// +// Enhanced with real database test data, this test creates registered nodes +// and tests various update types including DERP changes, node-specific changes, +// and full updates. This validates the change classification logic and ensures +// different update types are handled appropriately with actual node data. 
+// func TestBatcherUpdateTypes(t *testing.T) { +// for _, batcherFunc := range allBatcherFunctions { +// t.Run(batcherFunc.name, func(t *testing.T) { +// // Create test environment with real database and nodes +// testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 8) +// defer cleanup() + +// batcher := testData.Batcher +// testNodes := testData.Nodes + +// ch := make(chan *tailcfg.MapResponse, 10) +// // Use real node ID from test data +// batcher.AddNode(testNodes[0].n.ID, ch, false, "zstd", tailcfg.CapabilityVersion(100)) + +// tests := []struct { +// name string +// changeSet change.ChangeSet +// expectData bool // whether we expect to receive data +// description string +// }{ +// { +// name: "DERP change", +// changeSet: change.DERPSet, +// expectData: true, +// description: "DERP changes should generate map updates", +// }, +// { +// name: "Node key expiry", +// changeSet: change.KeyExpiry(testNodes[1].n.ID), +// expectData: true, +// description: "Node key expiry with real node data", +// }, +// { +// name: "Node new registration", +// changeSet: change.NodeAdded(testNodes[1].n.ID), +// expectData: true, +// description: "New node registration with real data", +// }, +// { +// name: "Full update", +// changeSet: change.FullSet, +// expectData: true, +// description: "Full updates with real node data", +// }, +// { +// name: "Policy change", +// changeSet: change.PolicySet, +// expectData: true, +// description: "Policy updates with real node data", +// }, +// } + +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// t.Logf("Testing: %s", tt.description) + +// // Clear any existing updates +// select { +// case <-ch: +// default: +// } + +// batcher.AddWork(tt.changeSet) + +// select { +// case data := <-ch: +// if !tt.expectData { +// t.Errorf("Unexpected update for %s: %d bytes", tt.name, len(data)) +// } else { +// t.Logf("%s: received %d bytes", tt.name, len(data)) + +// // Validate update content when we have data +// if len(data) > 0 { +// if valid, updateType := validateUpdateContent(data); !valid { +// t.Errorf("Invalid update content for %s: %s", tt.name, updateType) +// } else { +// t.Logf("%s: valid update type: %s", tt.name, updateType) +// } +// } else { +// t.Logf("%s: empty update (may be expected for some node changes)", tt.name) +// } +// } +// case <-time.After(100 * time.Millisecond): +// if tt.expectData { +// t.Errorf("Expected update for %s (%s) but none received", tt.name, tt.description) +// } else { +// t.Logf("%s: no update (expected)", tt.name) +// } +// } +// }) +// } +// }) +// } +// } + +// TestBatcherWorkQueueBatching tests that multiple changes get batched +// together and sent as a single update to reduce network overhead. +// +// Enhanced with real database test data, this test creates registered nodes +// and rapidly submits multiple types of changes including DERP updates and +// node changes. Due to the batching mechanism with BatchChangeDelay, these +// should be combined into fewer updates. This validates that the batching +// system works correctly with real node data and mixed change types. 
+func TestBatcherWorkQueueBatching(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 8) + defer cleanup() + + batcher := testData.Batcher + testNodes := testData.Nodes + + ch := make(chan *tailcfg.MapResponse, 10) + batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100)) + + // Track update content for validation + var receivedUpdates []*tailcfg.MapResponse + + // Add multiple changes rapidly to test batching + batcher.AddWork(change.DERPSet) + batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID)) + batcher.AddWork(change.DERPSet) + batcher.AddWork(change.NodeAdded(testNodes[1].n.ID)) + batcher.AddWork(change.DERPSet) + + // Collect updates with timeout + updateCount := 0 + timeout := time.After(200 * time.Millisecond) + for { + select { + case data := <-ch: + updateCount++ + receivedUpdates = append(receivedUpdates, data) + + // Validate update content + if data != nil { + if valid, reason := validateUpdateContent(data); valid { + t.Logf("Update %d: valid", updateCount) + } else { + t.Logf("Update %d: invalid: %s", updateCount, reason) + } + } else { + t.Logf("Update %d: nil update", updateCount) + } + case <-timeout: + // Expected: 5 changes should generate 6 updates (no batching in current implementation) + expectedUpdates := 6 + t.Logf("Received %d updates from %d changes (expected %d)", + updateCount, 5, expectedUpdates) + + if updateCount != expectedUpdates { + t.Errorf("Expected %d updates but received %d", expectedUpdates, updateCount) + } + + // Validate that all updates have valid content + validUpdates := 0 + for _, data := range receivedUpdates { + if data != nil { + if valid, _ := validateUpdateContent(data); valid { + validUpdates++ + } + } + } + + if validUpdates != updateCount { + t.Errorf("Expected all %d updates to be valid, but only %d were valid", + updateCount, validUpdates) + } + + return + } + } + }) + } +} + +// TestBatcherChannelClosingRace tests the fix for the async channel closing +// race condition that previously caused panics and data races. +// +// Enhanced with real database test data, this test simulates rapid node +// reconnections using real registered nodes while processing actual updates. +// The test verifies that channels are closed synchronously and deterministically +// even when real node updates are being processed, ensuring no race conditions +// occur during channel replacement with actual workload. 
+func XTestBatcherChannelClosingRace(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, 8) + defer cleanup() + + batcher := testData.Batcher + testNode := testData.Nodes[0] + var channelIssues int + var mutex sync.Mutex + + // Run rapid connect/disconnect cycles with real updates to test channel closing + for i := range 100 { + var wg sync.WaitGroup + + // First connection + ch1 := make(chan *tailcfg.MapResponse, 1) + wg.Add(1) + go func() { + defer wg.Done() + batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100)) + }() + + // Add real work during connection chaos + if i%10 == 0 { + batcher.AddWork(change.DERPSet) + } + + // Rapid second connection - should replace ch1 + ch2 := make(chan *tailcfg.MapResponse, 1) + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(1 * time.Microsecond) + batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100)) + }() + + // Remove second connection + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(2 * time.Microsecond) + batcher.RemoveNode(testNode.n.ID, ch2, false) + }() + + wg.Wait() + + // Verify ch1 behavior when replaced by ch2 + // The test is checking if ch1 gets closed/replaced properly + select { + case <-ch1: + // Channel received data or was closed, which is expected + case <-time.After(1 * time.Millisecond): + // If no data received, increment issues counter + mutex.Lock() + channelIssues++ + mutex.Unlock() + } + + // Clean up ch2 + select { + case <-ch2: + default: + } + } + + mutex.Lock() + defer mutex.Unlock() + + t.Logf("Channel closing issues: %d out of 100 iterations", channelIssues) + + // The main fix prevents panics and race conditions. Some timing variations + // are acceptable as long as there are no crashes or deadlocks. + if channelIssues > 50 { // Allow some timing variations + t.Errorf("Excessive channel closing issues: %d iterations", channelIssues) + } + }) + } +} + +// TestBatcherWorkerChannelSafety tests that worker goroutines handle closed +// channels safely without panicking when processing work items. +// +// Enhanced with real database test data, this test creates rapid connect/disconnect +// cycles using registered nodes while simultaneously queuing real work items. +// This creates a race where workers might try to send to channels that have been +// closed by node removal. The test validates that the safeSend() method properly +// handles closed channels with real update workloads. 
+func TestBatcherWorkerChannelSafety(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, 8) + defer cleanup() + + batcher := testData.Batcher + testNode := testData.Nodes[0] + var panics int + var channelErrors int + var invalidData int + var mutex sync.Mutex + + // Test rapid connect/disconnect with work generation + for i := range 50 { + func() { + defer func() { + if r := recover(); r != nil { + mutex.Lock() + panics++ + mutex.Unlock() + t.Logf("Panic caught: %v", r) + } + }() + + ch := make(chan *tailcfg.MapResponse, 5) + + // Add node and immediately queue real work + batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddWork(change.DERPSet) + + // Consumer goroutine to validate data and detect channel issues + go func() { + defer func() { + if r := recover(); r != nil { + mutex.Lock() + channelErrors++ + mutex.Unlock() + t.Logf("Channel consumer panic: %v", r) + } + }() + + for { + select { + case data, ok := <-ch: + if !ok { + // Channel was closed, which is expected + return + } + // Validate the data we received + if valid, reason := validateUpdateContent(data); !valid { + mutex.Lock() + invalidData++ + mutex.Unlock() + t.Logf("Invalid data received: %s", reason) + } + case <-time.After(10 * time.Millisecond): + // Timeout waiting for data + return + } + } + }() + + // Add node-specific work occasionally + if i%10 == 0 { + batcher.AddWork(change.KeyExpiry(testNode.n.ID)) + } + + // Rapid removal creates race between worker and removal + time.Sleep(time.Duration(i%3) * 100 * time.Microsecond) + batcher.RemoveNode(testNode.n.ID, ch, false) + + // Give workers time to process and close channels + time.Sleep(5 * time.Millisecond) + }() + } + + mutex.Lock() + defer mutex.Unlock() + + t.Logf("Worker safety test results: %d panics, %d channel errors, %d invalid data packets", + panics, channelErrors, invalidData) + + // Test failure conditions + if panics > 0 { + t.Errorf("Worker channel safety failed with %d panics", panics) + } + if channelErrors > 0 { + t.Errorf("Channel handling failed with %d channel errors", channelErrors) + } + if invalidData > 0 { + t.Errorf("Data validation failed with %d invalid data packets", invalidData) + } + }) + } +} + +// TestBatcherConcurrentClients tests that concurrent connection lifecycle changes +// don't affect other stable clients' ability to receive updates. +// +// The test sets up real test data with multiple users and registered nodes, +// then creates stable clients and churning clients that rapidly connect and +// disconnect. Work is generated continuously during these connection churn cycles using +// real node data. The test validates that stable clients continue to function +// normally and receive proper updates despite the connection churn from other clients, +// ensuring system stability under concurrent load. 
+func TestBatcherConcurrentClients(t *testing.T) { + if testing.Short() { + t.Skip("Skipping concurrent client test in short mode") + } + + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create comprehensive test environment with real data + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, TEST_USER_COUNT, TEST_NODES_PER_USER, 8) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + // Create update tracker for monitoring all updates + tracker := newUpdateTracker() + + // Set up stable clients using real node IDs + stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable + stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) + + for _, node := range stableNodes { + ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) + stableChannels[node.n.ID] = ch + batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + + // Monitor updates for each stable client + go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { + for { + select { + case data := <-channel: + if valid, reason := validateUpdateContent(data); valid { + tracker.recordUpdate(nodeID, 1) // Use 1 as update size since we have MapResponse + } else { + t.Errorf("Invalid update received for stable node %d: %s", nodeID, reason) + } + case <-time.After(TEST_TIMEOUT): + return + } + } + }(node.n.ID, ch) + } + + // Use remaining nodes for connection churn testing + churningNodes := allNodes[len(allNodes)/2:] + churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) + var churningChannelsMutex sync.Mutex // Protect concurrent map access + + var wg sync.WaitGroup + numCycles := 10 // Reduced for simpler test + panicCount := 0 + var panicMutex sync.Mutex + + // Track deadlock with timeout + done := make(chan struct{}) + go func() { + defer close(done) + + // Connection churn cycles - rapidly connect/disconnect to test concurrency safety + for i := range numCycles { + for _, node := range churningNodes { + wg.Add(2) + + // Connect churning node + go func(nodeID types.NodeID) { + defer func() { + if r := recover(); r != nil { + panicMutex.Lock() + panicCount++ + panicMutex.Unlock() + t.Logf("Panic in churning connect: %v", r) + } + wg.Done() + }() + + ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) + churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() + batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100)) + + // Consume updates to prevent blocking + go func() { + for { + select { + case data := <-ch: + if valid, _ := validateUpdateContent(data); valid { + tracker.recordUpdate(nodeID, 1) // Use 1 as update size since we have MapResponse + } + case <-time.After(20 * time.Millisecond): + return + } + } + }() + }(node.n.ID) + + // Disconnect churning node + go func(nodeID types.NodeID) { + defer func() { + if r := recover(); r != nil { + panicMutex.Lock() + panicCount++ + panicMutex.Unlock() + t.Logf("Panic in churning disconnect: %v", r) + } + wg.Done() + }() + + time.Sleep(time.Duration(i%5) * time.Millisecond) + churningChannelsMutex.Lock() + ch, exists := churningChannels[nodeID] + churningChannelsMutex.Unlock() + if exists { + batcher.RemoveNode(nodeID, ch, false) + } + }(node.n.ID) + } + + // Generate various types of work during racing + if i%3 == 0 { + // DERP changes + batcher.AddWork(change.DERPSet) + } + if i%5 == 0 { + // Full updates using real node data + batcher.AddWork(change.FullSet) + } + if i%7 == 0 && 
len(allNodes) > 0 { + // Node-specific changes using real nodes + node := allNodes[i%len(allNodes)] + batcher.AddWork(change.KeyExpiry(node.n.ID)) + } + + // Small delay to allow some batching + time.Sleep(2 * time.Millisecond) + } + + wg.Wait() + }() + + // Deadlock detection + select { + case <-done: + t.Logf("Connection churn cycles completed successfully") + case <-time.After(DEADLOCK_TIMEOUT): + t.Error("Test timed out - possible deadlock detected") + return + } + + // Allow final updates to be processed + time.Sleep(100 * time.Millisecond) + + // Validate results + panicMutex.Lock() + finalPanicCount := panicCount + panicMutex.Unlock() + + allStats := tracker.getAllStats() + + // Calculate expected vs actual updates + stableUpdateCount := 0 + churningUpdateCount := 0 + + // Count actual update sources to understand the pattern + // Let's track what we observe rather than trying to predict + expectedDerpUpdates := (numCycles + 2) / 3 + expectedFullUpdates := (numCycles + 4) / 5 + expectedKeyUpdates := (numCycles + 6) / 7 + totalGeneratedWork := expectedDerpUpdates + expectedFullUpdates + expectedKeyUpdates + + t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls", + expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork) + + for _, node := range stableNodes { + if stats, exists := allStats[node.n.ID]; exists { + stableUpdateCount += stats.TotalUpdates + t.Logf("Stable node %d: %d updates", + node.n.ID, stats.TotalUpdates) + } + + // Verify stable clients are still connected + if !batcher.IsConnected(node.n.ID) { + t.Errorf("Stable node %d should still be connected", node.n.ID) + } + } + + for _, node := range churningNodes { + if stats, exists := allStats[node.n.ID]; exists { + churningUpdateCount += stats.TotalUpdates + } + } + + t.Logf("Total updates - Stable clients: %d, Churning clients: %d", + stableUpdateCount, churningUpdateCount) + t.Logf("Average per stable client: %.1f updates", float64(stableUpdateCount)/float64(len(stableNodes))) + t.Logf("Panics during test: %d", finalPanicCount) + + // Validate test success criteria + if finalPanicCount > 0 { + t.Errorf("Test failed with %d panics", finalPanicCount) + } + + // Basic sanity check - stable clients should receive some updates + if stableUpdateCount == 0 { + t.Error("Stable clients received no updates - batcher may not be working") + } + + // Verify all stable clients are still functional + for _, node := range stableNodes { + if !batcher.IsConnected(node.n.ID) { + t.Errorf("Stable node %d lost connection during racing", node.n.ID) + } + } + }) + } +} + +// TestBatcherHighLoadStability tests batcher behavior under high concurrent load +// scenarios with multiple nodes rapidly connecting and disconnecting while +// continuous updates are generated. +// +// This test creates a high-stress environment with many nodes connecting and +// disconnecting rapidly while various types of updates are generated continuously. +// It validates that the system remains stable with no deadlocks, panics, or +// missed updates under sustained high load. The test uses real node data to +// generate authentic update scenarios and tracks comprehensive statistics. 
+func XTestBatcherScalability(t *testing.T) { + if testing.Short() { + t.Skip("Skipping scalability test in short mode") + } + + // Reduce verbose application logging for cleaner test output + originalLevel := zerolog.GlobalLevel() + defer zerolog.SetGlobalLevel(originalLevel) + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + + // Full test matrix for scalability testing + nodes := []int{25, 50, 100} // 250, 500, 1000, + + cycles := []int{10, 100} // 500 + bufferSizes := []int{1, 200, 1000} + chaosTypes := []string{"connection", "processing", "mixed"} + + type testCase struct { + name string + nodeCount int + cycles int + bufferSize int + chaosType string + expectBreak bool + description string + } + + var testCases []testCase + + // Generate all combinations of the test matrix + for _, nodeCount := range nodes { + for _, cycleCount := range cycles { + for _, bufferSize := range bufferSizes { + for _, chaosType := range chaosTypes { + expectBreak := false + // resourceIntensity := float64(nodeCount*cycleCount) / float64(bufferSize) + + // switch chaosType { + // case "processing": + // resourceIntensity *= 1.1 + // case "mixed": + // resourceIntensity *= 1.15 + // } + + // if resourceIntensity > 500000 { + // expectBreak = true + // } else if nodeCount >= 1000 && cycleCount >= 500 && bufferSize <= 1 { + // expectBreak = true + // } else if nodeCount >= 500 && cycleCount >= 500 && bufferSize <= 1 && chaosType == "mixed" { + // expectBreak = true + // } + + name := fmt.Sprintf("%s_%dn_%dc_%db", chaosType, nodeCount, cycleCount, bufferSize) + description := fmt.Sprintf("%s chaos: %d nodes, %d cycles, %d buffers", + chaosType, nodeCount, cycleCount, bufferSize) + + testCases = append(testCases, testCase{ + name: name, + nodeCount: nodeCount, + cycles: cycleCount, + bufferSize: bufferSize, + chaosType: chaosType, + expectBreak: expectBreak, + description: description, + }) + } + } + } + } + + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + for i, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create comprehensive test environment with real data using the specific buffer size for this test case + // Need 1000 nodes for largest test case, all from same user so they can be peers + usersNeeded := max(1, tc.nodeCount/1000) // 1 user per 1000 nodes, minimum 1 + nodesPerUser := tc.nodeCount / usersNeeded + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, usersNeeded, nodesPerUser, tc.bufferSize) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description) + t.Logf(" Cycles: %d, Buffer Size: %d, Chaos Type: %s", tc.cycles, tc.bufferSize, tc.chaosType) + + // Use provided nodes, limit to requested count + testNodes := allNodes[:min(len(allNodes), tc.nodeCount)] + + tracker := newUpdateTracker() + panicCount := int64(0) + deadlockDetected := false + + startTime := time.Now() + setupTime := time.Since(startTime) + t.Logf("Starting scalability test with %d nodes (setup took: %v)", len(testNodes), setupTime) + + // Comprehensive stress test + done := make(chan struct{}) + + // Start update consumers for all nodes + for i := range testNodes { + testNodes[i].start() + } + + // Give time for all tracking goroutines to start + time.Sleep(100 * time.Millisecond) + + // Connect all nodes first so they can see each other as peers + connectedNodes := make(map[types.NodeID]bool) + var connectedNodesMutex sync.RWMutex + for i := range 
testNodes { + node := &testNodes[i] + batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() + connectedNodes[node.n.ID] = true + connectedNodesMutex.Unlock() + } + + // Give more time for all connections to be established + time.Sleep(500 * time.Millisecond) + batcher.AddWork(change.FullSet) + time.Sleep(500 * time.Millisecond) // Allow initial update to propagate + + go func() { + defer close(done) + var wg sync.WaitGroup + + t.Logf("Starting load generation: %d cycles with %d nodes", tc.cycles, len(testNodes)) + + // Main load generation - varies by chaos type + for cycle := range tc.cycles { + if cycle%10 == 0 { + t.Logf("Cycle %d/%d completed", cycle, tc.cycles) + } + // Add delays for mixed chaos + if tc.chaosType == "mixed" && cycle%10 == 0 { + time.Sleep(time.Duration(cycle%2) * time.Microsecond) + } + + // For chaos testing, only disconnect/reconnect a subset of nodes + // This ensures some nodes stay connected to continue receiving updates + startIdx := cycle % len(testNodes) + endIdx := startIdx + len(testNodes)/4 + if endIdx > len(testNodes) { + endIdx = len(testNodes) + } + if startIdx >= endIdx { + startIdx = 0 + endIdx = min(len(testNodes)/4, len(testNodes)) + } + chaosNodes := testNodes[startIdx:endIdx] + if len(chaosNodes) == 0 { + chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos + } + + // Connection/disconnection cycles for subset of nodes + for i, node := range chaosNodes { + // Only add work if this is connection chaos or mixed + if tc.chaosType == "connection" || tc.chaosType == "mixed" { + wg.Add(2) + + // Disconnection first + go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + wg.Done() + }() + + connectedNodesMutex.RLock() + isConnected := connectedNodes[nodeID] + connectedNodesMutex.RUnlock() + + if isConnected { + batcher.RemoveNode(nodeID, channel, false) + connectedNodesMutex.Lock() + connectedNodes[nodeID] = false + connectedNodesMutex.Unlock() + } + }(node.n.ID, node.ch) + + // Then reconnection + go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse, index int) { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + wg.Done() + }() + + // Small delay before reconnecting + time.Sleep(time.Duration(index%3) * time.Millisecond) + batcher.AddNode(nodeID, channel, false, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() + connectedNodes[nodeID] = true + connectedNodesMutex.Unlock() + + // Add work to create load + if index%5 == 0 { + batcher.AddWork(change.FullSet) + } + }(node.n.ID, node.ch, i) + } + } + + // Concurrent work generation - scales with load + updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count + for i := range updateCount { + wg.Add(1) + go func(index int) { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + wg.Done() + }() + + // Generate different types of work to ensure updates are sent + switch index % 4 { + case 0: + batcher.AddWork(change.FullSet) + case 1: + batcher.AddWork(change.PolicySet) + case 2: + batcher.AddWork(change.DERPSet) + default: + // Pick a random node and generate a node change + if len(testNodes) > 0 { + nodeIdx := index % len(testNodes) + batcher.AddWork(change.NodeAdded(testNodes[nodeIdx].n.ID)) + } else { + batcher.AddWork(change.FullSet) + } + } + }(i) + } + } + + t.Logf("Waiting for all goroutines to complete") + wg.Wait() + 
t.Logf("All goroutines completed") + }() + + // Wait for completion with timeout and progress monitoring + progressTicker := time.NewTicker(10 * time.Second) + defer progressTicker.Stop() + + select { + case <-done: + t.Logf("Test completed successfully") + case <-time.After(TEST_TIMEOUT): + deadlockDetected = true + // Collect diagnostic information + allStats := tracker.getAllStats() + totalUpdates := 0 + for _, stats := range allStats { + totalUpdates += stats.TotalUpdates + } + interimPanics := atomic.LoadInt64(&panicCount) + t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT) + t.Logf(" Progress at timeout: %d total updates, %d panics", totalUpdates, interimPanics) + t.Logf(" Possible causes: deadlock, excessive load, or performance bottleneck") + + // Try to detect if workers are still active + if totalUpdates > 0 { + t.Logf(" System was processing updates - likely performance bottleneck") + } else { + t.Logf(" No updates processed - likely deadlock or startup issue") + } + } + + // Give time for batcher workers to process all the work and send updates + // BEFORE disconnecting nodes + time.Sleep(1 * time.Second) + + // Now disconnect all nodes from batcher to stop new updates + for i := range testNodes { + node := &testNodes[i] + batcher.RemoveNode(node.n.ID, node.ch, false) + } + + // Give time for enhanced tracking goroutines to process any remaining data in channels + time.Sleep(200 * time.Millisecond) + + // Cleanup nodes and get their final stats + totalUpdates := int64(0) + totalPatches := int64(0) + totalFull := int64(0) + maxPeersGlobal := 0 + nodeStatsReport := make([]string, 0, len(testNodes)) + + for i := range testNodes { + node := &testNodes[i] + stats := node.cleanup() + totalUpdates += stats.TotalUpdates + totalPatches += stats.PatchUpdates + totalFull += stats.FullUpdates + if stats.MaxPeersSeen > maxPeersGlobal { + maxPeersGlobal = stats.MaxPeersSeen + } + + if stats.TotalUpdates > 0 { + nodeStatsReport = append(nodeStatsReport, + fmt.Sprintf("Node %d: %d total (%d patch, %d full), max %d peers", + node.n.ID, stats.TotalUpdates, stats.PatchUpdates, stats.FullUpdates, stats.MaxPeersSeen)) + } + } + + // Comprehensive final summary + t.Logf("FINAL RESULTS: %d total updates (%d patch, %d full), max peers seen: %d", + totalUpdates, totalPatches, totalFull, maxPeersGlobal) + if len(nodeStatsReport) <= 10 { // Only log details for smaller tests + for _, report := range nodeStatsReport { + t.Logf(" %s", report) + } + } else { + t.Logf(" (%d nodes had activity, details suppressed for large test)", len(nodeStatsReport)) + } + + // Legacy tracker comparison (optional) + allStats := tracker.getAllStats() + legacyTotalUpdates := 0 + for _, stats := range allStats { + legacyTotalUpdates += stats.TotalUpdates + } + if legacyTotalUpdates != int(totalUpdates) { + t.Logf("Note: Legacy tracker mismatch - legacy: %d, new: %d", legacyTotalUpdates, totalUpdates) + } + + finalPanicCount := atomic.LoadInt64(&panicCount) + + // Validation based on expectation + testPassed := true + if tc.expectBreak { + // For tests expected to break, we're mainly checking that we don't crash + if finalPanicCount > 0 { + t.Errorf("System crashed with %d panics (even breaking point tests shouldn't crash)", finalPanicCount) + testPassed = false + } + // Timeout/deadlock is acceptable for breaking point tests + if deadlockDetected { + t.Logf("Expected breaking point reached: system overloaded at %d nodes", len(testNodes)) + } + } else { + // For tests expected to pass, validate proper operation 
+ if finalPanicCount > 0 { + t.Errorf("Scalability test failed with %d panics", finalPanicCount) + testPassed = false + } + if deadlockDetected { + t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes)) + testPassed = false + } + if totalUpdates == 0 { + t.Error("No updates received - system may be completely stalled") + testPassed = false + } + } + + // Clear success/failure indication + if testPassed { + t.Logf("✅ PASS: %s | %d nodes, %d updates, 0 panics, no deadlock", + tc.name, len(testNodes), totalUpdates) + } else { + t.Logf("❌ FAIL: %s | %d nodes, %d updates, %d panics, deadlock: %v", + tc.name, len(testNodes), totalUpdates, finalPanicCount, deadlockDetected) + } + }) + } + }) + } +} + +// TestBatcherFullPeerUpdates verifies that when multiple nodes are connected +// and we send a FullSet update, nodes receive the complete peer list. +func TestBatcherFullPeerUpdates(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with 3 nodes from same user (so they can be peers) + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + t.Logf("Created %d nodes in database", len(allNodes)) + + // Connect nodes one at a time to avoid overwhelming the work queue + for i, node := range allNodes { + batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + t.Logf("Connected node %d (ID: %d)", i, node.n.ID) + // Small delay between connections to allow NodeCameOnline processing + time.Sleep(50 * time.Millisecond) + } + + // Give additional time for all NodeCameOnline events to be processed + t.Logf("Waiting for NodeCameOnline events to settle...") + time.Sleep(500 * time.Millisecond) + + // Check how many peers each node should see + for i, node := range allNodes { + peers, err := testData.State.ListPeers(node.n.ID) + if err != nil { + t.Errorf("Error listing peers for node %d: %v", i, err) + } else { + t.Logf("Node %d should see %d peers from state", i, len(peers)) + } + } + + // Send a full update - this should generate full peer lists + t.Logf("Sending FullSet update...") + batcher.AddWork(change.FullSet) + + // Give much more time for workers to process the FullSet work items + t.Logf("Waiting for FullSet to be processed...") + time.Sleep(1 * time.Second) + + // Check what each node receives - read multiple updates + totalUpdates := 0 + foundFullUpdate := false + + // Read all available updates for each node + for i := range len(allNodes) { + nodeUpdates := 0 + t.Logf("Reading updates for node %d:", i) + + // Read up to 10 updates per node or until timeout/no more data + for updateNum := range 10 { + select { + case data := <-allNodes[i].ch: + nodeUpdates++ + totalUpdates++ + + // Parse and examine the update - data is already a MapResponse + if data == nil { + t.Errorf("Node %d update %d: nil MapResponse", i, updateNum) + continue + } + + updateType := "unknown" + if len(data.Peers) > 0 { + updateType = "FULL" + foundFullUpdate = true + } else if len(data.PeersChangedPatch) > 0 { + updateType = "PATCH" + } else if data.DERPMap != nil { + updateType = "DERP" + } + + t.Logf(" Update %d: %s - Peers=%d, PeersChangedPatch=%d, DERPMap=%v", + updateNum, updateType, len(data.Peers), len(data.PeersChangedPatch), data.DERPMap != nil) + + if len(data.Peers) > 0 { + t.Logf(" Full peer list with %d peers", len(data.Peers)) + for j, peer := range data.Peers[:min(3, 
len(data.Peers))] { + t.Logf(" Peer %d: NodeID=%d, Online=%v", j, peer.ID, peer.Online) + } + } + if len(data.PeersChangedPatch) > 0 { + t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch)) + for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] { + t.Logf(" Patch %d: NodeID=%d, Online=%v", j, patch.NodeID, patch.Online) + } + } + + case <-time.After(500 * time.Millisecond): + } + } + t.Logf("Node %d received %d updates", i, nodeUpdates) + } + + t.Logf("Total updates received across all nodes: %d", totalUpdates) + + if !foundFullUpdate { + t.Errorf("CRITICAL: No FULL updates received despite sending change.FullSet!") + t.Errorf("This confirms the bug - FullSet updates are not generating full peer responses") + } + }) + } +} + +// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items. +func TestBatcherWorkQueueTracing(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10) + defer cleanup() + + batcher := testData.Batcher + nodes := testData.Nodes + + t.Logf("=== WORK QUEUE TRACING TEST ===") + + // Connect first node + batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100)) + t.Logf("Connected node %d", nodes[0].n.ID) + + // Wait for initial NodeCameOnline to be processed + time.Sleep(200 * time.Millisecond) + + // Drain any initial updates + drainedCount := 0 + for { + select { + case <-nodes[0].ch: + drainedCount++ + case <-time.After(100 * time.Millisecond): + goto drained + } + } + drained: + t.Logf("Drained %d initial updates", drainedCount) + + // Now send a single FullSet update and trace it closely + t.Logf("Sending change.FullSet work item...") + batcher.AddWork(change.FullSet) + + // Give short time for processing + time.Sleep(100 * time.Millisecond) + + // Check if any update was received + select { + case data := <-nodes[0].ch: + t.Logf("SUCCESS: Received update after FullSet!") + + if data != nil { + // Detailed analysis of the response - data is already a MapResponse + t.Logf("Response details:") + t.Logf(" Peers: %d", len(data.Peers)) + t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch)) + t.Logf(" PeersChanged: %d", len(data.PeersChanged)) + t.Logf(" PeersRemoved: %d", len(data.PeersRemoved)) + t.Logf(" DERPMap: %v", data.DERPMap != nil) + t.Logf(" KeepAlive: %v", data.KeepAlive) + t.Logf(" Node: %v", data.Node != nil) + + if len(data.Peers) > 0 { + t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers)) + } else if len(data.PeersChangedPatch) > 0 { + t.Errorf("ERROR: Received patch update instead of full update!") + } else if data.DERPMap != nil { + t.Logf("Received DERP map update") + } else if data.Node != nil { + t.Logf("Received self node update") + } else { + t.Errorf("ERROR: Received unknown update type!") + } + + // Check if there should be peers available + peers, err := testData.State.ListPeers(nodes[0].n.ID) + if err != nil { + t.Errorf("Error getting peers from state: %v", err) + } else { + t.Logf("State shows %d peers available for this node", len(peers)) + if len(peers) > 0 && len(data.Peers) == 0 { + t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers)) + } + } + } else { + t.Errorf("Response data is nil") + } + case <-time.After(2 * time.Second): + t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!") + t.Errorf("This indicates FullSet work items are 
not being processed at all") + } + }) + } +} diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go new file mode 100644 index 00000000..b6102c01 --- /dev/null +++ b/hscontrol/mapper/builder.go @@ -0,0 +1,259 @@ +package mapper + +import ( + "net/netip" + "sort" + "time" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" + "tailscale.com/types/views" + "tailscale.com/util/multierr" +) + +// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse +type MapResponseBuilder struct { + resp *tailcfg.MapResponse + mapper *mapper + nodeID types.NodeID + capVer tailcfg.CapabilityVersion + errs []error +} + +// NewMapResponseBuilder creates a new builder with basic fields set +func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { + now := time.Now() + return &MapResponseBuilder{ + resp: &tailcfg.MapResponse{ + KeepAlive: false, + ControlTime: &now, + }, + mapper: m, + nodeID: nodeID, + errs: nil, + } +} + +// addError adds an error to the builder's error list +func (b *MapResponseBuilder) addError(err error) { + if err != nil { + b.errs = append(b.errs, err) + } +} + +// hasErrors returns true if the builder has accumulated any errors +func (b *MapResponseBuilder) hasErrors() bool { + return len(b.errs) > 0 +} + +// WithCapabilityVersion sets the capability version for the response +func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder { + b.capVer = capVer + return b +} + +// WithSelfNode adds the requesting node to the response +func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { + node, err := b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + b.addError(err) + return b + } + + _, matchers := b.mapper.state.Filter() + tailnode, err := tailNode( + node.View(), b.capVer, b.mapper.state, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) + }, + b.mapper.cfg) + if err != nil { + b.addError(err) + return b + } + + b.resp.Node = tailnode + return b +} + +// WithDERPMap adds the DERP map to the response +func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder { + b.resp.DERPMap = b.mapper.state.DERPMap() + return b +} + +// WithDomain adds the domain configuration +func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder { + b.resp.Domain = b.mapper.cfg.Domain() + return b +} + +// WithCollectServicesDisabled sets the collect services flag to false +func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder { + b.resp.CollectServices.Set(false) + return b +} + +// WithDebugConfig adds debug configuration +// It disables log tailing if the mapper's LogTail is not enabled +func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { + b.resp.Debug = &tailcfg.Debug{ + DisableLogTail: !b.mapper.cfg.LogTail.Enabled, + } + return b +} + +// WithSSHPolicy adds SSH policy configuration for the requesting node +func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { + node, err := b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + b.addError(err) + return b + } + + sshPolicy, err := b.mapper.state.SSHPolicy(node.View()) + if err != nil { + b.addError(err) + return b + } + + b.resp.SSHPolicy = sshPolicy + return b +} + +// WithDNSConfig adds DNS configuration for the requesting node +func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { + node, err 
:= b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + b.addError(err) + return b + } + + b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node) + return b +} + +// WithUserProfiles adds user profiles for the requesting node and given peers +func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder { + node, err := b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + b.addError(err) + return b + } + + b.resp.UserProfiles = generateUserProfiles(node, peers) + return b +} + +// WithPacketFilters adds packet filter rules based on policy +func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { + node, err := b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + b.addError(err) + return b + } + + filter, _ := b.mapper.state.Filter() + + // CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates) + // Currently, we do not send incremental package filters, however using the + // new PacketFilters field and "base" allows us to send a full update when we + // have to send an empty list, avoiding the hack in the else block. + b.resp.PacketFilters = map[string][]tailcfg.FilterRule{ + "base": policy.ReduceFilterRules(node.View(), filter), + } + + return b +} + +// WithPeers adds full peer list with policy filtering (for full map response) +func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder { + + tailPeers, err := b.buildTailPeers(peers) + if err != nil { + b.addError(err) + return b + } + + b.resp.Peers = tailPeers + return b +} + +// WithPeerChanges adds changed peers with policy filtering (for incremental updates) +func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder { + + tailPeers, err := b.buildTailPeers(peers) + if err != nil { + b.addError(err) + return b + } + + b.resp.PeersChanged = tailPeers + return b +} + +// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting +func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) { + node, err := b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + return nil, err + } + + filter, matchers := b.mapper.state.Filter() + + // If there are filter rules present, see if there are any nodes that cannot + // access each-other at all and remove them from the peers. + var changedViews views.Slice[types.NodeView] + if len(filter) > 0 { + changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers) + } else { + changedViews = peers.ViewSlice() + } + + tailPeers, err := tailNodes( + changedViews, b.capVer, b.mapper.state, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) + }, + b.mapper.cfg) + if err != nil { + return nil, err + } + + // Peers is always returned sorted by Node.ID. 
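+ // Stable sort by node ID so every response lists peers in a consistent order.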
+ sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + + return tailPeers, nil +} + +// WithPeerChangedPatch adds peer change patches +func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { + b.resp.PeersChangedPatch = changes + return b +} + +// WithPeersRemoved adds removed peer IDs +func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { + + var tailscaleIDs []tailcfg.NodeID + for _, id := range removedIDs { + tailscaleIDs = append(tailscaleIDs, id.NodeID()) + } + b.resp.PeersRemoved = tailscaleIDs + return b +} + +// Build finalizes the response and returns marshaled bytes +func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) { + if len(b.errs) > 0 { + return nil, multierr.New(b.errs...) + } + if debugDumpMapResponsePath != "" { + writeDebugMapResponse(b.resp, b.nodeID) + } + + return b.resp, nil +} diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go new file mode 100644 index 00000000..c8ff59ec --- /dev/null +++ b/hscontrol/mapper/builder_test.go @@ -0,0 +1,347 @@ +package mapper + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +func TestMapResponseBuilder_Basic(t *testing.T) { + cfg := &types.Config{ + BaseDomain: "example.com", + LogTail: types.LogTailConfig{ + Enabled: true, + }, + } + + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID) + + // Test basic builder creation + assert.NotNil(t, builder) + assert.Equal(t, nodeID, builder.nodeID) + assert.NotNil(t, builder.resp) + assert.False(t, builder.resp.KeepAlive) + assert.NotNil(t, builder.resp.ControlTime) + assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second) +} + +func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + capVer := tailcfg.CapabilityVersion(42) + + builder := m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer) + + assert.Equal(t, capVer, builder.capVer) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithDomain(t *testing.T) { + domain := "test.example.com" + cfg := &types.Config{ + ServerURL: "https://test.example.com", + BaseDomain: domain, + } + + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithDomain() + + assert.Equal(t, domain, builder.resp.Domain) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). 
+ WithCollectServicesDisabled() + + value, isSet := builder.resp.CollectServices.Get() + assert.True(t, isSet) + assert.False(t, value) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithDebugConfig(t *testing.T) { + tests := []struct { + name string + logTailEnabled bool + expected bool + }{ + { + name: "LogTail enabled", + logTailEnabled: true, + expected: false, // DisableLogTail should be false when LogTail is enabled + }, + { + name: "LogTail disabled", + logTailEnabled: false, + expected: true, // DisableLogTail should be true when LogTail is disabled + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &types.Config{ + LogTail: types.LogTailConfig{ + Enabled: tt.logTailEnabled, + }, + } + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithDebugConfig() + + require.NotNil(t, builder.resp.Debug) + assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail) + assert.False(t, builder.hasErrors()) + }) + } +} + +func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + changes := []*tailcfg.PeerChange{ + { + NodeID: 123, + DERPRegion: 1, + }, + { + NodeID: 456, + DERPRegion: 2, + }, + } + + builder := m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch(changes) + + assert.Equal(t, changes, builder.resp.PeersChangedPatch) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + removedID1 := types.NodeID(123) + removedID2 := types.NodeID(456) + + builder := m.NewMapResponseBuilder(nodeID). + WithPeersRemoved(removedID1, removedID2) + + expected := []tailcfg.NodeID{ + removedID1.NodeID(), + removedID2.NodeID(), + } + assert.Equal(t, expected, builder.resp.PeersRemoved) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_ErrorHandling(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + // Simulate an error in the builder + builder := m.NewMapResponseBuilder(nodeID) + builder.addError(assert.AnError) + + // All subsequent calls should continue to work and accumulate errors + result := builder. + WithDomain(). + WithCollectServicesDisabled(). + WithDebugConfig() + + assert.True(t, result.hasErrors()) + assert.Len(t, result.errs, 1) + assert.Equal(t, assert.AnError, result.errs[0]) + + // Build should return the error + data, err := result.Build("none") + assert.Nil(t, data) + assert.Error(t, err) +} + +func TestMapResponseBuilder_ChainedCalls(t *testing.T) { + domain := "chained.example.com" + cfg := &types.Config{ + ServerURL: "https://chained.example.com", + BaseDomain: domain, + LogTail: types.LogTailConfig{ + Enabled: false, + }, + } + + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + capVer := tailcfg.CapabilityVersion(99) + + builder := m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer). + WithDomain(). + WithCollectServicesDisabled(). 
+ WithDebugConfig() + + // Verify all fields are set correctly + assert.Equal(t, capVer, builder.capVer) + assert.Equal(t, domain, builder.resp.Domain) + value, isSet := builder.resp.CollectServices.Get() + assert.True(t, isSet) + assert.False(t, value) + assert.NotNil(t, builder.resp.Debug) + assert.True(t, builder.resp.Debug.DisableLogTail) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + removedID1 := types.NodeID(100) + removedID2 := types.NodeID(200) + + // Test calling WithPeersRemoved multiple times + builder := m.NewMapResponseBuilder(nodeID). + WithPeersRemoved(removedID1). + WithPeersRemoved(removedID2) + + // Second call should overwrite the first + expected := []tailcfg.NodeID{removedID2.NodeID()} + assert.Equal(t, expected, builder.resp.PeersRemoved) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch([]*tailcfg.PeerChange{}) + + assert.Empty(t, builder.resp.PeersChangedPatch) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch(nil) + + assert.Nil(t, builder.resp.PeersChangedPatch) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_MultipleErrors(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + // Create a builder and add multiple errors + builder := m.NewMapResponseBuilder(nodeID) + builder.addError(assert.AnError) + builder.addError(assert.AnError) + builder.addError(nil) // This should be ignored + + // All subsequent calls should continue to work + result := builder. + WithDomain(). 
+ WithCollectServicesDisabled() + + assert.True(t, result.hasErrors()) + assert.Len(t, result.errs, 2) // nil error should be ignored + + // Build should return a multierr + data, err := result.Build("none") + assert.Nil(t, data) + assert.Error(t, err) + + // The error should contain information about multiple errors + assert.Contains(t, err.Error(), "multiple errors") +} \ No newline at end of file diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 553658f5..43764457 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -1,7 +1,6 @@ package mapper import ( - "encoding/binary" "encoding/json" "fmt" "io/fs" @@ -10,31 +9,21 @@ import ( "os" "path" "slices" - "sort" "strings" - "sync" - "sync/atomic" "time" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "tailscale.com/envknob" - "tailscale.com/smallzstd" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" - "tailscale.com/types/views" ) const ( - nextDNSDoHPrefix = "https://dns.nextdns.io" - reservedResponseHeaderSize = 4 - mapperIDLength = 8 - debugMapResponsePerm = 0o755 + nextDNSDoHPrefix = "https://dns.nextdns.io" + mapperIDLength = 8 + debugMapResponsePerm = 0o755 ) var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH") @@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_ // - Create a "minifier" that removes info not needed for the node // - some sort of batching, wait for 5 or 60 seconds before sending -type Mapper struct { +type mapper struct { // Configuration - state *state.State - cfg *types.Config - notif *notifier.Notifier + state *state.State + cfg *types.Config + batcher Batcher - uid string created time.Time - seq uint64 } type patch struct { @@ -66,41 +53,31 @@ type patch struct { change *tailcfg.PeerChange } -func NewMapper( - state *state.State, +func newMapper( cfg *types.Config, - notif *notifier.Notifier, -) *Mapper { - uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) + state *state.State, +) *mapper { + // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) - return &Mapper{ + return &mapper{ state: state, cfg: cfg, - notif: notif, - uid: uid, created: time.Now(), - seq: 0, } } -func (m *Mapper) String() string { - return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created) -} - func generateUserProfiles( - node types.NodeView, - peers views.Slice[types.NodeView], + node *types.Node, + peers types.Nodes, ) []tailcfg.UserProfile { userMap := make(map[uint]*types.User) - ids := make([]uint, 0, peers.Len()+1) - user := node.User() - userMap[user.ID] = &user - ids = append(ids, user.ID) - for _, peer := range peers.All() { - peerUser := peer.User() - userMap[peerUser.ID] = &peerUser - ids = append(ids, peerUser.ID) + ids := make([]uint, 0, len(userMap)) + userMap[node.User.ID] = &node.User + ids = append(ids, node.User.ID) + for _, peer := range peers { + userMap[peer.User.ID] = &peer.User + ids = append(ids, peer.User.ID) } slices.Sort(ids) @@ -117,7 +94,7 @@ func generateUserProfiles( func generateDNSConfig( cfg *types.Config, - node types.NodeView, + node *types.Node, ) *tailcfg.DNSConfig { if cfg.TailcfgDNSConfig == nil { return nil @@ -137,17 +114,16 @@ func generateDNSConfig( // // 
This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ - "device_name": []string{node.Hostname()}, - "device_model": []string{node.Hostinfo().OS()}, + "device_name": []string{node.Hostname}, + "device_model": []string{node.Hostinfo.OS}, } - nodeIPs := node.IPs() - if len(nodeIPs) > 0 { - attrs.Add("device_ip", nodeIPs[0].String()) + if len(node.IPs()) > 0 { + attrs.Add("device_ip", node.IPs()[0].String()) } resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) @@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { } } -// fullMapResponse creates a complete MapResponse for a node. -// It is a separate function to make testing easier. -func (m *Mapper) fullMapResponse( - node types.NodeView, - peers views.Slice[types.NodeView], +// fullMapResponse returns a MapResponse for the given node. +func (m *mapper) fullMapResponse( + nodeID types.NodeID, capVer tailcfg.CapabilityVersion, + messages ...string, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, capVer) + peers, err := m.listPeers(nodeID) if err != nil { return nil, err } - err = appendPeerChanges( - resp, - true, // full change - m.state, - node, - capVer, - peers, - m.cfg, - ) - if err != nil { - return nil, err - } - - return resp, nil + return m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer). + WithSelfNode(). + WithDERPMap(). + WithDomain(). + WithCollectServicesDisabled(). + WithDebugConfig(). + WithSSHPolicy(). + WithDNSConfig(). + WithUserProfiles(peers). + WithPacketFilters(). + WithPeers(peers). + Build(messages...) } -// FullMapResponse returns a MapResponse for the given node. -func (m *Mapper) FullMapResponse( - mapRequest tailcfg.MapRequest, - node types.NodeView, - messages ...string, -) ([]byte, error) { - peers, err := m.ListPeers(node.ID()) - if err != nil { - return nil, err - } - - resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version) - if err != nil { - return nil, err - } - - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...) -} - -// ReadOnlyMapResponse returns a MapResponse for the given node. -// Lite means that the peers has been omitted, this is intended -// to be used to answer MapRequests with OmitPeers set to true. -func (m *Mapper) ReadOnlyMapResponse( - mapRequest tailcfg.MapRequest, - node types.NodeView, - messages ...string, -) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) - if err != nil { - return nil, err - } - - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...) 
-} - -func (m *Mapper) KeepAliveResponse( - mapRequest tailcfg.MapRequest, - node types.NodeView, -) ([]byte, error) { - resp := m.baseMapResponse() - resp.KeepAlive = true - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) -} - -func (m *Mapper) DERPMapResponse( - mapRequest tailcfg.MapRequest, - node types.NodeView, - derpMap *tailcfg.DERPMap, -) ([]byte, error) { - resp := m.baseMapResponse() - resp.DERPMap = derpMap - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) -} - -func (m *Mapper) PeerChangedResponse( - mapRequest tailcfg.MapRequest, - node types.NodeView, - changed map[types.NodeID]bool, - patches []*tailcfg.PeerChange, - messages ...string, -) ([]byte, error) { - var err error - resp := m.baseMapResponse() - - var removedIDs []tailcfg.NodeID - var changedIDs []types.NodeID - for nodeID, nodeChanged := range changed { - if nodeChanged { - if nodeID != node.ID() { - changedIDs = append(changedIDs, nodeID) - } - } else { - removedIDs = append(removedIDs, nodeID.NodeID()) - } - } - changedNodes := types.Nodes{} - if len(changedIDs) > 0 { - changedNodes, err = m.ListNodes(changedIDs...) - if err != nil { - return nil, err - } - } - - err = appendPeerChanges( - &resp, - false, // partial change - m.state, - node, - mapRequest.Version, - changedNodes.ViewSlice(), - m.cfg, - ) - if err != nil { - return nil, err - } - - resp.PeersRemoved = removedIDs - - // Sending patches as a part of a PeersChanged response - // is technically not suppose to be done, but they are - // applied after the PeersChanged. The patch list - // should _only_ contain Nodes that are not in the - // PeersChanged or PeersRemoved list and the caller - // should filter them out. - // - // From tailcfg docs: - // These are applied after Peers* above, but in practice the - // control server should only send these on their own, without - // the Peers* fields also set. - if patches != nil { - resp.PeersChangedPatch = patches - } - - _, matchers := m.state.Filter() - // Add the node itself, it might have changed, and particularly - // if there are no patches or changes, this is a self update. - tailnode, err := tailNode( - node, mapRequest.Version, m.state, - func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers) - }, - m.cfg) - if err != nil { - return nil, err - } - resp.Node = tailnode - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...) +func (m *mapper) derpMapResponse( + nodeID types.NodeID, +) (*tailcfg.MapResponse, error) { + return m.NewMapResponseBuilder(nodeID). + WithDERPMap(). + Build() } // PeerChangedPatchResponse creates a patch MapResponse with // incoming update from a state change. 
-func (m *Mapper) PeerChangedPatchResponse( - mapRequest tailcfg.MapRequest, - node types.NodeView, +func (m *mapper) peerChangedPatchResponse( + nodeID types.NodeID, changed []*tailcfg.PeerChange, -) ([]byte, error) { - resp := m.baseMapResponse() - resp.PeersChangedPatch = changed - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) -} - -func (m *Mapper) marshalMapResponse( - mapRequest tailcfg.MapRequest, - resp *tailcfg.MapResponse, - node types.NodeView, - compression string, - messages ...string, -) ([]byte, error) { - atomic.AddUint64(&m.seq, 1) - - jsonBody, err := json.Marshal(resp) - if err != nil { - return nil, fmt.Errorf("marshalling map response: %w", err) - } - - if debugDumpMapResponsePath != "" { - data := map[string]any{ - "Messages": messages, - "MapRequest": mapRequest, - "MapResponse": resp, - } - - responseType := "keepalive" - - switch { - case resp.Peers != nil && len(resp.Peers) > 0: - responseType = "full" - case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive: - responseType = "self" - case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: - responseType = "changed" - case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0: - responseType = "patch" - case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0: - responseType = "removed" - } - - body, err := json.MarshalIndent(data, "", " ") - if err != nil { - return nil, fmt.Errorf("marshalling map response: %w", err) - } - - perms := fs.FileMode(debugMapResponsePerm) - mPath := path.Join(debugDumpMapResponsePath, node.Hostname()) - err = os.MkdirAll(mPath, perms) - if err != nil { - panic(err) - } - - now := time.Now().Format("2006-01-02T15-04-05.999999999") - - mapResponsePath := path.Join( - mPath, - fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType), - ) - - log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) - err = os.WriteFile(mapResponsePath, body, perms) - if err != nil { - panic(err) - } - } - - var respBody []byte - if compression == util.ZstdCompression { - respBody = zstdEncode(jsonBody) - } else { - respBody = jsonBody - } - - data := make([]byte, reservedResponseHeaderSize) - binary.LittleEndian.PutUint32(data, uint32(len(respBody))) - data = append(data, respBody...) - - return data, nil -} - -func zstdEncode(in []byte) []byte { - encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) - if !ok { - panic("invalid type in sync pool") - } - out := encoder.EncodeAll(in, nil) - _ = encoder.Close() - zstdEncoderPool.Put(encoder) - - return out -} - -var zstdEncoderPool = &sync.Pool{ - New: func() any { - encoder, err := smallzstd.NewEncoder( - nil, - zstd.WithEncoderLevel(zstd.SpeedFastest)) - if err != nil { - panic(err) - } - - return encoder - }, -} - -// baseMapResponse returns a tailcfg.MapResponse with -// KeepAlive false and ControlTime set to now. -func (m *Mapper) baseMapResponse() tailcfg.MapResponse { - now := time.Now() - - resp := tailcfg.MapResponse{ - KeepAlive: false, - ControlTime: &now, - // TODO(kradalby): Implement PingRequest? - } - - return resp -} - -// baseWithConfigMapResponse returns a tailcfg.MapResponse struct -// with the basic configuration from headscale set. -// It is used in for bigger updates, such as full and lite, not -// incremental. 
-func (m *Mapper) baseWithConfigMapResponse( - node types.NodeView, - capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp := m.baseMapResponse() + return m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch(changed). + Build() +} - _, matchers := m.state.Filter() - tailnode, err := tailNode( - node, capVer, m.state, - func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers) - }, - m.cfg) +// peerChangeResponse returns a MapResponse with changed or added nodes. +func (m *mapper) peerChangeResponse( + nodeID types.NodeID, + capVer tailcfg.CapabilityVersion, + changedNodeID types.NodeID, +) (*tailcfg.MapResponse, error) { + peers, err := m.listPeers(nodeID, changedNodeID) if err != nil { return nil, err } - resp.Node = tailnode - resp.DERPMap = m.state.DERPMap() - - resp.Domain = m.cfg.Domain() - - // Do not instruct clients to collect services we do not - // support or do anything with them - resp.CollectServices = "false" - - resp.KeepAlive = false - - resp.Debug = &tailcfg.Debug{ - DisableLogTail: !m.cfg.LogTail.Enabled, - } - - return &resp, nil + return m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer). + WithSelfNode(). + WithUserProfiles(peers). + WithPeerChanges(peers). + Build() } -// ListPeers returns peers of node, regardless of any Policy or if the node is expired. +// peerRemovedResponse creates a MapResponse indicating that a peer has been removed. +func (m *mapper) peerRemovedResponse( + nodeID types.NodeID, + removedNodeID types.NodeID, +) (*tailcfg.MapResponse, error) { + return m.NewMapResponseBuilder(nodeID). + WithPeersRemoved(removedNodeID). + Build() +} + +func writeDebugMapResponse( + resp *tailcfg.MapResponse, + nodeID types.NodeID, + messages ...string, +) { + data := map[string]any{ + "Messages": messages, + "MapResponse": resp, + } + + responseType := "keepalive" + + switch { + case len(resp.Peers) > 0: + responseType = "full" + case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive: + responseType = "self" + case len(resp.PeersChanged) > 0: + responseType = "changed" + case len(resp.PeersChangedPatch) > 0: + responseType = "patch" + case len(resp.PeersRemoved) > 0: + responseType = "removed" + } + + body, err := json.MarshalIndent(data, "", " ") + if err != nil { + panic(err) + } + + perms := fs.FileMode(debugMapResponsePerm) + mPath := path.Join(debugDumpMapResponsePath, nodeID.String()) + err = os.MkdirAll(mPath, perms) + if err != nil { + panic(err) + } + + now := time.Now().Format("2006-01-02T15-04-05.999999999") + + mapResponsePath := path.Join( + mPath, + fmt.Sprintf("%s-%s.json", now, responseType), + ) + + log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) + if err != nil { + panic(err) + } +} + +// listPeers returns peers of node, regardless of any Policy or if the node is expired. // If no peer IDs are given, all peers are returned. // If at least one peer ID is given, only these peer nodes will be returned. -func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { +func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { peers, err := m.state.ListPeers(nodeID, peerIDs...) if err != nil { return nil, err } + // TODO(kradalby): Add back online via batcher. This was removed + // to avoid a circular dependency between the mapper and the notification. 
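+ // Each peer is annotated with its current online state as reported by the batcher.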
for _, peer := range peers { - online := m.notif.IsLikelyConnected(peer.ID) + online := m.batcher.IsConnected(peer.ID) peer.IsOnline = &online } return peers, nil } -// ListNodes queries the database for either all nodes if no parameters are given -// or for the given nodes if at least one node ID is given as parameter. -func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { - nodes, err := m.state.ListNodes(nodeIDs...) - if err != nil { - return nil, err - } - - for _, node := range nodes { - online := m.notif.IsLikelyConnected(node.ID) - node.IsOnline = &online - } - - return nodes, nil -} - // routeFilterFunc is a function that takes a node ID and returns a list of // netip.Prefixes that are allowed for that node. It is used to filter routes // from the primary route manager to the node. type routeFilterFunc func(id types.NodeID) []netip.Prefix - -// appendPeerChanges mutates a tailcfg.MapResponse with all the -// necessary changes when peers have changed. -func appendPeerChanges( - resp *tailcfg.MapResponse, - - fullChange bool, - state *state.State, - node types.NodeView, - capVer tailcfg.CapabilityVersion, - changed views.Slice[types.NodeView], - cfg *types.Config, -) error { - filter, matchers := state.Filter() - - sshPolicy, err := state.SSHPolicy(node) - if err != nil { - return err - } - - // If there are filter rules present, see if there are any nodes that cannot - // access each-other at all and remove them from the peers. - var reducedChanged views.Slice[types.NodeView] - if len(filter) > 0 { - reducedChanged = policy.ReduceNodes(node, changed, matchers) - } else { - reducedChanged = changed - } - - profiles := generateUserProfiles(node, reducedChanged) - - dnsConfig := generateDNSConfig(cfg, node) - - tailPeers, err := tailNodes( - reducedChanged, capVer, state, - func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers) - }, - cfg) - if err != nil { - return err - } - - // Peers is always returned sorted by Node.ID. - sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID - }) - - if fullChange { - resp.Peers = tailPeers - } else { - resp.PeersChanged = tailPeers - } - resp.DNSConfig = dnsConfig - resp.UserProfiles = profiles - resp.SSHPolicy = sshPolicy - - // CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates) - // Currently, we do not send incremental package filters, however using the - // new PacketFilters field and "base" allows us to send a full update when we - // have to send an empty list, avoiding the hack in the else block. 
- resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, filter), - } - - return nil -} diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index b5747c2b..198ba6c4 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -3,6 +3,7 @@ package mapper import ( "fmt" "net/netip" + "slices" "testing" "github.com/google/go-cmp/cmp" @@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) { &types.Config{ TailcfgDNSConfig: &dnsConfigOrig, }, - nodeInShared1.View(), + nodeInShared1, ) if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { @@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ // Filter peers by the provided IDs var filtered types.Nodes for _, peer := range m.peers { - for _, id := range peerIDs { - if peer.ID == id { - filtered = append(filtered, peer) - break - } + if slices.Contains(peerIDs, peer.ID) { + filtered = append(filtered, peer) } } @@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { // Filter nodes by the provided IDs var filtered types.Nodes for _, node := range m.nodes { - for _, id := range nodeIDs { - if node.ID == id { - filtered = append(filtered, node) - break - } + if slices.Contains(nodeIDs, node.ID) { + filtered = append(filtered, node) } } diff --git a/hscontrol/mapper/utils.go b/hscontrol/mapper/utils.go new file mode 100644 index 00000000..c1dce1f7 --- /dev/null +++ b/hscontrol/mapper/utils.go @@ -0,0 +1,47 @@ +package mapper + +import "tailscale.com/tailcfg" + +// mergePatch takes the current patch and a newer patch +// and override any field that has changed. +func mergePatch(currPatch, newPatch *tailcfg.PeerChange) { + if newPatch.DERPRegion != 0 { + currPatch.DERPRegion = newPatch.DERPRegion + } + + if newPatch.Cap != 0 { + currPatch.Cap = newPatch.Cap + } + + if newPatch.CapMap != nil { + currPatch.CapMap = newPatch.CapMap + } + + if newPatch.Endpoints != nil { + currPatch.Endpoints = newPatch.Endpoints + } + + if newPatch.Key != nil { + currPatch.Key = newPatch.Key + } + + if newPatch.KeySignature != nil { + currPatch.KeySignature = newPatch.KeySignature + } + + if newPatch.DiscoKey != nil { + currPatch.DiscoKey = newPatch.DiscoKey + } + + if newPatch.Online != nil { + currPatch.Online = newPatch.Online + } + + if newPatch.LastSeen != nil { + currPatch.LastSeen = newPatch.LastSeen + } + + if newPatch.KeyExpiry != nil { + currPatch.KeyExpiry = newPatch.KeyExpiry + } +} diff --git a/hscontrol/noise.go b/hscontrol/noise.go index ec4e4e5b..db39992e 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( ns.nodeKey = nv.NodeKey() - sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv) + sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct()) sess.tracef("a node sending a MapRequest with Noise protocol") if !sess.isStreaming() { sess.serve() @@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } - respBody, err := json.Marshal(registerResponse) - if err != nil { - httpError(writer, err) + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { + log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") return } - writer.Header().Set("Content-Type", 
"application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - writer.Write(respBody) + // Ensure response is flushed to client + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } } // getAndValidateNode retrieves the node from the database using the NodeKey // and validates that it matches the MachineKey from the Noise session. func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) { - nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey) + node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) } - return types.NodeView{}, err + return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil) } + nv := node.View() + // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. if ns.machineKey != nv.MachineKey() { return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil) diff --git a/hscontrol/notifier/metrics.go b/hscontrol/notifier/metrics.go deleted file mode 100644 index 8a7a8839..00000000 --- a/hscontrol/notifier/metrics.go +++ /dev/null @@ -1,68 +0,0 @@ -package notifier - -import ( - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "tailscale.com/envknob" -) - -const prometheusNamespace = "headscale" - -var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS") - -var notifierUpdateSent *prometheus.CounterVec - -func init() { - if debugHighCardinalityMetrics { - notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: prometheusNamespace, - Name: "notifier_update_sent_total", - Help: "total count of update sent on nodes channel", - }, []string{"status", "type", "trigger", "id"}) - } else { - notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: prometheusNamespace, - Name: "notifier_update_sent_total", - Help: "total count of update sent on nodes channel", - }, []string{"status", "type", "trigger"}) - } -} - -var ( - notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: prometheusNamespace, - Name: "notifier_waiters_for_lock", - Help: "gauge of waiters for the notifier lock", - }, []string{"type", "action"}) - notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{ - Namespace: prometheusNamespace, - Name: "notifier_wait_for_lock_seconds", - Help: "histogram of time spent waiting for the notifier lock", - Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10}, - }, []string{"action"}) - notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: prometheusNamespace, - Name: "notifier_update_received_total", - Help: "total count of updates received by notifier", - }, []string{"type", "trigger"}) - notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{ - Namespace: prometheusNamespace, - Name: "notifier_open_channels_total", - Help: "total count open channels in notifier", - }) - notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: prometheusNamespace, - Name: "notifier_batcher_waiters_for_lock", - Help: "gauge of waiters for the notifier batcher lock", - }, []string{"type", "action"}) - notifierBatcherChanges = 
promauto.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: prometheusNamespace, - Name: "notifier_batcher_changes_pending", - Help: "gauge of full changes pending in the notifier batcher", - }, []string{}) - notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: prometheusNamespace, - Name: "notifier_batcher_patches_pending", - Help: "gauge of patches pending in the notifier batcher", - }, []string{}) -) diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go deleted file mode 100644 index 6bd990c7..00000000 --- a/hscontrol/notifier/notifier.go +++ /dev/null @@ -1,488 +0,0 @@ -package notifier - -import ( - "context" - "fmt" - "sort" - "strings" - "sync" - "time" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/puzpuzpuz/xsync/v4" - "github.com/rs/zerolog/log" - "github.com/sasha-s/go-deadlock" - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/util/set" -) - -var ( - debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK") - debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT") -) - -func init() { - deadlock.Opts.Disable = !debugDeadlock - if debugDeadlock { - deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout() - deadlock.Opts.PrintAllCurrentGoroutines = true - } -} - -type Notifier struct { - l deadlock.Mutex - nodes map[types.NodeID]chan<- types.StateUpdate - connected *xsync.MapOf[types.NodeID, bool] - b *batcher - cfg *types.Config - closed bool -} - -func NewNotifier(cfg *types.Config) *Notifier { - n := &Notifier{ - nodes: make(map[types.NodeID]chan<- types.StateUpdate), - connected: xsync.NewMapOf[types.NodeID, bool](), - cfg: cfg, - closed: false, - } - b := newBatcher(cfg.Tuning.BatchChangeDelay, n) - n.b = b - - go b.doWork() - - return n -} - -// Close stops the batcher and closes all channels. -func (n *Notifier) Close() { - notifierWaitersForLock.WithLabelValues("lock", "close").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "close").Dec() - - n.closed = true - n.b.close() - - // Close channels safely using the helper method - for nodeID, c := range n.nodes { - n.safeCloseChannel(nodeID, c) - } - - // Clear node map after closing channels - n.nodes = make(map[types.NodeID]chan<- types.StateUpdate) -} - -// safeCloseChannel closes a channel and panic recovers if already closed. -func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) { - defer func() { - if r := recover(); r != nil { - log.Error(). - Uint64("node.id", nodeID.Uint64()). - Any("recover", r). - Msg("recovered from panic when closing channel in Close()") - } - }() - close(c) -} - -func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) { - log.Trace(). - Uint64("node.id", nID.Uint64()). - Int("open_chans", len(n.nodes)).Msgf(msg, args...) -} - -func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) { - start := time.Now() - notifierWaitersForLock.WithLabelValues("lock", "add").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "add").Dec() - notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds()) - - if n.closed { - return - } - - // If a channel exists, it means the node has opened a new - // connection. Close the old channel and replace it. 
- if curr, ok := n.nodes[nodeID]; ok { - n.tracef(nodeID, "channel present, closing and replacing") - // Use the safeCloseChannel helper in a goroutine to avoid deadlocks - // if/when someone is waiting to send on this channel - go func(ch chan<- types.StateUpdate) { - n.safeCloseChannel(nodeID, ch) - }(curr) - } - - n.nodes[nodeID] = c - n.connected.Store(nodeID, true) - - n.tracef(nodeID, "added new channel") - notifierNodeUpdateChans.Inc() -} - -// RemoveNode removes a node and a given channel from the notifier. -// It checks that the channel is the same as currently being updated -// and ignores the removal if it is not. -// RemoveNode reports if the node/chan was removed. -func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool { - start := time.Now() - notifierWaitersForLock.WithLabelValues("lock", "remove").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "remove").Dec() - notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds()) - - if n.closed { - return true - } - - if len(n.nodes) == 0 { - return true - } - - // If the channel exist, but it does not belong - // to the caller, ignore. - if curr, ok := n.nodes[nodeID]; ok { - if curr != c { - n.tracef(nodeID, "channel has been replaced, not removing") - return false - } - } - - delete(n.nodes, nodeID) - n.connected.Store(nodeID, false) - - n.tracef(nodeID, "removed channel") - notifierNodeUpdateChans.Dec() - - return true -} - -// IsConnected reports if a node is connected to headscale and has a -// poll session open. -func (n *Notifier) IsConnected(nodeID types.NodeID) bool { - notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec() - - if val, ok := n.connected.Load(nodeID); ok { - return val - } - - return false -} - -// IsLikelyConnected reports if a node is connected to headscale and has a -// poll session open, but doesn't lock, so might be wrong. -func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { - if val, ok := n.connected.Load(nodeID); ok { - return val - } - return false -} - -// LikelyConnectedMap returns a thread safe map of connected nodes. -func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] { - return n.connected -} - -func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) { - n.NotifyWithIgnore(ctx, update) -} - -func (n *Notifier) NotifyWithIgnore( - ctx context.Context, - update types.StateUpdate, - ignoreNodeIDs ...types.NodeID, -) { - if n.closed { - return - } - - notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() - n.b.addOrPassthrough(update) -} - -func (n *Notifier) NotifyByNodeID( - ctx context.Context, - update types.StateUpdate, - nodeID types.NodeID, -) { - start := time.Now() - notifierWaitersForLock.WithLabelValues("lock", "notify").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "notify").Dec() - notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds()) - - if n.closed { - return - } - - if c, ok := n.nodes[nodeID]; ok { - select { - case <-ctx.Done(): - log.Error(). - Err(ctx.Err()). - Uint64("node.id", nodeID.Uint64()). - Any("origin", types.NotifyOriginKey.Value(ctx)). - Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)). 
- Msgf("update not sent, context cancelled") - if debugHighCardinalityMetrics { - notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc() - } else { - notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() - } - - return - case c <- update: - n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname")) - if debugHighCardinalityMetrics { - notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc() - } else { - notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() - } - } - } -} - -func (n *Notifier) sendAll(update types.StateUpdate) { - start := time.Now() - notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec() - notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds()) - - if n.closed { - return - } - - for id, c := range n.nodes { - // Whenever an update is sent to all nodes, there is a chance that the node - // has disconnected and the goroutine that was supposed to consume the update - // has shut down the channel and is waiting for the lock held here in RemoveNode. - // This means that there is potential for a deadlock which would stop all updates - // going out to clients. This timeout prevents that from happening by moving on to the - // next node if the context is cancelled. After sendAll releases the lock, the add/remove - // call will succeed and the update will go to the correct nodes on the next call. - ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout) - defer cancel() - select { - case <-ctx.Done(): - log.Error(). - Err(ctx.Err()). - Uint64("node.id", id.Uint64()). 
- Msgf("update not sent, context cancelled") - if debugHighCardinalityMetrics { - notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc() - } else { - notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc() - } - - return - case c <- update: - if debugHighCardinalityMetrics { - notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc() - } else { - notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc() - } - } - } -} - -func (n *Notifier) String() string { - notifierWaitersForLock.WithLabelValues("lock", "string").Inc() - n.l.Lock() - defer n.l.Unlock() - notifierWaitersForLock.WithLabelValues("lock", "string").Dec() - - var b strings.Builder - fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes)) - - var keys []types.NodeID - n.connected.Range(func(key types.NodeID, value bool) bool { - keys = append(keys, key) - return true - }) - sort.Slice(keys, func(i, j int) bool { - return keys[i] < keys[j] - }) - - for _, key := range keys { - fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key]) - } - - b.WriteString("\n") - fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes)) - - for _, key := range keys { - val, _ := n.connected.Load(key) - fmt.Fprintf(&b, "\t%d: %t\n", key, val) - } - - return b.String() -} - -type batcher struct { - tick *time.Ticker - - mu sync.Mutex - - cancelCh chan struct{} - - changedNodeIDs set.Slice[types.NodeID] - nodesChanged bool - patches map[types.NodeID]tailcfg.PeerChange - patchesChanged bool - - n *Notifier -} - -func newBatcher(batchTime time.Duration, n *Notifier) *batcher { - return &batcher{ - tick: time.NewTicker(batchTime), - cancelCh: make(chan struct{}), - patches: make(map[types.NodeID]tailcfg.PeerChange), - n: n, - } -} - -func (b *batcher) close() { - b.cancelCh <- struct{}{} -} - -// addOrPassthrough adds the update to the batcher, if it is not a -// type that is currently batched, it will be sent immediately. -func (b *batcher) addOrPassthrough(update types.StateUpdate) { - notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc() - b.mu.Lock() - defer b.mu.Unlock() - notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec() - - switch update.Type { - case types.StatePeerChanged: - b.changedNodeIDs.Add(update.ChangeNodes...) - b.nodesChanged = true - notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len())) - - case types.StatePeerChangedPatch: - for _, newPatch := range update.ChangePatches { - if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok { - overwritePatch(&curr, newPatch) - b.patches[types.NodeID(newPatch.NodeID)] = curr - } else { - b.patches[types.NodeID(newPatch.NodeID)] = *newPatch - } - } - b.patchesChanged = true - notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches))) - - default: - b.n.sendAll(update) - } -} - -// flush sends all the accumulated patches to all -// nodes in the notifier. -func (b *batcher) flush() { - notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc() - b.mu.Lock() - defer b.mu.Unlock() - notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec() - - if b.nodesChanged || b.patchesChanged { - var patches []*tailcfg.PeerChange - // If a node is getting a full update from a change - // node update, then the patch can be dropped. 
- for nodeID, patch := range b.patches { - if b.changedNodeIDs.Contains(nodeID) { - delete(b.patches, nodeID) - } else { - patches = append(patches, &patch) - } - } - - changedNodes := b.changedNodeIDs.Slice().AsSlice() - sort.Slice(changedNodes, func(i, j int) bool { - return changedNodes[i] < changedNodes[j] - }) - - if b.changedNodeIDs.Slice().Len() > 0 { - update := types.UpdatePeerChanged(changedNodes...) - - b.n.sendAll(update) - } - - if len(patches) > 0 { - patchUpdate := types.UpdatePeerPatch(patches...) - - b.n.sendAll(patchUpdate) - } - - b.changedNodeIDs = set.Slice[types.NodeID]{} - notifierBatcherChanges.WithLabelValues().Set(0) - b.nodesChanged = false - b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches)) - notifierBatcherPatches.WithLabelValues().Set(0) - b.patchesChanged = false - } -} - -func (b *batcher) doWork() { - for { - select { - case <-b.cancelCh: - return - case <-b.tick.C: - b.flush() - } - } -} - -// overwritePatch takes the current patch and a newer patch -// and override any field that has changed. -func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) { - if newPatch.DERPRegion != 0 { - currPatch.DERPRegion = newPatch.DERPRegion - } - - if newPatch.Cap != 0 { - currPatch.Cap = newPatch.Cap - } - - if newPatch.CapMap != nil { - currPatch.CapMap = newPatch.CapMap - } - - if newPatch.Endpoints != nil { - currPatch.Endpoints = newPatch.Endpoints - } - - if newPatch.Key != nil { - currPatch.Key = newPatch.Key - } - - if newPatch.KeySignature != nil { - currPatch.KeySignature = newPatch.KeySignature - } - - if newPatch.DiscoKey != nil { - currPatch.DiscoKey = newPatch.DiscoKey - } - - if newPatch.Online != nil { - currPatch.Online = newPatch.Online - } - - if newPatch.LastSeen != nil { - currPatch.LastSeen = newPatch.LastSeen - } - - if newPatch.KeyExpiry != nil { - currPatch.KeyExpiry = newPatch.KeyExpiry - } -} diff --git a/hscontrol/notifier/notifier_test.go b/hscontrol/notifier/notifier_test.go deleted file mode 100644 index c3e96a8d..00000000 --- a/hscontrol/notifier/notifier_test.go +++ /dev/null @@ -1,342 +0,0 @@ -package notifier - -import ( - "fmt" - "math/rand" - "net/netip" - "slices" - "sort" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "tailscale.com/tailcfg" -) - -func TestBatcher(t *testing.T) { - tests := []struct { - name string - updates []types.StateUpdate - want []types.StateUpdate - }{ - { - name: "full-passthrough", - updates: []types.StateUpdate{ - { - Type: types.StateFullUpdate, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StateFullUpdate, - }, - }, - }, - { - name: "derp-passthrough", - updates: []types.StateUpdate{ - { - Type: types.StateDERPUpdated, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StateDERPUpdated, - }, - }, - }, - { - name: "single-node-update", - updates: []types.StateUpdate{ - { - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{ - 2, - }, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{ - 2, - }, - }, - }, - }, - { - name: "merge-node-update", - updates: []types.StateUpdate{ - { - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{ - 2, 4, - }, - }, - { - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{ - 2, 3, - }, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{ - 2, 3, 4, - }, - }, - }, - }, - { - name: 
"single-patch-update", - updates: []types.StateUpdate{ - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 2, - DERPRegion: 5, - }, - }, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 2, - DERPRegion: 5, - }, - }, - }, - }, - }, - { - name: "merge-patch-to-same-node-update", - updates: []types.StateUpdate{ - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 2, - DERPRegion: 5, - }, - }, - }, - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 2, - DERPRegion: 6, - }, - }, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 2, - DERPRegion: 6, - }, - }, - }, - }, - }, - { - name: "merge-patch-to-multiple-node-update", - updates: []types.StateUpdate{ - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 3, - Endpoints: []netip.AddrPort{ - netip.MustParseAddrPort("1.1.1.1:9090"), - }, - }, - }, - }, - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 3, - Endpoints: []netip.AddrPort{ - netip.MustParseAddrPort("1.1.1.1:9090"), - netip.MustParseAddrPort("2.2.2.2:8080"), - }, - }, - }, - }, - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 4, - DERPRegion: 6, - }, - }, - }, - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 4, - Cap: tailcfg.CapabilityVersion(54), - }, - }, - }, - }, - want: []types.StateUpdate{ - { - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: 3, - Endpoints: []netip.AddrPort{ - netip.MustParseAddrPort("1.1.1.1:9090"), - netip.MustParseAddrPort("2.2.2.2:8080"), - }, - }, - { - NodeID: 4, - DERPRegion: 6, - Cap: tailcfg.CapabilityVersion(54), - }, - }, - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - n := NewNotifier(&types.Config{ - Tuning: types.Tuning{ - // We will call flush manually for the tests, - // so do not run the worker. - BatchChangeDelay: time.Hour, - - // Since we do not load the config, we won't get the - // default, so set it manually so we dont time out - // and have flakes. - NotifierSendTimeout: time.Second, - }, - }) - - ch := make(chan types.StateUpdate, 30) - defer close(ch) - n.AddNode(1, ch) - defer n.RemoveNode(1, ch) - - for _, u := range tt.updates { - n.NotifyAll(t.Context(), u) - } - - n.b.flush() - - var got []types.StateUpdate - for len(ch) > 0 { - out := <-ch - got = append(got, out) - } - - // Make the inner order stable for comparison. - for _, u := range got { - slices.Sort(u.ChangeNodes) - sort.Slice(u.ChangePatches, func(i, j int) bool { - return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID - }) - } - - if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { - t.Errorf("batcher() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected -// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to -// close a channel that was already closed, which can happen when a node changes -// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting. 
-func TestIsLikelyConnectedRaceCondition(t *testing.T) { - // mock config for the notifier - cfg := &types.Config{ - Tuning: types.Tuning{ - NotifierSendTimeout: 1 * time.Second, - BatchChangeDelay: 1 * time.Second, - NodeMapSessionBufferedChanSize: 30, - }, - } - - notifier := NewNotifier(cfg) - defer notifier.Close() - - nodeID := types.NodeID(1) - updateChan := make(chan types.StateUpdate, 10) - - var wg sync.WaitGroup - - // Number of goroutines to spawn for concurrent access - concurrentAccessors := 100 - iterations := 100 - - // Add node to notifier - notifier.AddNode(nodeID, updateChan) - - // Track errors - errChan := make(chan string, concurrentAccessors*iterations) - - // Start goroutines to cause a race - wg.Add(concurrentAccessors) - for i := range concurrentAccessors { - go func(routineID int) { - defer wg.Done() - - for range iterations { - // Simulate race by having some goroutines check IsLikelyConnected - // while others add/remove the node - switch routineID % 3 { - case 0: - // This goroutine checks connection status - isConnected := notifier.IsLikelyConnected(nodeID) - if isConnected != true && isConnected != false { - errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected) - } - case 1: - // This goroutine removes the node - notifier.RemoveNode(nodeID, updateChan) - default: - // This goroutine adds the node back - notifier.AddNode(nodeID, updateChan) - } - - // Small random delay to increase chance of races - time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) - } - }(i) - } - - wg.Wait() - close(errChan) - - // Collate errors - var errors []string - for err := range errChan { - errors = append(errors, err) - } - - if len(errors) > 0 { - t.Errorf("Detected %d race condition errors: %v", len(errors), errors) - } -} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 5f1935e5..b8607903 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -16,9 +16,8 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" @@ -56,11 +55,10 @@ type RegistrationInfo struct { } type AuthProviderOIDC struct { + h *Headscale serverURL string cfg *types.OIDCConfig - state *state.State registrationCache *zcache.Cache[string, RegistrationInfo] - notifier *notifier.Notifier oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -68,10 +66,9 @@ type AuthProviderOIDC struct { func NewAuthProviderOIDC( ctx context.Context, + h *Headscale, serverURL string, cfg *types.OIDCConfig, - state *state.State, - notif *notifier.Notifier, ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -94,11 +91,10 @@ func NewAuthProviderOIDC( ) return &AuthProviderOIDC{ + h: h, serverURL: serverURL, cfg: cfg, - state: state, registrationCache: registrationCache, - notifier: notif, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Send policy update notifications if needed if policyChanged { - ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name) - a.notifier.NotifyAll(ctx, types.UpdateFull()) + a.h.Change(change.PolicyChange()) } // TODO(kradalby): Is this comment right? 
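The race the deleted test exercises comes from a node reconnecting so quickly that a new channel has already replaced the old one by the time the old goroutine cleans up. The pattern the notifier relied on, and which the mapBatcher now has to provide, is a registry whose RemoveNode only removes (and only reports a disconnect) when the channel being removed is still the registered one. A simplified sketch of that idea, with made-up type names rather than the real notifier/batcher API:

```go
package main

import (
	"fmt"
	"sync"
)

type NodeID uint64

// connRegistry tracks one live update channel per node. It illustrates the
// compare-before-remove rule: RemoveNode is a no-op when the registered
// channel has already been replaced by a newer connection.
type connRegistry struct {
	mu    sync.Mutex
	conns map[NodeID]chan struct{}
}

func newConnRegistry() *connRegistry {
	return &connRegistry{conns: make(map[NodeID]chan struct{})}
}

func (r *connRegistry) AddNode(id NodeID, ch chan struct{}) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.conns[id] = ch // a newer connection simply replaces the old channel
}

// RemoveNode reports true only if ch was still the registered channel,
// so the caller knows whether the node really went offline.
func (r *connRegistry) RemoveNode(id NodeID, ch chan struct{}) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.conns[id] != ch {
		return false // a rapid reconnect already replaced us; do not tear down
	}
	delete(r.conns, id)
	return true
}

func (r *connRegistry) IsLikelyConnected(id NodeID) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	_, ok := r.conns[id]
	return ok
}

func main() {
	reg := newConnRegistry()
	oldCh, newCh := make(chan struct{}), make(chan struct{})

	reg.AddNode(1, oldCh)
	reg.AddNode(1, newCh) // rapid reconnect replaces the channel

	fmt.Println(reg.RemoveNode(1, oldCh)) // false: stale connection, nothing removed
	fmt.Println(reg.IsLikelyConnected(1)) // true: the new connection is still live
}
```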
@@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Neither node nor machine key was found in the state cache meaning // that we could not reauth nor register the node. httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) - - return } func extractCodeAndStateParamFromRequest( @@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( var err error var newUser bool var policyChanged bool - user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier()) + user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, false, fmt.Errorf("creating or updating user: %w", err) } // if the user is still not found, create a new empty user. + // TODO(kradalby): This might cause us to not have an ID below which + // is a problem. if user == nil { newUser = true user = &types.User{} @@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( user.FromClaim(claims) if newUser { - user, policyChanged, err = a.state.CreateUser(*user) + user, policyChanged, err = a.h.state.CreateUser(*user) if err != nil { return nil, false, fmt.Errorf("creating user: %w", err) } } else { - _, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { + _, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { *u = *user return nil }) @@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration( registrationID types.RegistrationID, expiry time.Time, ) (bool, error) { - node, newNode, err := a.state.HandleNodeFromAuthPath( + node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( registrationID, types.UserID(user.ID), &expiry, @@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration( // ensure we send an update. // This works, but might be another good candidate for doing some sort of // eventbus. - routesChanged := a.state.AutoApproveRoutes(node) - _, policyChanged, err := a.state.SaveNode(node) + _ = a.h.state.AutoApproveRoutes(node) + _, policyChange, err := a.h.state.SaveNode(node) if err != nil { return false, fmt.Errorf("saving auto approved routes to node: %w", err) } - // Send policy update notifications if needed (from SaveNode or route changes) - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all") - a.notifier.NotifyAll(ctx, types.UpdateFull()) + // Policy updates are full and take precedence over node changes. + if !policyChange.Empty() { + a.h.Change(policyChange) + } else { + a.h.Change(nodeChange) } - if routesChanged { - ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) - a.notifier.NotifyByNodeID( - ctx, - types.UpdateSelf(node.ID), - node.ID, - ) - - ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) - a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) - } - - return newNode, nil + return !nodeChange.Empty(), nil } // TODO(kradalby): diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 5a9103e5..52457c9b 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf } } } + + // Also check approved subnet routes - nodes should have access + // to subnets they're approved to route traffic for. 
+ subnetRoutes := node.SubnetRoutes() + + for _, subnetRoute := range subnetRoutes { + if expanded.OverlapsPrefix(subnetRoute) { + dests = append(dests, dest) + continue DEST_LOOP + } + } } if len(dests) > 0 { @@ -142,16 +153,23 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { newApproved = append(newApproved, route) } } - if newApproved != nil { - newApproved = append(newApproved, node.ApprovedRoutes...) - tsaddr.SortPrefixes(newApproved) - newApproved = slices.Compact(newApproved) - newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { + + // Only modify ApprovedRoutes if we have new routes to approve. + // This prevents clearing existing approved routes when nodes + // temporarily don't have announced routes during policy changes. + if len(newApproved) > 0 { + combined := append(newApproved, node.ApprovedRoutes...) + tsaddr.SortPrefixes(combined) + combined = slices.Compact(combined) + combined = lo.Filter(combined, func(route netip.Prefix, index int) bool { return route.IsValid() }) - node.ApprovedRoutes = newApproved - return true + // Only update if the routes actually changed + if !slices.Equal(node.ApprovedRoutes, combined) { + node.ApprovedRoutes = combined + return true + } } return false diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9d838e56..c546eb20 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules( } if ips == nil { + log.Debug().Msgf("destination resolved to nil ips: %v", dest) continue } - for _, pref := range ips.Prefixes() { + prefixes := ips.Prefixes() + + for _, pref := range prefixes { for _, port := range dest.Ports { pr := tailcfg.NetPortRange{ IP: pref.String(), @@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy( return nil, nil } + log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname()) + var rules []*tailcfg.SSHRule for index, rule := range pol.SSHs { @@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy( var principals []*tailcfg.SSHPrincipal srcIPs, err := rule.Sources.Resolve(pol, users, nodes) if err != nil { - log.Trace().Err(err).Msgf("resolving source ips") + log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) + continue // Skip this rule if we can't resolve sources } for addr := range util.IPSetAddrIter(srcIPs) { diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 2f4be34e..de839770 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { // TODO(kradalby): This could potentially be optimized by only clearing the // policies for nodes that have changed. Particularly if the only difference is // that nodes has been added or removed. - defer clear(pm.sshPolicyMap) + clear(pm.sshPolicyMap) filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes) if err != nil { diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index c38d1991..a2541da6 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) { // In addition to unmarshalling, it will also validate the policy. // This is the only entrypoint of reading a policy from a file or other source. 
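
The revised AutoApproveRoutes above only touches a node's approved routes when there is something new to approve, and only reports a change when the merged set actually differs from what the node already had. A minimal sketch of that merge step, using the standard library plus tailscale's tsaddr helper (the `mergeApproved` name is illustrative):

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"

	"tailscale.com/net/tsaddr"
)

// mergeApproved combines newly approved routes with the existing approved set,
// deduplicates and drops invalid prefixes, and reports whether anything changed.
// This mirrors the guard that avoids clearing or rewriting approved routes when
// a node temporarily announces nothing during a policy change.
func mergeApproved(existing, newApproved []netip.Prefix) ([]netip.Prefix, bool) {
	if len(newApproved) == 0 {
		return existing, false
	}
	combined := append(slices.Clone(newApproved), existing...)
	tsaddr.SortPrefixes(combined)
	combined = slices.Compact(combined)
	combined = slices.DeleteFunc(combined, func(p netip.Prefix) bool { return !p.IsValid() })
	if slices.Equal(existing, combined) {
		return existing, false
	}
	return combined, true
}

func main() {
	existing := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}

	routes, changed := mergeApproved(existing, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
	fmt.Println(routes, changed) // [10.0.0.0/24 192.168.1.0/24] true

	_, changed = mergeApproved(routes, []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")})
	fmt.Println(changed) // false: already approved, nothing to persist
}
```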
func unmarshalPolicy(b []byte) (*Policy, error) { - if b == nil || len(b) == 0 { + if len(b) == 0 { return nil, nil } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index b048f62b..1833f060 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -2,20 +2,20 @@ package hscontrol import ( "context" + "encoding/binary" + "encoding/json" "fmt" "math/rand/v2" "net/http" - "net/netip" - "slices" "time" - "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" - xslices "golang.org/x/exp/slices" - "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/util/zstdframe" ) const ( @@ -31,18 +31,17 @@ type mapSession struct { req tailcfg.MapRequest ctx context.Context capVer tailcfg.CapabilityVersion - mapper *mapper.Mapper cancelChMu deadlock.Mutex - ch chan types.StateUpdate + ch chan *tailcfg.MapResponse cancelCh chan struct{} cancelChOpen bool keepAlive time.Duration keepAliveTicker *time.Ticker - node types.NodeView + node *types.Node w http.ResponseWriter warnf func(string, ...any) @@ -55,18 +54,9 @@ func (h *Headscale) newMapSession( ctx context.Context, req tailcfg.MapRequest, w http.ResponseWriter, - nv types.NodeView, + node *types.Node, ) *mapSession { - warnf, infof, tracef, errf := logPollFuncView(req, nv) - - var updateChan chan types.StateUpdate - if req.Stream { - // Use a buffered channel in case a node is not fully ready - // to receive a message to make sure we dont block the entire - // notifier. - updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) - updateChan <- types.UpdateFull() - } + warnf, infof, tracef, errf := logPollFunc(req, node) ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) @@ -75,11 +65,10 @@ func (h *Headscale) newMapSession( ctx: ctx, req: req, w: w, - node: nv, + node: node, capVer: req.Version, - mapper: h.mapper, - ch: updateChan, + ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize), cancelCh: make(chan struct{}), cancelChOpen: true, @@ -95,15 +84,11 @@ func (h *Headscale) newMapSession( } func (m *mapSession) isStreaming() bool { - return m.req.Stream && !m.req.ReadOnly + return m.req.Stream } func (m *mapSession) isEndpointUpdate() bool { - return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers -} - -func (m *mapSession) isReadOnlyUpdate() bool { - return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly + return !m.req.Stream && m.req.OmitPeers } func (m *mapSession) resetKeepAlive() { @@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() { func (m *mapSession) beforeServeLongPoll() { if m.node.IsEphemeral() { - m.h.ephemeralGC.Cancel(m.node.ID()) + m.h.ephemeralGC.Cancel(m.node.ID) } } func (m *mapSession) afterServeLongPoll() { if m.node.IsEphemeral() { - m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout) + m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout) } } // serve handles non-streaming requests. func (m *mapSession) serve() { - // TODO(kradalby): A set todos to harden: - // - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true - // This is the mechanism where the node gives us information about its // current configuration. 
// - // If OmitPeers is true, Stream is false, and ReadOnly is false, + // If OmitPeers is true and Stream is false // then the server will let clients update their endpoints without // breaking existing long-polling (Stream == true) connections. // In this case, the server can omit the entire response; the client @@ -138,26 +120,18 @@ func (m *mapSession) serve() { // // This is what Tailscale calls a Lite update, the client ignores // the response and just wants a 200. - // !req.stream && !req.ReadOnly && req.OmitPeers - // - // TODO(kradalby): remove ReadOnly when we only support capVer 68+ + // !req.stream && req.OmitPeers if m.isEndpointUpdate() { - m.handleEndpointUpdate() + c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req) + if err != nil { + httpError(m.w, err) + return + } - return - } + m.h.Change(c) - // ReadOnly is whether the client just wants to fetch the - // MapResponse, without updating their Endpoints. The - // Endpoints field will be ignored and LastSeen will not be - // updated and peers will not be notified of changes. - // - // The intended use is for clients to discover the DERP map at - // start-up before their first real endpoint update. - if m.isReadOnlyUpdate() { - m.handleReadOnlyRequest() - - return + m.w.WriteHeader(http.StatusOK) + mapResponseEndpointUpdates.WithLabelValues("ok").Inc() } } @@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() { close(m.cancelCh) m.cancelChMu.Unlock() - // only update node status if the node channel was removed. - // in principal, it will be removed, but the client rapidly - // reconnects, the channel might be of another connection. - // In that case, it is not closed and the node is still online. - if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) { - // TODO(kradalby): This can likely be made more effective, but likely most - // nodes has access to the same routes, so it might not be a big deal. - change, err := m.h.state.Disconnect(m.node.ID()) - if err != nil { - m.errf(err, "Failed to disconnect node %s", m.node.Hostname()) - } - - if change { - ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname()) - m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } + // TODO(kradalby): This can likely be made more effective, but likely most + // nodes has access to the same routes, so it might not be a big deal. + disconnectChange, err := m.h.state.Disconnect(m.node) + if err != nil { + m.errf(err, "Failed to disconnect node %s", m.node.Hostname) } + m.h.Change(disconnectChange) + + m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter()) m.afterServeLongPoll() m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) @@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() { m.h.pollNetMapStreamWG.Add(1) defer m.h.pollNetMapStreamWG.Done() - m.h.state.Connect(m.node.ID()) - - // Upgrade the writer to a ResponseController - rc := http.NewResponseController(m.w) - - // Longpolling will break if there is a write timeout, - // so it needs to be disabled. 
- rc.SetWriteDeadline(time.Time{}) - - ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname())) + ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) defer cancel() m.keepAliveTicker = time.NewTicker(m.keepAlive) - m.h.nodeNotifier.AddNode(m.node.ID(), m.ch) + // Add node to batcher BEFORE sending Connect change to prevent race condition + // where the change is sent before the node is in the batcher's node map + if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil { + m.errf(err, "failed to add node to batcher") + // Send empty response to client to fail fast for invalid/non-existent nodes + select { + case m.ch <- &tailcfg.MapResponse{}: + default: + // Channel might be closed + } + return + } + + // Now send the Connect change - the batcher handles NodeCameOnline internally + // but we still need to update routes and other state-level changes + connectChange := m.h.state.Connect(m.node) + if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline { + m.h.Change(connectChange) + } m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) @@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() { // Consume updates sent to node case update, ok := <-m.ch: + m.tracef("received update from channel, ok: %t", ok) if !ok { m.tracef("update channel closed, streaming session is likely being replaced") return } - // If the node has been removed from headscale, close the stream - if slices.Contains(update.Removed, m.node.ID()) { - m.tracef("node removed, closing stream") + if err := m.writeMap(update); err != nil { + m.errf(err, "cannot write update to client") return } - m.tracef("received stream update: %s %s", update.Type.String(), update.Message) - mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc() - - var data []byte - var err error - var lastMessage string - - // Ensure the node view is updated, for example, there - // might have been a hostinfo update in a sidechannel - // which contains data needed to generate a map response. 
- m.node, err = m.h.state.GetNodeViewByID(m.node.ID()) - if err != nil { - m.errf(err, "Could not get machine from db") - - return - } - - updateType := "full" - switch update.Type { - case types.StateFullUpdate: - m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) - case types.StatePeerChanged: - changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) - - for _, nodeID := range update.ChangeNodes { - changed[nodeID] = true - } - - lastMessage = update.Message - m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) - updateType = "change" - - case types.StatePeerChangedPatch: - m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) - updateType = "patch" - case types.StatePeerRemoved: - changed := make(map[types.NodeID]bool, len(update.Removed)) - - for _, nodeID := range update.Removed { - changed[nodeID] = false - } - m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) - updateType = "remove" - case types.StateSelfUpdate: - lastMessage = update.Message - m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) - updateType = "remove" - case types.StateDERPUpdated: - m.tracef("Sending DERPUpdate MapResponse") - data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap()) - updateType = "derp" - } - - if err != nil { - m.errf(err, "Could not get the create map update") - - return - } - - // Only send update if there is change - if data != nil { - startWrite := time.Now() - _, err = m.w.Write(data) - if err != nil { - mapResponseSent.WithLabelValues("error", updateType).Inc() - m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m) - return - } - - err = rc.Flush() - if err != nil { - mapResponseSent.WithLabelValues("error", updateType).Inc() - m.errf(err, "flushing the map response to client, for mapSession: %p", m) - return - } - - log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node") - - if debugHighCardinalityMetrics { - mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix())) - } - mapResponseSent.WithLabelValues("ok", updateType).Inc() - m.tracef("update sent") - m.resetKeepAlive() - } + m.tracef("update sent") + m.resetKeepAlive() case <-m.keepAliveTicker.C: - data, err := m.mapper.KeepAliveResponse(m.req, m.node) - if err != nil { - m.errf(err, "Error generating the keep alive msg") - mapResponseSent.WithLabelValues("error", "keepalive").Inc() - return - } - _, err = m.w.Write(data) - if err != nil { - m.errf(err, "Cannot write keep alive message") - mapResponseSent.WithLabelValues("error", "keepalive").Inc() - return - } - err = rc.Flush() - if err != nil { - m.errf(err, "flushing keep alive to client, for mapSession: %p", m) - mapResponseSent.WithLabelValues("error", "keepalive").Inc() + if err := m.writeMap(&keepAlive); err 
!= nil { + m.errf(err, "cannot write keep alive") return } if debugHighCardinalityMetrics { - mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix())) + mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) } mapResponseSent.WithLabelValues("ok", "keepalive").Inc() } } } -func (m *mapSession) handleEndpointUpdate() { - m.tracef("received endpoint update") - - // Get fresh node state from database for accurate route calculations - node, err := m.h.state.GetNodeByID(m.node.ID()) +// writeMap writes the map response to the client. +// It handles compression if requested and any headers that need to be set. +// It also handles flushing the response if the ResponseWriter +// implements http.Flusher. +func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { + jsonBody, err := json.Marshal(msg) if err != nil { - m.errf(err, "Failed to get fresh node from database for endpoint update") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - return + return fmt.Errorf("marshalling map response: %w", err) } - change := m.node.PeerChangeFromMapRequest(m.req) - - online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID()) - change.Online = &online - - node.ApplyPeerChange(&change) - - sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo) - - // The node might not set NetInfo if it has not changed and if - // the full HostInfo object is overwritten, the information is lost. - // If there is no NetInfo, keep the previous one. - // From 1.66 the client only sends it if changed: - // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 - // TODO(kradalby): evaluate if we need better comparing of hostinfo - // before we take the changes. - if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil { - m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo - } - node.Hostinfo = m.req.Hostinfo - - logTracePeerChange(node.Hostname, sendUpdate, &change) - - // If there is no changes and nothing to save, - // return early. - if peerChangeEmpty(change) && !sendUpdate { - mapResponseEndpointUpdates.WithLabelValues("noop").Inc() - return + if m.req.Compress == util.ZstdCompression { + jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression) } - // Auto approve any routes that have been defined in policy as - // auto approved. Check if this actually changed the node. - routesAutoApproved := m.h.state.AutoApproveRoutes(node) + data := make([]byte, reservedResponseHeaderSize) + binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) + data = append(data, jsonBody...) - // Always update routes for connected nodes to handle reconnection scenarios - // where routes need to be restored to the primary routes system - routesToSet := node.SubnetRoutes() + startWrite := time.Now() - if m.h.state.SetNodeRoutes(node.ID, routesToSet...) { - ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname) - m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } else if routesChanged { - // Only send peer changed notification if routes actually changed - ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname) - m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) - - // TODO(kradalby): I am not sure if we need this? 
- // Send an update to the node itself with to ensure it - // has an updated packetfilter allowing the new route - // if it is defined in the ACL. - ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname) - m.h.nodeNotifier.NotifyByNodeID( - ctx, - types.UpdateSelf(node.ID), - node.ID) + _, err = m.w.Write(data) + if err != nil { + return err } - // If routes were auto-approved, we need to save the node to persist the changes - if routesAutoApproved { - if _, _, err := m.h.state.SaveNode(node); err != nil { - m.errf(err, "Failed to save auto-approved routes to node") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - return + if m.isStreaming() { + if f, ok := m.w.(http.Flusher); ok { + f.Flush() + } else { + m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush") } } - // Check if there has been a change to Hostname and update them - // in the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the hostname change. - node.ApplyHostnameFromHostInfo(m.req.Hostinfo) + log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") - _, policyChanged, err := m.h.state.SaveNode(node) - if err != nil { - m.errf(err, "Failed to persist/update node in the database") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - - return - } - - // Send policy update notifications if needed - if policyChanged { - ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname) - m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) - } - - ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname) - m.h.nodeNotifier.NotifyWithIgnore( - ctx, - types.UpdatePeerChanged(node.ID), - node.ID, - ) - - m.w.WriteHeader(http.StatusOK) - mapResponseEndpointUpdates.WithLabelValues("ok").Inc() + return nil } -func (m *mapSession) handleReadOnlyRequest() { - m.tracef("Client asked for a lite update, responding without peers") - - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) - if err != nil { - m.errf(err, "Failed to create MapResponse") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseReadOnly.WithLabelValues("error").Inc() - return - } - - m.w.Header().Set("Content-Type", "application/json; charset=utf-8") - m.w.WriteHeader(http.StatusOK) - _, err = m.w.Write(mapResp) - if err != nil { - m.errf(err, "Failed to write response") - mapResponseReadOnly.WithLabelValues("error").Inc() - return - } - - m.w.WriteHeader(http.StatusOK) - mapResponseReadOnly.WithLabelValues("ok").Inc() +var keepAlive = tailcfg.MapResponse{ + KeepAlive: true, } -func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { - trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname) +func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) { + trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname) - if change.Key != nil { - trace = trace.Str("node_key", change.Key.ShortString()) + if peerChange.Key != nil { + trace = trace.Str("node_key", peerChange.Key.ShortString()) } - if change.DiscoKey != nil { - trace = trace.Str("disco_key", change.DiscoKey.ShortString()) + if peerChange.DiscoKey != nil { + trace = 
trace.Str("disco_key", peerChange.DiscoKey.ShortString()) } - if change.Online != nil { - trace = trace.Bool("online", *change.Online) + if peerChange.Online != nil { + trace = trace.Bool("online", *peerChange.Online) } - if change.Endpoints != nil { - eps := make([]string, len(change.Endpoints)) - for idx, ep := range change.Endpoints { + if peerChange.Endpoints != nil { + eps := make([]string, len(peerChange.Endpoints)) + for idx, ep := range peerChange.Endpoints { eps[idx] = ep.String() } @@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe trace = trace.Bool("hostinfo_changed", hostinfoChange) } - if change.DERPRegion != 0 { - trace = trace.Int("derp_region", change.DERPRegion) + if peerChange.DERPRegion != 0 { + trace = trace.Int("derp_region", peerChange.DERPRegion) } - trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received") -} - -func peerChangeEmpty(chng tailcfg.PeerChange) bool { - return chng.Key == nil && - chng.DiscoKey == nil && - chng.Online == nil && - chng.Endpoints == nil && - chng.DERPRegion == 0 && - chng.LastSeen == nil && - chng.KeyExpiry == nil + trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received") } func logPollFunc( @@ -554,7 +323,6 @@ func logPollFunc( return func(msg string, a ...any) { log.Warn(). Caller(). - Bool("readOnly", mapRequest.ReadOnly). Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). @@ -564,7 +332,6 @@ func logPollFunc( func(msg string, a ...any) { log.Info(). Caller(). - Bool("readOnly", mapRequest.ReadOnly). Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). @@ -574,7 +341,6 @@ func logPollFunc( func(msg string, a ...any) { log.Trace(). Caller(). - Bool("readOnly", mapRequest.ReadOnly). Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). @@ -584,7 +350,6 @@ func logPollFunc( func(err error, msg string, a ...any) { log.Error(). Caller(). - Bool("readOnly", mapRequest.ReadOnly). Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). @@ -593,91 +358,3 @@ func logPollFunc( Msgf(msg, a...) } } - -func logPollFuncView( - mapRequest tailcfg.MapRequest, - nodeView types.NodeView, -) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) { - return func(msg string, a ...any) { - log.Warn(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", nodeView.ID().Uint64()). - Str("node", nodeView.Hostname()). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Info(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", nodeView.ID().Uint64()). - Str("node", nodeView.Hostname()). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Trace(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", nodeView.ID().Uint64()). - Str("node", nodeView.Hostname()). - Msgf(msg, a...) - }, - func(err error, msg string, a ...any) { - log.Error(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). 
- Uint64("node.id", nodeView.ID().Uint64()). - Str("node", nodeView.Hostname()). - Err(err). - Msgf(msg, a...) - } -} - -// hostInfoChanged reports if hostInfo has changed in two ways, -// - first bool reports if an update needs to be sent to nodes -// - second reports if there has been changes to routes -// the caller can then use this info to save and update nodes -// and routes as needed. -func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { - if old.Equal(new) { - return false, false - } - - if old == nil && new != nil { - return true, true - } - - // Routes - oldRoutes := make([]netip.Prefix, 0) - if old != nil { - oldRoutes = old.RoutableIPs - } - newRoutes := new.RoutableIPs - - tsaddr.SortPrefixes(oldRoutes) - tsaddr.SortPrefixes(newRoutes) - - if !xslices.Equal(oldRoutes, newRoutes) { - return true, true - } - - // Services is mostly useful for discovery and not critical, - // except for peerapi, which is how nodes talk to each other. - // If peerapi was not part of the initial mapresponse, we - // need to make sure its sent out later as it is needed for - // Taildrop. - // TODO(kradalby): Length comparison is a bit naive, replace. - if len(old.Services) != len(new.Services) { - return true, false - } - - return false, false -} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index b754e594..02d5d3cd 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -17,10 +17,13 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" + xslices "golang.org/x/exp/slices" "gorm.io/gorm" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" @@ -46,12 +49,6 @@ type State struct { // cfg holds the current Headscale configuration cfg *types.Config - // in-memory data, protected by mu - // nodes contains the current set of registered nodes - nodes types.Nodes - // users contains the current set of users/namespaces - users types.Users - // subsystem keeping state // db provides persistent storage and database operations db *hsdb.HSDatabase @@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) { return &State{ cfg: cfg, - nodes: nodes, - users: users, - db: db, ipAlloc: ipAlloc, // TODO(kradalby): Update DERPMap @@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) { s.mu.Lock() defer s.mu.Unlock() + if err := s.db.DB.Save(&user).Error; err != nil { return nil, false, fmt.Errorf("creating user: %w", err) } @@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) { return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err) } + // Even if the policy manager doesn't detect a filter change, SSH policies + // might now be resolvable when they weren't before. If there are existing + // nodes, we should send a policy change to ensure they get updated SSH policies. 
+ if !policyChanged { + nodes, err := s.ListNodes() + if err == nil && len(nodes) > 0 { + policyChanged = true + } + } + + log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated") + // TODO(kradalby): implement the user in-memory cache return &user, policyChanged, nil @@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) { } // updateNodeTx performs a database transaction to update a node and refresh the policy manager. -func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) { +func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() @@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err return node, nil }) if err != nil { - return nil, false, err + return nil, change.EmptySet, err } // Check if policy manager needs updating policyChanged, err := s.updatePolicyManagerNodes() if err != nil { - return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err) + return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } // TODO(kradalby): implement the node in-memory cache - return node, policyChanged, nil + var c change.ChangeSet + if policyChanged { + c = change.PolicyChange() + } else { + // Basic node change without specific details since this is a generic update + c = change.NodeAdded(node.ID) + } + + return node, c, nil } // SaveNode persists an existing node to the database and updates the policy manager. -func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) { +func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() if err := s.db.DB.Save(node).Error; err != nil { - return nil, false, fmt.Errorf("saving node: %w", err) + return nil, change.EmptySet, fmt.Errorf("saving node: %w", err) } // Check if policy manager needs updating policyChanged, err := s.updatePolicyManagerNodes() if err != nil { - return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err) + return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) } // TODO(kradalby): implement the node in-memory cache - return node, policyChanged, nil + if policyChanged { + return node, change.PolicyChange(), nil + } + + return node, change.EmptySet, nil } // DeleteNode permanently removes a node and cleans up associated resources. // Returns whether policies changed and any error. This operation is irreversible. 
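
The state mutators above all follow the same shape: perform the write, ask the policy manager whether filters changed, and return a ChangeSet that is either a full policy change or a targeted per-node change. A condensed sketch of that decision, using a pared-down local stand-in for the `hscontrol/types/change` package (the `changeForWrite` helper is illustrative, and `IsFull` here is an assumption about the real predicate, which is not shown in this diff):

```go
package main

import "fmt"

type NodeID uint64

type Change int

// Constant values match the change package added in this diff.
const (
	Full            Change = 9
	Policy          Change = 11
	NodeNewOrUpdate Change = 25
)

type ChangeSet struct {
	Change Change
	NodeID NodeID
}

func PolicyChange() ChangeSet       { return ChangeSet{Change: Policy} }
func NodeAdded(id NodeID) ChangeSet { return ChangeSet{Change: NodeNewOrUpdate, NodeID: id} }

// IsFull is assumed to mean "fan out a full map update to everyone".
func (c ChangeSet) IsFull() bool { return c.Change == Policy || c.Change == Full }

// changeForWrite maps the policy manager's "did the filter change?" answer to
// the ChangeSet a state mutator hands back: a policy change fans out to all
// nodes, anything else stays scoped to the node that was touched.
func changeForWrite(nodeID NodeID, policyChanged bool) ChangeSet {
	if policyChanged {
		return PolicyChange()
	}
	return NodeAdded(nodeID)
}

func main() {
	fmt.Println(changeForWrite(2, true).IsFull())  // true: full update to all nodes
	fmt.Println(changeForWrite(2, false).IsFull()) // false: targeted update for node 2
}
```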
-func (s *State) DeleteNode(node *types.Node) (bool, error) { +func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) { err := s.db.DeleteNode(node) if err != nil { - return false, err + return change.EmptySet, err } + c := change.NodeRemoved(node.ID) + // Check if policy manager needs updating after node deletion policyChanged, err := s.updatePolicyManagerNodes() if err != nil { - return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err) + return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err) } - return policyChanged, nil + if policyChanged { + c = change.PolicyChange() + } + + return c, nil } -func (s *State) Connect(id types.NodeID) { +func (s *State) Connect(node *types.Node) change.ChangeSet { + c := change.NodeOnline(node.ID) + routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) + + if routeChange { + c = change.NodeAdded(node.ID) + } + + return c } -func (s *State) Disconnect(id types.NodeID) (bool, error) { - // TODO(kradalby): This node should update the in memory state - _, polChanged, err := s.SetLastSeen(id, time.Now()) +func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) { + c := change.NodeOffline(node.ID) + + _, _, err := s.SetLastSeen(node.ID, time.Now()) if err != nil { - return false, fmt.Errorf("disconnecting node: %w", err) + return c, fmt.Errorf("disconnecting node: %w", err) } - changed := s.primaryRoutes.SetRoutes(id) + if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange { + c = change.PolicyChange() + } - // TODO(kradalby): the returned change should be more nuanced allowing us to - // send more directed updates. - return changed || polChanged, nil + // TODO(kradalby): This node should update the in memory state + return c, nil } // GetNodeByID retrieves a node by ID. @@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) { } // SetNodeExpiry updates the expiration time for a node. -func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) { + n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.NodeSetExpiry(tx, nodeID, expiry) }) + if err != nil { + return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) + } + + if !c.IsFull() { + c = change.KeyExpiry(nodeID) + } + + return n, c, nil } // SetNodeTags assigns tags to a node for use in access control policies. -func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) { + n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetTags(tx, nodeID, tags) }) + if err != nil { + return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err) + } + + if !c.IsFull() { + c = change.NodeAdded(nodeID) + } + + return n, c, nil } // SetApprovedRoutes sets the network routes that a node is approved to advertise. 
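Connect and Disconnect above tie a node's lifetime to its primary subnet routes: coming online (re)registers the node's routes, going offline withdraws them, and in both cases a change to the primary set escalates the resulting ChangeSet to a wider update. A toy sketch of the route side of that, with a deliberately naive "first registered announcer wins" election standing in for the real routes package:

```go
package main

import (
	"fmt"
	"net/netip"
)

type NodeID uint64

// primaryRoutes is a toy stand-in for hscontrol/routes: it remembers which
// nodes announce each prefix and picks the first announcer as primary.
// SetRoutes with no prefixes withdraws everything a node announced (the
// Disconnect case); it reports whether any primary changed, which is what
// drives the escalation to a broader update in Connect/Disconnect.
type primaryRoutes struct {
	byPrefix map[netip.Prefix][]NodeID
}

func (pr *primaryRoutes) primary(p netip.Prefix) (NodeID, bool) {
	ids := pr.byPrefix[p]
	if len(ids) == 0 {
		return 0, false
	}
	return ids[0], true
}

func (pr *primaryRoutes) SetRoutes(id NodeID, routes ...netip.Prefix) bool {
	// Snapshot the current primaries so we can detect a change afterwards.
	before := map[netip.Prefix]NodeID{}
	for p := range pr.byPrefix {
		if n, ok := pr.primary(p); ok {
			before[p] = n
		}
	}

	// Drop the node from every prefix, then re-add it for the given routes.
	for p, ids := range pr.byPrefix {
		kept := ids[:0]
		for _, n := range ids {
			if n != id {
				kept = append(kept, n)
			}
		}
		pr.byPrefix[p] = kept
	}
	for _, p := range routes {
		pr.byPrefix[p] = append(pr.byPrefix[p], id)
	}

	// Report whether any primary changed, appeared, or disappeared.
	for p := range pr.byPrefix {
		n, ok := pr.primary(p)
		if !ok {
			if _, had := before[p]; had {
				return true
			}
			continue
		}
		if before[p] != n {
			return true
		}
	}
	return false
}

func main() {
	pr := &primaryRoutes{byPrefix: map[netip.Prefix][]NodeID{}}
	route := netip.MustParsePrefix("10.0.0.0/24")

	fmt.Println(pr.SetRoutes(1, route)) // true: node 1 becomes primary (Connect)
	fmt.Println(pr.SetRoutes(2, route)) // false: node 1 is still primary
	fmt.Println(pr.SetRoutes(1))        // true: node 1 withdrew, node 2 takes over (Disconnect)
}
```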
-func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) { + n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetApprovedRoutes(tx, nodeID, routes) }) + if err != nil { + return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) + } + + // Update primary routes after changing approved routes + routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...) + + if routeChange || !c.IsFull() { + c = change.PolicyChange() + } + + return n, c, nil } // RenameNode changes the display name of a node. -func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) { + n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.RenameNode(tx, nodeID, newName) }) + if err != nil { + return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err) + } + + if !c.IsFull() { + c = change.NodeAdded(nodeID) + } + + return n, c, nil } // SetLastSeen updates when a node was last seen, used for connectivity monitoring. -func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) { +func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) { return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetLastSeen(tx, nodeID, lastSeen) }) } // AssignNodeToUser transfers a node to a different user. -func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) { + n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.AssignNodeToUser(tx, nodeID, userID) }) + if err != nil { + return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err) + } + + if !c.IsFull() { + c = change.NodeAdded(nodeID) + } + + return n, c, nil } // BackfillNodeIPs assigns IP addresses to nodes that don't have them. @@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // ExpireExpiredNodes finds and processes expired nodes since the last check. // Returns next check time, state update with expired nodes, and whether any were found. -func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) { +func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) { return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck) } @@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) { } // SetNodeRoutes sets the primary routes for a node. -func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool { - return s.primaryRoutes.SetRoutes(nodeID, routes...) +func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet { + if s.primaryRoutes.SetRoutes(nodeID, routes...) 
{ + // Route changes affect packet filters for all nodes, so trigger a policy change + // to ensure filters are regenerated across the entire network + return change.PolicyChange() + } + + return change.EmptySet } // GetNodePrimaryRoutes returns the primary routes for a node. @@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath( userID types.UserID, expiry *time.Time, registrationMethod string, -) (*types.Node, bool, error) { +) (*types.Node, change.ChangeSet, error) { ipv4, ipv6, err := s.ipAlloc.Next() if err != nil { - return nil, false, err + return nil, change.EmptySet, err } return s.db.HandleNodeFromAuthPath( @@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, -) (*types.Node, bool, error) { +) (*types.Node, change.ChangeSet, bool, error) { pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) + if err != nil { + return nil, change.EmptySet, false, err + } err = pak.Validate() if err != nil { - return nil, false, err + return nil, change.EmptySet, false, err } nodeToRegister := types.Node{ @@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey( AuthKeyID: &pak.ID, } - // For auth key registration, ensure we don't keep an expired node - // This is especially important for re-registration after logout - if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) { + if !regReq.Expiry.IsZero() { nodeToRegister.Expiry = ®Req.Expiry - } else if !regReq.Expiry.IsZero() { - // If client is sending an expired time (e.g., after logout), - // don't set expiry so the node won't be considered expired - log.Debug(). - Time("requested_expiry", regReq.Expiry). - Str("node", regReq.Hostinfo.Hostname). - Msg("Ignoring expired expiry time from auth key registration") } ipv4, ipv6, err := s.ipAlloc.Next() if err != nil { - return nil, false, fmt.Errorf("allocating IPs: %w", err) + return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err) } node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { @@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey( return node, nil }) if err != nil { - return nil, false, fmt.Errorf("writing node to database: %w", err) + return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err) + } + + // Check if this is a logout request for an ephemeral node + if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { + // This is a logout request for an ephemeral node, delete it immediately + c, err := s.DeleteNode(node) + if err != nil { + return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err) + } + return nil, c, false, nil } // Check if policy manager needs updating // This is necessary because we just created a new node. // We need to ensure that the policy manager is aware of this new node. - policyChanged, err := s.updatePolicyManagerNodes() + // Also update users to ensure all users are known when evaluating policies. 
+	usersChanged, err := s.updatePolicyManagerUsers()
 	if err != nil {
-		return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
+		return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
 	}
 
-	return node, policyChanged, nil
+	nodesChanged, err := s.updatePolicyManagerNodes()
+	if err != nil {
+		return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
+	}
+
+	policyChanged := usersChanged || nodesChanged
+
+	c := change.NodeAdded(node.ID)
+
+	return node, c, policyChanged, nil
 }
 
 // AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
@@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
 		return false, fmt.Errorf("listing users for policy update: %w", err)
 	}
 
+	log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
+
 	changed, err := s.polMan.SetUsers(users)
 	if err != nil {
 		return false, fmt.Errorf("updating policy manager users: %w", err)
 	}
 
+	log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
+
 	return changed, nil
 }
 
@@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
 
 	return nil
 }
+
+// TODO(kradalby): This should just take the node ID?
+func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
+	// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
+	// which means we could shortcut the whole change thing if there are no other important updates.
+	peerChange := node.PeerChangeFromMapRequest(req)
+
+	node.ApplyPeerChange(&peerChange)
+
+	sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo)
+
+	// The node might not set NetInfo if it has not changed and if
+	// the full HostInfo object is overwritten, the information is lost.
+	// If there is no NetInfo, keep the previous one.
+	// From 1.66 the client only sends it if changed:
+	// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
+	// TODO(kradalby): evaluate if we need better comparing of hostinfo
+	// before we take the changes.
+	if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
+		req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
+	}
+	node.Hostinfo = req.Hostinfo
+
+	// If there are no changes and nothing to save,
+	// return early.
+	if peerChangeEmpty(peerChange) && !sendUpdate {
+		// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
+		return change.EmptySet, nil
+	}
+
+	c := change.EmptySet
+
+	// Check if the Hostinfo of the node has changed.
+	// If it has changed, check if there has been a change to
+	// the routable IPs of the host and update them in
+	// the database. Then send a Changed update
+	// (containing the whole node object) to peers to inform about
+	// the route change.
+	// If the hostinfo has changed, but not the routes, just update
+	// hostinfo and let the function continue.
+	if routesChanged {
+		// Auto approve any routes that have been defined in policy as
+		// auto approved. Check if this actually changed the node.
+		_ = s.AutoApproveRoutes(node)
+
+		// Update the routes of the given node in the route manager to
+		// see if an update needs to be sent.
+		c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
+	}
+
+	// Check if there has been a change to the hostname and update it
+	// in the database. Then send a Changed update
+	// (containing the whole node object) to peers to inform about
+	// the hostname change.
+	node.ApplyHostnameFromHostInfo(req.Hostinfo)
+
+	_, policyChange, err := s.SaveNode(node)
+	if err != nil {
+		return change.EmptySet, err
+	}
+
+	if policyChange.IsFull() {
+		c = policyChange
+	}
+
+	if c.Empty() {
+		c = change.NodeAdded(node.ID)
+	}
+
+	return c, nil
+}
+
+// hostInfoChanged reports if hostInfo has changed in two ways,
+// - first bool reports if an update needs to be sent to nodes
+// - second reports if there have been changes to routes
+// the caller can then use this info to save and update nodes
+// and routes as needed.
+func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
+	if old.Equal(new) {
+		return false, false
+	}
+
+	if old == nil && new != nil {
+		return true, true
+	}
+
+	// Routes
+	oldRoutes := make([]netip.Prefix, 0)
+	if old != nil {
+		oldRoutes = old.RoutableIPs
+	}
+	newRoutes := new.RoutableIPs
+
+	tsaddr.SortPrefixes(oldRoutes)
+	tsaddr.SortPrefixes(newRoutes)
+
+	if !xslices.Equal(oldRoutes, newRoutes) {
+		return true, true
+	}
+
+	// Services is mostly useful for discovery and not critical,
+	// except for peerapi, which is how nodes talk to each other.
+	// If peerapi was not part of the initial mapresponse, we
+	// need to make sure it's sent out later as it is needed for
+	// Taildrop.
+	// TODO(kradalby): Length comparison is a bit naive, replace.
+	if len(old.Services) != len(new.Services) {
+		return true, false
+	}
+
+	return false, false
+}
+
+func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
+	return peerChange.Key == nil &&
+		peerChange.DiscoKey == nil &&
+		peerChange.Online == nil &&
+		peerChange.Endpoints == nil &&
+		peerChange.DERPRegion == 0 &&
+		peerChange.LastSeen == nil &&
+		peerChange.KeyExpiry == nil
+}
diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go
new file mode 100644
index 00000000..3301cb35
--- /dev/null
+++ b/hscontrol/types/change/change.go
@@ -0,0 +1,183 @@
+//go:generate go tool stringer -type=Change
+package change
+
+import (
+	"errors"
+
+	"github.com/juanfont/headscale/hscontrol/types"
+)
+
+type (
+	NodeID = types.NodeID
+	UserID = types.UserID
+)
+
+type Change int
+
+const (
+	ChangeUnknown Change = 0
+
+	// Deprecated: Use specific change instead
+	// Full is a legacy change to ensure places where we
+	// have not yet determined the specific update, can send.
+	Full Change = 9
+
+	// Server changes.
+	Policy       Change = 11
+	DERP         Change = 12
+	ExtraRecords Change = 13
+
+	// Node changes.
+	NodeCameOnline  Change = 21
+	NodeWentOffline Change = 22
+	NodeRemove      Change = 23
+	NodeKeyExpiry   Change = 24
+	NodeNewOrUpdate Change = 25
+
+	// User changes.
+	UserNewOrUpdate Change = 51
+	UserRemove      Change = 52
+)
+
+// AlsoSelf reports whether this change should also be sent to the node itself.
+func (c Change) AlsoSelf() bool {
+	switch c {
+	case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
+		return true
+	}
+	return false
+}
+
+type ChangeSet struct {
+	Change Change
+
+	// SelfUpdateOnly indicates that this change should only be sent
+	// to the node itself, and not to other nodes.
+	// This is used for changes that are not relevant to other nodes.
+	// NodeID must be set if this is true.
+	SelfUpdateOnly bool
+
+	// NodeID if set, is the ID of the node that is being changed.
+	// It must be set if this is a node change.
+	NodeID types.NodeID
+
+	// UserID if set, is the ID of the user that is being changed.
+	// It must be set if this is a user change.
+	UserID types.UserID
+
+	// IsSubnetRouter indicates whether the node is a subnet router.
+	IsSubnetRouter bool
+}
+
+func (c *ChangeSet) Validate() error {
+	if c.Change >= NodeCameOnline && c.Change <= NodeNewOrUpdate {
+		if c.NodeID == 0 {
+			return errors.New("ChangeSet.NodeID must be set for node updates")
+		}
+	}
+
+	if c.Change >= UserNewOrUpdate && c.Change <= UserRemove {
+		if c.UserID == 0 {
+			return errors.New("ChangeSet.UserID must be set for user updates")
+		}
+	}
+
+	return nil
+}
+
+// Empty reports whether the ChangeSet is empty, meaning it does not
+// represent any change.
+func (c ChangeSet) Empty() bool {
+	return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
+}
+
+// IsFull reports whether the ChangeSet represents a full update.
+func (c ChangeSet) IsFull() bool {
+	return c.Change == Full || c.Change == Policy
+}
+
+func (c ChangeSet) AlsoSelf() bool {
+	// If NodeID is 0, it means this ChangeSet is not related to a specific node,
+	// so we consider it as a change that should be sent to all nodes.
+	if c.NodeID == 0 {
+		return true
+	}
+	return c.Change.AlsoSelf() || c.SelfUpdateOnly
+}
+
+var (
+	EmptySet        = ChangeSet{Change: ChangeUnknown}
+	FullSet         = ChangeSet{Change: Full}
+	DERPSet         = ChangeSet{Change: DERP}
+	PolicySet       = ChangeSet{Change: Policy}
+	ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
+)
+
+func FullSelf(id types.NodeID) ChangeSet {
+	return ChangeSet{
+		Change:         Full,
+		SelfUpdateOnly: true,
+		NodeID:         id,
+	}
+}
+
+func NodeAdded(id types.NodeID) ChangeSet {
+	return ChangeSet{
+		Change: NodeNewOrUpdate,
+		NodeID: id,
+	}
+}
+
+func NodeRemoved(id types.NodeID) ChangeSet {
+	return ChangeSet{
+		Change: NodeRemove,
+		NodeID: id,
+	}
+}
+
+func NodeOnline(id types.NodeID) ChangeSet {
+	return ChangeSet{
+		Change: NodeCameOnline,
+		NodeID: id,
+	}
+}
+
+func NodeOffline(id types.NodeID) ChangeSet {
+	return ChangeSet{
+		Change: NodeWentOffline,
+		NodeID: id,
+	}
+}
+
+func KeyExpiry(id types.NodeID) ChangeSet {
+	return ChangeSet{
+		Change: NodeKeyExpiry,
+		NodeID: id,
+	}
+}
+
+func UserAdded(id types.UserID) ChangeSet {
+	return ChangeSet{
+		Change: UserNewOrUpdate,
+		UserID: id,
+	}
+}
+
+func UserRemoved(id types.UserID) ChangeSet {
+	return ChangeSet{
+		Change: UserRemove,
+		UserID: id,
+	}
+}
+
+func PolicyChange() ChangeSet {
+	return ChangeSet{
+		Change: Policy,
+	}
+}
+
+func DERPChange() ChangeSet {
+	return ChangeSet{
+		Change: DERP,
+	}
+}
diff --git a/hscontrol/types/change/change_string.go b/hscontrol/types/change/change_string.go
new file mode 100644
index 00000000..dbf9d17e
--- /dev/null
+++ b/hscontrol/types/change/change_string.go
@@ -0,0 +1,57 @@
+// Code generated by "stringer -type=Change"; DO NOT EDIT.
+
+package change
+
+import "strconv"
+
+func _() {
+	// An "invalid array index" compiler error signifies that the constant values have changed.
+	// Re-run the stringer command to generate them again.
+ var x [1]struct{} + _ = x[ChangeUnknown-0] + _ = x[Full-9] + _ = x[Policy-11] + _ = x[DERP-12] + _ = x[ExtraRecords-13] + _ = x[NodeCameOnline-21] + _ = x[NodeWentOffline-22] + _ = x[NodeRemove-23] + _ = x[NodeKeyExpiry-24] + _ = x[NodeNewOrUpdate-25] + _ = x[UserNewOrUpdate-51] + _ = x[UserRemove-52] +} + +const ( + _Change_name_0 = "ChangeUnknown" + _Change_name_1 = "Full" + _Change_name_2 = "PolicyDERPExtraRecords" + _Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate" + _Change_name_4 = "UserNewOrUpdateUserRemove" +) + +var ( + _Change_index_2 = [...]uint8{0, 6, 10, 22} + _Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67} + _Change_index_4 = [...]uint8{0, 15, 25} +) + +func (i Change) String() string { + switch { + case i == 0: + return _Change_name_0 + case i == 9: + return _Change_name_1 + case 11 <= i && i <= 13: + i -= 11 + return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]] + case 21 <= i && i <= 25: + i -= 21 + return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]] + case 51 <= i && i <= 52: + i -= 51 + return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]] + default: + return "Change(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 51e11757..a80f2ab4 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -1,16 +1,16 @@ -//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey - +//go:generate go tool viewer --type=User,Node,PreAuthKey package types +//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey + import ( - "context" "errors" "fmt" + "runtime" "time" "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" - "tailscale.com/util/ctxkey" ) const ( @@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } -var ( - NotifyOriginKey = ctxkey.New("notify.origin", "") - NotifyHostnameKey = ctxkey.New("notify.hostname", "") -) - -func NotifyCtx(ctx context.Context, origin, hostname string) context.Context { - ctx2, _ := context.WithTimeout(ctx, 3*time.Second) - ctx2 = NotifyOriginKey.WithValue(ctx2, origin) - ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname) - return ctx2 -} - const RegistrationIDLength = 24 type RegistrationID string @@ -199,3 +187,20 @@ type RegisterNode struct { Node Node Registered chan *Node } + +// DefaultBatcherWorkers returns the default number of batcher workers. +// Default to 3/4 of CPU cores, minimum 1, no maximum. +func DefaultBatcherWorkers() int { + return DefaultBatcherWorkersFor(runtime.NumCPU()) +} + +// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count. +// Default to 3/4 of CPU cores, minimum 1, no maximum. 
+func DefaultBatcherWorkersFor(cpuCount int) int { + defaultWorkers := (cpuCount * 3) / 4 + if defaultWorkers < 1 { + defaultWorkers = 1 + } + + return defaultWorkers +} diff --git a/hscontrol/types/common_test.go b/hscontrol/types/common_test.go new file mode 100644 index 00000000..a443918b --- /dev/null +++ b/hscontrol/types/common_test.go @@ -0,0 +1,36 @@ +package types + +import ( + "testing" +) + +func TestDefaultBatcherWorkersFor(t *testing.T) { + tests := []struct { + cpuCount int + expected int + }{ + {1, 1}, // (1*3)/4 = 0, should be minimum 1 + {2, 1}, // (2*3)/4 = 1 + {4, 3}, // (4*3)/4 = 3 + {8, 6}, // (8*3)/4 = 6 + {12, 9}, // (12*3)/4 = 9 + {16, 12}, // (16*3)/4 = 12 + {20, 15}, // (20*3)/4 = 15 + {24, 18}, // (24*3)/4 = 18 + } + + for _, test := range tests { + result := DefaultBatcherWorkersFor(test.cpuCount) + if result != test.expected { + t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected) + } + } +} + +func TestDefaultBatcherWorkers(t *testing.T) { + // Just verify it returns a valid value (>= 1) + result := DefaultBatcherWorkers() + if result < 1 { + t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result) + } +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 1e35303e..44773a55 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -234,6 +234,7 @@ type Tuning struct { NotifierSendTimeout time.Duration BatchChangeDelay time.Duration NodeMapSessionBufferedChanSize int + BatcherWorkers int } func validatePKCEMethod(method string) error { @@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) { NodeMapSessionBufferedChanSize: viper.GetInt( "tuning.node_mapsession_buffered_chan_size", ), + BatcherWorkers: func() int { + if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 { + return workers + } + return DefaultBatcherWorkers() + }(), }, }, nil } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 32f0274c..81a2a86a 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix { return routes } +// IsSubnetRouter reports if the node has any subnet routes. +func (node *Node) IsSubnetRouter() bool { + return len(node.SubnetRoutes()) > 0 +} + func (node *Node) String() string { return node.Hostname } @@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix { return v.ж.SubnetRoutes() } +func (v NodeView) IsSubnetRouter() bool { + if !v.Valid() { + return false + } + return v.ж.IsSubnetRouter() +} + func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) { if !v.Valid() { return diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index e47666ff..46329c12 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -1,17 +1,16 @@ package types import ( - "fmt" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" ) type PAKError string func (e PAKError) Error() string { return string(e) } -func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) } // PreAuthKey describes a pre-authorization key usable in a particular user. type PreAuthKey struct { @@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error { if pak == nil { return PAKError("invalid authkey") } + + log.Debug(). + Str("key", pak.Key). + Bool("hasExpiration", pak.Expiration != nil). 
+ Time("expiration", func() time.Time { + if pak.Expiration != nil { + return *pak.Expiration + } + return time.Time{} + }()). + Time("now", time.Now()). + Bool("reusable", pak.Reusable). + Bool("used", pak.Used). + Msg("PreAuthKey.Validate: checking key") + if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { return PAKError("authkey expired") } diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index 30652e4b..140b70e2 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + "tailscale.com/util/dnsname" + "tailscale.com/util/must" ) func TestCheckForFQDNRules(t *testing.T) { @@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) { func TestMagicDNSRootDomains100(t *testing.T) { domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10")) - found := false - for _, domain := range domains { - if domain == "64.100.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) - - found = false - for _, domain := range domains { - if domain == "100.100.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) - - found = false - for _, domain := range domains { - if domain == "127.100.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa."))) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa."))) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa."))) } func TestMagicDNSRootDomains172(t *testing.T) { domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16")) - found := false - for _, domain := range domains { - if domain == "0.16.172.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) - - found = false - for _, domain := range domains { - if domain == "255.16.172.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa."))) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa."))) } // Happens when netmask is a multiple of 4 bits (sounds likely). 
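A minimal, illustrative sketch (not part of the patch) of how the new tuning.batcher_workers option resolves: the config wiring above reads the value with viper.GetInt and, when it is unset or non-positive, falls back to DefaultBatcherWorkers(), i.e. three quarters of the CPU count with a minimum of 1.

```go
package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// 0 stands in for an unset tuning.batcher_workers value (viper.GetInt returns 0 when the key is absent).
	configured := 0

	workers := configured
	if workers <= 0 {
		// Fall back to (NumCPU * 3) / 4, clamped to a minimum of 1.
		workers = types.DefaultBatcherWorkers()
	}
	fmt.Println("batcher workers:", workers)

	// The per-CPU-count helper is pure arithmetic and can be checked directly:
	fmt.Println(types.DefaultBatcherWorkersFor(1)) // (1*3)/4 = 0, clamped to 1
	fmt.Println(types.DefaultBatcherWorkersFor(8)) // (8*3)/4 = 6
}
```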
diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index a44a6e97..d7bc7897 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // Parse latencies for j := 5; j <= 7; j++ { - if matches[j] != "" { + if j < len(matches) && matches[j] != "" { ms, err := strconv.ParseFloat(matches[j], 64) if err != nil { return Traceroute{}, fmt.Errorf("parsing latency: %w", err) diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 1352a02b..8050f6e7 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -88,7 +88,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { var err error listNodes, err = headscale.ListNodes() assert.NoError(ct, err) - assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match before logout count") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count") }, 20*time.Second, 1*time.Second) for _, node := range listNodes { @@ -123,7 +123,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { var err error listNodes, err = headscale.ListNodes() assert.NoError(ct, err) - assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match after HTTPS reconnection") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection") }, 30*time.Second, 2*time.Second) for _, node := range listNodes { @@ -161,7 +161,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } listNodes, err = headscale.ListNodes() - require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + require.Len(t, listNodes, nodeCountBeforeLogout) for _, node := range listNodes { assertLastSeenSet(t, node) } @@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { "--user", strconv.FormatUint(userMap[userName].GetId(), 10), "expire", - key.Key, + key.GetKey(), }) assertNoErr(t, err) diff --git a/integration/cli_test.go b/integration/cli_test.go index 7f4f9936..42d191e0 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -604,7 +604,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err) - assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState, + assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState, "Expected node to be logged out, backend state: %s", status.BackendState) }, 30*time.Second, 2*time.Second) diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 86c1e046..799d70f3 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -147,3 +147,9 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) { config.CapAdd = append(config.CapAdd, "NET_ADMIN") config.Privileged = true } + +// DockerMemoryLimit sets memory limit and disables OOM kill for containers. 
+func DockerMemoryLimit(config *docker.HostConfig) { + config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes + config.OOMKillDisable = true +} diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 051b9261..e9ba69dd 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -145,9 +145,9 @@ func derpServerScenario( assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) for _, health := range status.Health { - assert.NotContains(ct, health, "could not connect to any relay server", + assert.NotContains(ct, health, "could not connect to any relay server", "Client %s should be connected to DERP relay", client.Hostname()) - assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", + assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", "Client %s should be connected to Headscale Embedded DERP", client.Hostname()) } }, 30*time.Second, 2*time.Second) @@ -166,9 +166,9 @@ func derpServerScenario( assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) for _, health := range status.Health { - assert.NotContains(ct, health, "could not connect to any relay server", + assert.NotContains(ct, health, "could not connect to any relay server", "Client %s should be connected to DERP relay after first run", client.Hostname()) - assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", + assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", "Client %s should be connected to Headscale Embedded DERP after first run", client.Hostname()) } }, 30*time.Second, 2*time.Second) @@ -191,9 +191,9 @@ func derpServerScenario( assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) for _, health := range status.Health { - assert.NotContains(ct, health, "could not connect to any relay server", + assert.NotContains(ct, health, "could not connect to any relay server", "Client %s should be connected to DERP relay after second run", client.Hostname()) - assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", + assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", "Client %s should be connected to Headscale Embedded DERP after second run", client.Hostname()) } }, 30*time.Second, 2*time.Second) diff --git a/integration/general_test.go b/integration/general_test.go index 0e1a8da5..4e250854 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -883,6 +883,10 @@ func TestNodeOnlineStatus(t *testing.T) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err) + if status == nil { + assert.Fail(ct, "status is nil") + return + } for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] @@ -984,16 +988,11 @@ func TestPingAllByIPManyUpDown(t *testing.T) { } // Wait for sync and successful pings after nodes come back up - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - err = scenario.WaitForTailscaleSync() - assert.NoError(ct, err) - - success := pingAllHelper(t, allClients, allAddrs) - assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up") - }, 30*time.Second, 2*time.Second) + err = scenario.WaitForTailscaleSync() + assert.NoError(t, err) success := pingAllHelper(t, allClients, allAddrs) - t.Logf("%d successful 
pings out of %d", success, len(allClients)*len(allIps)) + assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps)) } } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 5e7db275..e77d2fbe 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -260,7 +260,9 @@ func WithDERPConfig(derpMap tailcfg.DERPMap) Option { func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option { return func(hsic *HeadscaleInContainer) { hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String() - hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize) + hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa( + mapSessionChanSize, + ) } } @@ -279,10 +281,16 @@ func WithDebugPort(port int) Option { // buildEntrypoint builds the container entrypoint command based on configuration. func (hsic *HeadscaleInContainer) buildEntrypoint() []string { - debugCmd := fmt.Sprintf("/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve", hsic.debugPort) - - entrypoint := fmt.Sprintf("/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30", debugCmd) - + debugCmd := fmt.Sprintf( + "/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve", + hsic.debugPort, + ) + + entrypoint := fmt.Sprintf( + "/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30", + debugCmd, + ) + return []string{"/bin/bash", "-c", entrypoint} } @@ -447,8 +455,12 @@ func New( log.Printf("Created %s container\n", hsic.hostname) hsic.container = container - - log.Printf("Debug ports for %s: delve=%s, metrics/pprof=49090\n", hsic.hostname, hsic.GetHostDebugPort()) + + log.Printf( + "Debug ports for %s: delve=%s, metrics/pprof=49090\n", + hsic.hostname, + hsic.GetHostDebugPort(), + ) // Write the CA certificates to the container for i, cert := range hsic.caCerts { @@ -684,14 +696,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { return nil } - // First, let's see what files are actually in /tmp - tmpListing, err := t.Execute([]string{"ls", "-la", "/tmp/"}) - if err != nil { - log.Printf("Warning: could not list /tmp directory: %v", err) - } else { - log.Printf("Contents of /tmp in container %s:\n%s", t.hostname, tmpListing) - } - // Also check for any .sqlite files sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"}) if err != nil { @@ -718,12 +722,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { return errors.New("database file exists but has no schema (empty database)") } - // Show a preview of the schema (first 500 chars) - schemaPreview := schemaCheck - if len(schemaPreview) > 500 { - schemaPreview = schemaPreview[:500] + "..." 
- } - tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") if err != nil { return fmt.Errorf("failed to fetch database file: %w", err) @@ -740,7 +738,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { return fmt.Errorf("failed to read tar header: %w", err) } - log.Printf("Found file in tar: %s (type: %d, size: %d)", header.Name, header.Typeflag, header.Size) + log.Printf( + "Found file in tar: %s (type: %d, size: %d)", + header.Name, + header.Typeflag, + header.Size, + ) // Extract the first regular file we find if header.Typeflag == tar.TypeReg { @@ -756,11 +759,20 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { return fmt.Errorf("failed to copy database file: %w", err) } - log.Printf("Extracted database file: %s (%d bytes written, header claimed %d bytes)", dbPath, written, header.Size) + log.Printf( + "Extracted database file: %s (%d bytes written, header claimed %d bytes)", + dbPath, + written, + header.Size, + ) // Check if we actually wrote something if written == 0 { - return fmt.Errorf("database file is empty (size: %d, header size: %d)", written, header.Size) + return fmt.Errorf( + "database file is empty (size: %d, header size: %d)", + written, + header.Size, + ) } return nil @@ -871,7 +883,15 @@ func (t *HeadscaleInContainer) WaitForRunning() error { func (t *HeadscaleInContainer) CreateUser( user string, ) (*v1.User, error) { - command := []string{"headscale", "users", "create", user, fmt.Sprintf("--email=%s@test.no", user), "--output", "json"} + command := []string{ + "headscale", + "users", + "create", + user, + fmt.Sprintf("--email=%s@test.no", user), + "--output", + "json", + } result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -1182,13 +1202,18 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( []string{}, ) if err != nil { - return nil, fmt.Errorf("failed to execute list node command: %w", err) + return nil, fmt.Errorf( + "failed to execute approve routes command (node %d, routes %v): %w", + id, + routes, + err, + ) } var node *v1.Node err = json.Unmarshal([]byte(result), &node) if err != nil { - return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) + return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err) } return node, nil diff --git a/integration/route_test.go b/integration/route_test.go index aa6b9e2e..7243d3f2 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -310,7 +310,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Enable route on node 1 t.Logf("Enabling route on subnet router 1, no HA") _, err = headscale.ApproveRoutes( - 1, + MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{pref}, ) require.NoError(t, err) @@ -366,7 +366,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Enable route on node 2, now we will have a HA subnet router t.Logf("Enabling route on subnet router 2, now HA, subnetrouter 1 is primary, 2 is standby") _, err = headscale.ApproveRoutes( - 2, + MustFindNode(subRouter2.Hostname(), nodes).GetId(), []netip.Prefix{pref}, ) require.NoError(t, err) @@ -422,7 +422,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // be enabled. 
t.Logf("Enabling route on subnet router 3, now HA, subnetrouter 1 is primary, 2 and 3 is standby") _, err = headscale.ApproveRoutes( - 3, + MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{pref}, ) require.NoError(t, err) @@ -639,7 +639,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("disabling route in subnet router r3 (%s)", subRouter3.Hostname()) t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname()) - _, err = headscale.ApproveRoutes(nodes[2].GetId(), []netip.Prefix{}) + _, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{}) time.Sleep(5 * time.Second) @@ -647,9 +647,9 @@ func TestHASubnetRouterFailover(t *testing.T) { require.NoError(t, err) assert.Len(t, nodes, 6) - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 1, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -684,7 +684,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Disable the route of subnet router 1, making it failover to 2 t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname()) t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname()) - _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{}) + _, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{}) time.Sleep(5 * time.Second) @@ -692,9 +692,9 @@ func TestHASubnetRouterFailover(t *testing.T) { require.NoError(t, err) assert.Len(t, nodes, 6) - requireNodeRouteCount(t, nodes[0], 1, 0, 0) - requireNodeRouteCount(t, nodes[1], 1, 1, 1) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) + requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -729,9 +729,10 @@ func TestHASubnetRouterFailover(t *testing.T) { // enable the route of subnet router 1, no change expected t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname()) t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname()) + r1Node := MustFindNode(subRouter1.Hostname(), nodes) _, err = headscale.ApproveRoutes( - nodes[0].GetId(), - util.MustStringsToPrefixes(nodes[0].GetAvailableRoutes()), + r1Node.GetId(), + util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()), ) time.Sleep(5 * time.Second) @@ -740,9 +741,9 @@ func TestHASubnetRouterFailover(t *testing.T) { require.NoError(t, err) assert.Len(t, nodes, 6) - requireNodeRouteCount(t, nodes[0], 1, 1, 0) - requireNodeRouteCount(t, nodes[1], 1, 1, 1) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() diff --git a/integration/scenario.go 
b/integration/scenario.go index b235cf34..817d927b 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -223,7 +223,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { s.userToNetwork = userToNetwork - if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 { + if len(spec.OIDCUsers) != 0 { ttl := defaultAccessTTL if spec.OIDCAccessTTL != 0 { ttl = spec.OIDCAccessTTL diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 236aba20..3015503f 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -370,10 +370,12 @@ func TestSSHUserOnlyIsolation(t *testing.T) { } func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + t.Helper() return doSSHWithRetry(t, client, peer, true) } func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + t.Helper() return doSSHWithRetry(t, client, peer, false) } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 1818c16a..01603512 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -319,6 +319,7 @@ func New( dockertestutil.DockerRestartPolicy, dockertestutil.DockerAllowLocalIPv6, dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, ) case "unstable": tailscaleOptions.Repository = "tailscale/tailscale" @@ -329,6 +330,7 @@ func New( dockertestutil.DockerRestartPolicy, dockertestutil.DockerAllowLocalIPv6, dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, ) default: tailscaleOptions.Repository = "tailscale/tailscale" @@ -339,6 +341,7 @@ func New( dockertestutil.DockerRestartPolicy, dockertestutil.DockerAllowLocalIPv6, dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, ) } diff --git a/integration/utils.go b/integration/utils.go index a7ab048b..2e70b793 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -22,11 +22,11 @@ import ( const ( // derpPingTimeout defines the timeout for individual DERP ping operations - // Used in DERP connectivity tests to verify relay server communication + // Used in DERP connectivity tests to verify relay server communication. derpPingTimeout = 2 * time.Second - + // derpPingCount defines the number of ping attempts for DERP connectivity tests - // Higher count provides better reliability assessment of DERP connectivity + // Higher count provides better reliability assessment of DERP connectivity. derpPingCount = 10 ) @@ -317,11 +317,11 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) { // assertCommandOutputContains executes a command with exponential backoff retry until the output // contains the expected string or timeout is reached (10 seconds). -// This implements eventual consistency patterns and should be used instead of time.Sleep +// This implements eventual consistency patterns and should be used instead of time.Sleep // before executing commands that depend on network state propagation. // // Timeout: 10 seconds with exponential backoff -// Use cases: DNS resolution, route propagation, policy updates +// Use cases: DNS resolution, route propagation, policy updates. 
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { t.Helper() @@ -361,10 +361,10 @@ func isSelfClient(client TailscaleClient, addr string) bool { } func dockertestMaxWait() time.Duration { - wait := 120 * time.Second //nolint + wait := 300 * time.Second //nolint if util.IsCI() { - wait = 300 * time.Second //nolint + wait = 600 * time.Second //nolint } return wait diff --git a/hscontrol/capver/gen/main.go b/tools/capver/main.go similarity index 91% rename from hscontrol/capver/gen/main.go rename to tools/capver/main.go index 3b31686d..37bab0bc 100644 --- a/hscontrol/capver/gen/main.go +++ b/tools/capver/main.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "os" "regexp" @@ -21,7 +20,7 @@ import ( const ( releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases" rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go" - outputFile = "../capver_generated.go" + outputFile = "../../hscontrol/capver/capver_generated.go" ) type Release struct { @@ -105,7 +104,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion sortedVersions := xmaps.Keys(versions) sort.Strings(sortedVersions) for _, version := range sortedVersions { - file.WriteString(fmt.Sprintf("\t\"%s\": %d,\n", version, versions[version])) + fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version]) } file.WriteString("}\n") @@ -115,16 +114,13 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string) for _, v := range sortedVersions { cap := versions[v] - log.Printf("cap for v: %d, %s", cap, v) // If it is already set, skip and continue, // we only want the first tailscale vsion per // capability vsion. if _, ok := capVarToTailscaleVer[cap]; ok { - log.Printf("Skipping %d, %s", cap, v) continue } - log.Printf("Storing %d, %s", cap, v) capVarToTailscaleVer[cap] = v } @@ -133,7 +129,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion return capsSorted[i] < capsSorted[j] }) for _, capVer := range capsSorted { - file.WriteString(fmt.Sprintf("\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])) + fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]) } file.WriteString("}\n")
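For reference, a minimal, illustrative sketch (not part of the patch) of how the ChangeSet helpers added in hscontrol/types/change behave:

```go
package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types/change"
)

func main() {
	// A node-scoped update: NodeNewOrUpdate is delivered to peers and,
	// per Change.AlsoSelf, to the affected node itself.
	c := change.NodeAdded(1)
	fmt.Println(c.Empty(), c.IsFull(), c.AlsoSelf()) // false false true

	// Policy changes carry no NodeID; IsFull treats them as full updates,
	// and a zero NodeID means the change is relevant to every node.
	p := change.PolicyChange()
	fmt.Println(p.IsFull(), p.AlsoSelf()) // true true
}
```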