From da9018a0ebc23d67741f3b4caf61ed00a026ead9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 12 Nov 2025 09:36:36 -0600 Subject: [PATCH] types: make pre auth key use bcrypt (#2853) --- CHANGELOG.md | 14 +- cmd/headscale/cli/preauthkeys.go | 2 +- hscontrol/auth_test.go | 36 ++- hscontrol/db/api_key.go | 216 ++++++++++++-- hscontrol/db/api_key_test.go | 145 ++++++++++ hscontrol/db/db.go | 32 +++ hscontrol/db/ip_test.go | 5 - hscontrol/db/preauth_keys.go | 182 ++++++++++-- hscontrol/db/preauth_keys_test.go | 450 +++++++++++++++++++++++++++--- hscontrol/db/schema.sql | 3 + hscontrol/db/suite_test.go | 2 + hscontrol/db/users_test.go | 338 +++++++++++++++------- hscontrol/grpcv1.go | 18 ++ hscontrol/state/state.go | 11 +- hscontrol/types/api_key.go | 19 +- hscontrol/types/preauth_key.go | 55 +++- hscontrol/types/types_clone.go | 3 + hscontrol/types/types_view.go | 20 +- hscontrol/util/util.go | 17 ++ hscontrol/util/util_test.go | 96 +++++++ integration/cli_test.go | 11 +- 21 files changed, 1450 insertions(+), 225 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04a16156..e75a8a50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,13 +8,25 @@ The OIDC callback and device registration web pages have been updated to use the Material for MkDocs design system from the official documentation. The templates now use consistent typography, spacing, and colours across all registration flows. External links are properly secured with noreferrer/noopener attributes. +### Pre-authentication key security improvements + +Pre-authentication keys now use bcrypt hashing for improved security +[#2853](https://github.com/juanfont/headscale/pull/2853). Keys are stored as a +prefix and bcrypt hash instead of plaintext. The full key is only displayed once +at creation time. When listing keys, only the prefix is shown (e.g., +`hskey-auth-{prefix}-***`). All new keys use the format +`hskey-auth-{prefix}-{secret}`. Legacy plaintext keys continue to work for +backwards compatibility. ### Changes - Add NixOS module in repository for faster iteration [#2857](https://github.com/juanfont/headscale/pull/2857) - Add favicon to webpages [#2858](https://github.com/juanfont/headscale/pull/2858) -- Reclaim IPs from the IP allocator when nodes are deleted [#2831](https://github.com/juanfont/headscale/pull/2831) - Redesign OIDC callback and registration web templates [#2832](https://github.com/juanfont/headscale/pull/2832) +- Reclaim IPs from the IP allocator when nodes are deleted [#2831](https://github.com/juanfont/headscale/pull/2831) +- Add bcrypt hashing for pre-authentication keys [#2853](https://github.com/juanfont/headscale/pull/2853) +- Add structured prefix format for API keys (`hskey-api-{prefix}-{secret}`) [#2853](https://github.com/juanfont/headscale/pull/2853) +- Add registration keys for web authentication tracking (`hskey-reg-{random}`) [#2853](https://github.com/juanfont/headscale/pull/2853) ## 0.27.1 (2025-11-11) diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index c0c08831..e42fa1e3 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -88,7 +88,7 @@ var listPreAuthKeys = &cobra.Command{ tableData := pterm.TableData{ { "ID", - "Key", + "Key/Prefix", "Reusable", "Ephemeral", "Used", diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index bf6da356..9a5566c6 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -3026,7 +3026,11 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // Create user and single-use pre-auth key user := app.state.CreateUserForTest("test-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false + pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false + require.NoError(t, err) + + // Fetch the full pre-auth key to check Reusable field + pak, err := app.state.GetPreAuthKey(pakNew.Key) require.NoError(t, err) require.False(t, pak.Reusable, "key should be single-use for this test") @@ -3036,7 +3040,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // STEP 1: Initial registration with pre-auth key (simulates fresh node joining) initialReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: pak.Key, + AuthKey: pakNew.Key, }, NodeKey: nodeKey.Public(), Hostinfo: &tailcfg.Hostinfo{ @@ -3060,7 +3064,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { assert.Equal(t, machineKey.Public(), node.MachineKey()) // Verify pre-auth key is now marked as used - usedPak, err := app.state.GetPreAuthKey(pak.Key) + usedPak, err := app.state.GetPreAuthKey(pakNew.Key) require.NoError(t, err) assert.True(t, usedPak.Used, "pre-auth key should be marked as used after initial registration") @@ -3073,7 +3077,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") restartReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: pak.Key, // Same key, now marked as Used=true + AuthKey: pakNew.Key, // Same key, now marked as Used=true }, NodeKey: nodeKey.Public(), // Same node key Hostinfo: &tailcfg.Hostinfo{ @@ -3113,7 +3117,11 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) { app := createTestApp(t) user := app.state.CreateUserForTest("test-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true + pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true + require.NoError(t, err) + + // Fetch the full pre-auth key to check Reusable field + pak, err := app.state.GetPreAuthKey(pakNew.Key) require.NoError(t, err) require.True(t, pak.Reusable) @@ -3123,7 +3131,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) { // Initial registration initialReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: pak.Key, + AuthKey: pakNew.Key, }, NodeKey: nodeKey.Public(), Hostinfo: &tailcfg.Hostinfo{ @@ -3140,7 +3148,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) { // Node restart - re-registration with reusable key restartReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: pak.Key, // Reusable key + AuthKey: pakNew.Key, // Reusable key }, NodeKey: nodeKey.Public(), Hostinfo: &tailcfg.Hostinfo{ @@ -3209,7 +3217,11 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. // Create a SINGLE-USE pre-auth key (reusable=false) // This is the type of key that triggers the bug in issue #2830 - preAuthKey, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + preAuthKeyNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + require.NoError(t, err) + + // Fetch the full pre-auth key to check Reusable and Used fields + preAuthKey, err := app.state.GetPreAuthKey(preAuthKeyNew.Key) require.NoError(t, err) require.False(t, preAuthKey.Reusable, "Pre-auth key must be single-use to test issue #2830") require.False(t, preAuthKey.Used, "Pre-auth key should not be used yet") @@ -3222,7 +3234,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. // This simulates the first time the container starts and runs 'tailscale up --authkey=...' initialReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: preAuthKey.Key, + AuthKey: preAuthKeyNew.Key, // Use the full key from creation }, NodeKey: nodeKey.Public(), Hostinfo: &tailcfg.Hostinfo{ @@ -3238,7 +3250,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. require.Equal(t, "testuser", initialResp.User.DisplayName, "User should match the pre-auth key's user") // Verify the pre-auth key is now marked as Used - updatedKey, err := app.state.GetPreAuthKey(preAuthKey.Key) + updatedKey, err := app.state.GetPreAuthKey(preAuthKeyNew.Key) require.NoError(t, err) require.True(t, updatedKey.Used, "Pre-auth key should be marked as Used after initial registration") @@ -3253,7 +3265,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. // This is exactly what happens when a container restarts reregisterReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: preAuthKey.Key, // Same key, now marked as Used=true + AuthKey: preAuthKeyNew.Key, // Same key, now marked as Used=true }, NodeKey: nodeKey.Public(), // Same NodeKey Hostinfo: &tailcfg.Hostinfo{ @@ -3280,7 +3292,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. attackReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ - AuthKey: preAuthKey.Key, // Try to use the same key + AuthKey: preAuthKeyNew.Key, // Try to use the same key }, NodeKey: differentNodeKey.Public(), Hostinfo: &tailcfg.Hostinfo{ diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 51083145..7457670c 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -9,33 +9,64 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) const ( - apiPrefixLength = 7 - apiKeyLength = 32 + apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential + apiKeyPrefixLength = 12 + apiKeyHashLength = 64 + + // Legacy format constants. + legacyAPIPrefixLength = 7 + legacyAPIKeyLength = 32 ) -var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") +var ( + ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") + ErrAPIKeyGenerationFailed = errors.New("failed to generate API key") + ErrAPIKeyInvalidGeneration = errors.New("generated API key failed validation") +) // CreateAPIKey creates a new ApiKey in a user, and returns it. func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { - prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) + // Generate public prefix (12 chars) + prefix, err := util.GenerateRandomStringURLSafe(apiKeyPrefixLength) if err != nil { return "", nil, err } - toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength) + // Validate prefix + if len(prefix) != apiKeyPrefixLength { + return "", nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyPrefixLength, len(prefix)) + } + + if !isValidBase64URLSafe(prefix) { + return "", nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrAPIKeyInvalidGeneration) + } + + // Generate secret (64 chars) + secret, err := util.GenerateRandomStringURLSafe(apiKeyHashLength) if err != nil { return "", nil, err } - // Key to return to user, this will only be visible _once_ - keyStr := prefix + "." + toBeHashed + // Validate secret + if len(secret) != apiKeyHashLength { + return "", nil, fmt.Errorf("%w: generated secret has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyHashLength, len(secret)) + } - hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) + if !isValidBase64URLSafe(secret) { + return "", nil, fmt.Errorf("%w: generated secret contains invalid characters", ErrAPIKeyInvalidGeneration) + } + + // Full key string (shown ONCE to user) + keyStr := apiKeyPrefix + prefix + "-" + secret + + // bcrypt hash of secret + hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) if err != nil { return "", nil, err } @@ -103,23 +134,164 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { - prefix, hash, found := strings.Cut(keyStr, ".") - if !found { - return false, ErrAPIKeyFailedToParse - } - - key, err := hsdb.GetAPIKey(prefix) + key, err := validateAPIKey(hsdb.DB, keyStr) if err != nil { - return false, fmt.Errorf("failed to validate api key: %w", err) - } - - if key.Expiration.Before(time.Now()) { - return false, nil - } - - if err := bcrypt.CompareHashAndPassword(key.Hash, []byte(hash)); err != nil { return false, err } + if key.Expiration != nil && key.Expiration.Before(time.Now()) { + return false, nil + } + return true, nil } + +// ParseAPIKeyPrefix extracts the database prefix from a display prefix. +// Handles formats: "hskey-api-{12chars}-***", "hskey-api-{12chars}", or just "{12chars}". +// Returns the 12-character prefix suitable for database lookup. +func ParseAPIKeyPrefix(displayPrefix string) (string, error) { + // If it's already just the 12-character prefix, return it + if len(displayPrefix) == apiKeyPrefixLength && isValidBase64URLSafe(displayPrefix) { + return displayPrefix, nil + } + + // If it starts with the API key prefix, parse it + if strings.HasPrefix(displayPrefix, apiKeyPrefix) { + // Remove the "hskey-api-" prefix + _, remainder, found := strings.Cut(displayPrefix, apiKeyPrefix) + if !found { + return "", fmt.Errorf("%w: invalid display prefix format", ErrAPIKeyFailedToParse) + } + + // Extract just the first 12 characters (the actual prefix) + if len(remainder) < apiKeyPrefixLength { + return "", fmt.Errorf("%w: prefix too short", ErrAPIKeyFailedToParse) + } + + prefix := remainder[:apiKeyPrefixLength] + + // Validate it's base64 URL-safe + if !isValidBase64URLSafe(prefix) { + return "", fmt.Errorf("%w: prefix contains invalid characters", ErrAPIKeyFailedToParse) + } + + return prefix, nil + } + + // For legacy 7-character prefixes or other formats, return as-is + return displayPrefix, nil +} + +// validateAPIKey validates an API key and returns the key if valid. +// Handles both new (hskey-api-{prefix}-{secret}) and legacy (prefix.secret) formats. +func validateAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) { + // Validate input is not empty + if keyStr == "" { + return nil, ErrAPIKeyFailedToParse + } + + // Check for new format: hskey-api-{prefix}-{secret} + _, prefixAndSecret, found := strings.Cut(keyStr, apiKeyPrefix) + + if !found { + // Legacy format: prefix.secret + return validateLegacyAPIKey(db, keyStr) + } + + // New format: parse and verify + const expectedMinLength = apiKeyPrefixLength + 1 + apiKeyHashLength + if len(prefixAndSecret) < expectedMinLength { + return nil, fmt.Errorf( + "%w: key too short, expected at least %d chars after prefix, got %d", + ErrAPIKeyFailedToParse, + expectedMinLength, + len(prefixAndSecret), + ) + } + + // Use fixed-length parsing + prefix := prefixAndSecret[:apiKeyPrefixLength] + + // Validate separator at expected position + if prefixAndSecret[apiKeyPrefixLength] != '-' { + return nil, fmt.Errorf( + "%w: expected separator '-' at position %d, got '%c'", + ErrAPIKeyFailedToParse, + apiKeyPrefixLength, + prefixAndSecret[apiKeyPrefixLength], + ) + } + + secret := prefixAndSecret[apiKeyPrefixLength+1:] + + // Validate secret length + if len(secret) != apiKeyHashLength { + return nil, fmt.Errorf( + "%w: secret length mismatch, expected %d chars, got %d", + ErrAPIKeyFailedToParse, + apiKeyHashLength, + len(secret), + ) + } + + // Validate prefix contains only base64 URL-safe characters + if !isValidBase64URLSafe(prefix) { + return nil, fmt.Errorf( + "%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrAPIKeyFailedToParse, + ) + } + + // Validate secret contains only base64 URL-safe characters + if !isValidBase64URLSafe(secret) { + return nil, fmt.Errorf( + "%w: secret contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrAPIKeyFailedToParse, + ) + } + + // Look up by prefix (indexed) + var key types.APIKey + + err := db.First(&key, "prefix = ?", prefix).Error + if err != nil { + return nil, fmt.Errorf("API key not found: %w", err) + } + + // Verify bcrypt hash + err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret)) + if err != nil { + return nil, fmt.Errorf("invalid API key: %w", err) + } + + return &key, nil +} + +// validateLegacyAPIKey validates a legacy format API key (prefix.secret). +func validateLegacyAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) { + // Legacy format uses "." as separator + prefix, secret, found := strings.Cut(keyStr, ".") + if !found { + return nil, ErrAPIKeyFailedToParse + } + + // Legacy prefix is 7 chars + if len(prefix) != legacyAPIPrefixLength { + return nil, fmt.Errorf("%w: legacy prefix length mismatch", ErrAPIKeyFailedToParse) + } + + var key types.APIKey + + err := db.First(&key, "prefix = ?", prefix).Error + if err != nil { + return nil, fmt.Errorf("API key not found: %w", err) + } + + // Verify bcrypt (key.Hash stores bcrypt of full secret) + err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret)) + if err != nil { + return nil, fmt.Errorf("invalid API key: %w", err) + } + + return &key, nil +} diff --git a/hscontrol/db/api_key_test.go b/hscontrol/db/api_key_test.go index c0b4e988..6899da6c 100644 --- a/hscontrol/db/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -1,8 +1,14 @@ package db import ( + "strings" + "testing" "time" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" "gopkg.in/check.v1" ) @@ -87,3 +93,142 @@ func (*Suite) TestExpireAPIKey(c *check.C) { c.Assert(err, check.IsNil) c.Assert(notValid, check.Equals, false) } + +func TestAPIKeyWithPrefix(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "new_key_with_prefix", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + keyStr, apiKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + + // Verify format: hskey-api-{12-char-prefix}-{64-char-secret} + assert.True(t, strings.HasPrefix(keyStr, "hskey-api-")) + + _, prefixAndSecret, found := strings.Cut(keyStr, "hskey-api-") + assert.True(t, found) + assert.GreaterOrEqual(t, len(prefixAndSecret), 12+1+64) + + prefix := prefixAndSecret[:12] + assert.Len(t, prefix, 12) + assert.Equal(t, byte('-'), prefixAndSecret[12]) + secret := prefixAndSecret[13:] + assert.Len(t, secret, 64) + + // Verify stored fields + assert.Len(t, apiKey.Prefix, types.NewAPIKeyPrefixLength) + assert.NotNil(t, apiKey.Hash) + }, + }, + { + name: "new_key_can_be_retrieved", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + keyStr, createdKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + + // Validate the created key + valid, err := db.ValidateAPIKey(keyStr) + require.NoError(t, err) + assert.True(t, valid) + + // Verify prefix is correct length + assert.Len(t, createdKey.Prefix, types.NewAPIKeyPrefixLength) + }, + }, + { + name: "invalid_key_format_rejected", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + invalidKeys := []string{ + "", + "hskey-api-short", + "hskey-api-ABCDEFGHIJKL-tooshort", + "hskey-api-ABC$EFGHIJKL-" + strings.Repeat("a", 64), + "hskey-api-ABCDEFGHIJKL" + strings.Repeat("a", 64), // missing separator + } + + for _, invalidKey := range invalidKeys { + valid, err := db.ValidateAPIKey(invalidKey) + require.Error(t, err, "key should be rejected: %s", invalidKey) + assert.False(t, valid) + } + }, + }, + { + name: "legacy_key_still_works", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + // Insert legacy API key directly (7-char prefix + 32-char secret) + legacyPrefix := "abcdefg" + legacySecret := strings.Repeat("x", 32) + legacyKey := legacyPrefix + "." + legacySecret + hash, err := bcrypt.GenerateFromPassword([]byte(legacySecret), bcrypt.DefaultCost) + require.NoError(t, err) + + now := time.Now() + err = db.DB.Exec(` + INSERT INTO api_keys (prefix, hash, created_at) + VALUES (?, ?, ?) + `, legacyPrefix, hash, now).Error + require.NoError(t, err) + + // Validate legacy key + valid, err := db.ValidateAPIKey(legacyKey) + require.NoError(t, err) + assert.True(t, valid) + }, + }, + { + name: "wrong_secret_rejected", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + keyStr, _, err := db.CreateAPIKey(nil) + require.NoError(t, err) + + // Tamper with the secret + _, prefixAndSecret, _ := strings.Cut(keyStr, "hskey-api-") + prefix := prefixAndSecret[:12] + tamperedKey := "hskey-api-" + prefix + "-" + strings.Repeat("x", 64) + + valid, err := db.ValidateAPIKey(tamperedKey) + require.Error(t, err) + assert.False(t, valid) + }, + }, + { + name: "expired_key_rejected", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + // Create expired key + expired := time.Now().Add(-1 * time.Hour) + keyStr, _, err := db.CreateAPIKey(&expired) + require.NoError(t, err) + + // Should fail validation + valid, err := db.ValidateAPIKey(keyStr) + require.NoError(t, err) + assert.False(t, valid) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) + } +} diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 4eefee91..5539f5c5 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -991,6 +991,38 @@ AND auth_key_id NOT IN ( // - NEVER use gorm.AutoMigrate, write the exact migration steps needed // - AutoMigrate depends on the struct staying exactly the same, which it won't over time. // - Never write migrations that requires foreign keys to be disabled. + { + // Add columns for prefix and hash for pre auth keys, implementing + // them with the same security model as api keys. + ID: "202511011637-preauthkey-bcrypt", + Migrate: func(tx *gorm.DB) error { + // Check and add prefix column if it doesn't exist + if !tx.Migrator().HasColumn(&types.PreAuthKey{}, "prefix") { + err := tx.Migrator().AddColumn(&types.PreAuthKey{}, "prefix") + if err != nil { + return fmt.Errorf("adding prefix column: %w", err) + } + } + + // Check and add hash column if it doesn't exist + if !tx.Migrator().HasColumn(&types.PreAuthKey{}, "hash") { + err := tx.Migrator().AddColumn(&types.PreAuthKey{}, "hash") + if err != nil { + return fmt.Errorf("adding hash column: %w", err) + } + } + + // Create partial unique index to allow multiple legacy keys (NULL/empty prefix) + // while enforcing uniqueness for new bcrypt-based keys + err := tx.Exec("CREATE UNIQUE INDEX IF NOT EXISTS idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''").Error + if err != nil { + return fmt.Errorf("creating prefix index: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index f558cdf7..3ec81c9f 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - "github.com/davecgh/go-spew/spew" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" @@ -159,8 +158,6 @@ func TestIPAllocatorSequential(t *testing.T) { types.IPAllocationStrategySequential, ) - spew.Dump(alloc) - var got4s []netip.Addr var got6s []netip.Addr @@ -263,8 +260,6 @@ func TestIPAllocatorRandom(t *testing.T) { alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom) - spew.Dump(alloc) - for range tt.getCount { got4, got6, err := alloc.Next() if err != nil { diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 94575269..00260966 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -1,8 +1,6 @@ package db import ( - "crypto/rand" - "encoding/hex" "errors" "fmt" "slices" @@ -10,6 +8,8 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" "tailscale.com/util/set" ) @@ -28,12 +28,18 @@ func (hsdb *HSDatabase) CreatePreAuthKey( ephemeral bool, expiration *time.Time, aclTags []string, -) (*types.PreAuthKey, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { +) (*types.PreAuthKeyNew, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKeyNew, error) { return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) }) } +const ( + authKeyPrefix = "hskey-auth-" + authKeyPrefixLength = 12 + authKeyLength = 64 +) + // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func CreatePreAuthKey( tx *gorm.DB, @@ -42,7 +48,7 @@ func CreatePreAuthKey( ephemeral bool, expiration *time.Time, aclTags []string, -) (*types.PreAuthKey, error) { +) (*types.PreAuthKeyNew, error) { user, err := GetUserByID(tx, uid) if err != nil { return nil, err @@ -65,14 +71,43 @@ func CreatePreAuthKey( } now := time.Now().UTC() - // TODO(kradalby): unify the key generations spread all over the code. - kstr, err := generateKey() + + prefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength) + if err != nil { + return nil, err + } + + // Validate generated prefix (should always be valid, but be defensive) + if len(prefix) != authKeyPrefixLength { + return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, len(prefix)) + } + + if !isValidBase64URLSafe(prefix) { + return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse) + } + + toBeHashed, err := util.GenerateRandomStringURLSafe(authKeyLength) + if err != nil { + return nil, err + } + + // Validate generated hash (should always be valid, but be defensive) + if len(toBeHashed) != authKeyLength { + return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(toBeHashed)) + } + + if !isValidBase64URLSafe(toBeHashed) { + return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse) + } + + keyStr := authKeyPrefix + prefix + "-" + toBeHashed + + hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) if err != nil { return nil, err } key := types.PreAuthKey{ - Key: kstr, UserID: user.ID, User: *user, Reusable: reusable, @@ -80,13 +115,24 @@ func CreatePreAuthKey( CreatedAt: &now, Expiration: expiration, Tags: aclTags, + Prefix: prefix, // Store prefix + Hash: hash, // Store hash } if err := tx.Save(&key).Error; err != nil { return nil, fmt.Errorf("failed to create key in the database: %w", err) } - return &key, nil + return &types.PreAuthKeyNew{ + ID: key.ID, + Key: keyStr, + Reusable: key.Reusable, + Ephemeral: key.Ephemeral, + Tags: key.Tags, + Expiration: key.Expiration, + CreatedAt: key.CreatedAt, + User: key.User, + }, nil } func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { @@ -110,6 +156,107 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e return keys, nil } +var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey") + +func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { + var pak types.PreAuthKey + + // Validate input is not empty + if keyStr == "" { + return nil, ErrPreAuthKeyFailedToParse + } + + _, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix) + + if !found { + // Legacy format (plaintext) - backwards compatibility + err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error + if err != nil { + return nil, ErrPreAuthKeyNotFound + } + + return &pak, nil + } + + // New format: hskey-auth-{12-char-prefix}-{64-char-hash} + // Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77 + const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength + if len(prefixAndHash) < expectedMinLength { + return nil, fmt.Errorf( + "%w: key too short, expected at least %d chars after prefix, got %d", + ErrPreAuthKeyFailedToParse, + expectedMinLength, + len(prefixAndHash), + ) + } + + // Use fixed-length parsing instead of separator-based to handle dashes in base64 URL-safe + prefix := prefixAndHash[:authKeyPrefixLength] + + // Validate separator at expected position + if prefixAndHash[authKeyPrefixLength] != '-' { + return nil, fmt.Errorf( + "%w: expected separator '-' at position %d, got '%c'", + ErrPreAuthKeyFailedToParse, + authKeyPrefixLength, + prefixAndHash[authKeyPrefixLength], + ) + } + + hash := prefixAndHash[authKeyPrefixLength+1:] + + // Validate hash length + if len(hash) != authKeyLength { + return nil, fmt.Errorf( + "%w: hash length mismatch, expected %d chars, got %d", + ErrPreAuthKeyFailedToParse, + authKeyLength, + len(hash), + ) + } + + // Validate prefix contains only base64 URL-safe characters + if !isValidBase64URLSafe(prefix) { + return nil, fmt.Errorf( + "%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrPreAuthKeyFailedToParse, + ) + } + + // Validate hash contains only base64 URL-safe characters + if !isValidBase64URLSafe(hash) { + return nil, fmt.Errorf( + "%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrPreAuthKeyFailedToParse, + ) + } + + // Look up key by prefix + err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error + if err != nil { + return nil, ErrPreAuthKeyNotFound + } + + // Verify hash matches + err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash)) + if err != nil { + return nil, fmt.Errorf("invalid auth key: %w", err) + } + + return &pak, nil +} + +// isValidBase64URLSafe checks if a string contains only base64 URL-safe characters. +func isValidBase64URLSafe(s string) bool { + for _, c := range s { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && (c < '0' || c > '9') && c != '-' && c != '_' { + return false + } + } + + return true +} + func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { return GetPreAuthKey(hsdb.DB, key) } @@ -117,12 +264,7 @@ func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { // GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible // for checking if the key is usable (expired or used). func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { - pak := types.PreAuthKey{} - if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil { - return nil, ErrPreAuthKeyNotFound - } - - return &pak, nil + return findAuthKey(tx, key) } // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey @@ -159,13 +301,3 @@ func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { now := time.Now() return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error } - -func generateKey() (string, error) { - size := 24 - bytes := make([]byte, size) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - return hex.EncodeToString(bytes), nil -} diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 605e7442..9ad8ae42 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -1,74 +1,135 @@ package db import ( + "fmt" "slices" + "strings" "testing" + "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/check.v1" "tailscale.com/types/ptr" ) -func (*Suite) TestCreatePreAuthKey(c *check.C) { - // ID does not exist - _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) - c.Assert(err, check.NotNil) +func TestCreatePreAuthKey(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "error_invalid_user_id", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) + assert.Error(t, err) + }, + }, + { + name: "success_create_and_list", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) - // Did we get a valid key? - c.Assert(key.Key, check.NotNil) - c.Assert(len(key.Key), check.Equals, 48) + key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, key.Key) - // Make sure the User association is populated - c.Assert(key.User.ID, check.Equals, user.ID) + // List keys for the user + keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) + require.NoError(t, err) + assert.Len(t, keys, 1) - // ID does not exist - _, err = db.ListPreAuthKeys(1000000) - c.Assert(err, check.NotNil) + // Verify User association is populated + assert.Equal(t, user.ID, keys[0].User.ID) + }, + }, + { + name: "error_list_invalid_user_id", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) - c.Assert(err, check.IsNil) - c.Assert(len(keys), check.Equals, 1) + _, err := db.ListPreAuthKeys(1000000) + assert.Error(t, err) + }, + }, + } - // Make sure the User association is populated - c.Assert((keys)[0].User.ID, check.Equals, user.ID) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) + } } -func (*Suite) TestPreAuthKeyACLTags(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test8"}) - c.Assert(err, check.IsNil) +func TestPreAuthKeyACLTags(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "reject_malformed_tags", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) - c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected + user, err := db.CreateUser(types.User{Name: "test-tags-1"}) + require.NoError(t, err) - tags := []string{"tag:test1", "tag:test2"} - tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) - c.Assert(err, check.IsNil) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) + assert.Error(t, err) + }, + }, + { + name: "deduplicate_and_sort_tags", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) - c.Assert(err, check.IsNil) - gotTags := listedPaks[0].Proto().GetAclTags() - slices.Sort(gotTags) - c.Assert(gotTags, check.DeepEquals, tags) + user, err := db.CreateUser(types.User{Name: "test-tags-2"}) + require.NoError(t, err) + + expectedTags := []string{"tag:test1", "tag:test2"} + tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} + + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) + require.NoError(t, err) + + listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) + require.NoError(t, err) + require.Len(t, listedPaks, 1) + + gotTags := listedPaks[0].Proto().GetAclTags() + slices.Sort(gotTags) + assert.Equal(t, expectedTags, gotTags) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) + } } func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) user, err := db.CreateUser(types.User{Name: "test8"}) - assert.NoError(t, err) + require.NoError(t, err) key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"}) - assert.NoError(t, err) + require.NoError(t, err) node := types.Node{ ID: 0, @@ -79,6 +140,317 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { } db.DB.Save(&node) - err = db.DB.Delete(key).Error + err = db.DB.Delete(&types.PreAuthKey{ID: key.ID}).Error require.ErrorContains(t, err, "constraint failed: FOREIGN KEY constraint failed") } + +func TestPreAuthKeyAuthentication(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user := db.CreateUserForTest("test-user") + + tests := []struct { + name string + setupKey func() string // Returns key string to test + wantFindErr bool // Error when finding the key + wantValidateErr bool // Error when validating the key + validateResult func(*testing.T, *types.PreAuthKey) + }{ + { + name: "legacy_key_plaintext", + setupKey: func() string { + // Insert legacy key directly using GORM (simulate existing production key) + // Note: We use raw SQL to bypass GORM's handling and set prefix to empty string + // which simulates how legacy keys exist in production databases + legacyKey := "abc123def456ghi789jkl012mno345pqr678stu901vwx234yz" + now := time.Now() + + // Use raw SQL to insert with empty prefix to avoid UNIQUE constraint + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at) + VALUES (?, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, true, false, false, now).Error + require.NoError(t, err) + + return legacyKey + }, + wantFindErr: false, + wantValidateErr: false, + validateResult: func(t *testing.T, pak *types.PreAuthKey) { + t.Helper() + + assert.Equal(t, user.ID, pak.UserID) + assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated + assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix + assert.Nil(t, pak.Hash) // Legacy keys have nil Hash + }, + }, + { + name: "new_key_bcrypt", + setupKey: func() string { + // Create new key via API + keyStr, err := db.CreatePreAuthKey( + types.UserID(user.ID), + true, false, nil, []string{"tag:test"}, + ) + require.NoError(t, err) + + return keyStr.Key + }, + wantFindErr: false, + wantValidateErr: false, + validateResult: func(t *testing.T, pak *types.PreAuthKey) { + t.Helper() + + assert.Equal(t, user.ID, pak.UserID) + assert.Empty(t, pak.Key) // New keys have empty Key + assert.NotEmpty(t, pak.Prefix) // New keys have Prefix + assert.NotNil(t, pak.Hash) // New keys have Hash + assert.Len(t, pak.Prefix, 12) // Prefix is 12 chars + }, + }, + { + name: "new_key_format_validation", + setupKey: func() string { + keyStr, err := db.CreatePreAuthKey( + types.UserID(user.ID), + true, false, nil, nil, + ) + require.NoError(t, err) + + // Verify format: hskey-auth-{12-char-prefix}-{64-char-hash} + // Use fixed-length parsing since prefix/hash can contain dashes (base64 URL-safe) + assert.True(t, strings.HasPrefix(keyStr.Key, "hskey-auth-")) + + // Extract prefix and hash using fixed-length parsing like the real code does + _, prefixAndHash, found := strings.Cut(keyStr.Key, "hskey-auth-") + assert.True(t, found) + assert.GreaterOrEqual(t, len(prefixAndHash), 12+1+64) // prefix + '-' + hash minimum + + prefix := prefixAndHash[:12] + assert.Len(t, prefix, 12) // Prefix is 12 chars + assert.Equal(t, byte('-'), prefixAndHash[12]) // Separator + hash := prefixAndHash[13:] + assert.Len(t, hash, 64) // Hash is 64 chars + + return keyStr.Key + }, + wantFindErr: false, + wantValidateErr: false, + }, + { + name: "invalid_bcrypt_hash", + setupKey: func() string { + // Create valid key + key, err := db.CreatePreAuthKey( + types.UserID(user.ID), + true, false, nil, nil, + ) + require.NoError(t, err) + + keyStr := key.Key + + // Return key with tampered hash using fixed-length parsing + _, prefixAndHash, _ := strings.Cut(keyStr, "hskey-auth-") + prefix := prefixAndHash[:12] + + return "hskey-auth-" + prefix + "-" + "wrong_hash_here_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "empty_key", + setupKey: func() string { + return "" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "key_too_short", + setupKey: func() string { + return "hskey-auth-short" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "missing_separator", + setupKey: func() string { + return "hskey-auth-ABCDEFGHIJKLabcdefghijklmnopqrstuvwxyz1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "hash_too_short", + setupKey: func() string { + return "hskey-auth-ABCDEFGHIJKL-short" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "prefix_with_invalid_chars", + setupKey: func() string { + return "hskey-auth-ABC$EF@HIJKL-" + strings.Repeat("a", 64) + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "hash_with_invalid_chars", + setupKey: func() string { + return "hskey-auth-ABCDEFGHIJKL-" + "invalid$chars" + strings.Repeat("a", 54) + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "prefix_not_found_in_db", + setupKey: func() string { + // Create a validly formatted key but with a prefix that doesn't exist + return "hskey-auth-NotInDB12345-" + strings.Repeat("a", 64) + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "expired_legacy_key", + setupKey: func() string { + legacyKey := "expired_legacy_key_123456789012345678901234" + now := time.Now() + expiration := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago + + // Use raw SQL to avoid UNIQUE constraint on empty prefix + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at, expiration) + VALUES (?, ?, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, true, false, false, now, expiration).Error + require.NoError(t, err) + + return legacyKey + }, + wantFindErr: false, + wantValidateErr: true, + }, + { + name: "used_single_use_legacy_key", + setupKey: func() string { + legacyKey := "used_legacy_key_123456789012345678901234567" + now := time.Now() + + // Use raw SQL to avoid UNIQUE constraint on empty prefix + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at) + VALUES (?, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, false, false, true, now).Error + require.NoError(t, err) + + return legacyKey + }, + wantFindErr: false, + wantValidateErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyStr := tt.setupKey() + + pak, err := db.GetPreAuthKey(keyStr) + + if tt.wantFindErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, pak) + + // Check validation if needed + if tt.wantValidateErr { + err := pak.Validate() + assert.Error(t, err) + + return + } + + if tt.validateResult != nil { + tt.validateResult(t, pak) + } + }) + } +} + +func TestMultipleLegacyKeysAllowed(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user, err := db.CreateUser(types.User{Name: "test-legacy"}) + require.NoError(t, err) + + // Create multiple legacy keys by directly inserting with empty prefix + // This simulates the migration scenario where existing databases have multiple + // plaintext keys without prefix/hash fields + now := time.Now() + + for i := range 5 { + legacyKey := fmt.Sprintf("legacy_key_%d_%s", i, strings.Repeat("x", 40)) + + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at) + VALUES (?, '', NULL, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, true, false, false, now).Error + require.NoError(t, err, "should allow multiple legacy keys with empty prefix") + } + + // Verify all legacy keys can be retrieved + var legacyKeys []types.PreAuthKey + + err = db.DB.Where("prefix = '' OR prefix IS NULL").Find(&legacyKeys).Error + require.NoError(t, err) + assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys") + + // Now create new bcrypt-based keys - these should have unique prefixes + key1, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, key1.Key) + + key2, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, key2.Key) + + // Verify the new keys have different prefixes + pak1, err := db.GetPreAuthKey(key1.Key) + require.NoError(t, err) + assert.NotEmpty(t, pak1.Prefix) + + pak2, err := db.GetPreAuthKey(key2.Key) + require.NoError(t, err) + assert.NotEmpty(t, pak2.Prefix) + + assert.NotEqual(t, pak1.Prefix, pak2.Prefix, "new keys should have unique prefixes") + + // Verify we cannot manually insert duplicate non-empty prefixes + duplicatePrefix := "test_prefix1" + hash1 := []byte("hash1") + hash2 := []byte("hash2") + + // First insert should succeed + err = db.DB.Exec(` + INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at) + VALUES ('', ?, ?, ?, ?, ?, ?, ?) + `, duplicatePrefix, hash1, user.ID, true, false, false, now).Error + require.NoError(t, err, "first key with prefix should succeed") + + // Second insert with same prefix should fail + err = db.DB.Exec(` + INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at) + VALUES ('', ?, ?, ?, ?, ?, ?, ?) + `, duplicatePrefix, hash2, user.ID, true, false, false, now).Error + require.Error(t, err, "duplicate non-empty prefix should be rejected") + assert.Contains(t, err.Error(), "UNIQUE constraint failed", "should fail with UNIQUE constraint error") +} diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index 175e2aff..075c9d4d 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -48,6 +48,8 @@ CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users( CREATE TABLE pre_auth_keys( id integer PRIMARY KEY AUTOINCREMENT, key text, + prefix text, + hash blob, user_id integer, reusable numeric, ephemeral numeric DEFAULT false, @@ -59,6 +61,7 @@ CREATE TABLE pre_auth_keys( CONSTRAINT fk_pre_auth_keys_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL ); +CREATE UNIQUE INDEX idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''; CREATE TABLE api_keys( id integer PRIMARY KEY AUTOINCREMENT, diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 0589ff81..e28d4076 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog" "gopkg.in/check.v1" "zombiezen.com/go/postgrestest" ) @@ -56,6 +57,7 @@ func newSQLiteTestDB() (*HSDatabase, error) { } log.Printf("database path: %s", tmpDir+"/headscale_test.db") + zerolog.SetGlobalLevel(zerolog.Disabled) db, err = NewHeadscaleDatabase( types.DatabaseConfig{ diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 5b2f0c4b..53a10e80 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -1,138 +1,276 @@ package db import ( - "strings" + "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "gopkg.in/check.v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/types/ptr" ) -func (s *Suite) TestCreateAndDestroyUser(c *check.C) { +func TestCreateAndDestroyUser(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + user := db.CreateUserForTest("test") - c.Assert(user.Name, check.Equals, "test") + assert.Equal(t, "test", user.Name) users, err := db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) + require.NoError(t, err) + assert.Len(t, users, 1) err = db.DestroyUser(types.UserID(user.ID)) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = db.GetUserByID(types.UserID(user.ID)) - c.Assert(err, check.NotNil) + assert.Error(t, err) } -func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := db.DestroyUser(9998) - c.Assert(err, check.Equals, ErrUserNotFound) +func TestDestroyUserErrors(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "error_user_not_found", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - user := db.CreateUserForTest("test") + err := db.DestroyUser(9998) + assert.ErrorIs(t, err, ErrUserNotFound) + }, + }, + { + name: "success_deletes_preauthkeys", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") - err = db.DestroyUser(types.UserID(user.ID)) - c.Assert(err, check.IsNil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + require.NoError(t, err) - result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) - // destroying a user also deletes all associated preauthkeys - c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) + err = db.DestroyUser(types.UserID(user.ID)) + require.NoError(t, err) - user, err = db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) + // Verify preauth key was deleted (need to search by prefix for new keys) + var foundPak types.PreAuthKey - pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) + result := db.DB.First(&foundPak, "id = ?", pak.ID) + assert.ErrorIs(t, result.Error, gorm.ErrRecordNotFound) + }, + }, + { + name: "error_user_has_nodes", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - node := types.Node{ - ID: 0, - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + require.NoError(t, err) + + node := types.Node{ + ID: 0, + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + trx := db.DB.Save(&node) + require.NoError(t, trx.Error) + + err = db.DestroyUser(types.UserID(user.ID)) + assert.ErrorIs(t, err, ErrUserStillHasNodes) + }, + }, } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - err = db.DestroyUser(types.UserID(user.ID)) - c.Assert(err, check.Equals, ErrUserStillHasNodes) -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) -func (s *Suite) TestRenameUser(c *check.C) { - userTest := db.CreateUserForTest("test") - c.Assert(userTest.Name, check.Equals, "test") - - users, err := db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") - c.Assert(err, check.IsNil) - - users, err = db.ListUsers(&types.User{Name: "test"}) - c.Assert(err, check.Equals, nil) - c.Assert(len(users), check.Equals, 0) - - users, err = db.ListUsers(&types.User{Name: "test-renamed"}) - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = db.RenameUser(99988, "test") - c.Assert(err, check.Equals, ErrUserNotFound) - - userTest2 := db.CreateUserForTest("test2") - c.Assert(userTest2.Name, check.Equals, "test2") - - want := "UNIQUE constraint failed" - err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed") - if err == nil || !strings.Contains(err.Error(), want) { - c.Fatalf("expected failure with unique constraint, want: %q got: %q", want, err) + tt.test(t, db) + }) } } -func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser := db.CreateUserForTest("old") - newUser := db.CreateUserForTest("new") +func TestRenameUser(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "success_rename", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) + userTest := db.CreateUserForTest("test") + assert.Equal(t, "test", userTest.Name) - node := types.Node{ - ID: 12, - Hostname: "testnode", - UserID: oldUser.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + users, err := db.ListUsers() + require.NoError(t, err) + assert.Len(t, users, 1) + + err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") + require.NoError(t, err) + + users, err = db.ListUsers(&types.User{Name: "test"}) + require.NoError(t, err) + assert.Empty(t, users) + + users, err = db.ListUsers(&types.User{Name: "test-renamed"}) + require.NoError(t, err) + assert.Len(t, users, 1) + }, + }, + { + name: "error_user_not_found", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + err := db.RenameUser(99988, "test") + assert.ErrorIs(t, err, ErrUserNotFound) + }, + }, + { + name: "error_duplicate_name", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + userTest := db.CreateUserForTest("test") + userTest2 := db.CreateUserForTest("test2") + + assert.Equal(t, "test", userTest.Name) + assert.Equal(t, "test2", userTest2.Name) + + err := db.RenameUser(types.UserID(userTest2.ID), "test") + require.Error(t, err) + assert.Contains(t, err.Error(), "UNIQUE constraint failed") + }, + }, } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - c.Assert(node.UserID, check.Equals, oldUser.ID) - err = db.Write(func(tx *gorm.DB) error { - return AssignNodeToUser(tx, 12, types.UserID(newUser.ID)) - }) - c.Assert(err, check.IsNil) - // Reload node from database to see updated values - updatedNode, err := db.GetNodeByID(12) - c.Assert(err, check.IsNil) - c.Assert(updatedNode.UserID, check.Equals, newUser.ID) - c.Assert(updatedNode.User.Name, check.Equals, newUser.Name) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) - err = db.Write(func(tx *gorm.DB) error { - return AssignNodeToUser(tx, 12, 9584849) - }) - c.Assert(err, check.Equals, ErrUserNotFound) - - err = db.Write(func(tx *gorm.DB) error { - return AssignNodeToUser(tx, 12, types.UserID(newUser.ID)) - }) - c.Assert(err, check.IsNil) - // Reload node from database again to see updated values - finalNode, err := db.GetNodeByID(12) - c.Assert(err, check.IsNil) - c.Assert(finalNode.UserID, check.Equals, newUser.ID) - c.Assert(finalNode.User.Name, check.Equals, newUser.Name) + tt.test(t, db) + }) + } +} + +func TestAssignNodeToUser(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "success_reassign_node", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + oldUser := db.CreateUserForTest("old") + newUser := db.CreateUserForTest("new") + + pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) + require.NoError(t, err) + + node := types.Node{ + ID: 12, + Hostname: "testnode", + UserID: oldUser.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + trx := db.DB.Save(&node) + require.NoError(t, trx.Error) + assert.Equal(t, oldUser.ID, node.UserID) + + err = db.Write(func(tx *gorm.DB) error { + return AssignNodeToUser(tx, 12, types.UserID(newUser.ID)) + }) + require.NoError(t, err) + + // Reload node from database to see updated values + updatedNode, err := db.GetNodeByID(12) + require.NoError(t, err) + assert.Equal(t, newUser.ID, updatedNode.UserID) + assert.Equal(t, newUser.Name, updatedNode.User.Name) + }, + }, + { + name: "error_user_not_found", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + oldUser := db.CreateUserForTest("old") + + pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) + require.NoError(t, err) + + node := types.Node{ + ID: 12, + Hostname: "testnode", + UserID: oldUser.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + trx := db.DB.Save(&node) + require.NoError(t, trx.Error) + + err = db.Write(func(tx *gorm.DB) error { + return AssignNodeToUser(tx, 12, 9584849) + }) + assert.ErrorIs(t, err, ErrUserNotFound) + }, + }, + { + name: "success_reassign_to_same_user", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + user := db.CreateUserForTest("user") + + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + require.NoError(t, err) + + node := types.Node{ + ID: 12, + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + trx := db.DB.Save(&node) + require.NoError(t, trx.Error) + + err = db.Write(func(tx *gorm.DB) error { + return AssignNodeToUser(tx, 12, types.UserID(user.ID)) + }) + require.NoError(t, err) + + // Reload node from database again to see updated values + finalNode, err := db.GetNodeByID(12) + require.NoError(t, err) + assert.Equal(t, user.ID, finalNode.UserID) + assert.Equal(t, user.Name, finalNode.User.Name) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) + } } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 6d5189b8..3b8e9d47 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -236,10 +236,18 @@ func (api headscaleV1APIServer) RegisterNode( ctx context.Context, request *v1.RegisterNodeRequest, ) (*v1.RegisterNodeResponse, error) { + // Generate ephemeral registration key for tracking this registration flow in logs + registrationKey, err := util.GenerateRegistrationKey() + if err != nil { + log.Warn().Err(err).Msg("Failed to generate registration key") + registrationKey = "" // Continue without key if generation fails + } + log.Trace(). Caller(). Str("user", request.GetUser()). Str("registration_id", request.GetKey()). + Str("registration_key", registrationKey). Msg("Registering node") registrationId, err := types.RegistrationIDFromString(request.GetKey()) @@ -259,9 +267,19 @@ func (api headscaleV1APIServer) RegisterNode( util.RegisterMethodCLI, ) if err != nil { + log.Error(). + Str("registration_key", registrationKey). + Err(err). + Msg("Failed to register node") return nil, err } + log.Info(). + Str("registration_key", registrationKey). + Str("node_id", fmt.Sprintf("%d", node.ID())). + Str("hostname", node.Hostname()). + Msg("Node registered successfully") + // This is a bit of a back and forth, but we have a bit of a chicken and egg // dependency here. // Because the way the policy manager works, we need to have the node diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 6ef11f54..248784be 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -901,7 +901,14 @@ func (s *State) CreateAPIKey(expiration *time.Time) (string, *types.APIKey, erro } // GetAPIKey retrieves an API key by its prefix. -func (s *State) GetAPIKey(prefix string) (*types.APIKey, error) { +// Accepts both display format (hskey-api-{12chars}-***) and database format ({12chars}). +func (s *State) GetAPIKey(displayPrefix string) (*types.APIKey, error) { + // Parse the display prefix to extract the database prefix + prefix, err := hsdb.ParseAPIKeyPrefix(displayPrefix) + if err != nil { + return nil, err + } + return s.db.GetAPIKey(prefix) } @@ -921,7 +928,7 @@ func (s *State) DestroyAPIKey(key types.APIKey) error { } // CreatePreAuthKey generates a new pre-authentication key for a user. -func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKey, error) { +func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) { return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags) } diff --git a/hscontrol/types/api_key.go b/hscontrol/types/api_key.go index 8ca00044..b6a12b65 100644 --- a/hscontrol/types/api_key.go +++ b/hscontrol/types/api_key.go @@ -7,6 +7,13 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +const ( + // NewAPIKeyPrefixLength is the length of the prefix for new API keys. + NewAPIKeyPrefixLength = 12 + // LegacyAPIKeyPrefixLength is the length of the prefix for legacy API keys. + LegacyAPIKeyPrefixLength = 7 +) + // APIKey describes the datamodel for API keys used to remotely authenticate with // headscale. type APIKey struct { @@ -21,8 +28,16 @@ type APIKey struct { func (key *APIKey) Proto() *v1.ApiKey { protoKey := v1.ApiKey{ - Id: key.ID, - Prefix: key.Prefix, + Id: key.ID, + } + + // Show prefix format: distinguish between new (12-char) and legacy (7-char) keys + if len(key.Prefix) == NewAPIKeyPrefixLength { + // New format key (12-char prefix) + protoKey.Prefix = "hskey-api-" + key.Prefix + "-***" + } else { + // Legacy format key (7-char prefix) or fallback + protoKey.Prefix = key.Prefix + "***" } if key.Expiration != nil { diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 659e0a76..1081f451 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -14,8 +14,15 @@ func (e PAKError) Error() string { return string(e) } // PreAuthKey describes a pre-authorization key usable in a particular user. type PreAuthKey struct { - ID uint64 `gorm:"primary_key"` - Key string + ID uint64 `gorm:"primary_key"` + + // Legacy plaintext key (for backwards compatibility) + Key string + + // New bcrypt-based authentication + Prefix string + Hash []byte // bcrypt + UserID uint User User `gorm:"constraint:OnDelete:SET NULL;"` Reusable bool @@ -32,17 +39,59 @@ type PreAuthKey struct { Expiration *time.Time } +// PreAuthKeyNew is returned once when the key is created. +type PreAuthKeyNew struct { + ID uint64 `gorm:"primary_key"` + Key string + Reusable bool + Ephemeral bool + Tags []string + Expiration *time.Time + CreatedAt *time.Time + User User +} + +func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey { + protoKey := v1.PreAuthKey{ + Id: key.ID, + Key: key.Key, + User: key.User.Proto(), + Reusable: key.Reusable, + Ephemeral: key.Ephemeral, + AclTags: key.Tags, + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) + } + + return &protoKey +} + func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ User: key.User.Proto(), Id: key.ID, - Key: key.Key, Ephemeral: key.Ephemeral, Reusable: key.Reusable, Used: key.Used, AclTags: key.Tags, } + // For new keys (with prefix/hash), show the prefix so users can identify the key + // For legacy keys (with plaintext key), show the full key for backwards compatibility + if key.Prefix != "" { + protoKey.Key = "hskey-auth-" + key.Prefix + "-***" + } else if key.Key != "" { + // Legacy key - show full key for backwards compatibility + // TODO: Consider hiding this in a future major version + protoKey.Key = key.Key + } + if key.Expiration != nil { protoKey.Expiration = timestamppb.New(*key.Expiration) } diff --git a/hscontrol/types/types_clone.go b/hscontrol/types/types_clone.go index 3f530dc9..7699fb8f 100644 --- a/hscontrol/types/types_clone.go +++ b/hscontrol/types/types_clone.go @@ -110,6 +110,7 @@ func (src *PreAuthKey) Clone() *PreAuthKey { } dst := new(PreAuthKey) *dst = *src + dst.Hash = append(src.Hash[:0:0], src.Hash...) dst.Tags = append(src.Tags[:0:0], src.Tags...) if dst.CreatedAt != nil { dst.CreatedAt = ptr.To(*src.CreatedAt) @@ -124,6 +125,8 @@ func (src *PreAuthKey) Clone() *PreAuthKey { var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct { ID uint64 Key string + Prefix string + Hash []byte UserID uint User User Reusable bool diff --git a/hscontrol/types/types_view.go b/hscontrol/types/types_view.go index 5c31eac8..076f5dbb 100644 --- a/hscontrol/types/types_view.go +++ b/hscontrol/types/types_view.go @@ -239,14 +239,16 @@ func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error { return nil } -func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } -func (v PreAuthKeyView) Key() string { return v.ж.Key } -func (v PreAuthKeyView) UserID() uint { return v.ж.UserID } -func (v PreAuthKeyView) User() User { return v.ж.User } -func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } -func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } -func (v PreAuthKeyView) Used() bool { return v.ж.Used } -func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } +func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } +func (v PreAuthKeyView) Key() string { return v.ж.Key } +func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix } +func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) } +func (v PreAuthKeyView) UserID() uint { return v.ж.UserID } +func (v PreAuthKeyView) User() User { return v.ж.User } +func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } +func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } +func (v PreAuthKeyView) Used() bool { return v.ж.Used } +func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] { return views.ValuePointerOf(v.ж.CreatedAt) } @@ -259,6 +261,8 @@ func (v PreAuthKeyView) Expiration() views.ValuePointer[time.Time] { var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct { ID uint64 Key string + Prefix string + Hash []byte UserID uint User User Reusable bool diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index a9dc748e..4d828d02 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -294,3 +294,20 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri return InvalidString() } + +// GenerateRegistrationKey generates a vanity key for tracking web authentication +// registration flows in logs. This key is NOT stored in the database and does NOT use bcrypt - +// it's purely for observability and correlating log entries during the registration process. +func GenerateRegistrationKey() (string, error) { + const ( + registerKeyPrefix = "hskey-reg-" //nolint:gosec // This is a vanity key for logging, not a credential + registerKeyLength = 64 + ) + + randomPart, err := GenerateRandomStringURLSafe(registerKeyLength) + if err != nil { + return "", fmt.Errorf("generating registration key: %w", err) + } + + return registerKeyPrefix + randomPart, nil +} diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 22418e34..33f27b7a 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1288,3 +1288,99 @@ func TestEnsureHostname_Idempotent(t *testing.T) { t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) } } + +func TestGenerateRegistrationKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + test func(*testing.T) + }{ + { + name: "generates_key_with_correct_prefix", + test: func(t *testing.T) { + t.Helper() + + key, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + if !strings.HasPrefix(key, "hskey-reg-") { + t.Errorf("key does not have expected prefix: %s", key) + } + }, + }, + { + name: "generates_key_with_correct_length", + test: func(t *testing.T) { + t.Helper() + + key, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + // Expected format: hskey-reg-{64-char-random} + // Total length: 10 (prefix) + 64 (random) = 74 + if len(key) != 74 { + t.Errorf("key length = %d, want 74", len(key)) + } + }, + }, + { + name: "generates_unique_keys", + test: func(t *testing.T) { + t.Helper() + + key1, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + key2, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + if key1 == key2 { + t.Error("generated keys should be unique") + } + }, + }, + { + name: "key_contains_only_valid_chars", + test: func(t *testing.T) { + t.Helper() + + key, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + // Remove prefix + _, randomPart, found := strings.Cut(key, "hskey-reg-") + if !found { + t.Error("key does not contain expected prefix") + } + + // Verify base64 URL-safe characters (A-Za-z0-9_-) + for _, ch := range randomPart { + if (ch < 'A' || ch > 'Z') && + (ch < 'a' || ch > 'z') && + (ch < '0' || ch > '9') && + ch != '_' && ch != '-' { + t.Errorf("key contains invalid character: %c", ch) + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.test(t) + }) + } +} diff --git a/integration/cli_test.go b/integration/cli_test.go index 37e3c33d..dca37570 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -337,9 +337,10 @@ func TestPreAuthKeyCommand(t *testing.T) { }, ) - assert.NotEmpty(t, listedPreAuthKeys[1].GetKey()) - assert.NotEmpty(t, listedPreAuthKeys[2].GetKey()) - assert.NotEmpty(t, listedPreAuthKeys[3].GetKey()) + // New keys show prefix after listing, so check the created keys instead + assert.NotEmpty(t, keys[0].GetKey()) + assert.NotEmpty(t, keys[1].GetKey()) + assert.NotEmpty(t, keys[2].GetKey()) assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeys[2].GetExpiration().AsTime().After(time.Now())) @@ -370,7 +371,7 @@ func TestPreAuthKeyCommand(t *testing.T) { ) } - // Test key expiry + // Test key expiry - use the full key from creation, not the masked one from listing _, err = headscale.Execute( []string{ "headscale", @@ -378,7 +379,7 @@ func TestPreAuthKeyCommand(t *testing.T) { "--user", "1", "expire", - listedPreAuthKeys[1].GetKey(), + keys[0].GetKey(), }, ) require.NoError(t, err)