use helper function for constructing state updates (#2410)

This helps preventing messages being sent with the wrong update type
and payload combination, and it is shorter/neater.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-07 13:49:59 +01:00 committed by GitHub
parent b92bd3d27e
commit 1f0110fe06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 56 additions and 107 deletions

View File

@ -307,11 +307,9 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
h.cfg.TailcfgDNSConfig.ExtraRecords = records
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
Type: types.StateFullUpdate,
})
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
}
@ -511,9 +509,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
notif.NotifyAll(ctx, types.UpdateFull())
}
return nil
@ -535,9 +531,7 @@ func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
if filterChanged {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
notif.NotifyAll(ctx, types.UpdateFull())
return true, nil
}
@ -872,9 +866,7 @@ func (h *Headscale) Serve() error {
Msg("ACL policy successfully reloaded, notifying nodes of change")
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
default:
info := func(msg string) { log.Info().Msg(msg) }

View File

@ -93,15 +93,9 @@ func (h *Headscale) handleExistingNode(
}
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []types.NodeID{node.ID},
})
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
if changedNodes != nil {
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
}
}
@ -114,7 +108,7 @@ func (h *Headscale) handleExistingNode(
}
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID)
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
}
return &tailcfg.RegisterResponse{
@ -249,7 +243,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
if !updateSent {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID))
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}
return &tailcfg.RegisterResponse{

View File

@ -17,6 +17,7 @@ import (
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
const (
@ -626,11 +627,7 @@ func enableRoutes(tx *gorm.DB,
node.Routes = nRoutes
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "created in db.enableRoutes",
}, nil
return ptr.To(types.UpdatePeerChanged(node.ID)), nil
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
@ -717,10 +714,7 @@ func ExpireExpiredNodes(tx *gorm.DB,
}
if len(expired) > 0 {
return started, types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: expired,
}, true
return started, types.UpdatePeerPatch(expired...), true
}
return started, types.StateUpdate{}, false

View File

@ -12,6 +12,7 @@ import (
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/ptr"
"tailscale.com/util/set"
)
@ -470,11 +471,7 @@ nodeRouteLoop:
})
if len(changedNodes) != 0 {
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: chng,
Message: "called from db.FailoverNodeRoutesIfNecessary",
}, nil
return ptr.To(types.UpdatePeerChanged(chng...)), nil
}
return nil, nil

View File

@ -266,10 +266,7 @@ func (api headscaleV1APIServer) RegisterNode(
}
if !updateSent {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
@ -319,11 +316,7 @@ func (api headscaleV1APIServer) SetTags(
}
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.SetTags",
}, node.ID)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().
Str("node", node.Hostname).
@ -364,16 +357,10 @@ func (api headscaleV1APIServer) DeleteNode(
}
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []types.NodeID{node.ID},
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
if changedNodes != nil {
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
}
return &v1.DeleteNodeResponse{}, nil
@ -401,14 +388,11 @@ func (api headscaleV1APIServer) ExpireNode(
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
log.Trace().
Str("node", node.Hostname).
@ -439,11 +423,7 @@ func (api headscaleV1APIServer) RenameNode(
}
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.RenameNode",
}, node.ID)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().
Str("node", node.Hostname).
@ -602,10 +582,7 @@ func (api headscaleV1APIServer) DisableRoute(
if update != nil {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: update,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
}
return &v1.DisableRouteResponse{}, nil
@ -644,10 +621,7 @@ func (api headscaleV1APIServer) DeleteRoute(
if update != nil {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: update,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
}
return &v1.DeleteRouteResponse{}, nil
@ -809,9 +783,7 @@ func (api headscaleV1APIServer) SetPolicy(
// Only send update if the packet filter has changed.
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
response := &v1.SetPolicyResponse{

View File

@ -388,19 +388,13 @@ func (b *batcher) flush() {
})
if b.changedNodeIDs.Slice().Len() > 0 {
update := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
}
update := types.UpdatePeerChanged(changedNodes...)
b.n.sendAll(update)
}
if len(patches) > 0 {
patchUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: patches,
}
patchUpdate := types.UpdatePeerPatch(patches...)
b.n.sendAll(patchUpdate)
}

View File

@ -494,12 +494,12 @@ func (a *AuthProviderOIDC) handleRegistration(
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateSelf(node.ID),
types.UpdateSelf(node.ID),
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID)
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
}
return newNode, nil

View File

@ -68,9 +68,7 @@ func (h *Headscale) newMapSession(
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.StateUpdate{
Type: types.StateFullUpdate,
}
updateChan <- types.UpdateFull()
}
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
@ -428,12 +426,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
change,
},
}, node.ID)
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerPatch(change), node.ID)
}
func (m *mapSession) handleEndpointUpdate() {
@ -506,10 +499,7 @@ func (m *mapSession) handleEndpointUpdate() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{m.node.ID},
},
types.UpdateSelf(m.node.ID),
m.node.ID)
}
@ -530,11 +520,7 @@ func (m *mapSession) handleEndpointUpdate() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,
types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{m.node.ID},
Message: "called from handlePoll -> update",
},
types.UpdatePeerChanged(m.node.ID),
m.node.ID,
)

View File

@ -102,21 +102,41 @@ func (su *StateUpdate) Empty() bool {
return false
}
func StateSelf(nodeID NodeID) StateUpdate {
func UpdateFull() StateUpdate {
return StateUpdate{
Type: StateFullUpdate,
}
}
func UpdateSelf(nodeID NodeID) StateUpdate {
return StateUpdate{
Type: StateSelfUpdate,
ChangeNodes: []NodeID{nodeID},
}
}
func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate {
func UpdatePeerChanged(nodeIDs ...NodeID) StateUpdate {
return StateUpdate{
Type: StatePeerChanged,
ChangeNodes: nodeIDs,
}
}
func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
func UpdatePeerPatch(changes ...*tailcfg.PeerChange) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: changes,
}
}
func UpdatePeerRemoved(nodeIDs ...NodeID) StateUpdate {
return StateUpdate{
Type: StatePeerRemoved,
Removed: nodeIDs,
}
}
func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{