types: make pre auth key use bcrypt (#2853)

This commit is contained in:
Kristoffer Dalby
2025-11-12 09:36:36 -06:00
committed by GitHub
parent e3ced80278
commit da9018a0eb
21 changed files with 1450 additions and 225 deletions

View File

@@ -8,13 +8,25 @@ The OIDC callback and device registration web pages have been updated to use the
Material for MkDocs design system from the official documentation. The templates Material for MkDocs design system from the official documentation. The templates
now use consistent typography, spacing, and colours across all registration now use consistent typography, spacing, and colours across all registration
flows. External links are properly secured with noreferrer/noopener attributes. flows. External links are properly secured with noreferrer/noopener attributes.
### Pre-authentication key security improvements
Pre-authentication keys now use bcrypt hashing for improved security
[#2853](https://github.com/juanfont/headscale/pull/2853). Keys are stored as a
prefix and bcrypt hash instead of plaintext. The full key is only displayed once
at creation time. When listing keys, only the prefix is shown (e.g.,
`hskey-auth-{prefix}-***`). All new keys use the format
`hskey-auth-{prefix}-{secret}`. Legacy plaintext keys continue to work for
backwards compatibility.
### Changes ### Changes
- Add NixOS module in repository for faster iteration [#2857](https://github.com/juanfont/headscale/pull/2857) - Add NixOS module in repository for faster iteration [#2857](https://github.com/juanfont/headscale/pull/2857)
- Add favicon to webpages [#2858](https://github.com/juanfont/headscale/pull/2858) - Add favicon to webpages [#2858](https://github.com/juanfont/headscale/pull/2858)
- Reclaim IPs from the IP allocator when nodes are deleted [#2831](https://github.com/juanfont/headscale/pull/2831)
- Redesign OIDC callback and registration web templates [#2832](https://github.com/juanfont/headscale/pull/2832) - Redesign OIDC callback and registration web templates [#2832](https://github.com/juanfont/headscale/pull/2832)
- Reclaim IPs from the IP allocator when nodes are deleted [#2831](https://github.com/juanfont/headscale/pull/2831)
- Add bcrypt hashing for pre-authentication keys [#2853](https://github.com/juanfont/headscale/pull/2853)
- Add structured prefix format for API keys (`hskey-api-{prefix}-{secret}`) [#2853](https://github.com/juanfont/headscale/pull/2853)
- Add registration keys for web authentication tracking (`hskey-reg-{random}`) [#2853](https://github.com/juanfont/headscale/pull/2853)
## 0.27.1 (2025-11-11) ## 0.27.1 (2025-11-11)

View File

@@ -88,7 +88,7 @@ var listPreAuthKeys = &cobra.Command{
tableData := pterm.TableData{ tableData := pterm.TableData{
{ {
"ID", "ID",
"Key", "Key/Prefix",
"Reusable", "Reusable",
"Ephemeral", "Ephemeral",
"Used", "Used",

View File

@@ -3026,7 +3026,11 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// Create user and single-use pre-auth key // Create user and single-use pre-auth key
user := app.state.CreateUserForTest("test-user") user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable field
pak, err := app.state.GetPreAuthKey(pakNew.Key)
require.NoError(t, err) require.NoError(t, err)
require.False(t, pak.Reusable, "key should be single-use for this test") require.False(t, pak.Reusable, "key should be single-use for this test")
@@ -3036,7 +3040,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// STEP 1: Initial registration with pre-auth key (simulates fresh node joining) // STEP 1: Initial registration with pre-auth key (simulates fresh node joining)
initialReq := tailcfg.RegisterRequest{ initialReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, AuthKey: pakNew.Key,
}, },
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
@@ -3060,7 +3064,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
assert.Equal(t, machineKey.Public(), node.MachineKey()) assert.Equal(t, machineKey.Public(), node.MachineKey())
// Verify pre-auth key is now marked as used // Verify pre-auth key is now marked as used
usedPak, err := app.state.GetPreAuthKey(pak.Key) usedPak, err := app.state.GetPreAuthKey(pakNew.Key)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, usedPak.Used, "pre-auth key should be marked as used after initial registration") assert.True(t, usedPak.Used, "pre-auth key should be marked as used after initial registration")
@@ -3073,7 +3077,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key")
restartReq := tailcfg.RegisterRequest{ restartReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same key, now marked as Used=true AuthKey: pakNew.Key, // Same key, now marked as Used=true
}, },
NodeKey: nodeKey.Public(), // Same node key NodeKey: nodeKey.Public(), // Same node key
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
@@ -3113,7 +3117,11 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) {
app := createTestApp(t) app := createTestApp(t)
user := app.state.CreateUserForTest("test-user") user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable field
pak, err := app.state.GetPreAuthKey(pakNew.Key)
require.NoError(t, err) require.NoError(t, err)
require.True(t, pak.Reusable) require.True(t, pak.Reusable)
@@ -3123,7 +3131,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) {
// Initial registration // Initial registration
initialReq := tailcfg.RegisterRequest{ initialReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, AuthKey: pakNew.Key,
}, },
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
@@ -3140,7 +3148,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) {
// Node restart - re-registration with reusable key // Node restart - re-registration with reusable key
restartReq := tailcfg.RegisterRequest{ restartReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Reusable key AuthKey: pakNew.Key, // Reusable key
}, },
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
@@ -3209,7 +3217,11 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
// Create a SINGLE-USE pre-auth key (reusable=false) // Create a SINGLE-USE pre-auth key (reusable=false)
// This is the type of key that triggers the bug in issue #2830 // This is the type of key that triggers the bug in issue #2830
preAuthKey, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) preAuthKeyNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable and Used fields
preAuthKey, err := app.state.GetPreAuthKey(preAuthKeyNew.Key)
require.NoError(t, err) require.NoError(t, err)
require.False(t, preAuthKey.Reusable, "Pre-auth key must be single-use to test issue #2830") require.False(t, preAuthKey.Reusable, "Pre-auth key must be single-use to test issue #2830")
require.False(t, preAuthKey.Used, "Pre-auth key should not be used yet") require.False(t, preAuthKey.Used, "Pre-auth key should not be used yet")
@@ -3222,7 +3234,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
// This simulates the first time the container starts and runs 'tailscale up --authkey=...' // This simulates the first time the container starts and runs 'tailscale up --authkey=...'
initialReq := tailcfg.RegisterRequest{ initialReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: preAuthKey.Key, AuthKey: preAuthKeyNew.Key, // Use the full key from creation
}, },
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
@@ -3238,7 +3250,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
require.Equal(t, "testuser", initialResp.User.DisplayName, "User should match the pre-auth key's user") require.Equal(t, "testuser", initialResp.User.DisplayName, "User should match the pre-auth key's user")
// Verify the pre-auth key is now marked as Used // Verify the pre-auth key is now marked as Used
updatedKey, err := app.state.GetPreAuthKey(preAuthKey.Key) updatedKey, err := app.state.GetPreAuthKey(preAuthKeyNew.Key)
require.NoError(t, err) require.NoError(t, err)
require.True(t, updatedKey.Used, "Pre-auth key should be marked as Used after initial registration") require.True(t, updatedKey.Used, "Pre-auth key should be marked as Used after initial registration")
@@ -3253,7 +3265,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
// This is exactly what happens when a container restarts // This is exactly what happens when a container restarts
reregisterReq := tailcfg.RegisterRequest{ reregisterReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: preAuthKey.Key, // Same key, now marked as Used=true AuthKey: preAuthKeyNew.Key, // Same key, now marked as Used=true
}, },
NodeKey: nodeKey.Public(), // Same NodeKey NodeKey: nodeKey.Public(), // Same NodeKey
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
@@ -3280,7 +3292,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
attackReq := tailcfg.RegisterRequest{ attackReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{ Auth: &tailcfg.RegisterResponseAuth{
AuthKey: preAuthKey.Key, // Try to use the same key AuthKey: preAuthKeyNew.Key, // Try to use the same key
}, },
NodeKey: differentNodeKey.Public(), NodeKey: differentNodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{

View File

@@ -9,33 +9,64 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
const ( const (
apiPrefixLength = 7 apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential
apiKeyLength = 32 apiKeyPrefixLength = 12
apiKeyHashLength = 64
// Legacy format constants.
legacyAPIPrefixLength = 7
legacyAPIKeyLength = 32
) )
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") var (
ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
ErrAPIKeyGenerationFailed = errors.New("failed to generate API key")
ErrAPIKeyInvalidGeneration = errors.New("generated API key failed validation")
)
// CreateAPIKey creates a new ApiKey in a user, and returns it. // CreateAPIKey creates a new ApiKey in a user, and returns it.
func (hsdb *HSDatabase) CreateAPIKey( func (hsdb *HSDatabase) CreateAPIKey(
expiration *time.Time, expiration *time.Time,
) (string, *types.APIKey, error) { ) (string, *types.APIKey, error) {
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) // Generate public prefix (12 chars)
prefix, err := util.GenerateRandomStringURLSafe(apiKeyPrefixLength)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength) // Validate prefix
if len(prefix) != apiKeyPrefixLength {
return "", nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyPrefixLength, len(prefix))
}
if !isValidBase64URLSafe(prefix) {
return "", nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrAPIKeyInvalidGeneration)
}
// Generate secret (64 chars)
secret, err := util.GenerateRandomStringURLSafe(apiKeyHashLength)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
// Key to return to user, this will only be visible _once_ // Validate secret
keyStr := prefix + "." + toBeHashed if len(secret) != apiKeyHashLength {
return "", nil, fmt.Errorf("%w: generated secret has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyHashLength, len(secret))
}
hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) if !isValidBase64URLSafe(secret) {
return "", nil, fmt.Errorf("%w: generated secret contains invalid characters", ErrAPIKeyInvalidGeneration)
}
// Full key string (shown ONCE to user)
keyStr := apiKeyPrefix + prefix + "-" + secret
// bcrypt hash of secret
hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@@ -103,23 +134,164 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
} }
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
prefix, hash, found := strings.Cut(keyStr, ".") key, err := validateAPIKey(hsdb.DB, keyStr)
if !found {
return false, ErrAPIKeyFailedToParse
}
key, err := hsdb.GetAPIKey(prefix)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to validate api key: %w", err)
}
if key.Expiration.Before(time.Now()) {
return false, nil
}
if err := bcrypt.CompareHashAndPassword(key.Hash, []byte(hash)); err != nil {
return false, err return false, err
} }
if key.Expiration != nil && key.Expiration.Before(time.Now()) {
return false, nil
}
return true, nil return true, nil
} }
// ParseAPIKeyPrefix extracts the database prefix from a display prefix.
// Handles formats: "hskey-api-{12chars}-***", "hskey-api-{12chars}", or just "{12chars}".
// Returns the 12-character prefix suitable for database lookup.
func ParseAPIKeyPrefix(displayPrefix string) (string, error) {
// If it's already just the 12-character prefix, return it
if len(displayPrefix) == apiKeyPrefixLength && isValidBase64URLSafe(displayPrefix) {
return displayPrefix, nil
}
// If it starts with the API key prefix, parse it
if strings.HasPrefix(displayPrefix, apiKeyPrefix) {
// Remove the "hskey-api-" prefix
_, remainder, found := strings.Cut(displayPrefix, apiKeyPrefix)
if !found {
return "", fmt.Errorf("%w: invalid display prefix format", ErrAPIKeyFailedToParse)
}
// Extract just the first 12 characters (the actual prefix)
if len(remainder) < apiKeyPrefixLength {
return "", fmt.Errorf("%w: prefix too short", ErrAPIKeyFailedToParse)
}
prefix := remainder[:apiKeyPrefixLength]
// Validate it's base64 URL-safe
if !isValidBase64URLSafe(prefix) {
return "", fmt.Errorf("%w: prefix contains invalid characters", ErrAPIKeyFailedToParse)
}
return prefix, nil
}
// For legacy 7-character prefixes or other formats, return as-is
return displayPrefix, nil
}
// validateAPIKey validates an API key and returns the key if valid.
// Handles both new (hskey-api-{prefix}-{secret}) and legacy (prefix.secret) formats.
func validateAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) {
// Validate input is not empty
if keyStr == "" {
return nil, ErrAPIKeyFailedToParse
}
// Check for new format: hskey-api-{prefix}-{secret}
_, prefixAndSecret, found := strings.Cut(keyStr, apiKeyPrefix)
if !found {
// Legacy format: prefix.secret
return validateLegacyAPIKey(db, keyStr)
}
// New format: parse and verify
const expectedMinLength = apiKeyPrefixLength + 1 + apiKeyHashLength
if len(prefixAndSecret) < expectedMinLength {
return nil, fmt.Errorf(
"%w: key too short, expected at least %d chars after prefix, got %d",
ErrAPIKeyFailedToParse,
expectedMinLength,
len(prefixAndSecret),
)
}
// Use fixed-length parsing
prefix := prefixAndSecret[:apiKeyPrefixLength]
// Validate separator at expected position
if prefixAndSecret[apiKeyPrefixLength] != '-' {
return nil, fmt.Errorf(
"%w: expected separator '-' at position %d, got '%c'",
ErrAPIKeyFailedToParse,
apiKeyPrefixLength,
prefixAndSecret[apiKeyPrefixLength],
)
}
secret := prefixAndSecret[apiKeyPrefixLength+1:]
// Validate secret length
if len(secret) != apiKeyHashLength {
return nil, fmt.Errorf(
"%w: secret length mismatch, expected %d chars, got %d",
ErrAPIKeyFailedToParse,
apiKeyHashLength,
len(secret),
)
}
// Validate prefix contains only base64 URL-safe characters
if !isValidBase64URLSafe(prefix) {
return nil, fmt.Errorf(
"%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
ErrAPIKeyFailedToParse,
)
}
// Validate secret contains only base64 URL-safe characters
if !isValidBase64URLSafe(secret) {
return nil, fmt.Errorf(
"%w: secret contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
ErrAPIKeyFailedToParse,
)
}
// Look up by prefix (indexed)
var key types.APIKey
err := db.First(&key, "prefix = ?", prefix).Error
if err != nil {
return nil, fmt.Errorf("API key not found: %w", err)
}
// Verify bcrypt hash
err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret))
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return &key, nil
}
// validateLegacyAPIKey validates a legacy format API key (prefix.secret).
func validateLegacyAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) {
// Legacy format uses "." as separator
prefix, secret, found := strings.Cut(keyStr, ".")
if !found {
return nil, ErrAPIKeyFailedToParse
}
// Legacy prefix is 7 chars
if len(prefix) != legacyAPIPrefixLength {
return nil, fmt.Errorf("%w: legacy prefix length mismatch", ErrAPIKeyFailedToParse)
}
var key types.APIKey
err := db.First(&key, "prefix = ?", prefix).Error
if err != nil {
return nil, fmt.Errorf("API key not found: %w", err)
}
// Verify bcrypt (key.Hash stores bcrypt of full secret)
err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret))
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return &key, nil
}

View File

@@ -1,8 +1,14 @@
package db package db
import ( import (
"strings"
"testing"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@@ -87,3 +93,142 @@ func (*Suite) TestExpireAPIKey(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(notValid, check.Equals, false) c.Assert(notValid, check.Equals, false)
} }
func TestAPIKeyWithPrefix(t *testing.T) {
tests := []struct {
name string
test func(*testing.T, *HSDatabase)
}{
{
name: "new_key_with_prefix",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
keyStr, apiKey, err := db.CreateAPIKey(nil)
require.NoError(t, err)
// Verify format: hskey-api-{12-char-prefix}-{64-char-secret}
assert.True(t, strings.HasPrefix(keyStr, "hskey-api-"))
_, prefixAndSecret, found := strings.Cut(keyStr, "hskey-api-")
assert.True(t, found)
assert.GreaterOrEqual(t, len(prefixAndSecret), 12+1+64)
prefix := prefixAndSecret[:12]
assert.Len(t, prefix, 12)
assert.Equal(t, byte('-'), prefixAndSecret[12])
secret := prefixAndSecret[13:]
assert.Len(t, secret, 64)
// Verify stored fields
assert.Len(t, apiKey.Prefix, types.NewAPIKeyPrefixLength)
assert.NotNil(t, apiKey.Hash)
},
},
{
name: "new_key_can_be_retrieved",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
keyStr, createdKey, err := db.CreateAPIKey(nil)
require.NoError(t, err)
// Validate the created key
valid, err := db.ValidateAPIKey(keyStr)
require.NoError(t, err)
assert.True(t, valid)
// Verify prefix is correct length
assert.Len(t, createdKey.Prefix, types.NewAPIKeyPrefixLength)
},
},
{
name: "invalid_key_format_rejected",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
invalidKeys := []string{
"",
"hskey-api-short",
"hskey-api-ABCDEFGHIJKL-tooshort",
"hskey-api-ABC$EFGHIJKL-" + strings.Repeat("a", 64),
"hskey-api-ABCDEFGHIJKL" + strings.Repeat("a", 64), // missing separator
}
for _, invalidKey := range invalidKeys {
valid, err := db.ValidateAPIKey(invalidKey)
require.Error(t, err, "key should be rejected: %s", invalidKey)
assert.False(t, valid)
}
},
},
{
name: "legacy_key_still_works",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
// Insert legacy API key directly (7-char prefix + 32-char secret)
legacyPrefix := "abcdefg"
legacySecret := strings.Repeat("x", 32)
legacyKey := legacyPrefix + "." + legacySecret
hash, err := bcrypt.GenerateFromPassword([]byte(legacySecret), bcrypt.DefaultCost)
require.NoError(t, err)
now := time.Now()
err = db.DB.Exec(`
INSERT INTO api_keys (prefix, hash, created_at)
VALUES (?, ?, ?)
`, legacyPrefix, hash, now).Error
require.NoError(t, err)
// Validate legacy key
valid, err := db.ValidateAPIKey(legacyKey)
require.NoError(t, err)
assert.True(t, valid)
},
},
{
name: "wrong_secret_rejected",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
keyStr, _, err := db.CreateAPIKey(nil)
require.NoError(t, err)
// Tamper with the secret
_, prefixAndSecret, _ := strings.Cut(keyStr, "hskey-api-")
prefix := prefixAndSecret[:12]
tamperedKey := "hskey-api-" + prefix + "-" + strings.Repeat("x", 64)
valid, err := db.ValidateAPIKey(tamperedKey)
require.Error(t, err)
assert.False(t, valid)
},
},
{
name: "expired_key_rejected",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
// Create expired key
expired := time.Now().Add(-1 * time.Hour)
keyStr, _, err := db.CreateAPIKey(&expired)
require.NoError(t, err)
// Should fail validation
valid, err := db.ValidateAPIKey(keyStr)
require.NoError(t, err)
assert.False(t, valid)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
tt.test(t, db)
})
}
}

View File

@@ -991,6 +991,38 @@ AND auth_key_id NOT IN (
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed // - NEVER use gorm.AutoMigrate, write the exact migration steps needed
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time. // - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
// - Never write migrations that requires foreign keys to be disabled. // - Never write migrations that requires foreign keys to be disabled.
{
// Add columns for prefix and hash for pre auth keys, implementing
// them with the same security model as api keys.
ID: "202511011637-preauthkey-bcrypt",
Migrate: func(tx *gorm.DB) error {
// Check and add prefix column if it doesn't exist
if !tx.Migrator().HasColumn(&types.PreAuthKey{}, "prefix") {
err := tx.Migrator().AddColumn(&types.PreAuthKey{}, "prefix")
if err != nil {
return fmt.Errorf("adding prefix column: %w", err)
}
}
// Check and add hash column if it doesn't exist
if !tx.Migrator().HasColumn(&types.PreAuthKey{}, "hash") {
err := tx.Migrator().AddColumn(&types.PreAuthKey{}, "hash")
if err != nil {
return fmt.Errorf("adding hash column: %w", err)
}
}
// Create partial unique index to allow multiple legacy keys (NULL/empty prefix)
// while enforcing uniqueness for new bcrypt-based keys
err := tx.Exec("CREATE UNIQUE INDEX IF NOT EXISTS idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''").Error
if err != nil {
return fmt.Errorf("creating prefix index: %w", err)
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
}, },
) )

View File

@@ -6,7 +6,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@@ -159,8 +158,6 @@ func TestIPAllocatorSequential(t *testing.T) {
types.IPAllocationStrategySequential, types.IPAllocationStrategySequential,
) )
spew.Dump(alloc)
var got4s []netip.Addr var got4s []netip.Addr
var got6s []netip.Addr var got6s []netip.Addr
@@ -263,8 +260,6 @@ func TestIPAllocatorRandom(t *testing.T) {
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom) alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom)
spew.Dump(alloc)
for range tt.getCount { for range tt.getCount {
got4, got6, err := alloc.Next() got4, got6, err := alloc.Next()
if err != nil { if err != nil {

View File

@@ -1,8 +1,6 @@
package db package db
import ( import (
"crypto/rand"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"slices" "slices"
@@ -10,6 +8,8 @@ import (
"time" "time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
@@ -28,12 +28,18 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKeyNew, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKeyNew, error) {
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
}) })
} }
const (
authKeyPrefix = "hskey-auth-"
authKeyPrefixLength = 12
authKeyLength = 64
)
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey( func CreatePreAuthKey(
tx *gorm.DB, tx *gorm.DB,
@@ -42,7 +48,7 @@ func CreatePreAuthKey(
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKeyNew, error) {
user, err := GetUserByID(tx, uid) user, err := GetUserByID(tx, uid)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -65,14 +71,43 @@ func CreatePreAuthKey(
} }
now := time.Now().UTC() now := time.Now().UTC()
// TODO(kradalby): unify the key generations spread all over the code.
kstr, err := generateKey() prefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength)
if err != nil {
return nil, err
}
// Validate generated prefix (should always be valid, but be defensive)
if len(prefix) != authKeyPrefixLength {
return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, len(prefix))
}
if !isValidBase64URLSafe(prefix) {
return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse)
}
toBeHashed, err := util.GenerateRandomStringURLSafe(authKeyLength)
if err != nil {
return nil, err
}
// Validate generated hash (should always be valid, but be defensive)
if len(toBeHashed) != authKeyLength {
return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(toBeHashed))
}
if !isValidBase64URLSafe(toBeHashed) {
return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse)
}
keyStr := authKeyPrefix + prefix + "-" + toBeHashed
hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := types.PreAuthKey{ key := types.PreAuthKey{
Key: kstr,
UserID: user.ID, UserID: user.ID,
User: *user, User: *user,
Reusable: reusable, Reusable: reusable,
@@ -80,13 +115,24 @@ func CreatePreAuthKey(
CreatedAt: &now, CreatedAt: &now,
Expiration: expiration, Expiration: expiration,
Tags: aclTags, Tags: aclTags,
Prefix: prefix, // Store prefix
Hash: hash, // Store hash
} }
if err := tx.Save(&key).Error; err != nil { if err := tx.Save(&key).Error; err != nil {
return nil, fmt.Errorf("failed to create key in the database: %w", err) return nil, fmt.Errorf("failed to create key in the database: %w", err)
} }
return &key, nil return &types.PreAuthKeyNew{
ID: key.ID,
Key: keyStr,
Reusable: key.Reusable,
Ephemeral: key.Ephemeral,
Tags: key.Tags,
Expiration: key.Expiration,
CreatedAt: key.CreatedAt,
User: key.User,
}, nil
} }
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
@@ -110,6 +156,107 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
return keys, nil return keys, nil
} }
var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey")
func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) {
var pak types.PreAuthKey
// Validate input is not empty
if keyStr == "" {
return nil, ErrPreAuthKeyFailedToParse
}
_, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix)
if !found {
// Legacy format (plaintext) - backwards compatibility
err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error
if err != nil {
return nil, ErrPreAuthKeyNotFound
}
return &pak, nil
}
// New format: hskey-auth-{12-char-prefix}-{64-char-hash}
// Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77
const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength
if len(prefixAndHash) < expectedMinLength {
return nil, fmt.Errorf(
"%w: key too short, expected at least %d chars after prefix, got %d",
ErrPreAuthKeyFailedToParse,
expectedMinLength,
len(prefixAndHash),
)
}
// Use fixed-length parsing instead of separator-based to handle dashes in base64 URL-safe
prefix := prefixAndHash[:authKeyPrefixLength]
// Validate separator at expected position
if prefixAndHash[authKeyPrefixLength] != '-' {
return nil, fmt.Errorf(
"%w: expected separator '-' at position %d, got '%c'",
ErrPreAuthKeyFailedToParse,
authKeyPrefixLength,
prefixAndHash[authKeyPrefixLength],
)
}
hash := prefixAndHash[authKeyPrefixLength+1:]
// Validate hash length
if len(hash) != authKeyLength {
return nil, fmt.Errorf(
"%w: hash length mismatch, expected %d chars, got %d",
ErrPreAuthKeyFailedToParse,
authKeyLength,
len(hash),
)
}
// Validate prefix contains only base64 URL-safe characters
if !isValidBase64URLSafe(prefix) {
return nil, fmt.Errorf(
"%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
ErrPreAuthKeyFailedToParse,
)
}
// Validate hash contains only base64 URL-safe characters
if !isValidBase64URLSafe(hash) {
return nil, fmt.Errorf(
"%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
ErrPreAuthKeyFailedToParse,
)
}
// Look up key by prefix
err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error
if err != nil {
return nil, ErrPreAuthKeyNotFound
}
// Verify hash matches
err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash))
if err != nil {
return nil, fmt.Errorf("invalid auth key: %w", err)
}
return &pak, nil
}
// isValidBase64URLSafe checks if a string contains only base64 URL-safe characters.
func isValidBase64URLSafe(s string) bool {
for _, c := range s {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && (c < '0' || c > '9') && c != '-' && c != '_' {
return false
}
}
return true
}
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
return GetPreAuthKey(hsdb.DB, key) return GetPreAuthKey(hsdb.DB, key)
} }
@@ -117,12 +264,7 @@ func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible // GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
// for checking if the key is usable (expired or used). // for checking if the key is usable (expired or used).
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{} return findAuthKey(tx, key)
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
return nil, ErrPreAuthKeyNotFound
}
return &pak, nil
} }
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
@@ -159,13 +301,3 @@ func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
now := time.Now() now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
} }
func generateKey() (string, error) {
size := 24
bytes := make([]byte, size)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -1,74 +1,135 @@
package db package db
import ( import (
"fmt"
"slices" "slices"
"strings"
"testing" "testing"
"time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/check.v1"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func TestCreatePreAuthKey(t *testing.T) {
// ID does not exist tests := []struct {
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil) name string
c.Assert(err, check.NotNil) test func(*testing.T, *HSDatabase)
}{
{
name: "error_invalid_user_id",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
user, err := db.CreateUser(types.User{Name: "test"}) _, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
c.Assert(err, check.IsNil) assert.Error(t, err)
},
},
{
name: "success_create_and_list",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil) require.NoError(t, err)
// Did we get a valid key? key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(key.Key, check.NotNil) require.NoError(t, err)
c.Assert(len(key.Key), check.Equals, 48) assert.NotEmpty(t, key.Key)
// Make sure the User association is populated // List keys for the user
c.Assert(key.User.ID, check.Equals, user.ID) keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
require.NoError(t, err)
assert.Len(t, keys, 1)
// ID does not exist // Verify User association is populated
_, err = db.ListPreAuthKeys(1000000) assert.Equal(t, user.ID, keys[0].User.ID)
c.Assert(err, check.NotNil) },
},
{
name: "error_list_invalid_user_id",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) _, err := db.ListPreAuthKeys(1000000)
c.Assert(err, check.IsNil) assert.Error(t, err)
c.Assert(len(keys), check.Equals, 1) },
},
}
// Make sure the User association is populated for _, tt := range tests {
c.Assert((keys)[0].User.ID, check.Equals, user.ID) t.Run(tt.name, func(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
tt.test(t, db)
})
}
} }
func (*Suite) TestPreAuthKeyACLTags(c *check.C) { func TestPreAuthKeyACLTags(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test8"}) tests := []struct {
c.Assert(err, check.IsNil) name string
test func(*testing.T, *HSDatabase)
}{
{
name: "reject_malformed_tags",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) user, err := db.CreateUser(types.User{Name: "test-tags-1"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected require.NoError(t, err)
tags := []string{"tag:test1", "tag:test2"} _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} assert.Error(t, err)
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) },
c.Assert(err, check.IsNil) },
{
name: "deduplicate_and_sort_tags",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) user, err := db.CreateUser(types.User{Name: "test-tags-2"})
c.Assert(err, check.IsNil) require.NoError(t, err)
gotTags := listedPaks[0].Proto().GetAclTags()
slices.Sort(gotTags) expectedTags := []string{"tag:test1", "tag:test2"}
c.Assert(gotTags, check.DeepEquals, tags) tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
require.NoError(t, err)
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
require.NoError(t, err)
require.Len(t, listedPaks, 1)
gotTags := listedPaks[0].Proto().GetAclTags()
slices.Sort(gotTags)
assert.Equal(t, expectedTags, gotTags)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
tt.test(t, db)
})
}
} }
func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
db, err := newSQLiteTestDB() db, err := newSQLiteTestDB()
require.NoError(t, err) require.NoError(t, err)
user, err := db.CreateUser(types.User{Name: "test8"}) user, err := db.CreateUser(types.User{Name: "test8"})
assert.NoError(t, err) require.NoError(t, err)
key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"}) key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"})
assert.NoError(t, err) require.NoError(t, err)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
@@ -79,6 +140,317 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
} }
db.DB.Save(&node) db.DB.Save(&node)
err = db.DB.Delete(key).Error err = db.DB.Delete(&types.PreAuthKey{ID: key.ID}).Error
require.ErrorContains(t, err, "constraint failed: FOREIGN KEY constraint failed") require.ErrorContains(t, err, "constraint failed: FOREIGN KEY constraint failed")
} }
func TestPreAuthKeyAuthentication(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
user := db.CreateUserForTest("test-user")
tests := []struct {
name string
setupKey func() string // Returns key string to test
wantFindErr bool // Error when finding the key
wantValidateErr bool // Error when validating the key
validateResult func(*testing.T, *types.PreAuthKey)
}{
{
name: "legacy_key_plaintext",
setupKey: func() string {
// Insert legacy key directly using GORM (simulate existing production key)
// Note: We use raw SQL to bypass GORM's handling and set prefix to empty string
// which simulates how legacy keys exist in production databases
legacyKey := "abc123def456ghi789jkl012mno345pqr678stu901vwx234yz"
now := time.Now()
// Use raw SQL to insert with empty prefix to avoid UNIQUE constraint
err := db.DB.Exec(`
INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at)
VALUES (?, ?, ?, ?, ?, ?)
`, legacyKey, user.ID, true, false, false, now).Error
require.NoError(t, err)
return legacyKey
},
wantFindErr: false,
wantValidateErr: false,
validateResult: func(t *testing.T, pak *types.PreAuthKey) {
t.Helper()
assert.Equal(t, user.ID, pak.UserID)
assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated
assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix
assert.Nil(t, pak.Hash) // Legacy keys have nil Hash
},
},
{
name: "new_key_bcrypt",
setupKey: func() string {
// Create new key via API
keyStr, err := db.CreatePreAuthKey(
types.UserID(user.ID),
true, false, nil, []string{"tag:test"},
)
require.NoError(t, err)
return keyStr.Key
},
wantFindErr: false,
wantValidateErr: false,
validateResult: func(t *testing.T, pak *types.PreAuthKey) {
t.Helper()
assert.Equal(t, user.ID, pak.UserID)
assert.Empty(t, pak.Key) // New keys have empty Key
assert.NotEmpty(t, pak.Prefix) // New keys have Prefix
assert.NotNil(t, pak.Hash) // New keys have Hash
assert.Len(t, pak.Prefix, 12) // Prefix is 12 chars
},
},
{
name: "new_key_format_validation",
setupKey: func() string {
keyStr, err := db.CreatePreAuthKey(
types.UserID(user.ID),
true, false, nil, nil,
)
require.NoError(t, err)
// Verify format: hskey-auth-{12-char-prefix}-{64-char-hash}
// Use fixed-length parsing since prefix/hash can contain dashes (base64 URL-safe)
assert.True(t, strings.HasPrefix(keyStr.Key, "hskey-auth-"))
// Extract prefix and hash using fixed-length parsing like the real code does
_, prefixAndHash, found := strings.Cut(keyStr.Key, "hskey-auth-")
assert.True(t, found)
assert.GreaterOrEqual(t, len(prefixAndHash), 12+1+64) // prefix + '-' + hash minimum
prefix := prefixAndHash[:12]
assert.Len(t, prefix, 12) // Prefix is 12 chars
assert.Equal(t, byte('-'), prefixAndHash[12]) // Separator
hash := prefixAndHash[13:]
assert.Len(t, hash, 64) // Hash is 64 chars
return keyStr.Key
},
wantFindErr: false,
wantValidateErr: false,
},
{
name: "invalid_bcrypt_hash",
setupKey: func() string {
// Create valid key
key, err := db.CreatePreAuthKey(
types.UserID(user.ID),
true, false, nil, nil,
)
require.NoError(t, err)
keyStr := key.Key
// Return key with tampered hash using fixed-length parsing
_, prefixAndHash, _ := strings.Cut(keyStr, "hskey-auth-")
prefix := prefixAndHash[:12]
return "hskey-auth-" + prefix + "-" + "wrong_hash_here_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "empty_key",
setupKey: func() string {
return ""
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "key_too_short",
setupKey: func() string {
return "hskey-auth-short"
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "missing_separator",
setupKey: func() string {
return "hskey-auth-ABCDEFGHIJKLabcdefghijklmnopqrstuvwxyz1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ"
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "hash_too_short",
setupKey: func() string {
return "hskey-auth-ABCDEFGHIJKL-short"
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "prefix_with_invalid_chars",
setupKey: func() string {
return "hskey-auth-ABC$EF@HIJKL-" + strings.Repeat("a", 64)
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "hash_with_invalid_chars",
setupKey: func() string {
return "hskey-auth-ABCDEFGHIJKL-" + "invalid$chars" + strings.Repeat("a", 54)
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "prefix_not_found_in_db",
setupKey: func() string {
// Create a validly formatted key but with a prefix that doesn't exist
return "hskey-auth-NotInDB12345-" + strings.Repeat("a", 64)
},
wantFindErr: true,
wantValidateErr: false,
},
{
name: "expired_legacy_key",
setupKey: func() string {
legacyKey := "expired_legacy_key_123456789012345678901234"
now := time.Now()
expiration := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
// Use raw SQL to avoid UNIQUE constraint on empty prefix
err := db.DB.Exec(`
INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at, expiration)
VALUES (?, ?, ?, ?, ?, ?, ?)
`, legacyKey, user.ID, true, false, false, now, expiration).Error
require.NoError(t, err)
return legacyKey
},
wantFindErr: false,
wantValidateErr: true,
},
{
name: "used_single_use_legacy_key",
setupKey: func() string {
legacyKey := "used_legacy_key_123456789012345678901234567"
now := time.Now()
// Use raw SQL to avoid UNIQUE constraint on empty prefix
err := db.DB.Exec(`
INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at)
VALUES (?, ?, ?, ?, ?, ?)
`, legacyKey, user.ID, false, false, true, now).Error
require.NoError(t, err)
return legacyKey
},
wantFindErr: false,
wantValidateErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyStr := tt.setupKey()
pak, err := db.GetPreAuthKey(keyStr)
if tt.wantFindErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, pak)
// Check validation if needed
if tt.wantValidateErr {
err := pak.Validate()
assert.Error(t, err)
return
}
if tt.validateResult != nil {
tt.validateResult(t, pak)
}
})
}
}
func TestMultipleLegacyKeysAllowed(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
user, err := db.CreateUser(types.User{Name: "test-legacy"})
require.NoError(t, err)
// Create multiple legacy keys by directly inserting with empty prefix
// This simulates the migration scenario where existing databases have multiple
// plaintext keys without prefix/hash fields
now := time.Now()
for i := range 5 {
legacyKey := fmt.Sprintf("legacy_key_%d_%s", i, strings.Repeat("x", 40))
err := db.DB.Exec(`
INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at)
VALUES (?, '', NULL, ?, ?, ?, ?, ?)
`, legacyKey, user.ID, true, false, false, now).Error
require.NoError(t, err, "should allow multiple legacy keys with empty prefix")
}
// Verify all legacy keys can be retrieved
var legacyKeys []types.PreAuthKey
err = db.DB.Where("prefix = '' OR prefix IS NULL").Find(&legacyKeys).Error
require.NoError(t, err)
assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys")
// Now create new bcrypt-based keys - these should have unique prefixes
key1, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key1.Key)
key2, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key2.Key)
// Verify the new keys have different prefixes
pak1, err := db.GetPreAuthKey(key1.Key)
require.NoError(t, err)
assert.NotEmpty(t, pak1.Prefix)
pak2, err := db.GetPreAuthKey(key2.Key)
require.NoError(t, err)
assert.NotEmpty(t, pak2.Prefix)
assert.NotEqual(t, pak1.Prefix, pak2.Prefix, "new keys should have unique prefixes")
// Verify we cannot manually insert duplicate non-empty prefixes
duplicatePrefix := "test_prefix1"
hash1 := []byte("hash1")
hash2 := []byte("hash2")
// First insert should succeed
err = db.DB.Exec(`
INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at)
VALUES ('', ?, ?, ?, ?, ?, ?, ?)
`, duplicatePrefix, hash1, user.ID, true, false, false, now).Error
require.NoError(t, err, "first key with prefix should succeed")
// Second insert with same prefix should fail
err = db.DB.Exec(`
INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at)
VALUES ('', ?, ?, ?, ?, ?, ?, ?)
`, duplicatePrefix, hash2, user.ID, true, false, false, now).Error
require.Error(t, err, "duplicate non-empty prefix should be rejected")
assert.Contains(t, err.Error(), "UNIQUE constraint failed", "should fail with UNIQUE constraint error")
}

View File

@@ -48,6 +48,8 @@ CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users(
CREATE TABLE pre_auth_keys( CREATE TABLE pre_auth_keys(
id integer PRIMARY KEY AUTOINCREMENT, id integer PRIMARY KEY AUTOINCREMENT,
key text, key text,
prefix text,
hash blob,
user_id integer, user_id integer,
reusable numeric, reusable numeric,
ephemeral numeric DEFAULT false, ephemeral numeric DEFAULT false,
@@ -59,6 +61,7 @@ CREATE TABLE pre_auth_keys(
CONSTRAINT fk_pre_auth_keys_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL CONSTRAINT fk_pre_auth_keys_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL
); );
CREATE UNIQUE INDEX idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != '';
CREATE TABLE api_keys( CREATE TABLE api_keys(
id integer PRIMARY KEY AUTOINCREMENT, id integer PRIMARY KEY AUTOINCREMENT,

View File

@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"zombiezen.com/go/postgrestest" "zombiezen.com/go/postgrestest"
) )
@@ -56,6 +57,7 @@ func newSQLiteTestDB() (*HSDatabase, error) {
} }
log.Printf("database path: %s", tmpDir+"/headscale_test.db") log.Printf("database path: %s", tmpDir+"/headscale_test.db")
zerolog.SetGlobalLevel(zerolog.Disabled)
db, err = NewHeadscaleDatabase( db, err = NewHeadscaleDatabase(
types.DatabaseConfig{ types.DatabaseConfig{

View File

@@ -1,138 +1,276 @@
package db package db
import ( import (
"strings" "testing"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
) )
func (s *Suite) TestCreateAndDestroyUser(c *check.C) { func TestCreateAndDestroyUser(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
user := db.CreateUserForTest("test") user := db.CreateUserForTest("test")
c.Assert(user.Name, check.Equals, "test") assert.Equal(t, "test", user.Name)
users, err := db.ListUsers() users, err := db.ListUsers()
c.Assert(err, check.IsNil) require.NoError(t, err)
c.Assert(len(users), check.Equals, 1) assert.Len(t, users, 1)
err = db.DestroyUser(types.UserID(user.ID)) err = db.DestroyUser(types.UserID(user.ID))
c.Assert(err, check.IsNil) require.NoError(t, err)
_, err = db.GetUserByID(types.UserID(user.ID)) _, err = db.GetUserByID(types.UserID(user.ID))
c.Assert(err, check.NotNil) assert.Error(t, err)
} }
func (s *Suite) TestDestroyUserErrors(c *check.C) { func TestDestroyUserErrors(t *testing.T) {
err := db.DestroyUser(9998) tests := []struct {
c.Assert(err, check.Equals, ErrUserNotFound) name string
test func(*testing.T, *HSDatabase)
}{
{
name: "error_user_not_found",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
user := db.CreateUserForTest("test") err := db.DestroyUser(9998)
assert.ErrorIs(t, err, ErrUserNotFound)
},
},
{
name: "success_deletes_preauthkeys",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) user := db.CreateUserForTest("test")
c.Assert(err, check.IsNil)
err = db.DestroyUser(types.UserID(user.ID)) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) require.NoError(t, err)
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) err = db.DestroyUser(types.UserID(user.ID))
// destroying a user also deletes all associated preauthkeys require.NoError(t, err)
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
user, err = db.CreateUser(types.User{Name: "test"}) // Verify preauth key was deleted (need to search by prefix for new keys)
c.Assert(err, check.IsNil) var foundPak types.PreAuthKey
pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) result := db.DB.First(&foundPak, "id = ?", pak.ID)
c.Assert(err, check.IsNil) assert.ErrorIs(t, result.Error, gorm.ErrRecordNotFound)
},
},
{
name: "error_user_has_nodes",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
node := types.Node{ user, err := db.CreateUser(types.User{Name: "test"})
ID: 0, require.NoError(t, err)
Hostname: "testnode",
UserID: user.ID, pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
RegisterMethod: util.RegisterMethodAuthKey, require.NoError(t, err)
AuthKeyID: ptr.To(pak.ID),
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
require.NoError(t, trx.Error)
err = db.DestroyUser(types.UserID(user.ID))
assert.ErrorIs(t, err, ErrUserStillHasNodes)
},
},
} }
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
err = db.DestroyUser(types.UserID(user.ID)) for _, tt := range tests {
c.Assert(err, check.Equals, ErrUserStillHasNodes) t.Run(tt.name, func(t *testing.T) {
} db, err := newSQLiteTestDB()
require.NoError(t, err)
func (s *Suite) TestRenameUser(c *check.C) { tt.test(t, db)
userTest := db.CreateUserForTest("test") })
c.Assert(userTest.Name, check.Equals, "test")
users, err := db.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = db.RenameUser(types.UserID(userTest.ID), "test-renamed")
c.Assert(err, check.IsNil)
users, err = db.ListUsers(&types.User{Name: "test"})
c.Assert(err, check.Equals, nil)
c.Assert(len(users), check.Equals, 0)
users, err = db.ListUsers(&types.User{Name: "test-renamed"})
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = db.RenameUser(99988, "test")
c.Assert(err, check.Equals, ErrUserNotFound)
userTest2 := db.CreateUserForTest("test2")
c.Assert(userTest2.Name, check.Equals, "test2")
want := "UNIQUE constraint failed"
err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed")
if err == nil || !strings.Contains(err.Error(), want) {
c.Fatalf("expected failure with unique constraint, want: %q got: %q", want, err)
} }
} }
func (s *Suite) TestSetMachineUser(c *check.C) { func TestRenameUser(t *testing.T) {
oldUser := db.CreateUserForTest("old") tests := []struct {
newUser := db.CreateUserForTest("new") name string
test func(*testing.T, *HSDatabase)
}{
{
name: "success_rename",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) userTest := db.CreateUserForTest("test")
c.Assert(err, check.IsNil) assert.Equal(t, "test", userTest.Name)
node := types.Node{ users, err := db.ListUsers()
ID: 12, require.NoError(t, err)
Hostname: "testnode", assert.Len(t, users, 1)
UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey, err = db.RenameUser(types.UserID(userTest.ID), "test-renamed")
AuthKeyID: ptr.To(pak.ID), require.NoError(t, err)
users, err = db.ListUsers(&types.User{Name: "test"})
require.NoError(t, err)
assert.Empty(t, users)
users, err = db.ListUsers(&types.User{Name: "test-renamed"})
require.NoError(t, err)
assert.Len(t, users, 1)
},
},
{
name: "error_user_not_found",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
err := db.RenameUser(99988, "test")
assert.ErrorIs(t, err, ErrUserNotFound)
},
},
{
name: "error_duplicate_name",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
userTest := db.CreateUserForTest("test")
userTest2 := db.CreateUserForTest("test2")
assert.Equal(t, "test", userTest.Name)
assert.Equal(t, "test2", userTest2.Name)
err := db.RenameUser(types.UserID(userTest2.ID), "test")
require.Error(t, err)
assert.Contains(t, err.Error(), "UNIQUE constraint failed")
},
},
} }
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.Write(func(tx *gorm.DB) error { for _, tt := range tests {
return AssignNodeToUser(tx, 12, types.UserID(newUser.ID)) t.Run(tt.name, func(t *testing.T) {
}) db, err := newSQLiteTestDB()
c.Assert(err, check.IsNil) require.NoError(t, err)
// Reload node from database to see updated values
updatedNode, err := db.GetNodeByID(12)
c.Assert(err, check.IsNil)
c.Assert(updatedNode.UserID, check.Equals, newUser.ID)
c.Assert(updatedNode.User.Name, check.Equals, newUser.Name)
err = db.Write(func(tx *gorm.DB) error { tt.test(t, db)
return AssignNodeToUser(tx, 12, 9584849) })
}) }
c.Assert(err, check.Equals, ErrUserNotFound) }
err = db.Write(func(tx *gorm.DB) error { func TestAssignNodeToUser(t *testing.T) {
return AssignNodeToUser(tx, 12, types.UserID(newUser.ID)) tests := []struct {
}) name string
c.Assert(err, check.IsNil) test func(*testing.T, *HSDatabase)
// Reload node from database again to see updated values }{
finalNode, err := db.GetNodeByID(12) {
c.Assert(err, check.IsNil) name: "success_reassign_node",
c.Assert(finalNode.UserID, check.Equals, newUser.ID) test: func(t *testing.T, db *HSDatabase) {
c.Assert(finalNode.User.Name, check.Equals, newUser.Name) t.Helper()
oldUser := db.CreateUserForTest("old")
newUser := db.CreateUserForTest("new")
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
require.NoError(t, err)
node := types.Node{
ID: 12,
Hostname: "testnode",
UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
require.NoError(t, trx.Error)
assert.Equal(t, oldUser.ID, node.UserID)
err = db.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, 12, types.UserID(newUser.ID))
})
require.NoError(t, err)
// Reload node from database to see updated values
updatedNode, err := db.GetNodeByID(12)
require.NoError(t, err)
assert.Equal(t, newUser.ID, updatedNode.UserID)
assert.Equal(t, newUser.Name, updatedNode.User.Name)
},
},
{
name: "error_user_not_found",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
oldUser := db.CreateUserForTest("old")
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
require.NoError(t, err)
node := types.Node{
ID: 12,
Hostname: "testnode",
UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
require.NoError(t, trx.Error)
err = db.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, 12, 9584849)
})
assert.ErrorIs(t, err, ErrUserNotFound)
},
},
{
name: "success_reassign_to_same_user",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
user := db.CreateUserForTest("user")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
require.NoError(t, err)
node := types.Node{
ID: 12,
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
require.NoError(t, trx.Error)
err = db.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, 12, types.UserID(user.ID))
})
require.NoError(t, err)
// Reload node from database again to see updated values
finalNode, err := db.GetNodeByID(12)
require.NoError(t, err)
assert.Equal(t, user.ID, finalNode.UserID)
assert.Equal(t, user.Name, finalNode.User.Name)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
tt.test(t, db)
})
}
} }

View File

@@ -236,10 +236,18 @@ func (api headscaleV1APIServer) RegisterNode(
ctx context.Context, ctx context.Context,
request *v1.RegisterNodeRequest, request *v1.RegisterNodeRequest,
) (*v1.RegisterNodeResponse, error) { ) (*v1.RegisterNodeResponse, error) {
// Generate ephemeral registration key for tracking this registration flow in logs
registrationKey, err := util.GenerateRegistrationKey()
if err != nil {
log.Warn().Err(err).Msg("Failed to generate registration key")
registrationKey = "" // Continue without key if generation fails
}
log.Trace(). log.Trace().
Caller(). Caller().
Str("user", request.GetUser()). Str("user", request.GetUser()).
Str("registration_id", request.GetKey()). Str("registration_id", request.GetKey()).
Str("registration_key", registrationKey).
Msg("Registering node") Msg("Registering node")
registrationId, err := types.RegistrationIDFromString(request.GetKey()) registrationId, err := types.RegistrationIDFromString(request.GetKey())
@@ -259,9 +267,19 @@ func (api headscaleV1APIServer) RegisterNode(
util.RegisterMethodCLI, util.RegisterMethodCLI,
) )
if err != nil { if err != nil {
log.Error().
Str("registration_key", registrationKey).
Err(err).
Msg("Failed to register node")
return nil, err return nil, err
} }
log.Info().
Str("registration_key", registrationKey).
Str("node_id", fmt.Sprintf("%d", node.ID())).
Str("hostname", node.Hostname()).
Msg("Node registered successfully")
// This is a bit of a back and forth, but we have a bit of a chicken and egg // This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here. // dependency here.
// Because the way the policy manager works, we need to have the node // Because the way the policy manager works, we need to have the node

View File

@@ -901,7 +901,14 @@ func (s *State) CreateAPIKey(expiration *time.Time) (string, *types.APIKey, erro
} }
// GetAPIKey retrieves an API key by its prefix. // GetAPIKey retrieves an API key by its prefix.
func (s *State) GetAPIKey(prefix string) (*types.APIKey, error) { // Accepts both display format (hskey-api-{12chars}-***) and database format ({12chars}).
func (s *State) GetAPIKey(displayPrefix string) (*types.APIKey, error) {
// Parse the display prefix to extract the database prefix
prefix, err := hsdb.ParseAPIKeyPrefix(displayPrefix)
if err != nil {
return nil, err
}
return s.db.GetAPIKey(prefix) return s.db.GetAPIKey(prefix)
} }
@@ -921,7 +928,7 @@ func (s *State) DestroyAPIKey(key types.APIKey) error {
} }
// CreatePreAuthKey generates a new pre-authentication key for a user. // CreatePreAuthKey generates a new pre-authentication key for a user.
func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKey, error) { func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) {
return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags) return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags)
} }

View File

@@ -7,6 +7,13 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
const (
// NewAPIKeyPrefixLength is the length of the prefix for new API keys.
NewAPIKeyPrefixLength = 12
// LegacyAPIKeyPrefixLength is the length of the prefix for legacy API keys.
LegacyAPIKeyPrefixLength = 7
)
// APIKey describes the datamodel for API keys used to remotely authenticate with // APIKey describes the datamodel for API keys used to remotely authenticate with
// headscale. // headscale.
type APIKey struct { type APIKey struct {
@@ -21,8 +28,16 @@ type APIKey struct {
func (key *APIKey) Proto() *v1.ApiKey { func (key *APIKey) Proto() *v1.ApiKey {
protoKey := v1.ApiKey{ protoKey := v1.ApiKey{
Id: key.ID, Id: key.ID,
Prefix: key.Prefix, }
// Show prefix format: distinguish between new (12-char) and legacy (7-char) keys
if len(key.Prefix) == NewAPIKeyPrefixLength {
// New format key (12-char prefix)
protoKey.Prefix = "hskey-api-" + key.Prefix + "-***"
} else {
// Legacy format key (7-char prefix) or fallback
protoKey.Prefix = key.Prefix + "***"
} }
if key.Expiration != nil { if key.Expiration != nil {

View File

@@ -14,8 +14,15 @@ func (e PAKError) Error() string { return string(e) }
// PreAuthKey describes a pre-authorization key usable in a particular user. // PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct { type PreAuthKey struct {
ID uint64 `gorm:"primary_key"` ID uint64 `gorm:"primary_key"`
Key string
// Legacy plaintext key (for backwards compatibility)
Key string
// New bcrypt-based authentication
Prefix string
Hash []byte // bcrypt
UserID uint UserID uint
User User `gorm:"constraint:OnDelete:SET NULL;"` User User `gorm:"constraint:OnDelete:SET NULL;"`
Reusable bool Reusable bool
@@ -32,17 +39,59 @@ type PreAuthKey struct {
Expiration *time.Time Expiration *time.Time
} }
// PreAuthKeyNew is returned once when the key is created.
type PreAuthKeyNew struct {
ID uint64 `gorm:"primary_key"`
Key string
Reusable bool
Ephemeral bool
Tags []string
Expiration *time.Time
CreatedAt *time.Time
User User
}
func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
Id: key.ID,
Key: key.Key,
User: key.User.Proto(),
Reusable: key.Reusable,
Ephemeral: key.Ephemeral,
AclTags: key.Tags,
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
return &protoKey
}
func (key *PreAuthKey) Proto() *v1.PreAuthKey { func (key *PreAuthKey) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{ protoKey := v1.PreAuthKey{
User: key.User.Proto(), User: key.User.Proto(),
Id: key.ID, Id: key.ID,
Key: key.Key,
Ephemeral: key.Ephemeral, Ephemeral: key.Ephemeral,
Reusable: key.Reusable, Reusable: key.Reusable,
Used: key.Used, Used: key.Used,
AclTags: key.Tags, AclTags: key.Tags,
} }
// For new keys (with prefix/hash), show the prefix so users can identify the key
// For legacy keys (with plaintext key), show the full key for backwards compatibility
if key.Prefix != "" {
protoKey.Key = "hskey-auth-" + key.Prefix + "-***"
} else if key.Key != "" {
// Legacy key - show full key for backwards compatibility
// TODO: Consider hiding this in a future major version
protoKey.Key = key.Key
}
if key.Expiration != nil { if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration) protoKey.Expiration = timestamppb.New(*key.Expiration)
} }

View File

@@ -110,6 +110,7 @@ func (src *PreAuthKey) Clone() *PreAuthKey {
} }
dst := new(PreAuthKey) dst := new(PreAuthKey)
*dst = *src *dst = *src
dst.Hash = append(src.Hash[:0:0], src.Hash...)
dst.Tags = append(src.Tags[:0:0], src.Tags...) dst.Tags = append(src.Tags[:0:0], src.Tags...)
if dst.CreatedAt != nil { if dst.CreatedAt != nil {
dst.CreatedAt = ptr.To(*src.CreatedAt) dst.CreatedAt = ptr.To(*src.CreatedAt)
@@ -124,6 +125,8 @@ func (src *PreAuthKey) Clone() *PreAuthKey {
var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct { var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct {
ID uint64 ID uint64
Key string Key string
Prefix string
Hash []byte
UserID uint UserID uint
User User User User
Reusable bool Reusable bool

View File

@@ -239,14 +239,16 @@ func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } func (v PreAuthKeyView) ID() uint64 { return v.ж.ID }
func (v PreAuthKeyView) Key() string { return v.ж.Key } func (v PreAuthKeyView) Key() string { return v.ж.Key }
func (v PreAuthKeyView) UserID() uint { return v.ж.UserID } func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix }
func (v PreAuthKeyView) User() User { return v.ж.User } func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) }
func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } func (v PreAuthKeyView) UserID() uint { return v.ж.UserID }
func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } func (v PreAuthKeyView) User() User { return v.ж.User }
func (v PreAuthKeyView) Used() bool { return v.ж.Used } func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral }
func (v PreAuthKeyView) Used() bool { return v.ж.Used }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] { func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.CreatedAt) return views.ValuePointerOf(v.ж.CreatedAt)
} }
@@ -259,6 +261,8 @@ func (v PreAuthKeyView) Expiration() views.ValuePointer[time.Time] {
var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct { var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct {
ID uint64 ID uint64
Key string Key string
Prefix string
Hash []byte
UserID uint UserID uint
User User User User
Reusable bool Reusable bool

View File

@@ -294,3 +294,20 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
return InvalidString() return InvalidString()
} }
// GenerateRegistrationKey generates a vanity key for tracking web authentication
// registration flows in logs. This key is NOT stored in the database and does NOT use bcrypt -
// it's purely for observability and correlating log entries during the registration process.
func GenerateRegistrationKey() (string, error) {
const (
registerKeyPrefix = "hskey-reg-" //nolint:gosec // This is a vanity key for logging, not a credential
registerKeyLength = 64
)
randomPart, err := GenerateRandomStringURLSafe(registerKeyLength)
if err != nil {
return "", fmt.Errorf("generating registration key: %w", err)
}
return registerKeyPrefix + randomPart, nil
}

View File

@@ -1288,3 +1288,99 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
} }
} }
func TestGenerateRegistrationKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
test func(*testing.T)
}{
{
name: "generates_key_with_correct_prefix",
test: func(t *testing.T) {
t.Helper()
key, err := GenerateRegistrationKey()
if err != nil {
t.Errorf("GenerateRegistrationKey() error = %v", err)
}
if !strings.HasPrefix(key, "hskey-reg-") {
t.Errorf("key does not have expected prefix: %s", key)
}
},
},
{
name: "generates_key_with_correct_length",
test: func(t *testing.T) {
t.Helper()
key, err := GenerateRegistrationKey()
if err != nil {
t.Errorf("GenerateRegistrationKey() error = %v", err)
}
// Expected format: hskey-reg-{64-char-random}
// Total length: 10 (prefix) + 64 (random) = 74
if len(key) != 74 {
t.Errorf("key length = %d, want 74", len(key))
}
},
},
{
name: "generates_unique_keys",
test: func(t *testing.T) {
t.Helper()
key1, err := GenerateRegistrationKey()
if err != nil {
t.Errorf("GenerateRegistrationKey() error = %v", err)
}
key2, err := GenerateRegistrationKey()
if err != nil {
t.Errorf("GenerateRegistrationKey() error = %v", err)
}
if key1 == key2 {
t.Error("generated keys should be unique")
}
},
},
{
name: "key_contains_only_valid_chars",
test: func(t *testing.T) {
t.Helper()
key, err := GenerateRegistrationKey()
if err != nil {
t.Errorf("GenerateRegistrationKey() error = %v", err)
}
// Remove prefix
_, randomPart, found := strings.Cut(key, "hskey-reg-")
if !found {
t.Error("key does not contain expected prefix")
}
// Verify base64 URL-safe characters (A-Za-z0-9_-)
for _, ch := range randomPart {
if (ch < 'A' || ch > 'Z') &&
(ch < 'a' || ch > 'z') &&
(ch < '0' || ch > '9') &&
ch != '_' && ch != '-' {
t.Errorf("key contains invalid character: %c", ch)
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tt.test(t)
})
}
}

View File

@@ -337,9 +337,10 @@ func TestPreAuthKeyCommand(t *testing.T) {
}, },
) )
assert.NotEmpty(t, listedPreAuthKeys[1].GetKey()) // New keys show prefix after listing, so check the created keys instead
assert.NotEmpty(t, listedPreAuthKeys[2].GetKey()) assert.NotEmpty(t, keys[0].GetKey())
assert.NotEmpty(t, listedPreAuthKeys[3].GetKey()) assert.NotEmpty(t, keys[1].GetKey())
assert.NotEmpty(t, keys[2].GetKey())
assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now()))
assert.True(t, listedPreAuthKeys[2].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeys[2].GetExpiration().AsTime().After(time.Now()))
@@ -370,7 +371,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
) )
} }
// Test key expiry // Test key expiry - use the full key from creation, not the masked one from listing
_, err = headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
"headscale", "headscale",
@@ -378,7 +379,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
"--user", "--user",
"1", "1",
"expire", "expire",
listedPreAuthKeys[1].GetKey(), keys[0].GetKey(),
}, },
) )
require.NoError(t, err) require.NoError(t, err)