mapper: produce map before poll (#2628)

This commit is contained in:
Kristoffer Dalby
2025-07-28 11:15:53 +02:00
committed by GitHub
parent b2a18830ed
commit a058bf3cd3
70 changed files with 5771 additions and 2475 deletions

155
hscontrol/mapper/batcher.go Normal file
View File

@@ -0,0 +1,155 @@
package mapper
import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
// Batcher defines the common interface for all batcher implementations.
type Batcher interface {
Start()
Close()
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c change.ChangeSet)
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
}
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
return &LockFreeBatcher{
mapper: mapper,
workers: workers,
tick: time.NewTicker(batchTime),
// The size of this channel is arbitrary chosen, the sizing should be revisited.
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
}
}
// NewBatcherAndMapper creates a Batcher implementation.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
m := newMapper(cfg, state)
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
m.batcher = b
return b
}
// nodeConnection interface for different connection implementations.
type nodeConnection interface {
nodeID() types.NodeID
version() tailcfg.CapabilityVersion
send(data *tailcfg.MapResponse) error
}
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
if c.Empty() {
return nil, nil
}
// Validate inputs before processing
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
}
if mapper == nil {
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
}
var mapResp *tailcfg.MapResponse
var err error
switch c.Change {
case change.DERP:
mapResp, err = mapper.derpMapResponse(nodeID)
case change.NodeCameOnline, change.NodeWentOffline:
if c.IsSubnetRouter {
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
Online: ptr.To(c.Change == change.NodeCameOnline),
},
})
}
case change.NodeNewOrUpdate:
mapResp, err = mapper.fullMapResponse(nodeID, version)
case change.NodeRemove:
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
default:
// The following will always hit this:
// change.Full, change.Policy
mapResp, err = mapper.fullMapResponse(nodeID, version)
}
if err != nil {
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
}
// TODO(kradalby): Is this necessary?
// Validate the generated map response - only check for nil response
// Note: mapResp.Node can be nil for peer updates, which is valid
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
}
return mapResp, nil
}
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
if nc == nil {
return fmt.Errorf("nodeConnection is nil")
}
nodeID := nc.nodeID()
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
if err != nil {
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
}
if data == nil {
// No data to send is valid for some change types
return nil
}
// Send the map response
if err := nc.send(data); err != nil {
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
}
return nil
}
// workResult represents the result of processing a change.
type workResult struct {
mapResponse *tailcfg.MapResponse
err error
}
// work represents a unit of work to be processed by workers.
type work struct {
c change.ChangeSet
nodeID types.NodeID
resultCh chan<- workResult // optional channel for synchronous operations
}

View File

@@ -0,0 +1,491 @@
package mapper
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
tick *time.Ticker
mapper *mapper
workers int
// Lock-free concurrent maps
nodes *xsync.Map[types.NodeID, *nodeConn]
connected *xsync.Map[types.NodeID, *time.Time]
// Work queue channel
workCh chan work
ctx context.Context
cancel context.CancelFunc
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
batchMutex sync.RWMutex
// Metrics
totalNodes atomic.Int64
totalUpdates atomic.Int64
workQueuedCount atomic.Int64
workProcessed atomic.Int64
workErrors atomic.Int64
}
// AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
// First validate that we can generate initial map before doing anything else
fullSelfChange := change.FullSelf(id)
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
// This currently means that the goroutine for the node connection will do the processing
// which means that we might have uncontrolled concurrency.
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
// it to be processed in a more controlled manner.
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
if err != nil {
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
}
// Only after validation succeeds, create or update node connection
newConn := newNodeConn(id, c, version, b.mapper)
var conn *nodeConn
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
// Update existing connection
existing.updateConnection(c, version)
conn = existing
} else {
b.totalNodes.Add(1)
conn = newConn
}
// Mark as connected only after validation succeeds
b.connected.Store(id, nil) // nil = connected
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
// Send the validated initial map
if initialMap != nil {
if err := conn.send(initialMap); err != nil {
// Clean up the connection state on send failure
b.nodes.Delete(id)
b.connected.Delete(id)
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
}
// Notify other nodes that this node came online
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
}
return nil
}
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection,
// and notifies other nodes that this node has gone offline.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
// Check if this is the current connection and mark it as closed
if existing, ok := b.nodes.Load(id); ok {
if !existing.matchesChannel(c) {
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
return // Not the current connection, not an error
}
// Mark the connection as closed to prevent further sends
if connData := existing.connData.Load(); connData != nil {
connData.closed.Store(true)
}
}
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
// Remove node and mark disconnected atomically
b.nodes.Delete(id)
b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1)
// Notify other nodes that this node went offline
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
}
// AddWork queues a change to be processed by the batcher.
// Critical changes are processed immediately, while others are batched for efficiency.
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
b.addWork(c)
}
func (b *LockFreeBatcher) Start() {
b.ctx, b.cancel = context.WithCancel(context.Background())
go b.doWork()
}
func (b *LockFreeBatcher) Close() {
if b.cancel != nil {
b.cancel()
}
close(b.workCh)
}
func (b *LockFreeBatcher) doWork() {
log.Debug().Msg("batcher doWork loop started")
defer log.Debug().Msg("batcher doWork loop stopped")
for i := range b.workers {
go b.worker(i + 1)
}
for {
select {
case <-b.tick.C:
// Process batched changes
b.processBatchedChanges()
case <-b.ctx.Done():
return
}
}
}
func (b *LockFreeBatcher) worker(workerID int) {
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
for {
select {
case w, ok := <-b.workCh:
if !ok {
return
}
startTime := time.Now()
b.workProcessed.Add(1)
// If the resultCh is set, it means that this is a work request
// where there is a blocking function waiting for the map that
// is being generated.
// This is used for synchronous map generation.
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("failed to generate map response for synchronous work")
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Msg("node not found for synchronous work")
}
// Send result
select {
case w.resultCh <- result:
case <-b.ctx.Done():
return
}
duration := time.Since(startTime)
if duration > 100*time.Millisecond {
log.Warn().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Dur("duration", duration).
Msg("slow synchronous work processing")
}
continue
}
// If resultCh is nil, this is an asynchronous work request
// that should be processed and sent to the node instead of
// returned to the caller.
if nc, exists := b.nodes.Load(w.nodeID); exists {
// Check if this connection is still active before processing
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("skipping work for closed connection")
continue
}
err := nc.change(w.c)
if err != nil {
b.workErrors.Add(1)
log.Error().Err(err).
Int("workerID", workerID).
Uint64("node.id", w.c.NodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("failed to apply change")
}
} else {
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("node not found for asynchronous work - node may have disconnected")
}
duration := time.Since(startTime)
if duration > 100*time.Millisecond {
log.Warn().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Dur("duration", duration).
Msg("slow asynchronous work processing")
}
case <-b.ctx.Done():
return
}
}
}
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
// For critical changes that need immediate processing, send directly
if b.shouldProcessImmediately(c) {
if c.SelfUpdateOnly {
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() {
return true
}
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
return true
})
return
}
// For non-critical changes, add to batch
b.addToBatch(c)
}
// queueWork safely queues work
func (b *LockFreeBatcher) queueWork(w work) {
b.workQueuedCount.Add(1)
select {
case b.workCh <- w:
// Successfully queued
case <-b.ctx.Done():
// Batcher is shutting down
return
}
}
// shouldProcessImmediately determines if a change should bypass batching
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
// Process these changes immediately to avoid delaying critical functionality
switch c.Change {
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
return true
default:
return false
}
}
// addToBatch adds a change to the pending batch
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if c.SelfUpdateOnly {
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
changes = append(changes, c)
b.pendingChanges.Store(c.NodeID, changes)
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() {
return true
}
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, c)
b.pendingChanges.Store(nodeID, changes)
return true
})
}
// processBatchedChanges processes all pending batched changes
func (b *LockFreeBatcher) processBatchedChanges() {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if b.pendingChanges == nil {
return
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
if len(changes) == 0 {
return true
}
// Send all batched changes for this node
for _, c := range changes {
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
}
// Clear the pending changes for this node
b.pendingChanges.Delete(nodeID)
return true
})
}
// IsConnected is lock-free read.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
if val, ok := b.connected.Load(id); ok {
// nil means connected
return val == nil
}
return false
}
// ConnectedMap returns a lock-free map of all connected nodes.
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret := xsync.NewMap[types.NodeID, bool]()
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
// nil means connected
ret.Store(id, val == nil)
return true
})
return ret
}
// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
resultCh := make(chan workResult, 1)
// Queue the work with a result channel using the safe queueing method
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
// Wait for the result
select {
case result := <-resultCh:
return result.mapResponse, result.err
case <-b.ctx.Done():
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
}
}
// connectionData holds the channel and connection parameters.
type connectionData struct {
c chan<- *tailcfg.MapResponse
version tailcfg.CapabilityVersion
closed atomic.Bool // Track if this connection has been closed
}
// nodeConn described the node connection and its associated data.
type nodeConn struct {
id types.NodeID
mapper *mapper
// Atomic pointer to connection data - allows lock-free updates
connData atomic.Pointer[connectionData]
updateCount atomic.Int64
}
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
nc := &nodeConn{
id: id,
mapper: mapper,
}
// Initialize connection data
data := &connectionData{
c: c,
version: version,
}
nc.connData.Store(data)
return nc
}
// updateConnection atomically updates connection parameters.
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
newData := &connectionData{
c: c,
version: version,
}
nc.connData.Store(newData)
}
// matchesChannel checks if the given channel matches current connection.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
data := nc.connData.Load()
if data == nil {
return false
}
// Compare channel pointers directly
return data.c == c
}
// compressAndVersion atomically reads connection settings.
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
data := nc.connData.Load()
if data == nil {
return 0
}
return data.version
}
func (nc *nodeConn) nodeID() types.NodeID {
return nc.id
}
func (nc *nodeConn) change(c change.ChangeSet) error {
return handleNodeChange(nc, nc.mapper, c)
}
// send sends data to the node's channel.
// The node will pick it up and send it to the HTTP handler.
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
connData := nc.connData.Load()
if connData == nil {
return fmt.Errorf("node %d: no connection data", nc.id)
}
// Check if connection has been closed
if connData.closed.Load() {
return fmt.Errorf("node %d: connection closed", nc.id)
}
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
// the channel. That might mean that we are sending to a node that has gone offline, but
// the channel is still open.
connData.c <- data
nc.updateCount.Add(1)
return nil
}

File diff suppressed because it is too large Load Diff

259
hscontrol/mapper/builder.go Normal file
View File

@@ -0,0 +1,259 @@
package mapper
import (
"net/netip"
"sort"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
"tailscale.com/util/multierr"
)
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
type MapResponseBuilder struct {
resp *tailcfg.MapResponse
mapper *mapper
nodeID types.NodeID
capVer tailcfg.CapabilityVersion
errs []error
}
// NewMapResponseBuilder creates a new builder with basic fields set
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now()
return &MapResponseBuilder{
resp: &tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
},
mapper: m,
nodeID: nodeID,
errs: nil,
}
}
// addError adds an error to the builder's error list
func (b *MapResponseBuilder) addError(err error) {
if err != nil {
b.errs = append(b.errs, err)
}
}
// hasErrors returns true if the builder has accumulated any errors
func (b *MapResponseBuilder) hasErrors() bool {
return len(b.errs) > 0
}
// WithCapabilityVersion sets the capability version for the response
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
b.capVer = capVer
return b
}
// WithSelfNode adds the requesting node to the response
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
_, matchers := b.mapper.state.Filter()
tailnode, err := tailNode(
node.View(), b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {
b.addError(err)
return b
}
b.resp.Node = tailnode
return b
}
// WithDERPMap adds the DERP map to the response
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
b.resp.DERPMap = b.mapper.state.DERPMap()
return b
}
// WithDomain adds the domain configuration
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
b.resp.Domain = b.mapper.cfg.Domain()
return b
}
// WithCollectServicesDisabled sets the collect services flag to false
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
b.resp.CollectServices.Set(false)
return b
}
// WithDebugConfig adds debug configuration
// It disables log tailing if the mapper's LogTail is not enabled
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
}
return b
}
// WithSSHPolicy adds SSH policy configuration for the requesting node
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
if err != nil {
b.addError(err)
return b
}
b.resp.SSHPolicy = sshPolicy
return b
}
// WithDNSConfig adds DNS configuration for the requesting node
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
return b
}
// WithUserProfiles adds user profiles for the requesting node and given peers
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
b.resp.UserProfiles = generateUserProfiles(node, peers)
return b
}
// WithPacketFilters adds packet filter rules based on policy
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
filter, _ := b.mapper.state.Filter()
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
// Currently, we do not send incremental package filters, however using the
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node.View(), filter),
}
return b
}
// WithPeers adds full peer list with policy filtering (for full map response)
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
b.addError(err)
return b
}
b.resp.Peers = tailPeers
return b
}
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
b.addError(err)
return b
}
b.resp.PeersChanged = tailPeers
return b
}
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
return nil, err
}
filter, matchers := b.mapper.state.Filter()
// If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers.
var changedViews views.Slice[types.NodeView]
if len(filter) > 0 {
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
} else {
changedViews = peers.ViewSlice()
}
tailPeers, err := tailNodes(
changedViews, b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
return tailPeers, nil
}
// WithPeerChangedPatch adds peer change patches
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
b.resp.PeersChangedPatch = changes
return b
}
// WithPeersRemoved adds removed peer IDs
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
var tailscaleIDs []tailcfg.NodeID
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}
b.resp.PeersRemoved = tailscaleIDs
return b
}
// Build finalizes the response and returns marshaled bytes
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)
}
if debugDumpMapResponsePath != "" {
writeDebugMapResponse(b.resp, b.nodeID)
}
return b.resp, nil
}

View File

@@ -0,0 +1,347 @@
package mapper
import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
func TestMapResponseBuilder_Basic(t *testing.T) {
cfg := &types.Config{
BaseDomain: "example.com",
LogTail: types.LogTailConfig{
Enabled: true,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID)
// Test basic builder creation
assert.NotNil(t, builder)
assert.Equal(t, nodeID, builder.nodeID)
assert.NotNil(t, builder.resp)
assert.False(t, builder.resp.KeepAlive)
assert.NotNil(t, builder.resp.ControlTime)
assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
}
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(42)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer)
assert.Equal(t, capVer, builder.capVer)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithDomain(t *testing.T) {
domain := "test.example.com"
cfg := &types.Config{
ServerURL: "https://test.example.com",
BaseDomain: domain,
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDomain()
assert.Equal(t, domain, builder.resp.Domain)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithCollectServicesDisabled()
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
tests := []struct {
name string
logTailEnabled bool
expected bool
}{
{
name: "LogTail enabled",
logTailEnabled: true,
expected: false, // DisableLogTail should be false when LogTail is enabled
},
{
name: "LogTail disabled",
logTailEnabled: false,
expected: true, // DisableLogTail should be true when LogTail is disabled
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &types.Config{
LogTail: types.LogTailConfig{
Enabled: tt.logTailEnabled,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDebugConfig()
require.NotNil(t, builder.resp.Debug)
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
})
}
}
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
changes := []*tailcfg.PeerChange{
{
NodeID: 123,
DERPRegion: 1,
},
{
NodeID: 456,
DERPRegion: 2,
},
}
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changes)
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(123)
removedID2 := types.NodeID(456)
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1, removedID2)
expected := []tailcfg.NodeID{
removedID1.NodeID(),
removedID2.NodeID(),
}
assert.Equal(t, expected, builder.resp.PeersRemoved)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Simulate an error in the builder
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
// All subsequent calls should continue to work and accumulate errors
result := builder.
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 1)
assert.Equal(t, assert.AnError, result.errs[0])
// Build should return the error
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
}
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
domain := "chained.example.com"
cfg := &types.Config{
ServerURL: "https://chained.example.com",
BaseDomain: domain,
LogTail: types.LogTailConfig{
Enabled: false,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(99)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
// Verify all fields are set correctly
assert.Equal(t, capVer, builder.capVer)
assert.Equal(t, domain, builder.resp.Domain)
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
assert.NotNil(t, builder.resp.Debug)
assert.True(t, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(100)
removedID2 := types.NodeID(200)
// Test calling WithPeersRemoved multiple times
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1).
WithPeersRemoved(removedID2)
// Second call should overwrite the first
expected := []tailcfg.NodeID{removedID2.NodeID()}
assert.Equal(t, expected, builder.resp.PeersRemoved)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch([]*tailcfg.PeerChange{})
assert.Empty(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(nil)
assert.Nil(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Create a builder and add multiple errors
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
builder.addError(assert.AnError)
builder.addError(nil) // This should be ignored
// All subsequent calls should continue to work
result := builder.
WithDomain().
WithCollectServicesDisabled()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 2) // nil error should be ignored
// Build should return a multierr
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
// The error should contain information about multiple errors
assert.Contains(t, err.Error(), "multiple errors")
}

View File

@@ -1,7 +1,6 @@
package mapper
import (
"encoding/binary"
"encoding/json"
"fmt"
"io/fs"
@@ -10,31 +9,21 @@ import (
"os"
"path"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/envknob"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/views"
)
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
reservedResponseHeaderSize = 4
mapperIDLength = 8
debugMapResponsePerm = 0o755
nextDNSDoHPrefix = "https://dns.nextdns.io"
mapperIDLength = 8
debugMapResponsePerm = 0o755
)
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
@@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
// - Create a "minifier" that removes info not needed for the node
// - some sort of batching, wait for 5 or 60 seconds before sending
type Mapper struct {
type mapper struct {
// Configuration
state *state.State
cfg *types.Config
notif *notifier.Notifier
state *state.State
cfg *types.Config
batcher Batcher
uid string
created time.Time
seq uint64
}
type patch struct {
@@ -66,41 +53,31 @@ type patch struct {
change *tailcfg.PeerChange
}
func NewMapper(
state *state.State,
func newMapper(
cfg *types.Config,
notif *notifier.Notifier,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
state *state.State,
) *mapper {
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{
return &mapper{
state: state,
cfg: cfg,
notif: notif,
uid: uid,
created: time.Now(),
seq: 0,
}
}
func (m *Mapper) String() string {
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
}
func generateUserProfiles(
node types.NodeView,
peers views.Slice[types.NodeView],
node *types.Node,
peers types.Nodes,
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User)
ids := make([]uint, 0, peers.Len()+1)
user := node.User()
userMap[user.ID] = &user
ids = append(ids, user.ID)
for _, peer := range peers.All() {
peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
ids := make([]uint, 0, len(userMap))
userMap[node.User.ID] = &node.User
ids = append(ids, node.User.ID)
for _, peer := range peers {
userMap[peer.User.ID] = &peer.User
ids = append(ids, peer.User.ID)
}
slices.Sort(ids)
@@ -117,7 +94,7 @@ func generateUserProfiles(
func generateDNSConfig(
cfg *types.Config,
node types.NodeView,
node *types.Node,
) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil {
return nil
@@ -137,17 +114,16 @@ func generateDNSConfig(
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{node.Hostname()},
"device_model": []string{node.Hostinfo().OS()},
"device_name": []string{node.Hostname},
"device_model": []string{node.Hostinfo.OS},
}
nodeIPs := node.IPs()
if len(nodeIPs) > 0 {
attrs.Add("device_ip", nodeIPs[0].String())
if len(node.IPs()) > 0 {
attrs.Add("device_ip", node.IPs()[0].String())
}
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
@@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
}
}
// fullMapResponse creates a complete MapResponse for a node.
// It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse(
node types.NodeView,
peers views.Slice[types.NodeView],
// fullMapResponse returns a MapResponse for the given node.
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer)
peers, err := m.listPeers(nodeID)
if err != nil {
return nil, err
}
err = appendPeerChanges(
resp,
true, // full change
m.state,
node,
capVer,
peers,
m.cfg,
)
if err != nil {
return nil, err
}
return resp, nil
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithDERPMap().
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig().
WithSSHPolicy().
WithDNSConfig().
WithUserProfiles(peers).
WithPacketFilters().
WithPeers(peers).
Build(messages...)
}
// FullMapResponse returns a MapResponse for the given node.
func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
messages ...string,
) ([]byte, error) {
peers, err := m.ListPeers(node.ID())
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// ReadOnlyMapResponse returns a MapResponse for the given node.
// Lite means that the peers has been omitted, this is intended
// to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
func (m *Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.KeepAlive = true
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
derpMap *tailcfg.DERPMap,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.DERPMap = derpMap
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange,
messages ...string,
) ([]byte, error) {
var err error
resp := m.baseMapResponse()
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
if nodeChanged {
if nodeID != node.ID() {
changedIDs = append(changedIDs, nodeID)
}
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := types.Nodes{}
if len(changedIDs) > 0 {
changedNodes, err = m.ListNodes(changedIDs...)
if err != nil {
return nil, err
}
}
err = appendPeerChanges(
&resp,
false, // partial change
m.state,
node,
mapRequest.Version,
changedNodes.ViewSlice(),
m.cfg,
)
if err != nil {
return nil, err
}
resp.PeersRemoved = removedIDs
// Sending patches as a part of a PeersChanged response
// is technically not suppose to be done, but they are
// applied after the PeersChanged. The patch list
// should _only_ contain Nodes that are not in the
// PeersChanged or PeersRemoved list and the caller
// should filter them out.
//
// From tailcfg docs:
// These are applied after Peers* above, but in practice the
// control server should only send these on their own, without
// the Peers* fields also set.
if patches != nil {
resp.PeersChangedPatch = patches
}
_, matchers := m.state.Filter()
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(
node, mapRequest.Version, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil {
return nil, err
}
resp.Node = tailnode
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDERPMap().
Build()
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
func (m *mapper) peerChangedPatchResponse(
nodeID types.NodeID,
changed []*tailcfg.PeerChange,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) marshalMapResponse(
mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse,
node types.NodeView,
compression string,
messages ...string,
) ([]byte, error) {
atomic.AddUint64(&m.seq, 1)
jsonBody, err := json.Marshal(resp)
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
}
if debugDumpMapResponsePath != "" {
data := map[string]any{
"Messages": messages,
"MapRequest": mapRequest,
"MapResponse": resp,
}
responseType := "keepalive"
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
}
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
}
now := time.Now().Format("2006-01-02T15-04-05.999999999")
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
}
}
var respBody []byte
if compression == util.ZstdCompression {
respBody = zstdEncode(jsonBody)
} else {
respBody = jsonBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
panic("invalid type in sync pool")
}
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}
var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(
nil,
zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}
// baseMapResponse returns a tailcfg.MapResponse with
// KeepAlive false and ControlTime set to now.
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
// TODO(kradalby): Implement PingRequest?
}
return resp
}
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
// with the basic configuration from headscale set.
// It is used in for bigger updates, such as full and lite, not
// incremental.
func (m *Mapper) baseWithConfigMapResponse(
node types.NodeView,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
return m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changed).
Build()
}
_, matchers := m.state.Filter()
tailnode, err := tailNode(
node, capVer, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID, changedNodeID)
if err != nil {
return nil, err
}
resp.Node = tailnode
resp.DERPMap = m.state.DERPMap()
resp.Domain = m.cfg.Domain()
// Do not instruct clients to collect services we do not
// support or do anything with them
resp.CollectServices = "false"
resp.KeepAlive = false
resp.Debug = &tailcfg.Debug{
DisableLogTail: !m.cfg.LogTail.Enabled,
}
return &resp, nil
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
WithPeerChanges(peers).
Build()
}
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
nodeID types.NodeID,
removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedNodeID).
Build()
}
func writeDebugMapResponse(
resp *tailcfg.MapResponse,
nodeID types.NodeID,
messages ...string,
) {
data := map[string]any{
"Messages": messages,
"MapResponse": resp,
}
responseType := "keepalive"
switch {
case len(resp.Peers) > 0:
responseType = "full"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self"
case len(resp.PeersChanged) > 0:
responseType = "changed"
case len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil {
panic(err)
}
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
}
now := time.Now().Format("2006-01-02T15-04-05.999999999")
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s-%s.json", now, responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
}
}
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.state.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
// TODO(kradalby): Add back online via batcher. This was removed
// to avoid a circular dependency between the mapper and the notification.
for _, peer := range peers {
online := m.notif.IsLikelyConnected(peer.ID)
online := m.batcher.IsConnected(peer.ID)
peer.IsOnline = &online
}
return peers, nil
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {
return nil, err
}
for _, node := range nodes {
online := m.notif.IsLikelyConnected(node.ID)
node.IsOnline = &online
}
return nodes, nil
}
// routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.
type routeFilterFunc func(id types.NodeID) []netip.Prefix
// appendPeerChanges mutates a tailcfg.MapResponse with all the
// necessary changes when peers have changed.
func appendPeerChanges(
resp *tailcfg.MapResponse,
fullChange bool,
state *state.State,
node types.NodeView,
capVer tailcfg.CapabilityVersion,
changed views.Slice[types.NodeView],
cfg *types.Config,
) error {
filter, matchers := state.Filter()
sshPolicy, err := state.SSHPolicy(node)
if err != nil {
return err
}
// If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers.
var reducedChanged views.Slice[types.NodeView]
if len(filter) > 0 {
reducedChanged = policy.ReduceNodes(node, changed, matchers)
} else {
reducedChanged = changed
}
profiles := generateUserProfiles(node, reducedChanged)
dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(
reducedChanged, capVer, state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
},
cfg)
if err != nil {
return err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
if fullChange {
resp.Peers = tailPeers
} else {
resp.PeersChanged = tailPeers
}
resp.DNSConfig = dnsConfig
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
// Currently, we do not send incremental package filters, however using the
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, filter),
}
return nil
}

View File

@@ -3,6 +3,7 @@ package mapper
import (
"fmt"
"net/netip"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
@@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
&types.Config{
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1.View(),
nodeInShared1,
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
@@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
for _, id := range peerIDs {
if peer.ID == id {
filtered = append(filtered, peer)
break
}
if slices.Contains(peerIDs, peer.ID) {
filtered = append(filtered, peer)
}
}
@@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
for _, id := range nodeIDs {
if node.ID == id {
filtered = append(filtered, node)
break
}
if slices.Contains(nodeIDs, node.ID) {
filtered = append(filtered, node)
}
}

47
hscontrol/mapper/utils.go Normal file
View File

@@ -0,0 +1,47 @@
package mapper
import "tailscale.com/tailcfg"
// mergePatch takes the current patch and a newer patch
// and override any field that has changed.
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
if newPatch.DERPRegion != 0 {
currPatch.DERPRegion = newPatch.DERPRegion
}
if newPatch.Cap != 0 {
currPatch.Cap = newPatch.Cap
}
if newPatch.CapMap != nil {
currPatch.CapMap = newPatch.CapMap
}
if newPatch.Endpoints != nil {
currPatch.Endpoints = newPatch.Endpoints
}
if newPatch.Key != nil {
currPatch.Key = newPatch.Key
}
if newPatch.KeySignature != nil {
currPatch.KeySignature = newPatch.KeySignature
}
if newPatch.DiscoKey != nil {
currPatch.DiscoKey = newPatch.DiscoKey
}
if newPatch.Online != nil {
currPatch.Online = newPatch.Online
}
if newPatch.LastSeen != nil {
currPatch.LastSeen = newPatch.LastSeen
}
if newPatch.KeyExpiry != nil {
currPatch.KeyExpiry = newPatch.KeyExpiry
}
}