diff --git a/routes.go b/routes.go index 079c055c..3ba710be 100644 --- a/routes.go +++ b/routes.go @@ -23,109 +23,82 @@ type Route struct { IsPrimary bool } -// Deprecated: use machine function instead -// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by -// namespace and node name). -func (h *Headscale) GetAdvertisedNodeRoutes( - namespace string, - nodeName string, -) (*[]netip.Prefix, error) { - machine, err := h.GetMachine(namespace, nodeName) +type Routes []Route + +func (r *Route) String() string { + return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) +} + +func (rs Routes) toPrefixes() []netip.Prefix { + prefixes := make([]netip.Prefix, len(rs)) + for i, r := range rs { + prefixes[i] = netip.Prefix(r.Prefix) + } + return prefixes +} + +// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) +// Exit nodes are not considered for this, as they are never marked as Primary +func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { + var routes []Route + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). + Find(&routes).Error if err != nil { return nil, err } - return &machine.HostInfo.RoutableIPs, nil + return routes, nil } -// Deprecated: use machine function instead -// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by -// namespace and node name). -func (h *Headscale) GetEnabledNodeRoutes( - namespace string, - nodeName string, -) ([]netip.Prefix, error) { - machine, err := h.GetMachine(namespace, nodeName) - if err != nil { - return nil, err - } - - return machine.EnabledRoutes, nil -} - -// Deprecated: use machine function instead -// IsNodeRouteEnabled checks if a certain route has been enabled. -func (h *Headscale) IsNodeRouteEnabled( - namespace string, - nodeName string, - routeStr string, -) bool { - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return false - } - - enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName) - if err != nil { - return false - } - - for _, enabledRoute := range enabledRoutes { - if route == enabledRoute { - return true - } - } - - return false -} - -// Deprecated: use EnableRoute in machine.go -// EnableNodeRoute enables a subnet route advertised by a node (identified by -// namespace and node name). -func (h *Headscale) EnableNodeRoute( - namespace string, - nodeName string, - routeStr string, -) error { - machine, err := h.GetMachine(namespace, nodeName) +func (h *Headscale) processMachineRoutes(machine *Machine) error { + currentRoutes := []Route{} + err := h.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { return err } - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return err + advertisedRoutes := map[netip.Prefix]bool{} + for _, prefix := range machine.HostInfo.RoutableIPs { + advertisedRoutes[prefix] = false } - availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName) - if err != nil { - return err - } - - enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName) - if err != nil { - return err - } - - available := false - for _, availableRoute := range *availableRoutes { - // If the route is available, and not yet enabled, add it to the new routing table - if route == availableRoute { - available = true - if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) { - enabledRoutes = append(enabledRoutes, route) + for _, route := range currentRoutes { + if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { + if !route.Advertised { + route.Advertised = true + err := h.db.Save(&route).Error + if err != nil { + return err + } + } + advertisedRoutes[netip.Prefix(route.Prefix)] = true + } else { + if route.Advertised { + route.Advertised = false + route.Enabled = false + err := h.db.Save(&route).Error + if err != nil { + return err + } } } } - if !available { - return ErrRouteIsNotAvailable - } - - machine.EnabledRoutes = enabledRoutes - - if err := h.db.Save(&machine).Error; err != nil { - return fmt.Errorf("failed to update node routes in the database: %w", err) + for prefix, exists := range advertisedRoutes { + if !exists { + route := Route{ + MachineID: machine.ID, + Prefix: IPPrefix(prefix), + Advertised: true, + Enabled: false, + } + err := h.db.Create(&route).Error + if err != nil { + return err + } + } } return nil