use helper function for constructing state updates (#2410)

This helps preventing messages being sent with the wrong update type
and payload combination, and it is shorter/neater.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-07 13:49:59 +01:00 committed by GitHub
parent b92bd3d27e
commit 1f0110fe06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 56 additions and 107 deletions

View File

@ -307,11 +307,9 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
h.cfg.TailcfgDNSConfig.ExtraRecords = records h.cfg.TailcfgDNSConfig.ExtraRecords = records
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all") ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ // TODO(kradalby): We can probably do better than sending a full update here,
// TODO(kradalby): We can probably do better than sending a full update here, // but for now this will ensure that all of the nodes get the new records.
// but for now this will ensure that all of the nodes get the new records. h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
Type: types.StateFullUpdate,
})
} }
} }
} }
@ -511,9 +509,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
if changed { if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{ notif.NotifyAll(ctx, types.UpdateFull())
Type: types.StateFullUpdate,
})
} }
return nil return nil
@ -535,9 +531,7 @@ func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
if filterChanged { if filterChanged {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{ notif.NotifyAll(ctx, types.UpdateFull())
Type: types.StateFullUpdate,
})
return true, nil return true, nil
} }
@ -872,9 +866,7 @@ func (h *Headscale) Serve() error {
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
Type: types.StateFullUpdate,
})
} }
default: default:
info := func(msg string) { log.Info().Msg(msg) } info := func(msg string) { log.Info().Msg(msg) }

View File

@ -93,15 +93,9 @@ func (h *Headscale) handleExistingNode(
} }
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
Type: types.StatePeerRemoved,
Removed: []types.NodeID{node.ID},
})
if changedNodes != nil { if changedNodes != nil {
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
} }
} }
@ -114,7 +108,7 @@ func (h *Headscale) handleExistingNode(
} }
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID) h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
} }
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{
@ -249,7 +243,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
if !updateSent { if !updateSent {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname) ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID)) h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
} }
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{

View File

@ -17,6 +17,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/ptr"
) )
const ( const (
@ -626,11 +627,7 @@ func enableRoutes(tx *gorm.DB,
node.Routes = nRoutes node.Routes = nRoutes
return &types.StateUpdate{ return ptr.To(types.UpdatePeerChanged(node.ID)), nil
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "created in db.enableRoutes",
}, nil
} }
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
@ -717,10 +714,7 @@ func ExpireExpiredNodes(tx *gorm.DB,
} }
if len(expired) > 0 { if len(expired) > 0 {
return started, types.StateUpdate{ return started, types.UpdatePeerPatch(expired...), true
Type: types.StatePeerChangedPatch,
ChangePatches: expired,
}, true
} }
return started, types.StateUpdate{}, false return started, types.StateUpdate{}, false

View File

@ -12,6 +12,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/types/ptr"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
@ -470,11 +471,7 @@ nodeRouteLoop:
}) })
if len(changedNodes) != 0 { if len(changedNodes) != 0 {
return &types.StateUpdate{ return ptr.To(types.UpdatePeerChanged(chng...)), nil
Type: types.StatePeerChanged,
ChangeNodes: chng,
Message: "called from db.FailoverNodeRoutesIfNecessary",
}, nil
} }
return nil, nil return nil, nil

View File

@ -266,10 +266,7 @@ func (api headscaleV1APIServer) RegisterNode(
} }
if !updateSent { if !updateSent {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname) ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
})
} }
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
@ -319,11 +316,7 @@ func (api headscaleV1APIServer) SetTags(
} }
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.SetTags",
}, node.ID)
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -364,16 +357,10 @@ func (api headscaleV1APIServer) DeleteNode(
} }
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
Type: types.StatePeerRemoved,
Removed: []types.NodeID{node.ID},
})
if changedNodes != nil { if changedNodes != nil {
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
} }
return &v1.DeleteNodeResponse{}, nil return &v1.DeleteNodeResponse{}, nil
@ -401,14 +388,11 @@ func (api headscaleV1APIServer) ExpireNode(
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID( api.h.nodeNotifier.NotifyByNodeID(
ctx, ctx,
types.StateUpdate{ types.UpdateSelf(node.ID),
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID) node.ID)
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID) api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -439,11 +423,7 @@ func (api headscaleV1APIServer) RenameNode(
} }
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.RenameNode",
}, node.ID)
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -602,10 +582,7 @@ func (api headscaleV1APIServer) DisableRoute(
if update != nil { if update != nil {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown") ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
Type: types.StatePeerChanged,
ChangeNodes: update,
})
} }
return &v1.DisableRouteResponse{}, nil return &v1.DisableRouteResponse{}, nil
@ -644,10 +621,7 @@ func (api headscaleV1APIServer) DeleteRoute(
if update != nil { if update != nil {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown") ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
Type: types.StatePeerChanged,
ChangeNodes: update,
})
} }
return &v1.DeleteRouteResponse{}, nil return &v1.DeleteRouteResponse{}, nil
@ -809,9 +783,7 @@ func (api headscaleV1APIServer) SetPolicy(
// Only send update if the packet filter has changed. // Only send update if the packet filter has changed.
if changed { if changed {
ctx := types.NotifyCtx(context.Background(), "acl-update", "na") ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
Type: types.StateFullUpdate,
})
} }
response := &v1.SetPolicyResponse{ response := &v1.SetPolicyResponse{

View File

@ -388,19 +388,13 @@ func (b *batcher) flush() {
}) })
if b.changedNodeIDs.Slice().Len() > 0 { if b.changedNodeIDs.Slice().Len() > 0 {
update := types.StateUpdate{ update := types.UpdatePeerChanged(changedNodes...)
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
}
b.n.sendAll(update) b.n.sendAll(update)
} }
if len(patches) > 0 { if len(patches) > 0 {
patchUpdate := types.StateUpdate{ patchUpdate := types.UpdatePeerPatch(patches...)
Type: types.StatePeerChangedPatch,
ChangePatches: patches,
}
b.n.sendAll(patchUpdate) b.n.sendAll(patchUpdate)
} }

View File

@ -494,12 +494,12 @@ func (a *AuthProviderOIDC) handleRegistration(
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID( a.notifier.NotifyByNodeID(
ctx, ctx,
types.StateSelf(node.ID), types.UpdateSelf(node.ID),
node.ID, node.ID,
) )
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID) a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
} }
return newNode, nil return newNode, nil

View File

@ -68,9 +68,7 @@ func (h *Headscale) newMapSession(
// to receive a message to make sure we dont block the entire // to receive a message to make sure we dont block the entire
// notifier. // notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.StateUpdate{ updateChan <- types.UpdateFull()
Type: types.StateFullUpdate,
}
} }
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
@ -428,12 +426,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
} }
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerPatch(change), node.ID)
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
change,
},
}, node.ID)
} }
func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleEndpointUpdate() {
@ -506,10 +499,7 @@ func (m *mapSession) handleEndpointUpdate() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID( m.h.nodeNotifier.NotifyByNodeID(
ctx, ctx,
types.StateUpdate{ types.UpdateSelf(m.node.ID),
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{m.node.ID},
},
m.node.ID) m.node.ID)
} }
@ -530,11 +520,7 @@ func (m *mapSession) handleEndpointUpdate() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore( m.h.nodeNotifier.NotifyWithIgnore(
ctx, ctx,
types.StateUpdate{ types.UpdatePeerChanged(m.node.ID),
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{m.node.ID},
Message: "called from handlePoll -> update",
},
m.node.ID, m.node.ID,
) )

View File

@ -102,21 +102,41 @@ func (su *StateUpdate) Empty() bool {
return false return false
} }
func StateSelf(nodeID NodeID) StateUpdate { func UpdateFull() StateUpdate {
return StateUpdate{
Type: StateFullUpdate,
}
}
func UpdateSelf(nodeID NodeID) StateUpdate {
return StateUpdate{ return StateUpdate{
Type: StateSelfUpdate, Type: StateSelfUpdate,
ChangeNodes: []NodeID{nodeID}, ChangeNodes: []NodeID{nodeID},
} }
} }
func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate { func UpdatePeerChanged(nodeIDs ...NodeID) StateUpdate {
return StateUpdate{ return StateUpdate{
Type: StatePeerChanged, Type: StatePeerChanged,
ChangeNodes: nodeIDs, ChangeNodes: nodeIDs,
} }
} }
func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { func UpdatePeerPatch(changes ...*tailcfg.PeerChange) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: changes,
}
}
func UpdatePeerRemoved(nodeIDs ...NodeID) StateUpdate {
return StateUpdate{
Type: StatePeerRemoved,
Removed: nodeIDs,
}
}
func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
return StateUpdate{ return StateUpdate{
Type: StatePeerChangedPatch, Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{ ChangePatches: []*tailcfg.PeerChange{