Split up MapResponse

This commits extends the mapper with functions for creating "delta"
MapResponses for different purposes (peer changed, peer removed, derp).

This wires up the new state management with a new StateUpdate struct
letting the poll worker know what kind of update to send to the
connected nodes.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-29 11:20:22 +01:00 committed by Kristoffer Dalby
parent 66ff1fcd40
commit 4b65cf48d0
8 changed files with 284 additions and 115 deletions

View File

@ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region h.DERPMap.Regions[region.RegionID] = &region
} }
h.nodeNotifier.NotifyAll() h.nodeNotifier.NotifyAll(types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: *h.DERPMap,
})
} }
} }
} }
@ -721,7 +724,9 @@ func (h *Headscale) Serve() error {
Str("path", aclPath). Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
h.nodeNotifier.NotifyAll() h.nodeNotifier.NotifyAll(types.StateUpdate{
Type: types.StateFullUpdate,
})
} }
default: default:

View File

@ -13,6 +13,7 @@ import (
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags(
} }
machine.ForcedTags = newTags machine.ForcedTags = newTags
hsdb.notifier.NotifyWithIgnore(machine.MachineKey) hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to update tags for machine in the database: %w", err) return fmt.Errorf("failed to update tags for machine in the database: %w", err)
@ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
now := time.Now() now := time.Now()
machine.Expiry = &now machine.Expiry = &now
hsdb.notifier.NotifyWithIgnore(machine.MachineKey) hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to expire machine in the database: %w", err) return fmt.Errorf("failed to expire machine in the database: %w", err)
@ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
} }
machine.GivenName = newName machine.GivenName = newName
hsdb.notifier.NotifyWithIgnore(machine.MachineKey) hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err) return fmt.Errorf("failed to rename machine in the database: %w", err)
@ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
machine.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry machine.Expiry = &expiry
hsdb.notifier.NotifyWithIgnore(machine.MachineKey) hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf( return fmt.Errorf(
@ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string)
return false return false
} }
func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool {
ret := make(map[tailcfg.NodeID]bool)
for _, peer := range peers {
ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline()
}
return ret
}
func (hsdb *HSDatabase) ListOnlineMachines(
machine *types.Machine,
) (map[tailcfg.NodeID]bool, error) {
peers, err := hsdb.ListPeers(machine)
if err != nil {
return nil, err
}
return OnlineMachineMap(peers), nil
}
// enableRoutes enables new routes based on a list of new routes. // enableRoutes enables new routes based on a list of new routes.
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error {
newRoutes := make([]netip.Prefix, len(routeStrs)) newRoutes := make([]netip.Prefix, len(routeStrs))
@ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
} }
} }
hsdb.notifier.NotifyWithIgnore(machine.MachineKey) hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
return nil return nil
} }
@ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
return return
} }
expiredFound := false expired := make([]tailcfg.NodeID, 0)
for idx, machine := range machines { for idx, machine := range machines {
if machine.IsEphemeral() && machine.LastSeen != nil && if machine.IsEphemeral() && machine.LastSeen != nil &&
time.Now(). time.Now().
After(machine.LastSeen.Add(inactivityThreshhold)) { After(machine.LastSeen.Add(inactivityThreshhold)) {
expiredFound = true expired = append(expired, tailcfg.NodeID(machine.ID))
log.Info(). log.Info().
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Ephemeral client removed from database") Msg("Ephemeral client removed from database")
@ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
} }
} }
if expiredFound { if len(expired) > 0 {
hsdb.notifier.NotifyAll() hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
} }
} }
} }
@ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
return time.Unix(0, 0) return time.Unix(0, 0)
} }
expiredFound := false expired := make([]tailcfg.NodeID, 0)
for index, machine := range machines { for index, machine := range machines {
if machine.IsExpired() && if machine.IsExpired() &&
machine.Expiry.After(lastCheck) { machine.Expiry.After(lastCheck) {
expiredFound = true expired = append(expired, tailcfg.NodeID(machine.ID))
err := hsdb.ExpireMachine(&machines[index]) err := hsdb.ExpireMachine(&machines[index])
if err != nil { if err != nil {
@ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
} }
} }
if expiredFound { if len(expired) > 0 {
hsdb.notifier.NotifyAll() hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
} }
} }

View File

@ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
log.Error().Err(err).Msg("error getting routes") log.Error().Err(err).Msg("error getting routes")
} }
routesChanged := false changedMachines := make([]uint64, 0)
for pos, route := range routes { for pos, route := range routes {
if route.IsExitRoute() { if route.IsExitRoute() {
continue continue
@ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
return err return err
} }
routesChanged = true changedMachines = append(changedMachines, route.MachineID)
continue continue
} }
@ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
return err return err
} }
routesChanged = true changedMachines = append(changedMachines, route.MachineID)
} }
} }
if routesChanged { if len(changedMachines) > 0 {
hsdb.notifier.NotifyAll() hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: changedMachines,
})
} }
return nil return nil

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"sort"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -129,45 +130,35 @@ func fullMapResponse(
return nil, err return nil, err
} }
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
now := time.Now() now := time.Now()
resp := tailcfg.MapResponse{ resp := tailcfg.MapResponse{
KeepAlive: false,
Node: tailnode, Node: tailnode,
// TODO: Only send if updated
DERPMap: derpMap,
// TODO: Only send if updated
Peers: tailPeers, Peers: tailPeers,
// TODO(kradalby): Implement: DERPMap: derpMap,
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
// PeersChanged
// PeersRemoved
// PeersChangedPatch
// PeerSeenChange
// OnlineChange
// TODO: Only send if updated
DNSConfig: dnsConfig, DNSConfig: dnsConfig,
// TODO: Only send if updated
Domain: baseDomain, Domain: baseDomain,
// Do not instruct clients to collect services, we do not // Do not instruct clients to collect services we do not
// support or do anything with them // support or do anything with them
CollectServices: "false", CollectServices: "false",
// TODO: Only send if updated
PacketFilter: policy.ReduceFilterRules(machine, rules), PacketFilter: policy.ReduceFilterRules(machine, rules),
UserProfiles: profiles, UserProfiles: profiles,
// TODO: Only send if updated
SSHPolicy: sshPolicy, SSHPolicy: sshPolicy,
ControlTime: &now, ControlTime: &now,
KeepAlive: false,
OnlineChange: db.OnlineMachineMap(peers),
Debug: &tailcfg.Debug{ Debug: &tailcfg.Debug{
DisableLogTail: !logtail, DisableLogTail: !logtail,
@ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
} }
} }
// CreateMapResponse returns a MapResponse for the given machine. // FullMapResponse returns a MapResponse for the given machine.
func (m Mapper) CreateMapResponse( func (m Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *types.Machine, machine *types.Machine,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
@ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse(
} }
if m.isNoise { if m.isNoise {
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
} }
var machineKey key.MachinePublic return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
} }
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) func (m Mapper) KeepAliveResponse(
}
func (m Mapper) CreateKeepAliveResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *types.Machine, machine *types.Machine,
) ([]byte, error) { ) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{ resp := m.baseMapResponse(machine)
KeepAlive: true, resp.KeepAlive = true
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
} }
if m.isNoise { func (m Mapper) DERPMapResponse(
return m.marshalMapResponse( mapRequest tailcfg.MapRequest,
keepAliveResponse, machine *types.Machine,
key.MachinePublic{}, derpMap tailcfg.DERPMap,
mapRequest.Compress, ) ([]byte, error) {
resp := m.baseMapResponse(machine)
resp.DERPMap = &derpMap
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
machineKeys []uint64,
pol *policy.ACLPolicy,
) ([]byte, error) {
var err error
changed := make(types.Machines, len(machineKeys))
lastSeen := make(map[tailcfg.NodeID]bool)
for idx, machineKey := range machineKeys {
peer, err := m.db.GetMachineByID(machineKey)
if err != nil {
return nil, err
}
changed[idx] = *peer
// We have just seen the node, let the peers update their list.
lastSeen[tailcfg.NodeID(peer.ID)] = true
}
rules, _, err := policy.GenerateFilterAndSSHRules(
pol,
machine,
changed,
) )
if err != nil {
return nil, err
} }
// Filter out peers that have expired.
changed = lo.Filter(changed, func(item types.Machine, index int) bool {
return !item.IsExpired()
})
// If there are filter rules present, see if there are any machines that cannot
// access eachother at all and remove them from the changed.
if len(rules) > 0 {
changed = policy.FilterMachinesByACL(machine, changed, rules)
}
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
if err != nil {
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
resp := m.baseMapResponse(machine)
resp.PeersChanged = tailPeers
resp.PeerSeenChange = lastSeen
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) PeerRemovedResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
removed []tailcfg.NodeID,
) ([]byte, error) {
resp := m.baseMapResponse(machine)
resp.PeersRemoved = removed
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) marshalMapResponse(
resp *tailcfg.MapResponse,
machine *types.Machine,
compression string,
) ([]byte, error) {
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil { if err != nil {
@ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse(
return nil, err return nil, err
} }
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
}
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
// If isNoise is set, then the JSON body will be returned
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
func MarshalResponse(
resp interface{},
isNoise bool,
privateKey2019 *key.MachinePrivate,
machineKey key.MachinePublic,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if !isNoise && privateKey2019 != nil {
return privateKey2019.SealTo(machineKey, jsonBody), nil
}
return jsonBody, nil
}
func (m Mapper) marshalMapResponse(
resp interface{},
machineKey key.MachinePublic,
compression string,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp) jsonBody, err := json.Marshal(resp)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse(
return data, nil return data, nil
} }
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
// If isNoise is set, then the JSON body will be returned
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
func MarshalResponse(
resp interface{},
isNoise bool,
privateKey2019 *key.MachinePrivate,
machineKey key.MachinePublic,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if !isNoise && privateKey2019 != nil {
return privateKey2019.SealTo(machineKey, jsonBody), nil
}
return jsonBody, nil
}
func zstdEncode(in []byte) []byte { func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok { if !ok {
@ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{
return encoder return encoder
}, },
} }
func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse {
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
}
online, err := m.db.ListOnlineMachines(machine)
if err == nil {
resp.OnlineChange = online
}
return resp
}

View File

@ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{}, DNSConfig: &tailcfg.DNSConfig{},
Domain: "", Domain: "",
CollectServices: "false", CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
PacketFilter: []tailcfg.FilterRule{}, PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
@ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{}, DNSConfig: &tailcfg.DNSConfig{},
Domain: "", Domain: "",
CollectServices: "false", CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
PacketFilter: []tailcfg.FilterRule{ PacketFilter: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.2/32"}, SrcIPs: []string{"100.64.0.2/32"},

View File

@ -3,24 +3,25 @@ package notifier
import ( import (
"sync" "sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
) )
type Notifier struct { type Notifier struct {
l sync.RWMutex l sync.RWMutex
nodes map[string]chan<- struct{} nodes map[string]chan<- types.StateUpdate
} }
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
return &Notifier{} return &Notifier{}
} }
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) { func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
n.l.Lock() n.l.Lock()
defer n.l.Unlock() defer n.l.Unlock()
if n.nodes == nil { if n.nodes == nil {
n.nodes = make(map[string]chan<- struct{}) n.nodes = make(map[string]chan<- types.StateUpdate)
} }
n.nodes[machineKey] = c n.nodes[machineKey] = c
@ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) {
delete(n.nodes, machineKey) delete(n.nodes, machineKey)
} }
func (n *Notifier) NotifyAll() { func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore() n.NotifyWithIgnore(update)
} }
func (n *Notifier) NotifyWithIgnore(ignore ...string) { func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
@ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) {
continue continue
} }
c <- struct{}{} c <- update
} }
} }

View File

@ -116,7 +116,7 @@ func (h *Headscale) handlePoll(
return return
} }
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
if err != nil { if err != nil {
logErr(err, "Failed to create MapResponse") logErr(err, "Failed to create MapResponse")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
@ -163,7 +163,12 @@ func (h *Headscale) handlePoll(
Inc() Inc()
// Tell all the other nodes about the new endpoint, but dont update ourselves. // Tell all the other nodes about the new endpoint, but dont update ourselves.
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey) h.nodeNotifier.NotifyWithIgnore(
types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
},
machine.MachineKey)
return return
} else if mapRequest.OmitPeers && mapRequest.Stream { } else if mapRequest.OmitPeers && mapRequest.Stream {
@ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream(
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)
const chanSize = 8 const chanSize = 8
updateChan := make(chan struct{}, chanSize) updateChan := make(chan types.StateUpdate, chanSize)
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
@ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream(
for { for {
select { select {
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine) data, err := mapp.KeepAliveResponse(mapRequest, machine)
if err != nil { if err != nil {
logErr(err, "Error generating the keep alive msg") logErr(err, "Error generating the keep alive msg")
@ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream(
return return
} }
case <-updateChan: case update := <-updateChan:
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) var data []byte
var err error
switch update.Type {
case types.StateFullUpdate:
data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
case types.StatePeerChanged:
data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy)
case types.StatePeerRemoved:
data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed)
case types.StateDERPUpdated:
data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap)
}
if err != nil { if err != nil {
logErr(err, "Could not get the map update") logErr(err, "Could not get the create map update")
return return
} }
@ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream(
} }
} }
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) {
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine). Str("machine", machine).

View File

@ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) {
return string(bytes), err return string(bytes), err
} }
type StateUpdateType int
const (
StateFullUpdate StateUpdateType = iota
StatePeerChanged
StatePeerRemoved
StateDERPUpdated
)
// StateUpdate is an internal message containing information about
// a state change that has happened to the network.
type StateUpdate struct {
// The type of update
Type StateUpdateType
// Changed must be set when Type is StatePeerChanged and
// contain the Machine IDs of machines that has changed.
Changed []uint64
// Removed must be set when Type is StatePeerRemoved and
// contain a list of the nodes that has been removed from
// the network.
Removed []tailcfg.NodeID
// DERPMap must be set when Type is StateDERPUpdated and
// contain the new DERP Map.
DERPMap tailcfg.DERPMap
}