mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-09 19:33:20 -05:00
8571513e3c
* reformat code This is mostly an automated change with `make lint`. I had to manually please golangci-lint in routes_test because of a short variable name. * fix start -> strategy which was wrongly corrected by linter
517 lines
9.4 KiB
Go
517 lines
9.4 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"net/netip"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/davecgh/go-spew/spew"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
)
|
|
|
|
// mpp parses a CIDR string such as "100.64.0.0/10" and returns a pointer to
// a freshly allocated netip.Prefix holding the result. It panics on malformed
// input (via netip.MustParsePrefix), which is acceptable for test fixtures.
var mpp = func(cidr string) *netip.Prefix {
	parsed := netip.MustParsePrefix(cidr)

	return &parsed
}
|
|
|
|
// na parses an IPv4 or IPv6 address literal into a netip.Addr, panicking
// (via netip.MustParseAddr) if the string is not a valid address.
var na = func(addr string) netip.Addr {
	parsed := netip.MustParseAddr(addr)

	return parsed
}
|
|
|
|
var nap = func(pref string) *netip.Addr {
|
|
n := na(pref)
|
|
return &n
|
|
}
|
|
|
|
func TestIPAllocatorSequential(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
dbFunc func() *HSDatabase
|
|
|
|
prefix4 *netip.Prefix
|
|
prefix6 *netip.Prefix
|
|
getCount int
|
|
want4 []netip.Addr
|
|
want6 []netip.Addr
|
|
}{
|
|
{
|
|
name: "simple",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 1,
|
|
|
|
want4: []netip.Addr{
|
|
na("100.64.0.1"),
|
|
},
|
|
want6: []netip.Addr{
|
|
na("fd7a:115c:a1e0::1"),
|
|
},
|
|
},
|
|
{
|
|
name: "simple-v4",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
|
|
getCount: 1,
|
|
|
|
want4: []netip.Addr{
|
|
na("100.64.0.1"),
|
|
},
|
|
},
|
|
{
|
|
name: "simple-v6",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 1,
|
|
|
|
want6: []netip.Addr{
|
|
na("fd7a:115c:a1e0::1"),
|
|
},
|
|
},
|
|
{
|
|
name: "simple-with-db",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "simple-with-db")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.1"),
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 1,
|
|
|
|
want4: []netip.Addr{
|
|
na("100.64.0.2"),
|
|
},
|
|
want6: []netip.Addr{
|
|
na("fd7a:115c:a1e0::2"),
|
|
},
|
|
},
|
|
{
|
|
name: "before-after-free-middle-in-db",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "before-after-free-middle-in-db")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.2"),
|
|
IPv6: nap("fd7a:115c:a1e0::2"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 2,
|
|
|
|
want4: []netip.Addr{
|
|
na("100.64.0.1"),
|
|
na("100.64.0.3"),
|
|
},
|
|
want6: []netip.Addr{
|
|
na("fd7a:115c:a1e0::1"),
|
|
na("fd7a:115c:a1e0::3"),
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := tt.dbFunc()
|
|
|
|
alloc, _ := NewIPAllocator(
|
|
db,
|
|
tt.prefix4,
|
|
tt.prefix6,
|
|
types.IPAllocationStrategySequential,
|
|
)
|
|
|
|
spew.Dump(alloc)
|
|
|
|
var got4s []netip.Addr
|
|
var got6s []netip.Addr
|
|
|
|
for range tt.getCount {
|
|
got4, got6, err := alloc.Next()
|
|
if err != nil {
|
|
t.Fatalf("allocating next IP: %s", err)
|
|
}
|
|
|
|
if got4 != nil {
|
|
got4s = append(got4s, *got4)
|
|
}
|
|
|
|
if got6 != nil {
|
|
got6s = append(got6s, *got6)
|
|
}
|
|
}
|
|
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
|
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
|
|
if diff := cmp.Diff(tt.want6, got6s, util.Comparers...); diff != "" {
|
|
t.Errorf("IPAllocator 6s unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIPAllocatorRandom(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
dbFunc func() *HSDatabase
|
|
|
|
getCount int
|
|
|
|
prefix4 *netip.Prefix
|
|
prefix6 *netip.Prefix
|
|
want4 bool
|
|
want6 bool
|
|
}{
|
|
{
|
|
name: "simple",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 1,
|
|
|
|
want4: true,
|
|
want6: true,
|
|
},
|
|
{
|
|
name: "simple-v4",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
|
|
getCount: 1,
|
|
|
|
want4: true,
|
|
want6: false,
|
|
},
|
|
{
|
|
name: "simple-v6",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 1,
|
|
|
|
want4: false,
|
|
want6: true,
|
|
},
|
|
{
|
|
name: "generate-lots-of-random",
|
|
dbFunc: func() *HSDatabase {
|
|
return nil
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
getCount: 1000,
|
|
|
|
want4: true,
|
|
want6: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := tt.dbFunc()
|
|
|
|
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom)
|
|
|
|
spew.Dump(alloc)
|
|
|
|
for range tt.getCount {
|
|
got4, got6, err := alloc.Next()
|
|
if err != nil {
|
|
t.Fatalf("allocating next IP: %s", err)
|
|
}
|
|
|
|
t.Logf("addrs ipv4: %v, ipv6: %v", got4, got6)
|
|
|
|
if tt.want4 {
|
|
if got4 == nil {
|
|
t.Fatalf("expected ipv4 addr, got nil")
|
|
}
|
|
}
|
|
|
|
if tt.want6 {
|
|
if got6 == nil {
|
|
t.Fatalf("expected ipv4 addr, got nil")
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBackfillIPAddresses(t *testing.T) {
|
|
fullNodeP := func(i int) *types.Node {
|
|
v4 := fmt.Sprintf("100.64.0.%d", i)
|
|
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
|
return &types.Node{
|
|
IPv4DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: v4,
|
|
},
|
|
IPv4: nap(v4),
|
|
IPv6DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: v6,
|
|
},
|
|
IPv6: nap(v6),
|
|
}
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
dbFunc func() *HSDatabase
|
|
|
|
prefix4 *netip.Prefix
|
|
prefix6 *netip.Prefix
|
|
want types.Nodes
|
|
}{
|
|
{
|
|
name: "simple-backfill-ipv6",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "simple-backfill-ipv6")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.1"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
want: types.Nodes{
|
|
&types.Node{
|
|
IPv4DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: "100.64.0.1",
|
|
},
|
|
IPv4: nap("100.64.0.1"),
|
|
IPv6DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: "fd7a:115c:a1e0::1",
|
|
},
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "simple-backfill-ipv4",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "simple-backfill-ipv4")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
want: types.Nodes{
|
|
&types.Node{
|
|
IPv4DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: "100.64.0.1",
|
|
},
|
|
IPv4: nap("100.64.0.1"),
|
|
IPv6DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: "fd7a:115c:a1e0::1",
|
|
},
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "simple-backfill-remove-ipv6",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "simple-backfill-remove-ipv6")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.1"),
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
|
|
want: types.Nodes{
|
|
&types.Node{
|
|
IPv4DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: "100.64.0.1",
|
|
},
|
|
IPv4: nap("100.64.0.1"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "simple-backfill-remove-ipv4",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "simple-backfill-remove-ipv4")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.1"),
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
want: types.Nodes{
|
|
&types.Node{
|
|
IPv6DatabaseField: sql.NullString{
|
|
Valid: true,
|
|
String: "fd7a:115c:a1e0::1",
|
|
},
|
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "multi-backfill-ipv6",
|
|
dbFunc: func() *HSDatabase {
|
|
db := dbForTest(t, "simple-backfill-ipv6")
|
|
user := types.User{Name: ""}
|
|
db.DB.Save(&user)
|
|
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.1"),
|
|
})
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.2"),
|
|
})
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.3"),
|
|
})
|
|
db.DB.Save(&types.Node{
|
|
User: user,
|
|
IPv4: nap("100.64.0.4"),
|
|
})
|
|
|
|
return db
|
|
},
|
|
|
|
prefix4: mpp("100.64.0.0/10"),
|
|
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
|
|
|
want: types.Nodes{
|
|
fullNodeP(1),
|
|
fullNodeP(2),
|
|
fullNodeP(3),
|
|
fullNodeP(4),
|
|
},
|
|
},
|
|
}
|
|
|
|
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
|
|
"ID",
|
|
"MachineKeyDatabaseField",
|
|
"NodeKeyDatabaseField",
|
|
"DiscoKeyDatabaseField",
|
|
"User",
|
|
"UserID",
|
|
"Endpoints",
|
|
"HostinfoDatabaseField",
|
|
"Hostinfo",
|
|
"Routes",
|
|
"CreatedAt",
|
|
"UpdatedAt",
|
|
))
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := tt.dbFunc()
|
|
|
|
alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategySequential)
|
|
if err != nil {
|
|
t.Fatalf("failed to set up ip alloc: %s", err)
|
|
}
|
|
|
|
logs, err := db.BackfillNodeIPs(alloc)
|
|
if err != nil {
|
|
t.Fatalf("failed to backfill: %s", err)
|
|
}
|
|
|
|
t.Logf("backfill log: \n%s", strings.Join(logs, "\n"))
|
|
|
|
got, err := db.ListNodes()
|
|
if err != nil {
|
|
t.Fatalf("failed to get nodes: %s", err)
|
|
}
|
|
|
|
if diff := cmp.Diff(tt.want, got, comps...); diff != "" {
|
|
t.Errorf("Backfill unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|