227 lines
5.4 KiB
Go
Raw Normal View History

package db
2021-04-23 00:25:01 +02:00
import (
"crypto/rand"
"encoding/hex"
2021-06-24 15:44:19 +02:00
"errors"
2022-05-30 15:31:06 +02:00
"fmt"
"strings"
2021-04-23 00:25:01 +02:00
"time"
2021-06-24 15:44:19 +02:00
"github.com/juanfont/headscale/hscontrol/types"
2021-06-24 15:44:19 +02:00
"gorm.io/gorm"
2021-04-23 00:25:01 +02:00
)
// Sentinel errors produced by the pre-auth key helpers in this file.
// Compare against them with errors.Is, since some are returned wrapped.
var (
	// ErrPreAuthKeyNotFound is returned by ValidatePreAuthKey when no stored
	// key matches the supplied key string.
	ErrPreAuthKeyNotFound = errors.New("AuthKey not found")

	// ErrPreAuthKeyExpired is returned by ValidatePreAuthKey when the key's
	// expiration time lies in the past.
	ErrPreAuthKeyExpired = errors.New("AuthKey expired")

	// ErrSingleUseAuthKeyHasBeenUsed is returned by ValidatePreAuthKey for a
	// non-reusable, non-ephemeral key that has already been consumed.
	ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")

	// ErrUserMismatch is returned by GetPreAuthKey when the key belongs to a
	// different user than the one requested.
	ErrUserMismatch = errors.New("user mismatch")

	// ErrPreAuthKeyACLTagInvalid is wrapped by CreatePreAuthKey when an ACL
	// tag does not start with "tag:".
	ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
)
2021-05-05 23:00:04 +02:00
// CreatePreAuthKey runs the package-level CreatePreAuthKey through the Write
// helper on hsdb.DB, forwarding every argument unchanged.
func (hsdb *HSDatabase) CreatePreAuthKey(
	userName string,
	reusable bool,
	ephemeral bool,
	expiration *time.Time,
	aclTags []string,
) (*types.PreAuthKey, error) {
	create := func(tx *gorm.DB) (*types.PreAuthKey, error) {
		return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
	}

	return Write(hsdb.DB, create)
}
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey(
tx *gorm.DB,
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
2021-04-23 00:25:01 +02:00
if err != nil {
return nil, err
}
for _, tag := range aclTags {
if !strings.HasPrefix(tag, "tag:") {
return nil, fmt.Errorf(
"%w: '%s' did not begin with 'tag:'",
ErrPreAuthKeyACLTagInvalid,
tag,
)
}
}
2021-04-23 00:25:01 +02:00
now := time.Now().UTC()
kstr, err := generateKey()
2021-04-23 00:25:01 +02:00
if err != nil {
return nil, err
}
key := types.PreAuthKey{
Key: kstr,
UserID: user.ID,
User: *user,
Reusable: reusable,
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
2021-04-23 00:25:01 +02:00
}
2022-05-30 15:31:06 +02:00
if err := tx.Save(&key).Error; err != nil {
return nil, fmt.Errorf("failed to create key in the database: %w", err)
}
if len(aclTags) > 0 {
seenTags := map[string]bool{}
for _, tag := range aclTags {
if !seenTags[tag] {
if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return nil, fmt.Errorf(
"failed to ceate key tag in the database: %w",
err,
)
}
seenTags[tag] = true
}
}
}
if err != nil {
return nil, err
2022-05-30 15:31:06 +02:00
}
2021-04-23 00:25:01 +02:00
2021-11-15 16:15:50 +00:00
return &key, nil
2021-04-23 00:25:01 +02:00
}
// ListPreAuthKeys runs the package-level ListPreAuthKeys through the Read
// helper on hsdb.DB.
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
	list := func(rx *gorm.DB) ([]types.PreAuthKey, error) {
		return ListPreAuthKeys(rx, userName)
	}

	return Read(hsdb.DB, list)
}
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
2021-04-23 00:25:01 +02:00
if err != nil {
return nil, err
}
keys := []types.PreAuthKey{}
if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
2021-04-23 00:25:01 +02:00
return nil, err
}
2021-11-14 16:46:09 +01:00
return keys, nil
2021-04-23 00:25:01 +02:00
}
2021-11-13 08:39:04 +00:00
// GetPreAuthKey returns the PreAuthKey identified by key, provided that it
// passes ValidatePreAuthKey and belongs to the user named user.
// Any validation error is returned as-is. If the key belongs to a different
// user, ErrUserMismatch is returned.
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
	pak, err := ValidatePreAuthKey(tx, key)

	switch {
	case err != nil:
		return nil, err
	case pak.User.Name != user:
		return nil, ErrUserMismatch
	default:
		return pak, nil
	}
}
// DestroyPreAuthKey permanently deletes pak together with its ACL tag rows.
// Both deletes run in a single transaction and use Unscoped, so soft-delete
// scoping is bypassed.
//
// Deleting a key that no longer exists is NOT reported as an error: the
// deletes simply affect zero rows and nil is returned. Callers that need to
// distinguish that case should look the key up first (e.g. GetPreAuthKey).
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
	return tx.Transaction(func(db *gorm.DB) error {
		// Tags point at the key through PreAuthKeyID, so remove them first.
		if err := db.Unscoped().
			Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).
			Delete(&types.PreAuthKeyACLTag{}).Error; err != nil {
			return fmt.Errorf("deleting ACL tags of pre auth key %d: %w", pak.ID, err)
		}

		if err := db.Unscoped().Delete(&pak).Error; err != nil {
			return fmt.Errorf("deleting pre auth key %d: %w", pak.ID, err)
		}

		return nil
	})
}
// ExpirePreAuthKey runs the package-level ExpirePreAuthKey for k through
// hsdb.Write.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
	expire := func(tx *gorm.DB) error {
		return ExpirePreAuthKey(tx, k)
	}

	return hsdb.Write(expire)
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
2021-08-07 23:57:52 +02:00
return err
}
2021-11-14 16:46:09 +01:00
2021-08-07 23:57:52 +02:00
return nil
}
// UsePreAuthKey marks a PreAuthKey as used.
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
k.Used = true
if err := tx.Save(k).Error; err != nil {
2022-05-30 15:31:06 +02:00
return fmt.Errorf("failed to update key used status in the database: %w", err)
}
return nil
}
// ValidatePreAuthKey runs the package-level ValidatePreAuthKey for k through
// the Read helper on hsdb.DB.
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
	validate := func(rx *gorm.DB) (*types.PreAuthKey, error) {
		return ValidatePreAuthKey(rx, k)
	}

	return Read(hsdb.DB, validate)
}
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
2021-11-13 08:36:45 +00:00
result.Error,
gorm.ErrRecordNotFound,
) {
2022-07-29 17:35:21 +02:00
return nil, ErrPreAuthKeyNotFound
2021-05-05 23:00:04 +02:00
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
2022-07-29 17:35:21 +02:00
return nil, ErrPreAuthKeyExpired
2021-05-05 23:00:04 +02:00
}
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
return &pak, nil
}
2023-09-24 13:42:05 +02:00
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
2023-09-24 13:42:05 +02:00
Where(&types.Node{AuthKeyID: uint(pak.ID)}).
Find(&nodes).Error; err != nil {
return nil, err
}
2023-09-24 13:42:05 +02:00
if len(nodes) != 0 || pak.Used {
2022-07-29 17:35:21 +02:00
return nil, ErrSingleUseAuthKeyHasBeenUsed
}
2021-05-05 23:00:04 +02:00
return &pak, nil
}
// generateKey returns a fresh random key made of 24 bytes read from
// crypto/rand, hex-encoded into a 48-character lowercase string. Any error
// from the random source is returned with an empty key.
func generateKey() (string, error) {
	const keyBytes = 24

	raw := make([]byte, keyBytes)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}

	return hex.EncodeToString(raw), nil
}