diff --git a/db.go b/db.go index fc539cc3..72386ca8 100644 --- a/db.go +++ b/db.go @@ -18,8 +18,10 @@ import ( ) const ( - dbVersion = "1" - errValueNotFound = Error("not found") + dbVersion = "1" + + errValueNotFound = Error("not found") + ErrCannotParsePrefix = Error("cannot parse prefix") ) // KV is a key-value store in a psql table. For future use... @@ -79,6 +81,65 @@ func (h *Headscale) initDB() error { } } + err = db.AutoMigrate(&Route{}) + if err != nil { + return err + } + + if db.Migrator().HasColumn(&Machine{}, "enabled_routes") { + log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") + + type MachineAux struct { + ID uint64 + EnabledRoutes IPPrefixes + } + + machinesAux := []MachineAux{} + err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error + if err != nil { + log.Fatal().Err(err).Msg("Error accessing db") + } + for _, machine := range machinesAux { + for _, prefix := range machine.EnabledRoutes { + if err != nil { + log.Error(). + Err(err). + Str("enabled_route", prefix.String()). + Msg("Error parsing enabled_route") + continue + } + + err = db.Preload("Machine").Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).First(&Route{}).Error + if err == nil { + log.Info(). + Str("enabled_route", prefix.String()). + Msg("Route already migrated to new table, skipping") + continue + } + + route := Route{ + MachineID: machine.ID, + Advertised: true, + Enabled: true, + Prefix: IPPrefix(prefix), + } + if err := h.db.Create(&route).Error; err != nil { + log.Error().Err(err).Msg("Error creating route") + } else { + log.Info(). + Uint64("machine_id", route.MachineID). + Str("prefix", prefix.String()). + Msg("Route migrated") + } + } + } + + err = db.Migrator().DropColumn(&Machine{}, "enabled_routes") + if err != nil { + log.Error().Err(err).Msg("Error dropping enabled_routes column") + } + } + err = db.AutoMigrate(&Machine{}) if err != nil { return err @@ -264,6 +325,28 @@ func (hi HostInfo) Value() (driver.Value, error) { return string(bytes), err } +type IPPrefix netip.Prefix + +func (i *IPPrefix) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + prefix, err := netip.ParsePrefix(value) + if err != nil { + return err + } + *i = IPPrefix(prefix) + return nil + default: + return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i IPPrefix) Value() (driver.Value, error) { + prefixStr := netip.Prefix(i).String() + return prefixStr, nil +} + type IPPrefixes []netip.Prefix func (i *IPPrefixes) Scan(destination interface{}) error { diff --git a/routes.go b/routes.go index 676c79b6..3ba710be 100644 --- a/routes.go +++ b/routes.go @@ -3,115 +3,102 @@ package headscale import ( "fmt" "net/netip" + + "gorm.io/gorm" ) const ( ErrRouteIsNotAvailable = Error("route is not available") ) -// Deprecated: use machine function instead -// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by -// namespace and node name). -func (h *Headscale) GetAdvertisedNodeRoutes( - namespace string, - nodeName string, -) (*[]netip.Prefix, error) { - machine, err := h.GetMachine(namespace, nodeName) +type Route struct { + gorm.Model + + MachineID uint64 + Machine Machine + Prefix IPPrefix + + Advertised bool + Enabled bool + IsPrimary bool +} + +type Routes []Route + +func (r *Route) String() string { + return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) +} + +func (rs Routes) toPrefixes() []netip.Prefix { + prefixes := make([]netip.Prefix, len(rs)) + for i, r := range rs { + prefixes[i] = netip.Prefix(r.Prefix) + } + return prefixes +} + +// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) +// Exit nodes are not considered for this, as they are never marked as Primary +func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { + var routes []Route + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). + Find(&routes).Error if err != nil { return nil, err } - return &machine.HostInfo.RoutableIPs, nil + return routes, nil } -// Deprecated: use machine function instead -// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by -// namespace and node name). -func (h *Headscale) GetEnabledNodeRoutes( - namespace string, - nodeName string, -) ([]netip.Prefix, error) { - machine, err := h.GetMachine(namespace, nodeName) - if err != nil { - return nil, err - } - - return machine.EnabledRoutes, nil -} - -// Deprecated: use machine function instead -// IsNodeRouteEnabled checks if a certain route has been enabled. -func (h *Headscale) IsNodeRouteEnabled( - namespace string, - nodeName string, - routeStr string, -) bool { - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return false - } - - enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName) - if err != nil { - return false - } - - for _, enabledRoute := range enabledRoutes { - if route == enabledRoute { - return true - } - } - - return false -} - -// Deprecated: use EnableRoute in machine.go -// EnableNodeRoute enables a subnet route advertised by a node (identified by -// namespace and node name). -func (h *Headscale) EnableNodeRoute( - namespace string, - nodeName string, - routeStr string, -) error { - machine, err := h.GetMachine(namespace, nodeName) +func (h *Headscale) processMachineRoutes(machine *Machine) error { + currentRoutes := []Route{} + err := h.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { return err } - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return err + advertisedRoutes := map[netip.Prefix]bool{} + for _, prefix := range machine.HostInfo.RoutableIPs { + advertisedRoutes[prefix] = false } - availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName) - if err != nil { - return err - } - - enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName) - if err != nil { - return err - } - - available := false - for _, availableRoute := range *availableRoutes { - // If the route is available, and not yet enabled, add it to the new routing table - if route == availableRoute { - available = true - if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) { - enabledRoutes = append(enabledRoutes, route) + for _, route := range currentRoutes { + if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { + if !route.Advertised { + route.Advertised = true + err := h.db.Save(&route).Error + if err != nil { + return err + } + } + advertisedRoutes[netip.Prefix(route.Prefix)] = true + } else { + if route.Advertised { + route.Advertised = false + route.Enabled = false + err := h.db.Save(&route).Error + if err != nil { + return err + } } } } - if !available { - return ErrRouteIsNotAvailable - } - - machine.EnabledRoutes = enabledRoutes - - if err := h.db.Save(&machine).Error; err != nil { - return fmt.Errorf("failed to update node routes in the database: %w", err) + for prefix, exists := range advertisedRoutes { + if !exists { + route := Route{ + MachineID: machine.ID, + Prefix: IPPrefix(prefix), + Advertised: true, + Enabled: false, + } + err := h.db.Create(&route).Error + if err != nil { + return err + } + } } return nil