From 1f0110fe06fca32ccff677df9e2288a62c130c25 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 7 Feb 2025 13:49:59 +0100 Subject: [PATCH] use helper function for constructing state updates (#2410) This helps preventing messages being sent with the wrong update type and payload combination, and it is shorter/neater. Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 20 +++++--------- hscontrol/auth.go | 14 +++------- hscontrol/db/node.go | 12 +++------ hscontrol/db/routes.go | 7 ++--- hscontrol/grpcv1.go | 48 +++++++--------------------------- hscontrol/notifier/notifier.go | 10 ++----- hscontrol/oidc.go | 4 +-- hscontrol/poll.go | 22 +++------------- hscontrol/types/common.go | 26 +++++++++++++++--- 9 files changed, 56 insertions(+), 107 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 5623c76a..2f1cd4cd 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -307,11 +307,9 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { h.cfg.TailcfgDNSConfig.ExtraRecords = records ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - // TODO(kradalby): We can probably do better than sending a full update here, - // but for now this will ensure that all of the nodes get the new records. - Type: types.StateFullUpdate, - }) + // TODO(kradalby): We can probably do better than sending a full update here, + // but for now this will ensure that all of the nodes get the new records. + h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } } } @@ -511,9 +509,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not if changed { ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") - notif.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + notif.NotifyAll(ctx, types.UpdateFull()) } return nil @@ -535,9 +531,7 @@ func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not if filterChanged { ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") - notif.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + notif.NotifyAll(ctx, types.UpdateFull()) return true, nil } @@ -872,9 +866,7 @@ func (h *Headscale) Serve() error { Msg("ACL policy successfully reloaded, notifying nodes of change") ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } default: info := func(msg string) { log.Info().Msg(msg) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 7695f1ae..4cc7058b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -93,15 +93,9 @@ func (h *Headscale) handleExistingNode( } ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []types.NodeID{node.ID}, - }) + h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID)) if changedNodes != nil { - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changedNodes, - }) + h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...)) } } @@ -114,7 +108,7 @@ func (h *Headscale) handleExistingNode( } ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") - h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID) + h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID) } return &tailcfg.RegisterResponse{ @@ -249,7 +243,7 @@ func (h *Headscale) handleRegisterWithAuthKey( if !updateSent { ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname) - h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID)) + h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID)) } return &tailcfg.RegisterResponse{ diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 11a13056..0c167856 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -17,6 +17,7 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) const ( @@ -626,11 +627,7 @@ func enableRoutes(tx *gorm.DB, node.Routes = nRoutes - return &types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{node.ID}, - Message: "created in db.enableRoutes", - }, nil + return ptr.To(types.UpdatePeerChanged(node.ID)), nil } func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { @@ -717,10 +714,7 @@ func ExpireExpiredNodes(tx *gorm.DB, } if len(expired) > 0 { - return started, types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: expired, - }, true + return started, types.UpdatePeerPatch(expired...), true } return started, types.StateUpdate{}, false diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 8d86145a..b2bda26b 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -12,6 +12,7 @@ import ( "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/net/tsaddr" + "tailscale.com/types/ptr" "tailscale.com/util/set" ) @@ -470,11 +471,7 @@ nodeRouteLoop: }) if len(changedNodes) != 0 { - return &types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: chng, - Message: "called from db.FailoverNodeRoutesIfNecessary", - }, nil + return ptr.To(types.UpdatePeerChanged(chng...)), nil } return nil, nil diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 7eadd0a7..59fe4ebd 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -266,10 +266,7 @@ func (api headscaleV1APIServer) RegisterNode( } if !updateSent { ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{node.ID}, - }) + api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID)) } return &v1.RegisterNodeResponse{Node: node.Proto()}, nil @@ -319,11 +316,7 @@ func (api headscaleV1APIServer) SetTags( } ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{node.ID}, - Message: "called from api.SetTags", - }, node.ID) + api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) log.Trace(). Str("node", node.Hostname). @@ -364,16 +357,10 @@ func (api headscaleV1APIServer) DeleteNode( } ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []types.NodeID{node.ID}, - }) + api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID)) if changedNodes != nil { - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changedNodes, - }) + api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...)) } return &v1.DeleteNodeResponse{}, nil @@ -401,14 +388,11 @@ func (api headscaleV1APIServer) ExpireNode( ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) api.h.nodeNotifier.NotifyByNodeID( ctx, - types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: []types.NodeID{node.ID}, - }, + types.UpdateSelf(node.ID), node.ID) ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID) + api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID) log.Trace(). Str("node", node.Hostname). @@ -439,11 +423,7 @@ func (api headscaleV1APIServer) RenameNode( } ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) - api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{node.ID}, - Message: "called from api.RenameNode", - }, node.ID) + api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) log.Trace(). Str("node", node.Hostname). @@ -602,10 +582,7 @@ func (api headscaleV1APIServer) DisableRoute( if update != nil { ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: update, - }) + api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...)) } return &v1.DisableRouteResponse{}, nil @@ -644,10 +621,7 @@ func (api headscaleV1APIServer) DeleteRoute( if update != nil { ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: update, - }) + api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...)) } return &v1.DeleteRouteResponse{}, nil @@ -809,9 +783,7 @@ func (api headscaleV1APIServer) SetPolicy( // Only send update if the packet filter has changed. if changed { ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } response := &v1.SetPolicyResponse{ diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 166d572d..4d2e277b 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -388,19 +388,13 @@ func (b *batcher) flush() { }) if b.changedNodeIDs.Slice().Len() > 0 { - update := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changedNodes, - } + update := types.UpdatePeerChanged(changedNodes...) b.n.sendAll(update) } if len(patches) > 0 { - patchUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: patches, - } + patchUpdate := types.UpdatePeerPatch(patches...) b.n.sendAll(patchUpdate) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index d6a6d59f..d7a46a87 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -494,12 +494,12 @@ func (a *AuthProviderOIDC) handleRegistration( ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) a.notifier.NotifyByNodeID( ctx, - types.StateSelf(node.ID), + types.UpdateSelf(node.ID), node.ID, ) ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) - a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID) + a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) } return newNode, nil diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 88c6288b..2df35c36 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -68,9 +68,7 @@ func (h *Headscale) newMapSession( // to receive a message to make sure we dont block the entire // notifier. updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) - updateChan <- types.StateUpdate{ - Type: types.StateFullUpdate, - } + updateChan <- types.UpdateFull() } ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) @@ -428,12 +426,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { } ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname) - h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - change, - }, - }, node.ID) + h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerPatch(change), node.ID) } func (m *mapSession) handleEndpointUpdate() { @@ -506,10 +499,7 @@ func (m *mapSession) handleEndpointUpdate() { ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) m.h.nodeNotifier.NotifyByNodeID( ctx, - types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: []types.NodeID{m.node.ID}, - }, + types.UpdateSelf(m.node.ID), m.node.ID) } @@ -530,11 +520,7 @@ func (m *mapSession) handleEndpointUpdate() { ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname) m.h.nodeNotifier.NotifyWithIgnore( ctx, - types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{m.node.ID}, - Message: "called from handlePoll -> update", - }, + types.UpdatePeerChanged(m.node.ID), m.node.ID, ) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index e5cef8fd..c8d696af 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -102,21 +102,41 @@ func (su *StateUpdate) Empty() bool { return false } -func StateSelf(nodeID NodeID) StateUpdate { +func UpdateFull() StateUpdate { + return StateUpdate{ + Type: StateFullUpdate, + } +} + +func UpdateSelf(nodeID NodeID) StateUpdate { return StateUpdate{ Type: StateSelfUpdate, ChangeNodes: []NodeID{nodeID}, } } -func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate { +func UpdatePeerChanged(nodeIDs ...NodeID) StateUpdate { return StateUpdate{ Type: StatePeerChanged, ChangeNodes: nodeIDs, } } -func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { +func UpdatePeerPatch(changes ...*tailcfg.PeerChange) StateUpdate { + return StateUpdate{ + Type: StatePeerChangedPatch, + ChangePatches: changes, + } +} + +func UpdatePeerRemoved(nodeIDs ...NodeID) StateUpdate { + return StateUpdate{ + Type: StatePeerRemoved, + Removed: nodeIDs, + } +} + +func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { return StateUpdate{ Type: StatePeerChangedPatch, ChangePatches: []*tailcfg.PeerChange{