152 lines
4.4 KiB
Go

package policy
import (
"net/netip"
"slices"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/types/views"
)
// ReduceNodes narrows nodes down to the peers that share an access
// relationship with node. A peer is kept when node.CanAccess reports true for
// it, or when the peer's CanAccess reports true for node, both evaluated
// against matchers. Any entry with the same ID as node is skipped.
//
// Input order is preserved in the returned slice view.
func ReduceNodes(
	node types.NodeView,
	nodes views.Slice[types.NodeView],
	matchers []matcher.Match,
) views.Slice[types.NodeView] {
	var visible []types.NodeView

	for idx := range nodes.Len() {
		candidate := nodes.At(idx)

		// A node is never its own peer.
		if candidate.ID() == node.ID() {
			continue
		}

		// Access in either direction is enough to make the two nodes peers.
		if !node.CanAccess(matchers, candidate) && !candidate.CanAccess(matchers, node) {
			continue
		}

		visible = append(visible, candidate)
	}

	return views.SliceOf(visible)
}
// ReduceRoutes returns the subset of routes for which
// node.CanAccessRoute(matchers, route) reports true. Input order is
// preserved, and the result is nil when no route qualifies.
func ReduceRoutes(
	node types.NodeView,
	routes []netip.Prefix,
	matchers []matcher.Match,
) []netip.Prefix {
	var permitted []netip.Prefix

	for i := range routes {
		if !node.CanAccessRoute(matchers, routes[i]) {
			continue
		}

		permitted = append(permitted, routes[i])
	}

	return permitted
}
// BuildPeerMap computes, for every node ID, the list of nodes it shares an
// access relationship with. Two nodes are recorded as peers of each other when
// CanAccess reports true in at least one direction against matchers. Nodes
// without any peer get no entry in the returned map.
func BuildPeerMap(
	nodes views.Slice[types.NodeView],
	matchers []matcher.Match,
) map[types.NodeID][]types.NodeView {
	total := nodes.Len()
	peers := make(map[types.NodeID][]types.NodeView, total)

	// ReduceNodes answers the same question for a single node, so calling it
	// for every node would test each pair twice (O(n^2) checks). Here every
	// unordered pair (i, j) with i < j is tested once and a match is recorded
	// on both sides, which halves the number of pair checks.
	for i := range total {
		left := nodes.At(i)

		for j := i + 1; j < total; j++ {
			right := nodes.At(j)

			leftID, rightID := left.ID(), right.ID()
			if leftID == rightID {
				continue
			}

			// Neither direction grants access: the two are not peers.
			if !left.CanAccess(matchers, right) && !right.CanAccess(matchers, left) {
				continue
			}

			peers[leftID] = append(peers[leftID], right)
			peers[rightID] = append(peers[rightID], left)
		}
	}

	return peers
}
// ApproveRoutesWithPolicy merges a node's currently approved routes with the
// announced routes that policy allows to be auto-approved, and returns the
// resulting list together with a flag reporting whether it changed.
//
// The returned list is built from:
//   - every route in currentApproved, whether or not it is still announced;
//   - every route in announcedRoutes that is not yet approved and for which
//     pm.NodeCanApproveRoute(nv, route) reports true.
//
// The list is then sorted, de-duplicated, and stripped of invalid prefixes.
// Auto-approval only ever adds routes here; it never drops a valid route that
// was already approved.
//
// The boolean is true when the resulting list differs from the sorted
// currentApproved. When pm is nil, currentApproved is returned as-is with
// false.
func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
	if pm == nil {
		return currentApproved, false
	}

	// Seed the result with a private copy of everything already approved so
	// the caller's slice is never mutated by the appends and sort below.
	newApproved := append(make([]netip.Prefix, 0, len(currentApproved)), currentApproved...)

	// Add announced routes that are new and that policy is willing to approve.
	// The policy is not consulted for routes that are already in the list.
	for _, candidate := range announcedRoutes {
		if slices.Contains(newApproved, candidate) || !pm.NodeCanApproveRoute(nv, candidate) {
			continue
		}

		newApproved = append(newApproved, candidate)
	}

	// Normalise: sort, collapse adjacent duplicates, drop invalid prefixes.
	tsaddr.SortPrefixes(newApproved)
	newApproved = lo.Filter(slices.Compact(newApproved), func(prefix netip.Prefix, _ int) bool {
		return prefix.IsValid()
	})

	// Sorted copy of the input, used only to detect and describe changes.
	baseline := slices.Clone(currentApproved)
	tsaddr.SortPrefixes(baseline)

	if slices.Equal(baseline, newApproved) {
		return newApproved, false
	}

	// Split the result into carried-over and newly added routes for the log.
	var added, kept []netip.Prefix
	for _, prefix := range newApproved {
		if slices.Contains(baseline, prefix) {
			kept = append(kept, prefix)

			continue
		}

		added = append(added, prefix)
	}

	if len(added) > 0 {
		log.Debug().
			Uint64("node.id", nv.ID().Uint64()).
			Str("node.name", nv.Hostname()).
			Strs("routes.added", util.PrefixesToString(added)).
			Strs("routes.kept", util.PrefixesToString(kept)).
			Int("routes.total", len(newApproved)).
			Msg("Routes auto-approved by policy")
	}

	return newApproved, true
}