mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-20 17:56:02 -05:00
235 lines
6.1 KiB
Go
235 lines
6.1 KiB
Go
package db
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gopkg.in/check.v1"
|
|
)
|
|
|
|
func (*Suite) TestCreateAPIKey(c *check.C) {
|
|
apiKeyStr, apiKey, err := db.CreateAPIKey(nil)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(apiKey, check.NotNil)
|
|
|
|
// Did we get a valid key?
|
|
c.Assert(apiKey.Prefix, check.NotNil)
|
|
c.Assert(apiKey.Hash, check.NotNil)
|
|
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
|
|
|
_, err = db.ListAPIKeys()
|
|
c.Assert(err, check.IsNil)
|
|
|
|
keys, err := db.ListAPIKeys()
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(len(keys), check.Equals, 1)
|
|
}
|
|
|
|
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
|
key, err := db.GetAPIKey("does-not-exist")
|
|
c.Assert(err, check.NotNil)
|
|
c.Assert(key, check.IsNil)
|
|
}
|
|
|
|
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
|
nowPlus2 := time.Now().Add(2 * time.Hour)
|
|
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(apiKey, check.NotNil)
|
|
|
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(valid, check.Equals, true)
|
|
}
|
|
|
|
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
|
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
|
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(apiKey, check.NotNil)
|
|
|
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(valid, check.Equals, false)
|
|
|
|
now := time.Now()
|
|
apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(apiKey, check.NotNil)
|
|
|
|
validNow, err := db.ValidateAPIKey(apiKeyStrNow)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(validNow, check.Equals, false)
|
|
|
|
validSilly, err := db.ValidateAPIKey("nota.validkey")
|
|
c.Assert(err, check.NotNil)
|
|
c.Assert(validSilly, check.Equals, false)
|
|
|
|
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
|
c.Assert(err, check.NotNil)
|
|
c.Assert(validWithErr, check.Equals, false)
|
|
}
|
|
|
|
func (*Suite) TestExpireAPIKey(c *check.C) {
|
|
nowPlus2 := time.Now().Add(2 * time.Hour)
|
|
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(apiKey, check.NotNil)
|
|
|
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(valid, check.Equals, true)
|
|
|
|
err = db.ExpireAPIKey(apiKey)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(apiKey.Expiration, check.NotNil)
|
|
|
|
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
|
c.Assert(err, check.IsNil)
|
|
c.Assert(notValid, check.Equals, false)
|
|
}
|
|
|
|
func TestAPIKeyWithPrefix(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
test func(*testing.T, *HSDatabase)
|
|
}{
|
|
{
|
|
name: "new_key_with_prefix",
|
|
test: func(t *testing.T, db *HSDatabase) {
|
|
t.Helper()
|
|
|
|
keyStr, apiKey, err := db.CreateAPIKey(nil)
|
|
require.NoError(t, err)
|
|
|
|
// Verify format: hskey-api-{12-char-prefix}-{64-char-secret}
|
|
assert.True(t, strings.HasPrefix(keyStr, "hskey-api-"))
|
|
|
|
_, prefixAndSecret, found := strings.Cut(keyStr, "hskey-api-")
|
|
assert.True(t, found)
|
|
assert.GreaterOrEqual(t, len(prefixAndSecret), 12+1+64)
|
|
|
|
prefix := prefixAndSecret[:12]
|
|
assert.Len(t, prefix, 12)
|
|
assert.Equal(t, byte('-'), prefixAndSecret[12])
|
|
secret := prefixAndSecret[13:]
|
|
assert.Len(t, secret, 64)
|
|
|
|
// Verify stored fields
|
|
assert.Len(t, apiKey.Prefix, types.NewAPIKeyPrefixLength)
|
|
assert.NotNil(t, apiKey.Hash)
|
|
},
|
|
},
|
|
{
|
|
name: "new_key_can_be_retrieved",
|
|
test: func(t *testing.T, db *HSDatabase) {
|
|
t.Helper()
|
|
|
|
keyStr, createdKey, err := db.CreateAPIKey(nil)
|
|
require.NoError(t, err)
|
|
|
|
// Validate the created key
|
|
valid, err := db.ValidateAPIKey(keyStr)
|
|
require.NoError(t, err)
|
|
assert.True(t, valid)
|
|
|
|
// Verify prefix is correct length
|
|
assert.Len(t, createdKey.Prefix, types.NewAPIKeyPrefixLength)
|
|
},
|
|
},
|
|
{
|
|
name: "invalid_key_format_rejected",
|
|
test: func(t *testing.T, db *HSDatabase) {
|
|
t.Helper()
|
|
|
|
invalidKeys := []string{
|
|
"",
|
|
"hskey-api-short",
|
|
"hskey-api-ABCDEFGHIJKL-tooshort",
|
|
"hskey-api-ABC$EFGHIJKL-" + strings.Repeat("a", 64),
|
|
"hskey-api-ABCDEFGHIJKL" + strings.Repeat("a", 64), // missing separator
|
|
}
|
|
|
|
for _, invalidKey := range invalidKeys {
|
|
valid, err := db.ValidateAPIKey(invalidKey)
|
|
require.Error(t, err, "key should be rejected: %s", invalidKey)
|
|
assert.False(t, valid)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "legacy_key_still_works",
|
|
test: func(t *testing.T, db *HSDatabase) {
|
|
t.Helper()
|
|
|
|
// Insert legacy API key directly (7-char prefix + 32-char secret)
|
|
legacyPrefix := "abcdefg"
|
|
legacySecret := strings.Repeat("x", 32)
|
|
legacyKey := legacyPrefix + "." + legacySecret
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(legacySecret), bcrypt.DefaultCost)
|
|
require.NoError(t, err)
|
|
|
|
now := time.Now()
|
|
err = db.DB.Exec(`
|
|
INSERT INTO api_keys (prefix, hash, created_at)
|
|
VALUES (?, ?, ?)
|
|
`, legacyPrefix, hash, now).Error
|
|
require.NoError(t, err)
|
|
|
|
// Validate legacy key
|
|
valid, err := db.ValidateAPIKey(legacyKey)
|
|
require.NoError(t, err)
|
|
assert.True(t, valid)
|
|
},
|
|
},
|
|
{
|
|
name: "wrong_secret_rejected",
|
|
test: func(t *testing.T, db *HSDatabase) {
|
|
t.Helper()
|
|
|
|
keyStr, _, err := db.CreateAPIKey(nil)
|
|
require.NoError(t, err)
|
|
|
|
// Tamper with the secret
|
|
_, prefixAndSecret, _ := strings.Cut(keyStr, "hskey-api-")
|
|
prefix := prefixAndSecret[:12]
|
|
tamperedKey := "hskey-api-" + prefix + "-" + strings.Repeat("x", 64)
|
|
|
|
valid, err := db.ValidateAPIKey(tamperedKey)
|
|
require.Error(t, err)
|
|
assert.False(t, valid)
|
|
},
|
|
},
|
|
{
|
|
name: "expired_key_rejected",
|
|
test: func(t *testing.T, db *HSDatabase) {
|
|
t.Helper()
|
|
|
|
// Create expired key
|
|
expired := time.Now().Add(-1 * time.Hour)
|
|
keyStr, _, err := db.CreateAPIKey(&expired)
|
|
require.NoError(t, err)
|
|
|
|
// Should fail validation
|
|
valid, err := db.ValidateAPIKey(keyStr)
|
|
require.NoError(t, err)
|
|
assert.False(t, valid)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db, err := newSQLiteTestDB()
|
|
require.NoError(t, err)
|
|
|
|
tt.test(t, db)
|
|
})
|
|
}
|
|
}
|