poll: use nodeview everywhere

There was a bug in HA subnet router handover where we used stale node data
from the longpoll session that we handed to Connect. This meant that we got
some odd behaviour where routes would not be deactivated correctly.

This commit changes the code so that the nodeview is used throughout, and we
load the current node to be updated in the write path and then handle it all
there to be consistent.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2025-07-08 09:49:05 +02:00
committed by Kristoffer Dalby
parent 4a8d2d9ed3
commit b904276f2b
4 changed files with 176 additions and 117 deletions

View File

@@ -42,7 +42,7 @@ type mapSession struct {
keepAlive time.Duration
keepAliveTicker *time.Ticker
node *types.Node
node types.NodeView
w http.ResponseWriter
warnf func(string, ...any)
@@ -55,9 +55,9 @@ func (h *Headscale) newMapSession(
ctx context.Context,
req tailcfg.MapRequest,
w http.ResponseWriter,
node *types.Node,
nv types.NodeView,
) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node)
warnf, infof, tracef, errf := logPollFuncView(req, nv)
var updateChan chan types.StateUpdate
if req.Stream {
@@ -75,7 +75,7 @@ func (h *Headscale) newMapSession(
ctx: ctx,
req: req,
w: w,
node: node,
node: nv,
capVer: req.Version,
mapper: h.mapper,
@@ -112,13 +112,13 @@ func (m *mapSession) resetKeepAlive() {
// beforeServeLongPoll is invoked at the start of serveLongPoll. When the
// session's node is ephemeral, it calls ephemeralGC.Cancel with the node's ID.
func (m *mapSession) beforeServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Cancel(m.node.ID)
m.h.ephemeralGC.Cancel(m.node.ID())
}
}
// afterServeLongPoll is the counterpart of beforeServeLongPoll. When the
// session's node is ephemeral, it calls ephemeralGC.Schedule with the node's
// ID and the configured EphemeralNodeInactivityTimeout.
// NOTE(review): presumably invoked once the long-poll session ends; the call
// site is not visible here — confirm.
func (m *mapSession) afterServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
}
}
@@ -168,10 +168,6 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll()
// For now, mapSession uses a normal node, but since serveLongPoll is a read operation,
// convert the node to a view at the beginning.
nv := m.node.View()
// Clean up the session when the client disconnects
defer func() {
m.cancelChMu.Lock()
@@ -183,16 +179,16 @@ func (m *mapSession) serveLongPoll() {
// in principal, it will be removed, but the client rapidly
// reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(nv.ID(), m.ch) {
if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
change, err := m.h.state.Disconnect(nv)
change, err := m.h.state.Disconnect(m.node.ID())
if err != nil {
m.errf(err, "Failed to disconnect node %s", nv.Hostname())
m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
}
if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
@@ -205,10 +201,7 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
if m.h.state.Connect(nv) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
m.h.state.Connect(m.node.ID())
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
@@ -217,20 +210,12 @@ func (m *mapSession) serveLongPoll() {
// so it needs to be disabled.
rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, nv.Hostname()))
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
defer cancel()
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(nv.ID(), m.ch)
go func() {
changed := m.h.state.Connect(nv)
if changed {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}()
m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@@ -257,7 +242,7 @@ func (m *mapSession) serveLongPoll() {
}
// If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, nv.ID()) {
if slices.Contains(update.Removed, m.node.ID()) {
m.tracef("node removed, closing stream")
return
}
@@ -269,25 +254,21 @@ func (m *mapSession) serveLongPoll() {
var err error
var lastMessage string
// Ensure the node object is updated, for example, there
// Ensure the node view is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
m.node, err = m.h.state.GetNodeByID(nv.ID())
m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
if err != nil {
m.errf(err, "Could not get machine from db")
return
}
// Update the node view to reflect the latest node state
// TODO(kradalby): This should become a full read only path, with no update for the node view
// in the new mapper model.
nv = m.node.View()
updateType := "full"
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, nv, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
@@ -297,12 +278,12 @@ func (m *mapSession) serveLongPoll() {
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change"
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, nv, update.ChangePatches)
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch"
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
@@ -311,17 +292,17 @@ func (m *mapSession) serveLongPoll() {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, nv, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, nv, m.h.state.DERPMap())
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
updateType = "derp"
}
@@ -348,10 +329,10 @@ func (m *mapSession) serveLongPoll() {
return
}
log.Trace().Str("node", nv.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", nv.MachineKey().String()).Msg("finished writing mapresp to node")
log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, nv.ID().String()).Set(float64(time.Now().Unix()))
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent")
@@ -359,7 +340,7 @@ func (m *mapSession) serveLongPoll() {
}
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, nv)
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil {
m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
@@ -379,7 +360,7 @@ func (m *mapSession) serveLongPoll() {
}
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", nv.ID().String()).Set(float64(time.Now().Unix()))
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
@@ -389,14 +370,23 @@ func (m *mapSession) serveLongPoll() {
func (m *mapSession) handleEndpointUpdate() {
m.tracef("received endpoint update")
// Get fresh node state from database for accurate route calculations
node, err := m.h.state.GetNodeByID(m.node.ID())
if err != nil {
m.errf(err, "Failed to get fresh node from database for endpoint update")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
change := m.node.PeerChangeFromMapRequest(m.req)
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID)
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
change.Online = &online
m.node.ApplyPeerChange(&change)
node.ApplyPeerChange(&change)
sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo)
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
// The node might not set NetInfo if it has not changed and if
// the full HostInfo object is overwritten, the information is lost.
@@ -405,12 +395,12 @@ func (m *mapSession) handleEndpointUpdate() {
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
// TODO(kradalby): evaluate if we need better comparing of hostinfo
// before we take the changes.
if m.req.Hostinfo.NetInfo == nil && m.node.Hostinfo != nil {
m.req.Hostinfo.NetInfo = m.node.Hostinfo.NetInfo
if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
}
m.node.Hostinfo = m.req.Hostinfo
node.Hostinfo = m.req.Hostinfo
logTracePeerChange(m.node.Hostname, sendUpdate, &change)
logTracePeerChange(node.Hostname, sendUpdate, &change)
// If there is no changes and nothing to save,
// return early.
@@ -419,47 +409,40 @@ func (m *mapSession) handleEndpointUpdate() {
return
}
// Check if the Hostinfo of the node has changed.
// If it has changed, check if there has been a change to
// the routable IPs of the host and update them in
// the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the route change.
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
routesAutoApproved := m.h.state.AutoApproveRoutes(m.node)
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
routesAutoApproved := m.h.state.AutoApproveRoutes(node)
// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
if m.h.state.SetNodeRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(m.node.ID), m.node.ID)
// Always update routes for connected nodes to handle reconnection scenarios
// where routes need to be restored to the primary routes system
routesToSet := node.SubnetRoutes()
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(m.node.ID),
m.node.ID)
}
if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else if routesChanged {
// Only send peer changed notification if routes actually changed
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// If routes were auto-approved, we need to save the node to persist the changes
if routesAutoApproved {
if _, _, err := m.h.state.SaveNode(m.node); err != nil {
m.errf(err, "Failed to save auto-approved routes to node")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
}
// If routes were auto-approved, we need to save the node to persist the changes
if routesAutoApproved {
if _, _, err := m.h.state.SaveNode(node); err != nil {
m.errf(err, "Failed to save auto-approved routes to node")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
}
@@ -467,9 +450,9 @@ func (m *mapSession) handleEndpointUpdate() {
// in the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the hostname change.
m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
_, policyChanged, err := m.h.state.SaveNode(m.node)
_, policyChanged, err := m.h.state.SaveNode(node)
if err != nil {
m.errf(err, "Failed to persist/update node in the database")
http.Error(m.w, "", http.StatusInternalServerError)
@@ -480,15 +463,15 @@ func (m *mapSession) handleEndpointUpdate() {
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", m.node.Hostname)
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname)
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,
types.UpdatePeerChanged(m.node.ID),
m.node.ID,
types.UpdatePeerChanged(node.ID),
node.ID,
)
m.w.WriteHeader(http.StatusOK)
@@ -498,7 +481,7 @@ func (m *mapSession) handleEndpointUpdate() {
func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node.View())
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil {
m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError)
@@ -611,6 +594,53 @@ func logPollFunc(
}
}
// logPollFuncView returns four logging helpers bound to a map request and a
// node view, in this order: warn, info, trace and error.
//
// Every helper tags its log event with the request's readOnly, omitPeers and
// stream flags plus the node's ID ("node.id") and hostname ("node"), then
// formats the message with Msgf. The error helper additionally attaches the
// given error. The node view is read each time a helper is called, not when
// the helpers are created.
func logPollFuncView(
	mapRequest tailcfg.MapRequest,
	nodeView types.NodeView,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
	// Warn-level helper.
	warnf := func(format string, args ...any) {
		log.Warn().
			Caller().
			Bool("readOnly", mapRequest.ReadOnly).
			Bool("omitPeers", mapRequest.OmitPeers).
			Bool("stream", mapRequest.Stream).
			Uint64("node.id", nodeView.ID().Uint64()).
			Str("node", nodeView.Hostname()).
			Msgf(format, args...)
	}

	// Info-level helper.
	infof := func(format string, args ...any) {
		log.Info().
			Caller().
			Bool("readOnly", mapRequest.ReadOnly).
			Bool("omitPeers", mapRequest.OmitPeers).
			Bool("stream", mapRequest.Stream).
			Uint64("node.id", nodeView.ID().Uint64()).
			Str("node", nodeView.Hostname()).
			Msgf(format, args...)
	}

	// Trace-level helper.
	tracef := func(format string, args ...any) {
		log.Trace().
			Caller().
			Bool("readOnly", mapRequest.ReadOnly).
			Bool("omitPeers", mapRequest.OmitPeers).
			Bool("stream", mapRequest.Stream).
			Uint64("node.id", nodeView.ID().Uint64()).
			Str("node", nodeView.Hostname()).
			Msgf(format, args...)
	}

	// Error-level helper; the error is attached after the common fields.
	errf := func(cause error, format string, args ...any) {
		log.Error().
			Caller().
			Bool("readOnly", mapRequest.ReadOnly).
			Bool("omitPeers", mapRequest.OmitPeers).
			Bool("stream", mapRequest.Stream).
			Uint64("node.id", nodeView.ID().Uint64()).
			Str("node", nodeView.Hostname()).
			Err(cause).
			Msgf(format, args...)
	}

	return warnf, infof, tracef, errf
}
// hostInfoChanged reports if hostInfo has changed in two ways,
// - first bool reports if an update needs to be sent to nodes
// - second reports if there has been changes to routes