diff --git a/api_common.go b/api_common.go index 9e7ff480..1eaad57c 100644 --- a/api_common.go +++ b/api_common.go @@ -13,7 +13,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -37,7 +37,7 @@ func (h *Headscale) generateMapResponse( profiles := h.getMapResponseUserProfiles(*machine, peers) - nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). diff --git a/grpcv1.go b/grpcv1.go index 25ee7777..998b9902 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -374,7 +374,7 @@ func (api headscaleV1APIServer) GetMachineRoute( } return &v1.GetMachineRouteResponse{ - Routes: machine.RoutesToProto(), + Routes: api.h.RoutesToProto(machine), }, nil } @@ -393,7 +393,7 @@ func (api headscaleV1APIServer) EnableMachineRoutes( } return &v1.EnableMachineRoutesResponse{ - Routes: machine.RoutesToProto(), + Routes: api.h.RoutesToProto(machine), }, nil } diff --git a/machine.go b/machine.go index b688be65..d3d54802 100644 --- a/machine.go +++ b/machine.go @@ -13,6 +13,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -76,9 +77,8 @@ type Machine struct { LastSuccessfulUpdate *time.Time Expiry *time.Time - HostInfo HostInfo - Endpoints StringList - EnabledRoutes IPPrefixes + HostInfo HostInfo + Endpoints StringList CreatedAt time.Time UpdatedAt time.Time @@ -595,14 +595,15 @@ func (machines MachinesP) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (machines Machines) toNodes( +func (h *Headscale) toNodes( + machines Machines, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(machines)) for index, machine := range machines { - node, err := machine.toNode(baseDomain, dnsConfig) + node, err := h.toNode(machine, baseDomain, dnsConfig) if err != nil { return nil, err } @@ -615,7 +616,8 @@ func (machines Machines) toNodes( // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS. -func (machine Machine) toNode( +func (h *Headscale) toNode( + machine Machine, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) (*tailcfg.Node, error) { @@ -663,24 +665,19 @@ func (machine Machine) toNode( []netip.Prefix{}, addrs...) // we append the node own IP, as it is required by the clients - allowedIPs = append(allowedIPs, machine.EnabledRoutes...) - - // TODO(kradalby): This is kind of a hack where we say that - // all the announced routes (except exit), is presented as primary - // routes. This might be problematic if two nodes expose the same route. - // This was added to address an issue where subnet routers stopped working - // when we only populated AllowedIPs. - primaryRoutes := []netip.Prefix{} - if len(machine.EnabledRoutes) > 0 { - for _, route := range machine.EnabledRoutes { - if route == ExitRouteV4 || route == ExitRouteV6 { - continue - } - - primaryRoutes = append(primaryRoutes, route) - } + enabledRoutes, err := h.GetEnabledRoutes(&machine) + if err != nil { + return nil, err } + allowedIPs = append(allowedIPs, enabledRoutes...) + + primaryRoutes, err := h.getMachinePrimaryRoutes(&machine) + if err != nil { + return nil, err + } + primaryPrefixes := Routes(primaryRoutes).toPrefixes() + var derp string if machine.HostInfo.NetInfo != nil { derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP) @@ -733,7 +730,7 @@ func (machine Machine) toNode( DiscoKey: discoKey, Addresses: addrs, AllowedIPs: allowedIPs, - PrimaryRoutes: primaryRoutes, + PrimaryRoutes: primaryPrefixes, Endpoints: machine.Endpoints, DERP: derp, @@ -923,21 +920,66 @@ func (h *Headscale) RegisterMachine(machine Machine, return &machine, nil } -func (machine *Machine) GetAdvertisedRoutes() []netip.Prefix { - return machine.HostInfo.RoutableIPs +// GetAdvertisedRoutes returns the routes that are be advertised by the given machine. +func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { + routes := []Route{} + + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get advertised routes for machine") + return nil, err + } + + prefixes := []netip.Prefix{} + for _, route := range routes { + prefixes = append(prefixes, netip.Prefix(route.Prefix)) + } + + return prefixes, nil } -func (machine *Machine) GetEnabledRoutes() []netip.Prefix { - return machine.EnabledRoutes +// GetEnabledRoutes returns the routes that are enabled for the machine. +func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { + routes := []Route{} + + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true). + Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get enabled routes for machine") + return nil, err + } + + prefixes := []netip.Prefix{} + for _, route := range routes { + prefixes = append(prefixes, netip.Prefix(route.Prefix)) + } + + return prefixes, nil } -func (machine *Machine) IsRoutesEnabled(routeStr string) bool { +func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes := machine.GetEnabledRoutes() + enabledRoutes, err := h.GetEnabledRoutes(machine) + if err != nil { + log.Error().Err(err).Msg("Could not get enabled routes") + return false + } for _, enabledRoute := range enabledRoutes { if route == enabledRoute { @@ -948,8 +990,7 @@ func (machine *Machine) IsRoutesEnabled(routeStr string) bool { return false } -// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the -// previous list of routes. +// EnableRoutes enables new routes based on a list of new routes. func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { @@ -961,8 +1002,13 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { newRoutes[index] = route } + advertisedRoutes, err := h.GetAdvertisedRoutes(machine) + if err != nil { + return err + } + for _, newRoute := range newRoutes { - if !contains(machine.GetAdvertisedRoutes(), newRoute) { + if !contains(advertisedRoutes, newRoute) { return fmt.Errorf( "route (%s) is not available on node %s: %w", machine.Hostname, @@ -971,52 +1017,70 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { } } - machine.EnabledRoutes = newRoutes - - if err := h.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed enable routes for machine in the database: %w", err) + // Separate loop so we don't leave things in a half-updated state + for _, prefix := range newRoutes { + route := Route{} + err := h.db.Preload("Machine"). + Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). + First(&route).Error + if err == nil { + route.Enabled = true + err = h.db.Save(&route).Error + if err != nil { + return fmt.Errorf("failed to enable route: %w", err) + } + } else { + return fmt.Errorf("failed to find route: %w", err) + } } return nil } -// Enabled any routes advertised by a machine that match the ACL autoApprovers policy. -func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) { +// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. +func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { if len(machine.IPAddresses) == 0 { - return // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs + return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } - approvedRoutes := make([]netip.Prefix, 0, len(machine.HostInfo.RoutableIPs)) - thisMachine := []Machine{*machine} + routes := []Route{} + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID).Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get advertised routes for machine") - for _, advertisedRoute := range machine.HostInfo.RoutableIPs { - if contains(machine.EnabledRoutes, advertisedRoute) { - continue // Skip routes that are already enabled for the node - } + return err + } - routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers( - advertisedRoute, - ) + approvedRoutes := []Route{} + + for _, advertisedRoute := range routes { + routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers(netip.Prefix(advertisedRoute.Prefix)) if err != nil { log.Err(err). Str("advertisedRoute", advertisedRoute.String()). Uint64("machineId", machine.ID). Msg("Failed to resolve autoApprovers for advertised route") - return + return err } for _, approvedAlias := range routeApprovers { if approvedAlias == machine.Namespace.Name { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - approvedIps, err := expandAlias(thisMachine, *h.aclPolicy, approvedAlias, h.cfg.OIDC.StripEmaildomain) + approvedIps, err := expandAlias([]Machine{*machine}, *h.aclPolicy, approvedAlias, h.cfg.OIDC.StripEmaildomain) if err != nil { log.Err(err). Str("alias", approvedAlias). Msg("Failed to expand alias when processing autoApprovers policy") - return + return err } // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first @@ -1028,20 +1092,33 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) { } for _, approvedRoute := range approvedRoutes { - if !contains(machine.EnabledRoutes, approvedRoute) { - log.Info(). - Str("route", approvedRoute.String()). - Uint64("client", machine.ID). - Msg("Enabling autoApproved route for client") - machine.EnabledRoutes = append(machine.EnabledRoutes, approvedRoute) + approvedRoute.Enabled = true + err = h.db.Save(&approvedRoute).Error + if err != nil { + log.Err(err). + Str("approvedRoute", approvedRoute.String()). + Uint64("machineId", machine.ID). + Msg("Failed to enable approved route") + + return err } } + + return nil } -func (machine *Machine) RoutesToProto() *v1.Routes { - availableRoutes := machine.GetAdvertisedRoutes() +func (h *Headscale) RoutesToProto(machine *Machine) *v1.Routes { + availableRoutes, err := h.GetAdvertisedRoutes(machine) + if err != nil { + log.Error().Err(err).Msg("Could not get advertised routes") + return nil + } - enabledRoutes := machine.GetEnabledRoutes() + enabledRoutes, err := h.GetEnabledRoutes(machine) + if err != nil { + log.Error().Err(err).Msg("Could not get enabled routes") + return nil + } return &v1.Routes{ AdvertisedRoutes: ipPrefixToString(availableRoutes),