mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-20 09:46:01 -05:00
304 lines
8.2 KiB
Go
304 lines
8.2 KiB
Go
package db
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gorm.io/gorm"
|
|
"tailscale.com/util/set"
|
|
)
|
|
|
|
var (
|
|
ErrPreAuthKeyNotFound = errors.New("AuthKey not found")
|
|
ErrPreAuthKeyExpired = errors.New("AuthKey expired")
|
|
ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")
|
|
ErrUserMismatch = errors.New("user mismatch")
|
|
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
|
)
|
|
|
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
|
uid types.UserID,
|
|
reusable bool,
|
|
ephemeral bool,
|
|
expiration *time.Time,
|
|
aclTags []string,
|
|
) (*types.PreAuthKeyNew, error) {
|
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKeyNew, error) {
|
|
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
|
|
})
|
|
}
|
|
|
|
// Layout of a new-format pre-auth key:
//
//	hskey-auth-{prefix}-{secret}
//
// The prefix is stored in plaintext and used for lookup; the secret is only
// stored as a bcrypt hash.
const (
	// authKeyPrefix marks a key as new-format (as opposed to a legacy
	// plaintext key).
	authKeyPrefix = "hskey-auth-"

	// authKeyPrefixLength is the number of characters in the plaintext
	// lookup prefix.
	authKeyPrefixLength = 12

	// authKeyLength is the number of characters in the secret part that is
	// bcrypt-hashed before storage.
	authKeyLength = 64
)
|
|
|
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
|
func CreatePreAuthKey(
|
|
tx *gorm.DB,
|
|
uid types.UserID,
|
|
reusable bool,
|
|
ephemeral bool,
|
|
expiration *time.Time,
|
|
aclTags []string,
|
|
) (*types.PreAuthKeyNew, error) {
|
|
user, err := GetUserByID(tx, uid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Remove duplicates and sort for consistency
|
|
aclTags = set.SetOf(aclTags).Slice()
|
|
slices.Sort(aclTags)
|
|
|
|
// TODO(kradalby): factor out and create a reusable tag validation,
|
|
// check if there is one in Tailscale's lib.
|
|
for _, tag := range aclTags {
|
|
if !strings.HasPrefix(tag, "tag:") {
|
|
return nil, fmt.Errorf(
|
|
"%w: '%s' did not begin with 'tag:'",
|
|
ErrPreAuthKeyACLTagInvalid,
|
|
tag,
|
|
)
|
|
}
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
prefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Validate generated prefix (should always be valid, but be defensive)
|
|
if len(prefix) != authKeyPrefixLength {
|
|
return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, len(prefix))
|
|
}
|
|
|
|
if !isValidBase64URLSafe(prefix) {
|
|
return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse)
|
|
}
|
|
|
|
toBeHashed, err := util.GenerateRandomStringURLSafe(authKeyLength)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Validate generated hash (should always be valid, but be defensive)
|
|
if len(toBeHashed) != authKeyLength {
|
|
return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(toBeHashed))
|
|
}
|
|
|
|
if !isValidBase64URLSafe(toBeHashed) {
|
|
return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse)
|
|
}
|
|
|
|
keyStr := authKeyPrefix + prefix + "-" + toBeHashed
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
key := types.PreAuthKey{
|
|
UserID: user.ID,
|
|
User: *user,
|
|
Reusable: reusable,
|
|
Ephemeral: ephemeral,
|
|
CreatedAt: &now,
|
|
Expiration: expiration,
|
|
Tags: aclTags,
|
|
Prefix: prefix, // Store prefix
|
|
Hash: hash, // Store hash
|
|
}
|
|
|
|
if err := tx.Save(&key).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to create key in the database: %w", err)
|
|
}
|
|
|
|
return &types.PreAuthKeyNew{
|
|
ID: key.ID,
|
|
Key: keyStr,
|
|
Reusable: key.Reusable,
|
|
Ephemeral: key.Ephemeral,
|
|
Tags: key.Tags,
|
|
Expiration: key.Expiration,
|
|
CreatedAt: key.CreatedAt,
|
|
User: key.User,
|
|
}, nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
|
return ListPreAuthKeysByUser(rx, uid)
|
|
})
|
|
}
|
|
|
|
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
|
|
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
|
|
user, err := GetUserByID(tx, uid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keys := []types.PreAuthKey{}
|
|
if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey")
|
|
|
|
func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) {
|
|
var pak types.PreAuthKey
|
|
|
|
// Validate input is not empty
|
|
if keyStr == "" {
|
|
return nil, ErrPreAuthKeyFailedToParse
|
|
}
|
|
|
|
_, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix)
|
|
|
|
if !found {
|
|
// Legacy format (plaintext) - backwards compatibility
|
|
err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error
|
|
if err != nil {
|
|
return nil, ErrPreAuthKeyNotFound
|
|
}
|
|
|
|
return &pak, nil
|
|
}
|
|
|
|
// New format: hskey-auth-{12-char-prefix}-{64-char-hash}
|
|
// Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77
|
|
const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength
|
|
if len(prefixAndHash) < expectedMinLength {
|
|
return nil, fmt.Errorf(
|
|
"%w: key too short, expected at least %d chars after prefix, got %d",
|
|
ErrPreAuthKeyFailedToParse,
|
|
expectedMinLength,
|
|
len(prefixAndHash),
|
|
)
|
|
}
|
|
|
|
// Use fixed-length parsing instead of separator-based to handle dashes in base64 URL-safe
|
|
prefix := prefixAndHash[:authKeyPrefixLength]
|
|
|
|
// Validate separator at expected position
|
|
if prefixAndHash[authKeyPrefixLength] != '-' {
|
|
return nil, fmt.Errorf(
|
|
"%w: expected separator '-' at position %d, got '%c'",
|
|
ErrPreAuthKeyFailedToParse,
|
|
authKeyPrefixLength,
|
|
prefixAndHash[authKeyPrefixLength],
|
|
)
|
|
}
|
|
|
|
hash := prefixAndHash[authKeyPrefixLength+1:]
|
|
|
|
// Validate hash length
|
|
if len(hash) != authKeyLength {
|
|
return nil, fmt.Errorf(
|
|
"%w: hash length mismatch, expected %d chars, got %d",
|
|
ErrPreAuthKeyFailedToParse,
|
|
authKeyLength,
|
|
len(hash),
|
|
)
|
|
}
|
|
|
|
// Validate prefix contains only base64 URL-safe characters
|
|
if !isValidBase64URLSafe(prefix) {
|
|
return nil, fmt.Errorf(
|
|
"%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
|
|
ErrPreAuthKeyFailedToParse,
|
|
)
|
|
}
|
|
|
|
// Validate hash contains only base64 URL-safe characters
|
|
if !isValidBase64URLSafe(hash) {
|
|
return nil, fmt.Errorf(
|
|
"%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
|
|
ErrPreAuthKeyFailedToParse,
|
|
)
|
|
}
|
|
|
|
// Look up key by prefix
|
|
err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error
|
|
if err != nil {
|
|
return nil, ErrPreAuthKeyNotFound
|
|
}
|
|
|
|
// Verify hash matches
|
|
err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid auth key: %w", err)
|
|
}
|
|
|
|
return &pak, nil
|
|
}
|
|
|
|
// isValidBase64URLSafe reports whether every character of s belongs to the
// base64 URL-safe alphabet (A-Z, a-z, 0-9, '-', '_'). The empty string is
// considered valid. Any non-ASCII byte makes the string invalid.
func isValidBase64URLSafe(s string) bool {
	for i := 0; i < len(s); i++ {
		switch ch := s[i]; {
		case 'A' <= ch && ch <= 'Z',
			'a' <= ch && ch <= 'z',
			'0' <= ch && ch <= '9',
			ch == '-',
			ch == '_':
			continue
		default:
			return false
		}
	}

	return true
}
|
|
|
|
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
|
return GetPreAuthKey(hsdb.DB, key)
|
|
}
|
|
|
|
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
|
|
// for checking if the key is usable (expired or used).
|
|
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
|
|
return findAuthKey(tx, key)
|
|
}
|
|
|
|
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
|
// does not exist.
|
|
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
|
|
return tx.Transaction(func(db *gorm.DB) error {
|
|
if result := db.Unscoped().Delete(pak); result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
|
return hsdb.Write(func(tx *gorm.DB) error {
|
|
return ExpirePreAuthKey(tx, k)
|
|
})
|
|
}
|
|
|
|
// UsePreAuthKey marks a PreAuthKey as used.
|
|
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
|
err := tx.Model(k).Update("used", true).Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
|
}
|
|
|
|
k.Used = true
|
|
return nil
|
|
}
|
|
|
|
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
|
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
|
now := time.Now()
|
|
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
|
}
|