package db import ( "errors" "fmt" "slices" "strings" "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" "tailscale.com/util/set" ) var ( ErrPreAuthKeyNotFound = errors.New("AuthKey not found") ErrPreAuthKeyExpired = errors.New("AuthKey expired") ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") ErrUserMismatch = errors.New("user mismatch") ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) func (hsdb *HSDatabase) CreatePreAuthKey( uid types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKeyNew, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKeyNew, error) { return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) }) } const ( authKeyPrefix = "hskey-auth-" authKeyPrefixLength = 12 authKeyLength = 64 ) // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func CreatePreAuthKey( tx *gorm.DB, uid types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKeyNew, error) { user, err := GetUserByID(tx, uid) if err != nil { return nil, err } // Remove duplicates and sort for consistency aclTags = set.SetOf(aclTags).Slice() slices.Sort(aclTags) // TODO(kradalby): factor out and create a reusable tag validation, // check if there is one in Tailscale's lib. for _, tag := range aclTags { if !strings.HasPrefix(tag, "tag:") { return nil, fmt.Errorf( "%w: '%s' did not begin with 'tag:'", ErrPreAuthKeyACLTagInvalid, tag, ) } } now := time.Now().UTC() prefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength) if err != nil { return nil, err } // Validate generated prefix (should always be valid, but be defensive) if len(prefix) != authKeyPrefixLength { return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, len(prefix)) } if !isValidBase64URLSafe(prefix) { return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse) } toBeHashed, err := util.GenerateRandomStringURLSafe(authKeyLength) if err != nil { return nil, err } // Validate generated hash (should always be valid, but be defensive) if len(toBeHashed) != authKeyLength { return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(toBeHashed)) } if !isValidBase64URLSafe(toBeHashed) { return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse) } keyStr := authKeyPrefix + prefix + "-" + toBeHashed hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) if err != nil { return nil, err } key := types.PreAuthKey{ UserID: user.ID, User: *user, Reusable: reusable, Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, Tags: aclTags, Prefix: prefix, // Store prefix Hash: hash, // Store hash } if err := tx.Save(&key).Error; err != nil { return nil, fmt.Errorf("failed to create key in the database: %w", err) } return &types.PreAuthKeyNew{ ID: key.ID, Key: keyStr, Reusable: key.Reusable, Ephemeral: key.Ephemeral, Tags: key.Tags, Expiration: key.Expiration, CreatedAt: key.CreatedAt, User: key.User, }, nil } func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { return ListPreAuthKeysByUser(rx, uid) }) } // ListPreAuthKeysByUser returns the list of PreAuthKeys for a user. func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) { user, err := GetUserByID(tx, uid) if err != nil { return nil, err } keys := []types.PreAuthKey{} if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } return keys, nil } var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey") func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { var pak types.PreAuthKey // Validate input is not empty if keyStr == "" { return nil, ErrPreAuthKeyFailedToParse } _, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix) if !found { // Legacy format (plaintext) - backwards compatibility err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error if err != nil { return nil, ErrPreAuthKeyNotFound } return &pak, nil } // New format: hskey-auth-{12-char-prefix}-{64-char-hash} // Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77 const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength if len(prefixAndHash) < expectedMinLength { return nil, fmt.Errorf( "%w: key too short, expected at least %d chars after prefix, got %d", ErrPreAuthKeyFailedToParse, expectedMinLength, len(prefixAndHash), ) } // Use fixed-length parsing instead of separator-based to handle dashes in base64 URL-safe prefix := prefixAndHash[:authKeyPrefixLength] // Validate separator at expected position if prefixAndHash[authKeyPrefixLength] != '-' { return nil, fmt.Errorf( "%w: expected separator '-' at position %d, got '%c'", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, prefixAndHash[authKeyPrefixLength], ) } hash := prefixAndHash[authKeyPrefixLength+1:] // Validate hash length if len(hash) != authKeyLength { return nil, fmt.Errorf( "%w: hash length mismatch, expected %d chars, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(hash), ) } // Validate prefix contains only base64 URL-safe characters if !isValidBase64URLSafe(prefix) { return nil, fmt.Errorf( "%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", ErrPreAuthKeyFailedToParse, ) } // Validate hash contains only base64 URL-safe characters if !isValidBase64URLSafe(hash) { return nil, fmt.Errorf( "%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", ErrPreAuthKeyFailedToParse, ) } // Look up key by prefix err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error if err != nil { return nil, ErrPreAuthKeyNotFound } // Verify hash matches err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash)) if err != nil { return nil, fmt.Errorf("invalid auth key: %w", err) } return &pak, nil } // isValidBase64URLSafe checks if a string contains only base64 URL-safe characters. func isValidBase64URLSafe(s string) bool { for _, c := range s { if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && (c < '0' || c > '9') && c != '-' && c != '_' { return false } } return true } func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { return GetPreAuthKey(hsdb.DB, key) } // GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible // for checking if the key is usable (expired or used). func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { return findAuthKey(tx, key) } // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error { return tx.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Delete(pak); result.Error != nil { return result.Error } return nil }) } func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { return hsdb.Write(func(tx *gorm.DB) error { return ExpirePreAuthKey(tx, k) }) } // UsePreAuthKey marks a PreAuthKey as used. func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { err := tx.Model(k).Update("used", true).Error if err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } k.Used = true return nil } // MarkExpirePreAuthKey marks a PreAuthKey as expired. func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { now := time.Now() return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error }