mapper: produce map before poll (#2628)
hscontrol/mapper/batcher.go (new file, 155 lines)
@@ -0,0 +1,155 @@
package mapper

import (
	"fmt"
	"time"

	"github.com/juanfont/headscale/hscontrol/state"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"github.com/puzpuzpuz/xsync/v4"
	"tailscale.com/tailcfg"
	"tailscale.com/types/ptr"
)

type batcherFunc func(cfg *types.Config, state *state.State) Batcher

// Batcher defines the common interface for all batcher implementations.
type Batcher interface {
	Start()
	Close()
	AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
	RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
	IsConnected(id types.NodeID) bool
	ConnectedMap() *xsync.Map[types.NodeID, bool]
	AddWork(c change.ChangeSet)
	MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
}

func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
	return &LockFreeBatcher{
		mapper:  mapper,
		workers: workers,
		tick:    time.NewTicker(batchTime),

		// The size of this channel is arbitrarily chosen; the sizing should be revisited.
		workCh:         make(chan work, workers*200),
		nodes:          xsync.NewMap[types.NodeID, *nodeConn](),
		connected:      xsync.NewMap[types.NodeID, *time.Time](),
		pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
	}
}

// NewBatcherAndMapper creates a Batcher implementation.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
	m := newMapper(cfg, state)
	b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
	m.batcher = b

	return b
}
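A minimal sketch of how a caller might drive this interface, assuming cfg and st are an already-loaded *types.Config and *state.State; the surrounding wiring is illustrative, not part of this commit:

	b := NewBatcherAndMapper(cfg, st)
	b.Start()
	defer b.Close()

	// Queue a policy change; the worker pool fans it out to every
	// connected node (change.Policy is one of the immediate changes).
	b.AddWork(change.ChangeSet{Change: change.Policy})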
// nodeConnection interface for different connection implementations.
type nodeConnection interface {
	nodeID() types.NodeID
	version() tailcfg.CapabilityVersion
	send(data *tailcfg.MapResponse) error
}

// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
	if c.Empty() {
		return nil, nil
	}

	// Validate inputs before processing
	if nodeID == 0 {
		return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
	}

	if mapper == nil {
		return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
	}

	var mapResp *tailcfg.MapResponse
	var err error

	switch c.Change {
	case change.DERP:
		mapResp, err = mapper.derpMapResponse(nodeID)

	case change.NodeCameOnline, change.NodeWentOffline:
		if c.IsSubnetRouter {
			// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
			mapResp, err = mapper.fullMapResponse(nodeID, version)
		} else {
			mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
				{
					NodeID: c.NodeID.NodeID(),
					Online: ptr.To(c.Change == change.NodeCameOnline),
				},
			})
		}

	case change.NodeNewOrUpdate:
		mapResp, err = mapper.fullMapResponse(nodeID, version)

	case change.NodeRemove:
		mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)

	default:
		// The following will always hit this:
		// change.Full, change.Policy
		mapResp, err = mapper.fullMapResponse(nodeID, version)
	}

	if err != nil {
		return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
	}

	// TODO(kradalby): Is this necessary?
	// Validate the generated map response - only check for nil response.
	// Note: mapResp.Node can be nil for peer updates, which is valid.
	if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
		return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
	}

	return mapResp, nil
}
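To make the dispatch above concrete, a hedged sketch of the cheapest path: a non-router peer coming online yields a small patch rather than a full map. The node IDs and capability version are made up, and m is assumed to be a ready *mapper:

	resp, err := generateMapResponse(
		types.NodeID(1),               // recipient of the response
		tailcfg.CapabilityVersion(90), // assumed client capability version
		m,
		change.ChangeSet{NodeID: types.NodeID(2), Change: change.NodeCameOnline},
	)
	if err == nil && resp != nil {
		// resp.PeersChangedPatch marks node 2 as online for node 1.
		_ = resp
	}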
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
	if nc == nil {
		return fmt.Errorf("nodeConnection is nil")
	}

	nodeID := nc.nodeID()
	data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
	if err != nil {
		return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
	}

	if data == nil {
		// No data to send is valid for some change types.
		return nil
	}

	// Send the map response.
	if err := nc.send(data); err != nil {
		return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
	}

	return nil
}

// workResult represents the result of processing a change.
type workResult struct {
	mapResponse *tailcfg.MapResponse
	err         error
}

// work represents a unit of work to be processed by workers.
type work struct {
	c        change.ChangeSet
	nodeID   types.NodeID
	resultCh chan<- workResult // optional channel for synchronous operations
}
hscontrol/mapper/batcher_lockfree.go (new file, 491 lines)
@@ -0,0 +1,491 @@
package mapper

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"github.com/puzpuzpuz/xsync/v4"
	"github.com/rs/zerolog/log"
	"tailscale.com/tailcfg"
	"tailscale.com/types/ptr"
)

// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
	tick    *time.Ticker
	mapper  *mapper
	workers int

	// Lock-free concurrent maps
	nodes     *xsync.Map[types.NodeID, *nodeConn]
	connected *xsync.Map[types.NodeID, *time.Time]

	// Work queue channel
	workCh chan work
	ctx    context.Context
	cancel context.CancelFunc

	// Batching state
	pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
	batchMutex     sync.RWMutex

	// Metrics
	totalNodes      atomic.Int64
	totalUpdates    atomic.Int64
	workQueuedCount atomic.Int64
	workProcessed   atomic.Int64
	workErrors      atomic.Int64
}

// AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
	// First validate that we can generate an initial map before doing anything else.
	fullSelfChange := change.FullSelf(id)

	// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
	// This currently means that the goroutine for the node connection will do the processing
	// which means that we might have uncontrolled concurrency.
	// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
	// it to be processed in a more controlled manner.
	initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
	if err != nil {
		return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
	}

	// Only after validation succeeds, create or update the node connection.
	newConn := newNodeConn(id, c, version, b.mapper)

	var conn *nodeConn
	if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
		// Update existing connection
		existing.updateConnection(c, version)
		conn = existing
	} else {
		b.totalNodes.Add(1)
		conn = newConn
	}

	// Mark as connected only after validation succeeds.
	b.connected.Store(id, nil) // nil = connected

	log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")

	// Send the validated initial map.
	if initialMap != nil {
		if err := conn.send(initialMap); err != nil {
			// Clean up the connection state on send failure.
			b.nodes.Delete(id)
			b.connected.Delete(id)
			return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
		}

		// Notify other nodes that this node came online.
		b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
	}

	return nil
}
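This ordering is what the commit title refers to: the map is produced and validated before the node is admitted to polling, so a failing map rejects the connection up front with no state to unwind. A rough sketch of a long-poll handler built on that contract; the function shape and channel buffering here are assumptions for illustration:

	func servePoll(b Batcher, id types.NodeID, capVer tailcfg.CapabilityVersion) error {
		// Buffered so AddNode's initial send cannot block in this sketch.
		ch := make(chan *tailcfg.MapResponse, 1)
		if err := b.AddNode(id, ch, false, capVer); err != nil {
			// Initial map generation failed; the node was never registered,
			// so the poll can be rejected with no cleanup needed.
			return err
		}
		defer b.RemoveNode(id, ch, false)

		for resp := range ch {
			_ = resp // stream each MapResponse to the HTTP long-poll body
		}

		return nil
	}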
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection,
// and notifies other nodes that this node has gone offline.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
	// Check if this is the current connection and mark it as closed.
	if existing, ok := b.nodes.Load(id); ok {
		if !existing.matchesChannel(c) {
			log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
			return // Not the current connection, not an error
		}

		// Mark the connection as closed to prevent further sends.
		if connData := existing.connData.Load(); connData != nil {
			connData.closed.Store(true)
		}
	}

	log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")

	// Remove node and mark disconnected atomically.
	b.nodes.Delete(id)
	b.connected.Store(id, ptr.To(time.Now()))
	b.totalNodes.Add(-1)

	// Notify other nodes that this node went offline.
	b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
}

// AddWork queues a change to be processed by the batcher.
// Critical changes are processed immediately, while others are batched for efficiency.
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
	b.addWork(c)
}

func (b *LockFreeBatcher) Start() {
	b.ctx, b.cancel = context.WithCancel(context.Background())
	go b.doWork()
}

func (b *LockFreeBatcher) Close() {
	if b.cancel != nil {
		b.cancel()
	}
	close(b.workCh)
}

func (b *LockFreeBatcher) doWork() {
	log.Debug().Msg("batcher doWork loop started")
	defer log.Debug().Msg("batcher doWork loop stopped")

	for i := range b.workers {
		go b.worker(i + 1)
	}

	for {
		select {
		case <-b.tick.C:
			// Process batched changes
			b.processBatchedChanges()
		case <-b.ctx.Done():
			return
		}
	}
}

func (b *LockFreeBatcher) worker(workerID int) {
	log.Debug().Int("workerID", workerID).Msg("batcher worker started")
	defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")

	for {
		select {
		case w, ok := <-b.workCh:
			if !ok {
				return
			}

			startTime := time.Now()
			b.workProcessed.Add(1)

			// If the resultCh is set, it means that this is a work request
			// where there is a blocking function waiting for the map that
			// is being generated.
			// This is used for synchronous map generation.
			if w.resultCh != nil {
				var result workResult
				if nc, exists := b.nodes.Load(w.nodeID); exists {
					result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
					if result.err != nil {
						b.workErrors.Add(1)
						log.Error().Err(result.err).
							Int("workerID", workerID).
							Uint64("node.id", w.nodeID.Uint64()).
							Str("change", w.c.Change.String()).
							Msg("failed to generate map response for synchronous work")
					}
				} else {
					result.err = fmt.Errorf("node %d not found", w.nodeID)
					b.workErrors.Add(1)
					log.Error().Err(result.err).
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Msg("node not found for synchronous work")
				}

				// Send result
				select {
				case w.resultCh <- result:
				case <-b.ctx.Done():
					return
				}

				duration := time.Since(startTime)
				if duration > 100*time.Millisecond {
					log.Warn().
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Str("change", w.c.Change.String()).
						Dur("duration", duration).
						Msg("slow synchronous work processing")
				}
				continue
			}

			// If resultCh is nil, this is an asynchronous work request
			// that should be processed and sent to the node instead of
			// returned to the caller.
			if nc, exists := b.nodes.Load(w.nodeID); exists {
				// Check if this connection is still active before processing.
				if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
					log.Debug().
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Str("change", w.c.Change.String()).
						Msg("skipping work for closed connection")
					continue
				}

				err := nc.change(w.c)
				if err != nil {
					b.workErrors.Add(1)
					log.Error().Err(err).
						Int("workerID", workerID).
						Uint64("node.id", w.c.NodeID.Uint64()).
						Str("change", w.c.Change.String()).
						Msg("failed to apply change")
				}
			} else {
				log.Debug().
					Int("workerID", workerID).
					Uint64("node.id", w.nodeID.Uint64()).
					Str("change", w.c.Change.String()).
					Msg("node not found for asynchronous work - node may have disconnected")
			}

			duration := time.Since(startTime)
			if duration > 100*time.Millisecond {
				log.Warn().
					Int("workerID", workerID).
					Uint64("node.id", w.nodeID.Uint64()).
					Str("change", w.c.Change.String()).
					Dur("duration", duration).
					Msg("slow asynchronous work processing")
			}

		case <-b.ctx.Done():
			return
		}
	}
}

func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
	// For critical changes that need immediate processing, send directly.
	if b.shouldProcessImmediately(c) {
		if c.SelfUpdateOnly {
			b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
			return
		}
		b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
			if c.NodeID == nodeID && !c.AlsoSelf() {
				return true
			}
			b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
			return true
		})
		return
	}

	// For non-critical changes, add to the batch.
	b.addToBatch(c)
}

// queueWork safely queues work.
func (b *LockFreeBatcher) queueWork(w work) {
	b.workQueuedCount.Add(1)

	select {
	case b.workCh <- w:
		// Successfully queued
	case <-b.ctx.Done():
		// Batcher is shutting down
		return
	}
}

// shouldProcessImmediately determines if a change should bypass batching.
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
	// Process these changes immediately to avoid delaying critical functionality.
	switch c.Change {
	case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
		return true
	default:
		return false
	}
}
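A short illustration of the two paths, grounded in the switch above:

	// change.Policy is in the immediate list, so it is queued for every
	// node right away.
	b.AddWork(change.ChangeSet{Change: change.Policy})

	// change.DERP is not listed, so it accumulates in pendingChanges and
	// is flushed by processBatchedChanges on the next ticker tick.
	b.AddWork(change.ChangeSet{Change: change.DERP})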
// addToBatch adds a change to the pending batch.
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
	b.batchMutex.Lock()
	defer b.batchMutex.Unlock()

	if c.SelfUpdateOnly {
		changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
		changes = append(changes, c)
		b.pendingChanges.Store(c.NodeID, changes)
		return
	}

	b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
		if c.NodeID == nodeID && !c.AlsoSelf() {
			return true
		}

		changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
		changes = append(changes, c)
		b.pendingChanges.Store(nodeID, changes)
		return true
	})
}

// processBatchedChanges processes all pending batched changes.
func (b *LockFreeBatcher) processBatchedChanges() {
	b.batchMutex.Lock()
	defer b.batchMutex.Unlock()

	if b.pendingChanges == nil {
		return
	}

	// Process all pending changes.
	b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
		if len(changes) == 0 {
			return true
		}

		// Send all batched changes for this node.
		for _, c := range changes {
			b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
		}

		// Clear the pending changes for this node.
		b.pendingChanges.Delete(nodeID)
		return true
	})
}

// IsConnected is a lock-free read.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
	if val, ok := b.connected.Load(id); ok {
		// nil means connected
		return val == nil
	}

	return false
}

// ConnectedMap returns a lock-free map of all connected nodes.
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
	ret := xsync.NewMap[types.NodeID, bool]()

	b.connected.Range(func(id types.NodeID, val *time.Time) bool {
		// nil means connected
		ret.Store(id, val == nil)
		return true
	})

	return ret
}

// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
	resultCh := make(chan workResult, 1)

	// Queue the work with a result channel using the safe queueing method.
	b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})

	// Wait for the result.
	select {
	case result := <-resultCh:
		return result.mapResponse, result.err
	case <-b.ctx.Done():
		return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
	}
}
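A hedged sketch of the synchronous path, reusing change.FullSelf as AddNode does above; the node ID is illustrative:

	// Blocks until a worker generates the map (or the batcher shuts down),
	// so callers get backpressure from the same pool as asynchronous work.
	resp, err := b.MapResponseFromChange(types.NodeID(5), change.FullSelf(types.NodeID(5)))
	if err != nil {
		// handle generation failure or batcher shutdown
	}
	_ = resp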
// connectionData holds the channel and connection parameters.
type connectionData struct {
	c       chan<- *tailcfg.MapResponse
	version tailcfg.CapabilityVersion
	closed  atomic.Bool // Track if this connection has been closed
}

// nodeConn describes the node connection and its associated data.
type nodeConn struct {
	id     types.NodeID
	mapper *mapper

	// Atomic pointer to connection data - allows lock-free updates
	connData atomic.Pointer[connectionData]

	updateCount atomic.Int64
}

func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
	nc := &nodeConn{
		id:     id,
		mapper: mapper,
	}

	// Initialize connection data
	data := &connectionData{
		c:       c,
		version: version,
	}
	nc.connData.Store(data)

	return nc
}

// updateConnection atomically updates connection parameters.
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
	newData := &connectionData{
		c:       c,
		version: version,
	}
	nc.connData.Store(newData)
}

// matchesChannel checks if the given channel matches the current connection.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
	data := nc.connData.Load()
	if data == nil {
		return false
	}
	// Compare channel pointers directly
	return data.c == c
}

// version atomically reads the connection's capability version.
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
	data := nc.connData.Load()
	if data == nil {
		return 0
	}

	return data.version
}

func (nc *nodeConn) nodeID() types.NodeID {
	return nc.id
}

func (nc *nodeConn) change(c change.ChangeSet) error {
	return handleNodeChange(nc, nc.mapper, c)
}

// send sends data to the node's channel.
// The node will pick it up and send it to the HTTP handler.
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
	connData := nc.connData.Load()
	if connData == nil {
		return fmt.Errorf("node %d: no connection data", nc.id)
	}

	// Check if the connection has been closed.
	if connData.closed.Load() {
		return fmt.Errorf("node %d: connection closed", nc.id)
	}

	// TODO(kradalby): We might need some sort of timeout here if the client is not reading
	// the channel. That might mean that we are sending to a node that has gone offline, but
	// the channel is still open.
	connData.c <- data
	nc.updateCount.Add(1)

	return nil
}
hscontrol/mapper/batcher_test.go (new file, 1977 lines)
File diff suppressed because it is too large
hscontrol/mapper/builder.go (new file, 259 lines)
@@ -0,0 +1,259 @@
package mapper

import (
	"net/netip"
	"sort"
	"time"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
	"tailscale.com/types/views"
	"tailscale.com/util/multierr"
)

// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
type MapResponseBuilder struct {
	resp   *tailcfg.MapResponse
	mapper *mapper
	nodeID types.NodeID
	capVer tailcfg.CapabilityVersion
	errs   []error
}

// NewMapResponseBuilder creates a new builder with basic fields set.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
	now := time.Now()
	return &MapResponseBuilder{
		resp: &tailcfg.MapResponse{
			KeepAlive:   false,
			ControlTime: &now,
		},
		mapper: m,
		nodeID: nodeID,
		errs:   nil,
	}
}

// addError adds an error to the builder's error list.
func (b *MapResponseBuilder) addError(err error) {
	if err != nil {
		b.errs = append(b.errs, err)
	}
}

// hasErrors returns true if the builder has accumulated any errors.
func (b *MapResponseBuilder) hasErrors() bool {
	return len(b.errs) > 0
}

// WithCapabilityVersion sets the capability version for the response.
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
	b.capVer = capVer
	return b
}

// WithSelfNode adds the requesting node to the response.
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	_, matchers := b.mapper.state.Filter()
	tailnode, err := tailNode(
		node.View(), b.capVer, b.mapper.state,
		func(id types.NodeID) []netip.Prefix {
			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
		},
		b.mapper.cfg)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.Node = tailnode
	return b
}

// WithDERPMap adds the DERP map to the response.
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
	b.resp.DERPMap = b.mapper.state.DERPMap()
	return b
}

// WithDomain adds the domain configuration.
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
	b.resp.Domain = b.mapper.cfg.Domain()
	return b
}

// WithCollectServicesDisabled sets the collect services flag to false.
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
	b.resp.CollectServices.Set(false)
	return b
}

// WithDebugConfig adds debug configuration.
// It disables log tailing if the mapper's LogTail is not enabled.
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
	b.resp.Debug = &tailcfg.Debug{
		DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
	}
	return b
}

// WithSSHPolicy adds SSH policy configuration for the requesting node.
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.SSHPolicy = sshPolicy
	return b
}

// WithDNSConfig adds DNS configuration for the requesting node.
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
	return b
}

// WithUserProfiles adds user profiles for the requesting node and given peers.
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.UserProfiles = generateUserProfiles(node, peers)
	return b
}

// WithPacketFilters adds packet filter rules based on policy.
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	filter, _ := b.mapper.state.Filter()

	// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
	// Currently, we do not send incremental packet filters; however, using the
	// new PacketFilters field and "base" allows us to send a full update when we
	// have to send an empty list, avoiding the hack in the else block.
	b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
		"base": policy.ReduceFilterRules(node.View(), filter),
	}

	return b
}

// WithPeers adds the full peer list with policy filtering (for a full map response).
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
	tailPeers, err := b.buildTailPeers(peers)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.Peers = tailPeers
	return b
}

// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
	tailPeers, err := b.buildTailPeers(peers)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.PeersChanged = tailPeers
	return b
}

// buildTailPeers converts types.Nodes to []*tailcfg.Node with policy filtering and sorting.
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		return nil, err
	}

	filter, matchers := b.mapper.state.Filter()

	// If there are filter rules present, see if there are any nodes that cannot
	// access each other at all and remove them from the peers.
	var changedViews views.Slice[types.NodeView]
	if len(filter) > 0 {
		changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
	} else {
		changedViews = peers.ViewSlice()
	}

	tailPeers, err := tailNodes(
		changedViews, b.capVer, b.mapper.state,
		func(id types.NodeID) []netip.Prefix {
			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
		},
		b.mapper.cfg)
	if err != nil {
		return nil, err
	}

	// Peers is always returned sorted by Node.ID.
	sort.SliceStable(tailPeers, func(x, y int) bool {
		return tailPeers[x].ID < tailPeers[y].ID
	})

	return tailPeers, nil
}

// WithPeerChangedPatch adds peer change patches.
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
	b.resp.PeersChangedPatch = changes
	return b
}

// WithPeersRemoved adds removed peer IDs.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
	var tailscaleIDs []tailcfg.NodeID
	for _, id := range removedIDs {
		tailscaleIDs = append(tailscaleIDs, id.NodeID())
	}
	b.resp.PeersRemoved = tailscaleIDs
	return b
}

// Build finalizes the response, returning it or the accumulated errors.
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
	if len(b.errs) > 0 {
		return nil, multierr.New(b.errs...)
	}
	if debugDumpMapResponsePath != "" {
		writeDebugMapResponse(b.resp, b.nodeID)
	}

	return b.resp, nil
}
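The intended call pattern appears later in this commit in fullMapResponse: chain the With* steps and let Build surface whatever errors accumulated. A reduced sketch, assuming m is a ready *mapper:

	// Errors from any step collect in b.errs, so intermediate steps need
	// no error handling of their own; Build joins them via multierr.
	resp, err := m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer).
		WithSelfNode().
		WithDERPMap().
		Build()
	if err != nil {
		// one or more builder steps failed
	}
	_ = resp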
hscontrol/mapper/builder_test.go (new file, 347 lines)
@@ -0,0 +1,347 @@
package mapper

import (
	"testing"
	"time"

	"github.com/juanfont/headscale/hscontrol/state"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"tailscale.com/tailcfg"
)

func TestMapResponseBuilder_Basic(t *testing.T) {
	cfg := &types.Config{
		BaseDomain: "example.com",
		LogTail: types.LogTailConfig{
			Enabled: true,
		},
	}

	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID)

	// Test basic builder creation
	assert.NotNil(t, builder)
	assert.Equal(t, nodeID, builder.nodeID)
	assert.NotNil(t, builder.resp)
	assert.False(t, builder.resp.KeepAlive)
	assert.NotNil(t, builder.resp.ControlTime)
	assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
}

func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	capVer := tailcfg.CapabilityVersion(42)

	builder := m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer)

	assert.Equal(t, capVer, builder.capVer)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_WithDomain(t *testing.T) {
	domain := "test.example.com"
	cfg := &types.Config{
		ServerURL:  "https://test.example.com",
		BaseDomain: domain,
	}

	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithDomain()

	assert.Equal(t, domain, builder.resp.Domain)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithCollectServicesDisabled()

	value, isSet := builder.resp.CollectServices.Get()
	assert.True(t, isSet)
	assert.False(t, value)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
	tests := []struct {
		name           string
		logTailEnabled bool
		expected       bool
	}{
		{
			name:           "LogTail enabled",
			logTailEnabled: true,
			expected:       false, // DisableLogTail should be false when LogTail is enabled
		},
		{
			name:           "LogTail disabled",
			logTailEnabled: false,
			expected:       true, // DisableLogTail should be true when LogTail is disabled
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &types.Config{
				LogTail: types.LogTailConfig{
					Enabled: tt.logTailEnabled,
				},
			}
			mockState := &state.State{}
			m := &mapper{
				cfg:   cfg,
				state: mockState,
			}

			nodeID := types.NodeID(1)

			builder := m.NewMapResponseBuilder(nodeID).
				WithDebugConfig()

			require.NotNil(t, builder.resp.Debug)
			assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
			assert.False(t, builder.hasErrors())
		})
	}
}

func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	changes := []*tailcfg.PeerChange{
		{
			NodeID:     123,
			DERPRegion: 1,
		},
		{
			NodeID:     456,
			DERPRegion: 2,
		},
	}

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch(changes)

	assert.Equal(t, changes, builder.resp.PeersChangedPatch)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	removedID1 := types.NodeID(123)
	removedID2 := types.NodeID(456)

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeersRemoved(removedID1, removedID2)

	expected := []tailcfg.NodeID{
		removedID1.NodeID(),
		removedID2.NodeID(),
	}
	assert.Equal(t, expected, builder.resp.PeersRemoved)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	// Simulate an error in the builder
	builder := m.NewMapResponseBuilder(nodeID)
	builder.addError(assert.AnError)

	// All subsequent calls should continue to work and accumulate errors
	result := builder.
		WithDomain().
		WithCollectServicesDisabled().
		WithDebugConfig()

	assert.True(t, result.hasErrors())
	assert.Len(t, result.errs, 1)
	assert.Equal(t, assert.AnError, result.errs[0])

	// Build should return the error
	data, err := result.Build("none")
	assert.Nil(t, data)
	assert.Error(t, err)
}

func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
	domain := "chained.example.com"
	cfg := &types.Config{
		ServerURL:  "https://chained.example.com",
		BaseDomain: domain,
		LogTail: types.LogTailConfig{
			Enabled: false,
		},
	}

	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	capVer := tailcfg.CapabilityVersion(99)

	builder := m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer).
		WithDomain().
		WithCollectServicesDisabled().
		WithDebugConfig()

	// Verify all fields are set correctly
	assert.Equal(t, capVer, builder.capVer)
	assert.Equal(t, domain, builder.resp.Domain)
	value, isSet := builder.resp.CollectServices.Get()
	assert.True(t, isSet)
	assert.False(t, value)
	assert.NotNil(t, builder.resp.Debug)
	assert.True(t, builder.resp.Debug.DisableLogTail)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	removedID1 := types.NodeID(100)
	removedID2 := types.NodeID(200)

	// Test calling WithPeersRemoved multiple times
	builder := m.NewMapResponseBuilder(nodeID).
		WithPeersRemoved(removedID1).
		WithPeersRemoved(removedID2)

	// Second call should overwrite the first
	expected := []tailcfg.NodeID{removedID2.NodeID()}
	assert.Equal(t, expected, builder.resp.PeersRemoved)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch([]*tailcfg.PeerChange{})

	assert.Empty(t, builder.resp.PeersChangedPatch)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch(nil)

	assert.Nil(t, builder.resp.PeersChangedPatch)
	assert.False(t, builder.hasErrors())
}

func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	// Create a builder and add multiple errors
	builder := m.NewMapResponseBuilder(nodeID)
	builder.addError(assert.AnError)
	builder.addError(assert.AnError)
	builder.addError(nil) // This should be ignored

	// All subsequent calls should continue to work
	result := builder.
		WithDomain().
		WithCollectServicesDisabled()

	assert.True(t, result.hasErrors())
	assert.Len(t, result.errs, 2) // nil error should be ignored

	// Build should return a multierr
	data, err := result.Build("none")
	assert.Nil(t, data)
	assert.Error(t, err)

	// The error should contain information about multiple errors
	assert.Contains(t, err.Error(), "multiple errors")
}
hscontrol/mapper/mapper.go
@@ -1,7 +1,6 @@
 package mapper

 import (
-	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io/fs"
@@ -10,31 +9,21 @@ import (
 	"os"
 	"path"
 	"slices"
-	"sort"
 	"strings"
-	"sync"
-	"sync/atomic"
 	"time"

-	"github.com/juanfont/headscale/hscontrol/notifier"
-	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/state"
 	"github.com/juanfont/headscale/hscontrol/types"
-	"github.com/juanfont/headscale/hscontrol/util"
-	"github.com/klauspost/compress/zstd"
 	"github.com/rs/zerolog/log"
 	"tailscale.com/envknob"
-	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/dnstype"
-	"tailscale.com/types/views"
 )

 const (
-	nextDNSDoHPrefix           = "https://dns.nextdns.io"
-	reservedResponseHeaderSize = 4
-	mapperIDLength             = 8
-	debugMapResponsePerm       = 0o755
+	nextDNSDoHPrefix     = "https://dns.nextdns.io"
+	mapperIDLength       = 8
+	debugMapResponsePerm = 0o755
 )

 var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
@@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
 // - Create a "minifier" that removes info not needed for the node
 // - some sort of batching, wait for 5 or 60 seconds before sending

-type Mapper struct {
+type mapper struct {
 	// Configuration
-	state *state.State
-	cfg   *types.Config
-	notif *notifier.Notifier
+	state   *state.State
+	cfg     *types.Config
+	batcher Batcher

-	uid     string
 	created time.Time
 	seq     uint64
 }

 type patch struct {
@@ -66,41 +53,31 @@ type patch struct {
 	change *tailcfg.PeerChange
 }

-func NewMapper(
-	state *state.State,
+func newMapper(
 	cfg *types.Config,
-	notif *notifier.Notifier,
-) *Mapper {
-	uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
+	state *state.State,
+) *mapper {
+	// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)

-	return &Mapper{
+	return &mapper{
 		state: state,
 		cfg:   cfg,
-		notif: notif,

-		uid:     uid,
 		created: time.Now(),
 		seq:     0,
 	}
 }

-func (m *Mapper) String() string {
-	return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
-}
-
 func generateUserProfiles(
-	node types.NodeView,
-	peers views.Slice[types.NodeView],
+	node *types.Node,
+	peers types.Nodes,
 ) []tailcfg.UserProfile {
 	userMap := make(map[uint]*types.User)
-	ids := make([]uint, 0, peers.Len()+1)
-	user := node.User()
-	userMap[user.ID] = &user
-	ids = append(ids, user.ID)
-	for _, peer := range peers.All() {
-		peerUser := peer.User()
-		userMap[peerUser.ID] = &peerUser
-		ids = append(ids, peerUser.ID)
+	ids := make([]uint, 0, len(userMap))
+	userMap[node.User.ID] = &node.User
+	ids = append(ids, node.User.ID)
+	for _, peer := range peers {
+		userMap[peer.User.ID] = &peer.User
+		ids = append(ids, peer.User.ID)
 	}

 	slices.Sort(ids)
@@ -117,7 +94,7 @@ func generateUserProfiles(
 
 func generateDNSConfig(
 	cfg *types.Config,
-	node types.NodeView,
+	node *types.Node,
 ) *tailcfg.DNSConfig {
 	if cfg.TailcfgDNSConfig == nil {
 		return nil
@@ -137,17 +114,16 @@ func generateDNSConfig(
 //
 // This will produce a resolver like:
 // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
-func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
+func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
 	for _, resolver := range resolvers {
 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
 			attrs := url.Values{
-				"device_name":  []string{node.Hostname()},
-				"device_model": []string{node.Hostinfo().OS()},
+				"device_name":  []string{node.Hostname},
+				"device_model": []string{node.Hostinfo.OS},
 			}

-			nodeIPs := node.IPs()
-			if len(nodeIPs) > 0 {
-				attrs.Add("device_ip", nodeIPs[0].String())
+			if len(node.IPs()) > 0 {
+				attrs.Add("device_ip", node.IPs()[0].String())
 			}

 			resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
@@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
}
|
||||
|
||||
// fullMapResponse creates a complete MapResponse for a node.
|
||||
// It is a separate function to make testing easier.
|
||||
func (m *Mapper) fullMapResponse(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||
peers, err := m.listPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
resp,
|
||||
true, // full change
|
||||
m.state,
|
||||
node,
|
||||
capVer,
|
||||
peers,
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig().
|
||||
WithSSHPolicy().
|
||||
WithDNSConfig().
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peers).
|
||||
Build(messages...)
|
||||
}
|
||||
|
||||
// FullMapResponse returns a MapResponse for the given node.
|
||||
func (m *Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
peers, err := m.ListPeers(node.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
// ReadOnlyMapResponse returns a MapResponse for the given node.
|
||||
// Lite means that the peers has been omitted, this is intended
|
||||
// to be used to answer MapRequests with OmitPeers set to true.
|
||||
func (m *Mapper) ReadOnlyMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.KeepAlive = true
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) DERPMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.DERPMap = derpMap
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) PeerChangedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
changed map[types.NodeID]bool,
|
||||
patches []*tailcfg.PeerChange,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
var err error
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
var removedIDs []tailcfg.NodeID
|
||||
var changedIDs []types.NodeID
|
||||
for nodeID, nodeChanged := range changed {
|
||||
if nodeChanged {
|
||||
if nodeID != node.ID() {
|
||||
changedIDs = append(changedIDs, nodeID)
|
||||
}
|
||||
} else {
|
||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||
}
|
||||
}
|
||||
changedNodes := types.Nodes{}
|
||||
if len(changedIDs) > 0 {
|
||||
changedNodes, err = m.ListNodes(changedIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
&resp,
|
||||
false, // partial change
|
||||
m.state,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
changedNodes.ViewSlice(),
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.PeersRemoved = removedIDs
|
||||
|
||||
// Sending patches as a part of a PeersChanged response
|
||||
// is technically not suppose to be done, but they are
|
||||
// applied after the PeersChanged. The patch list
|
||||
// should _only_ contain Nodes that are not in the
|
||||
// PeersChanged or PeersRemoved list and the caller
|
||||
// should filter them out.
|
||||
//
|
||||
// From tailcfg docs:
|
||||
// These are applied after Peers* above, but in practice the
|
||||
// control server should only send these on their own, without
|
||||
// the Peers* fields also set.
|
||||
if patches != nil {
|
||||
resp.PeersChangedPatch = patches
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
// Add the node itself, it might have changed, and particularly
|
||||
// if there are no patches or changes, this is a self update.
|
||||
tailnode, err := tailNode(
|
||||
node, mapRequest.Version, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
||||
func (m *mapper) derpMapResponse(
|
||||
nodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDERPMap().
|
||||
Build()
|
||||
}
|
||||
|
||||
// PeerChangedPatchResponse creates a patch MapResponse with
|
||||
// incoming update from a state change.
|
||||
func (m *Mapper) PeerChangedPatchResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
func (m *mapper) peerChangedPatchResponse(
|
||||
nodeID types.NodeID,
|
||||
changed []*tailcfg.PeerChange,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.PeersChangedPatch = changed
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) marshalMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
resp *tailcfg.MapResponse,
|
||||
node types.NodeView,
|
||||
compression string,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
atomic.AddUint64(&m.seq, 1)
|
||||
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapRequest": mapRequest,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
panic("invalid type in sync pool")
|
||||
}
|
||||
out := encoder.EncodeAll(in, nil)
|
||||
_ = encoder.Close()
|
||||
zstdEncoderPool.Put(encoder)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var zstdEncoderPool = &sync.Pool{
|
||||
New: func() any {
|
||||
encoder, err := smallzstd.NewEncoder(
|
||||
nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoder
|
||||
},
|
||||
}
|
||||
|
||||
// baseMapResponse returns a tailcfg.MapResponse with
|
||||
// KeepAlive false and ControlTime set to now.
|
||||
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
// TODO(kradalby): Implement PingRequest?
|
||||
}
|
||||
|
||||
return resp
|
||||
}

// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
// with the basic configuration from headscale set.
// It is used for bigger updates, such as full and lite, not
// incremental ones.
func (m *Mapper) baseWithConfigMapResponse(
	node types.NodeView,
	capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
	resp := m.baseMapResponse()

	_, matchers := m.state.Filter()
	tailnode, err := tailNode(
		node, capVer, m.state,
		func(id types.NodeID) []netip.Prefix {
			return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
		},
		m.cfg)
	if err != nil {
		return nil, err
	}
	resp.Node = tailnode

	resp.DERPMap = m.state.DERPMap()

	resp.Domain = m.cfg.Domain()

	// Do not instruct clients to collect services we do not
	// support or do anything with them
	resp.CollectServices = "false"

	resp.KeepAlive = false

	resp.Debug = &tailcfg.Debug{
		DisableLogTail: !m.cfg.LogTail.Enabled,
	}

	return &resp, nil
}

// peerChangedPatchResponse creates a patch MapResponse with
// incremental changes to peers.
func (m *mapper) peerChangedPatchResponse(
	nodeID types.NodeID,
	changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
	return m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch(changed).
		Build()
}

// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
	nodeID types.NodeID,
	capVer tailcfg.CapabilityVersion,
	changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
	peers, err := m.listPeers(nodeID, changedNodeID)
	if err != nil {
		return nil, err
	}

	return m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer).
		WithSelfNode().
		WithUserProfiles(peers).
		WithPeerChanges(peers).
		Build()
}

// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
	nodeID types.NodeID,
	removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
	return m.NewMapResponseBuilder(nodeID).
		WithPeersRemoved(removedNodeID).
		Build()
}

func writeDebugMapResponse(
	resp *tailcfg.MapResponse,
	nodeID types.NodeID,
	messages ...string,
) {
	data := map[string]any{
		"Messages":    messages,
		"MapResponse": resp,
	}

	responseType := "keepalive"

	switch {
	case len(resp.Peers) > 0:
		responseType = "full"
	case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
		responseType = "self"
	case len(resp.PeersChanged) > 0:
		responseType = "changed"
	case len(resp.PeersChangedPatch) > 0:
		responseType = "patch"
	case len(resp.PeersRemoved) > 0:
		responseType = "removed"
	}

	body, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		panic(err)
	}

	perms := fs.FileMode(debugMapResponsePerm)
	mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
	err = os.MkdirAll(mPath, perms)
	if err != nil {
		panic(err)
	}

	now := time.Now().Format("2006-01-02T15-04-05.999999999")

	mapResponsePath := path.Join(
		mPath,
		fmt.Sprintf("%s-%s.json", now, responseType),
	)

	log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
	err = os.WriteFile(mapResponsePath, body, perms)
	if err != nil {
		panic(err)
	}
}

// listPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
	peers, err := m.state.ListPeers(nodeID, peerIDs...)
	if err != nil {
		return nil, err
	}

	// TODO(kradalby): Add back online via batcher. This was removed
	// to avoid a circular dependency between the mapper and the notification.
	for _, peer := range peers {
		online := m.batcher.IsConnected(peer.ID)
		peer.IsOnline = &online
	}

	return peers, nil
}

// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
	nodes, err := m.state.ListNodes(nodeIDs...)
	if err != nil {
		return nil, err
	}

	for _, node := range nodes {
		online := m.notif.IsLikelyConnected(node.ID)
		node.IsOnline = &online
	}

	return nodes, nil
}

// routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.
type routeFilterFunc func(id types.NodeID) []netip.Prefix
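
As an illustration of the type, the simplest possible routeFilterFunc performs no policy reduction at all and passes the primary routes through unchanged (allowAllRoutes is hypothetical; appendPeerChanges below builds the real, policy-reducing variant inline):

// allowAllRoutes returns a routeFilterFunc that allows every route the
// primary route manager reports for the node, with no ACL reduction.
func allowAllRoutes(st *state.State) routeFilterFunc {
	return func(id types.NodeID) []netip.Prefix {
		return st.GetNodePrimaryRoutes(id)
	}
}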

// appendPeerChanges mutates a tailcfg.MapResponse with all the
// necessary changes when peers have changed.
func appendPeerChanges(
	resp *tailcfg.MapResponse,

	fullChange bool,
	state *state.State,
	node types.NodeView,
	capVer tailcfg.CapabilityVersion,
	changed views.Slice[types.NodeView],
	cfg *types.Config,
) error {
	filter, matchers := state.Filter()

	sshPolicy, err := state.SSHPolicy(node)
	if err != nil {
		return err
	}

	// If there are filter rules present, see if there are any nodes that cannot
	// access each other at all and remove them from the peers.
	var reducedChanged views.Slice[types.NodeView]
	if len(filter) > 0 {
		reducedChanged = policy.ReduceNodes(node, changed, matchers)
	} else {
		reducedChanged = changed
	}

	profiles := generateUserProfiles(node, reducedChanged)

	dnsConfig := generateDNSConfig(cfg, node)

	tailPeers, err := tailNodes(
		reducedChanged, capVer, state,
		func(id types.NodeID) []netip.Prefix {
			return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
		},
		cfg)
	if err != nil {
		return err
	}

	// Peers is always returned sorted by Node.ID.
	sort.SliceStable(tailPeers, func(x, y int) bool {
		return tailPeers[x].ID < tailPeers[y].ID
	})

	if fullChange {
		resp.Peers = tailPeers
	} else {
		resp.PeersChanged = tailPeers
	}
	resp.DNSConfig = dnsConfig
	resp.UserProfiles = profiles
	resp.SSHPolicy = sshPolicy

	// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
	// Currently, we do not send incremental packet filters, however using the
	// new PacketFilters field and "base" allows us to send a full update when we
	// have to send an empty list, avoiding the hack in the else block.
	resp.PacketFilters = map[string][]tailcfg.FilterRule{
		"base": policy.ReduceFilterRules(node, filter),
	}

	return nil
}
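
For orientation, the only behavioural fork in appendPeerChanges is the fullChange flag: true fills resp.Peers, which makes the client replace its whole peer list, while false fills resp.PeersChanged, which the client merges incrementally. A minimal sketch of a caller (fullPeerResponse is hypothetical, not part of this change):

// fullPeerResponse builds a response that carries the complete peer list.
func fullPeerResponse(
	st *state.State,
	node types.NodeView,
	capVer tailcfg.CapabilityVersion,
	peers views.Slice[types.NodeView],
	cfg *types.Config,
) (*tailcfg.MapResponse, error) {
	var resp tailcfg.MapResponse
	// fullChange=true -> resp.Peers; false would populate resp.PeersChanged.
	if err := appendPeerChanges(&resp, true, st, node, capVer, peers, cfg); err != nil {
		return nil, err
	}

	return &resp, nil
}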

hscontrol/mapper/mapper_test.go
@@ -3,6 +3,7 @@ package mapper
 import (
 	"fmt"
 	"net/netip"
+	"slices"
 	"testing"

 	"github.com/google/go-cmp/cmp"
@@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
 			&types.Config{
 				TailcfgDNSConfig: &dnsConfigOrig,
 			},
-			nodeInShared1.View(),
+			nodeInShared1,
 		)

 		if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
@@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
 	// Filter peers by the provided IDs
 	var filtered types.Nodes
 	for _, peer := range m.peers {
-		for _, id := range peerIDs {
-			if peer.ID == id {
-				filtered = append(filtered, peer)
-				break
-			}
+		if slices.Contains(peerIDs, peer.ID) {
+			filtered = append(filtered, peer)
 		}
 	}

@@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
 	// Filter nodes by the provided IDs
 	var filtered types.Nodes
 	for _, node := range m.nodes {
-		for _, id := range nodeIDs {
-			if node.ID == id {
-				filtered = append(filtered, node)
-				break
-			}
+		if slices.Contains(nodeIDs, node.ID) {
+			filtered = append(filtered, node)
 		}
 	}

47
hscontrol/mapper/utils.go
Normal file
@@ -0,0 +1,47 @@
package mapper

import "tailscale.com/tailcfg"

// mergePatch takes the current patch and a newer patch
// and overrides any field that has changed.
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
	if newPatch.DERPRegion != 0 {
		currPatch.DERPRegion = newPatch.DERPRegion
	}

	if newPatch.Cap != 0 {
		currPatch.Cap = newPatch.Cap
	}

	if newPatch.CapMap != nil {
		currPatch.CapMap = newPatch.CapMap
	}

	if newPatch.Endpoints != nil {
		currPatch.Endpoints = newPatch.Endpoints
	}

	if newPatch.Key != nil {
		currPatch.Key = newPatch.Key
	}

	if newPatch.KeySignature != nil {
		currPatch.KeySignature = newPatch.KeySignature
	}

	if newPatch.DiscoKey != nil {
		currPatch.DiscoKey = newPatch.DiscoKey
	}

	if newPatch.Online != nil {
		currPatch.Online = newPatch.Online
	}

	if newPatch.LastSeen != nil {
		currPatch.LastSeen = newPatch.LastSeen
	}

	if newPatch.KeyExpiry != nil {
		currPatch.KeyExpiry = newPatch.KeyExpiry
	}
}
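
A minimal usage sketch: when several patches for the same peer accumulate between batches, mergePatch lets them collapse into a single PeerChange, with later patches winning field by field (coalescePatches is hypothetical, not part of this change):

// coalescePatches folds a batch of patches for a single peer into one
// PeerChange; later entries override earlier ones via mergePatch.
func coalescePatches(patches []*tailcfg.PeerChange) *tailcfg.PeerChange {
	if len(patches) == 0 {
		return nil
	}

	merged := &tailcfg.PeerChange{NodeID: patches[0].NodeID}
	for _, patch := range patches {
		mergePatch(merged, patch)
	}

	return merged
}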