mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-25 03:46:06 -05:00
types: make pre auth key use bcrypt (#2853)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
@@ -10,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
@@ -28,12 +28,18 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
) (*types.PreAuthKeyNew, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKeyNew, error) {
|
||||
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
|
||||
})
|
||||
}
|
||||
|
||||
const (
|
||||
authKeyPrefix = "hskey-auth-"
|
||||
authKeyPrefixLength = 12
|
||||
authKeyLength = 64
|
||||
)
|
||||
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func CreatePreAuthKey(
|
||||
tx *gorm.DB,
|
||||
@@ -42,7 +48,7 @@ func CreatePreAuthKey(
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
) (*types.PreAuthKeyNew, error) {
|
||||
user, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -65,14 +71,43 @@ func CreatePreAuthKey(
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
// TODO(kradalby): unify the key generations spread all over the code.
|
||||
kstr, err := generateKey()
|
||||
|
||||
prefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate generated prefix (should always be valid, but be defensive)
|
||||
if len(prefix) != authKeyPrefixLength {
|
||||
return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, len(prefix))
|
||||
}
|
||||
|
||||
if !isValidBase64URLSafe(prefix) {
|
||||
return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse)
|
||||
}
|
||||
|
||||
toBeHashed, err := util.GenerateRandomStringURLSafe(authKeyLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate generated hash (should always be valid, but be defensive)
|
||||
if len(toBeHashed) != authKeyLength {
|
||||
return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(toBeHashed))
|
||||
}
|
||||
|
||||
if !isValidBase64URLSafe(toBeHashed) {
|
||||
return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse)
|
||||
}
|
||||
|
||||
keyStr := authKeyPrefix + prefix + "-" + toBeHashed
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := types.PreAuthKey{
|
||||
Key: kstr,
|
||||
UserID: user.ID,
|
||||
User: *user,
|
||||
Reusable: reusable,
|
||||
@@ -80,13 +115,24 @@ func CreatePreAuthKey(
|
||||
CreatedAt: &now,
|
||||
Expiration: expiration,
|
||||
Tags: aclTags,
|
||||
Prefix: prefix, // Store prefix
|
||||
Hash: hash, // Store hash
|
||||
}
|
||||
|
||||
if err := tx.Save(&key).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create key in the database: %w", err)
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
return &types.PreAuthKeyNew{
|
||||
ID: key.ID,
|
||||
Key: keyStr,
|
||||
Reusable: key.Reusable,
|
||||
Ephemeral: key.Ephemeral,
|
||||
Tags: key.Tags,
|
||||
Expiration: key.Expiration,
|
||||
CreatedAt: key.CreatedAt,
|
||||
User: key.User,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
||||
@@ -110,6 +156,107 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey")
|
||||
|
||||
func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) {
|
||||
var pak types.PreAuthKey
|
||||
|
||||
// Validate input is not empty
|
||||
if keyStr == "" {
|
||||
return nil, ErrPreAuthKeyFailedToParse
|
||||
}
|
||||
|
||||
_, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix)
|
||||
|
||||
if !found {
|
||||
// Legacy format (plaintext) - backwards compatibility
|
||||
err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error
|
||||
if err != nil {
|
||||
return nil, ErrPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
// New format: hskey-auth-{12-char-prefix}-{64-char-hash}
|
||||
// Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77
|
||||
const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength
|
||||
if len(prefixAndHash) < expectedMinLength {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: key too short, expected at least %d chars after prefix, got %d",
|
||||
ErrPreAuthKeyFailedToParse,
|
||||
expectedMinLength,
|
||||
len(prefixAndHash),
|
||||
)
|
||||
}
|
||||
|
||||
// Use fixed-length parsing instead of separator-based to handle dashes in base64 URL-safe
|
||||
prefix := prefixAndHash[:authKeyPrefixLength]
|
||||
|
||||
// Validate separator at expected position
|
||||
if prefixAndHash[authKeyPrefixLength] != '-' {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: expected separator '-' at position %d, got '%c'",
|
||||
ErrPreAuthKeyFailedToParse,
|
||||
authKeyPrefixLength,
|
||||
prefixAndHash[authKeyPrefixLength],
|
||||
)
|
||||
}
|
||||
|
||||
hash := prefixAndHash[authKeyPrefixLength+1:]
|
||||
|
||||
// Validate hash length
|
||||
if len(hash) != authKeyLength {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: hash length mismatch, expected %d chars, got %d",
|
||||
ErrPreAuthKeyFailedToParse,
|
||||
authKeyLength,
|
||||
len(hash),
|
||||
)
|
||||
}
|
||||
|
||||
// Validate prefix contains only base64 URL-safe characters
|
||||
if !isValidBase64URLSafe(prefix) {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
|
||||
ErrPreAuthKeyFailedToParse,
|
||||
)
|
||||
}
|
||||
|
||||
// Validate hash contains only base64 URL-safe characters
|
||||
if !isValidBase64URLSafe(hash) {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
|
||||
ErrPreAuthKeyFailedToParse,
|
||||
)
|
||||
}
|
||||
|
||||
// Look up key by prefix
|
||||
err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error
|
||||
if err != nil {
|
||||
return nil, ErrPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
// Verify hash matches
|
||||
err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid auth key: %w", err)
|
||||
}
|
||||
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
// isValidBase64URLSafe checks if a string contains only base64 URL-safe characters.
|
||||
func isValidBase64URLSafe(s string) bool {
|
||||
for _, c := range s {
|
||||
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && (c < '0' || c > '9') && c != '-' && c != '_' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
||||
return GetPreAuthKey(hsdb.DB, key)
|
||||
}
|
||||
@@ -117,12 +264,7 @@ func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
|
||||
// for checking if the key is usable (expired or used).
|
||||
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
|
||||
return nil, ErrPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
return &pak, nil
|
||||
return findAuthKey(tx, key)
|
||||
}
|
||||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
@@ -159,13 +301,3 @@ func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
||||
}
|
||||
|
||||
func generateKey() (string, error) {
|
||||
size := 24
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user