// headscale/hscontrol/db/ip_test.go
//
// 152 lines
// 2.7 KiB
// Go

package db
import (
"net/netip"
"os"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
// TestIPAllocator checks that an allocator built by NewIPAllocator returns the
// expected IPv4/IPv6 address pairs from successive calls to Next. It covers
// three setups: no database at all, a database whose only node already holds
// the first address of each prefix, and a database whose only node holds the
// second address of each prefix.
func TestIPAllocator(t *testing.T) {
	mpp := func(pref string) netip.Prefix {
		return netip.MustParsePrefix(pref)
	}
	na := func(addr string) netip.Addr {
		return netip.MustParseAddr(addr)
	}

	// newDb creates a sqlite-backed HSDatabase inside a fresh temporary
	// directory. It receives the *testing.T of the (sub)test that calls it so
	// that t.Fatalf runs on the goroutine owning that test, and it keeps the
	// handle in a local variable instead of writing to state declared outside
	// this function.
	newDb := func(t *testing.T) *HSDatabase {
		t.Helper()

		tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
		if err != nil {
			t.Fatalf("creating temp dir: %s", err)
		}
		// Best-effort cleanup: a leftover temp dir must not fail the test.
		t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })

		db, err := NewHeadscaleDatabase(
			types.DatabaseConfig{
				Type: "sqlite3",
				Sqlite: types.SqliteConfig{
					Path: tmpDir + "/headscale_test.db",
				},
			},
			"",
		)
		if err != nil {
			t.Fatalf("creating database: %s", err)
		}

		return db
	}

	// newDbWithNode returns a fresh database that contains a single node
	// holding addrs, failing the test if the node cannot be stored.
	// NOTE(review): checking .Error assumes db.DB is a GORM handle — confirm.
	newDbWithNode := func(t *testing.T, addrs types.NodeAddresses) *HSDatabase {
		t.Helper()

		db := newDb(t)
		if err := db.DB.Save(&types.Node{IPAddresses: addrs}).Error; err != nil {
			t.Fatalf("saving node: %s", err)
		}

		return db
	}

	tests := []struct {
		name string
		// dbFunc builds the database handed to NewIPAllocator; it may
		// return nil to exercise the allocator without a database.
		dbFunc   func(t *testing.T) *HSDatabase
		prefix4  netip.Prefix
		prefix6  netip.Prefix
		getCount int // number of times Next is called
		want     []types.NodeAddresses
	}{
		{
			name: "simple",
			dbFunc: func(*testing.T) *HSDatabase {
				return nil
			},
			prefix4:  mpp("100.64.0.0/10"),
			prefix6:  mpp("fd7a:115c:a1e0::/48"),
			getCount: 1,
			want: []types.NodeAddresses{
				{
					na("100.64.0.1"),
					na("fd7a:115c:a1e0::1"),
				},
			},
		},
		{
			name: "simple-with-db",
			dbFunc: func(t *testing.T) *HSDatabase {
				return newDbWithNode(t, types.NodeAddresses{
					na("100.64.0.1"),
					na("fd7a:115c:a1e0::1"),
				})
			},
			prefix4:  mpp("100.64.0.0/10"),
			prefix6:  mpp("fd7a:115c:a1e0::/48"),
			getCount: 1,
			want: []types.NodeAddresses{
				{
					na("100.64.0.2"),
					na("fd7a:115c:a1e0::2"),
				},
			},
		},
		{
			name: "before-after-free-middle-in-db",
			dbFunc: func(t *testing.T) *HSDatabase {
				return newDbWithNode(t, types.NodeAddresses{
					na("100.64.0.2"),
					na("fd7a:115c:a1e0::2"),
				})
			},
			prefix4:  mpp("100.64.0.0/10"),
			prefix6:  mpp("fd7a:115c:a1e0::/48"),
			getCount: 2,
			want: []types.NodeAddresses{
				{
					na("100.64.0.1"),
					na("fd7a:115c:a1e0::1"),
				},
				{
					na("100.64.0.3"),
					na("fd7a:115c:a1e0::3"),
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := tt.dbFunc(t)

			alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6)
			if err != nil {
				t.Fatalf("creating IP allocator: %s", err)
			}

			// Log through t so the dump is attributed to this subtest and
			// only shown on failure or with -v.
			t.Logf("allocator state:\n%s", spew.Sdump(alloc))
			t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())

			// Left as a nil slice (not pre-sized) so that a zero getCount
			// still compares equal to a nil want.
			var got []types.NodeAddresses

			for range tt.getCount {
				gotSet, err := alloc.Next()
				if err != nil {
					t.Fatalf("allocating next IP: %s", err)
				}

				got = append(got, gotSet)
			}

			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
				t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}