330 lines
8.1 KiB
Go
330 lines
8.1 KiB
Go
package db
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"net/netip"
|
|
"sync"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
"github.com/rs/zerolog/log"
|
|
"go4.org/netipx"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// IPAllocator is a singleton responsible for allocating
|
|
// IP addresses for nodes and making sure the same
|
|
// address is not handed out twice. There can only be one
|
|
// and it needs to be created before any other database
|
|
// writes occur.
|
|
type IPAllocator struct {
|
|
mu sync.Mutex
|
|
|
|
prefix4 *netip.Prefix
|
|
prefix6 *netip.Prefix
|
|
|
|
// Previous IPs handed out
|
|
prev4 netip.Addr
|
|
prev6 netip.Addr
|
|
|
|
// strategy used for handing out IP addresses.
|
|
strategy types.IPAllocationStrategy
|
|
|
|
// Set of all IPs handed out.
|
|
// This might not be in sync with the database,
|
|
// but it is more conservative. If saves to the
|
|
// database fails, the IP will be allocated here
|
|
// until the next restart of Headscale.
|
|
usedIPs netipx.IPSetBuilder
|
|
}
|
|
|
|
// NewIPAllocator returns a new IPAllocator singleton which
|
|
// can be used to hand out unique IP addresses within the
|
|
// provided IPv4 and IPv6 prefix. It needs to be created
|
|
// when headscale starts and needs to finish its read
|
|
// transaction before any writes to the database occur.
|
|
func NewIPAllocator(
|
|
db *HSDatabase,
|
|
prefix4, prefix6 *netip.Prefix,
|
|
strategy types.IPAllocationStrategy,
|
|
) (*IPAllocator, error) {
|
|
ret := IPAllocator{
|
|
prefix4: prefix4,
|
|
prefix6: prefix6,
|
|
|
|
strategy: strategy,
|
|
}
|
|
|
|
var v4s []sql.NullString
|
|
var v6s []sql.NullString
|
|
|
|
if db != nil {
|
|
err := db.Read(func(rx *gorm.DB) error {
|
|
return rx.Model(&types.Node{}).Pluck("ipv4", &v4s).Error
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading IPv4 addresses from database: %w", err)
|
|
}
|
|
|
|
err = db.Read(func(rx *gorm.DB) error {
|
|
return rx.Model(&types.Node{}).Pluck("ipv6", &v6s).Error
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading IPv6 addresses from database: %w", err)
|
|
}
|
|
}
|
|
|
|
var ips netipx.IPSetBuilder
|
|
|
|
// Add network and broadcast addrs to used pool so they
|
|
// are not handed out to nodes.
|
|
if prefix4 != nil {
|
|
network4, broadcast4 := util.GetIPPrefixEndpoints(*prefix4)
|
|
ips.Add(network4)
|
|
ips.Add(broadcast4)
|
|
|
|
// Use network as starting point, it will be used to call .Next()
|
|
// TODO(kradalby): Could potentially take all the IPs loaded from
|
|
// the database into account to start at a more "educated" location.
|
|
ret.prev4 = network4
|
|
}
|
|
|
|
if prefix6 != nil {
|
|
network6, broadcast6 := util.GetIPPrefixEndpoints(*prefix6)
|
|
ips.Add(network6)
|
|
ips.Add(broadcast6)
|
|
|
|
ret.prev6 = network6
|
|
}
|
|
|
|
// Fetch all the IP Addresses currently handed out from the Database
|
|
// and add them to the used IP set.
|
|
for _, addrStr := range append(v4s, v6s...) {
|
|
if addrStr.Valid {
|
|
addr, err := netip.ParseAddr(addrStr.String)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing IP address from database: %w", err)
|
|
}
|
|
|
|
ips.Add(addr)
|
|
}
|
|
}
|
|
|
|
// Build the initial IPSet to validate that we can use it.
|
|
_, err := ips.IPSet()
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"building initial IP Set: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
ret.usedIPs = ips
|
|
|
|
return &ret, nil
|
|
}
|
|
|
|
func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
|
i.mu.Lock()
|
|
defer i.mu.Unlock()
|
|
|
|
var err error
|
|
var ret4 *netip.Addr
|
|
var ret6 *netip.Addr
|
|
|
|
if i.prefix4 != nil {
|
|
ret4, err = i.next(i.prev4, i.prefix4)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
|
}
|
|
i.prev4 = *ret4
|
|
}
|
|
|
|
if i.prefix6 != nil {
|
|
ret6, err = i.next(i.prev6, i.prefix6)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
|
}
|
|
i.prev6 = *ret6
|
|
}
|
|
|
|
return ret4, ret6, nil
|
|
}
|
|
|
|
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
|
|
|
|
func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
|
i.mu.Lock()
|
|
defer i.mu.Unlock()
|
|
|
|
return i.next(prev, prefix)
|
|
}
|
|
|
|
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
|
var err error
|
|
var ip netip.Addr
|
|
|
|
switch i.strategy {
|
|
case types.IPAllocationStrategySequential:
|
|
// Get the first IP in our prefix
|
|
ip = prev.Next()
|
|
case types.IPAllocationStrategyRandom:
|
|
ip, err = randomNext(*prefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting random IP: %w", err)
|
|
}
|
|
}
|
|
|
|
// TODO(kradalby): maybe this can be done less often.
|
|
set, err := i.usedIPs.IPSet()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for {
|
|
if !prefix.Contains(ip) {
|
|
return nil, ErrCouldNotAllocateIP
|
|
}
|
|
|
|
// Check if the IP has already been allocated.
|
|
if set.Contains(ip) {
|
|
switch i.strategy {
|
|
case types.IPAllocationStrategySequential:
|
|
ip = ip.Next()
|
|
case types.IPAllocationStrategyRandom:
|
|
ip, err = randomNext(*prefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting random IP: %w", err)
|
|
}
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
i.usedIPs.Add(ip)
|
|
|
|
return &ip, nil
|
|
}
|
|
}
|
|
|
|
func randomNext(pfx netip.Prefix) (netip.Addr, error) {
|
|
rang := netipx.RangeOfPrefix(pfx)
|
|
fromIP, toIP := rang.From(), rang.To()
|
|
|
|
var from, to big.Int
|
|
|
|
from.SetBytes(fromIP.AsSlice())
|
|
to.SetBytes(toIP.AsSlice())
|
|
|
|
// Find the max, this is how we can do "random range",
|
|
// get the "max" as 0 -> to - from and then add back from
|
|
// after.
|
|
tempMax := big.NewInt(0).Sub(&to, &from)
|
|
|
|
out, err := rand.Int(rand.Reader, tempMax)
|
|
if err != nil {
|
|
return netip.Addr{}, fmt.Errorf("generating random IP: %w", err)
|
|
}
|
|
|
|
valInRange := big.NewInt(0).Add(&from, out)
|
|
|
|
ip, ok := netip.AddrFromSlice(valInRange.Bytes())
|
|
if !ok {
|
|
return netip.Addr{}, fmt.Errorf("generated ip bytes are invalid ip")
|
|
}
|
|
|
|
if !pfx.Contains(ip) {
|
|
return netip.Addr{}, fmt.Errorf(
|
|
"generated ip(%s) not in prefix(%s)",
|
|
ip.String(),
|
|
pfx.String(),
|
|
)
|
|
}
|
|
|
|
return ip, nil
|
|
}
|
|
|
|
// BackfillNodeIPs will take a database transaction, and
|
|
// iterate through all of the current nodes in headscale
|
|
// and ensure it has IP addresses according to the current
|
|
// configuration.
|
|
// This means that if both IPv4 and IPv6 is set in the
|
|
// config, and some nodes are missing that type of IP,
|
|
// it will be added.
|
|
// If a prefix type has been removed (IPv4 or IPv6), it
|
|
// will remove the IPs in that family from the node.
|
|
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
|
var err error
|
|
var ret []string
|
|
err = db.Write(func(tx *gorm.DB) error {
|
|
if i == nil {
|
|
return errors.New("backfilling IPs: ip allocator was nil")
|
|
}
|
|
|
|
log.Trace().Msgf("starting to backfill IPs")
|
|
|
|
nodes, err := ListNodes(tx)
|
|
if err != nil {
|
|
return fmt.Errorf("listing nodes to backfill IPs: %w", err)
|
|
}
|
|
|
|
for _, node := range nodes {
|
|
log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
|
|
|
|
changed := false
|
|
// IPv4 prefix is set, but node ip is missing, alloc
|
|
if i.prefix4 != nil && node.IPv4 == nil {
|
|
ret4, err := i.nextLocked(i.prev4, i.prefix4)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err)
|
|
}
|
|
|
|
node.IPv4 = ret4
|
|
changed = true
|
|
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
|
}
|
|
|
|
// IPv6 prefix is set, but node ip is missing, alloc
|
|
if i.prefix6 != nil && node.IPv6 == nil {
|
|
ret6, err := i.nextLocked(i.prev6, i.prefix6)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err)
|
|
}
|
|
|
|
node.IPv6 = ret6
|
|
changed = true
|
|
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
|
}
|
|
|
|
// IPv4 prefix is not set, but node has IP, remove
|
|
if i.prefix4 == nil && node.IPv4 != nil {
|
|
ret = append(ret, fmt.Sprintf("removing IPv4 %q from Node(%d) %q", node.IPv4.String(), node.ID, node.Hostname))
|
|
node.IPv4 = nil
|
|
changed = true
|
|
}
|
|
|
|
// IPv6 prefix is not set, but node has IP, remove
|
|
if i.prefix6 == nil && node.IPv6 != nil {
|
|
ret = append(ret, fmt.Sprintf("removing IPv6 %q from Node(%d) %q", node.IPv6.String(), node.ID, node.Hostname))
|
|
node.IPv6 = nil
|
|
changed = true
|
|
}
|
|
|
|
if changed {
|
|
err := tx.Save(node).Error
|
|
if err != nil {
|
|
return fmt.Errorf("saving node(%d) after adding IPs: %w", node.ID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return ret, err
|
|
}
|