poll: use nodeview everywhere

There was a bug in HA subnet router handover where we used stale node data
from the longpoll session that we handed to Connect. This meant that we got
some odd behaviour where routes would not be deactivated correctly.

This commit changes to the nodeview is used through out, and we load the
current node to be updated in the write path and then handle it all there
to be consistent.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-07-08 09:49:05 +02:00 committed by Kristoffer Dalby
parent 4a8d2d9ed3
commit b904276f2b
4 changed files with 176 additions and 117 deletions

View File

@ -206,6 +206,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
} else if changed { } else if changed {
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname) ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
// Existing node re-registering without route changes
// Still need to notify peers about the node being active again
// Use UpdateFull to ensure all peers get complete peer maps
ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{

View File

@ -213,15 +213,15 @@ func (ns *noiseServer) NoisePollNetMapHandler(
return return
} }
node, err := ns.getAndValidateNode(mapRequest) nv, err := ns.getAndValidateNode(mapRequest)
if err != nil { if err != nil {
httpError(writer, err) httpError(writer, err)
return return
} }
ns.nodeKey = node.NodeKey ns.nodeKey = nv.NodeKey()
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node) sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
sess.tracef("a node sending a MapRequest with Noise protocol") sess.tracef("a node sending a MapRequest with Noise protocol")
if !sess.isStreaming() { if !sess.isStreaming() {
sess.serve() sess.serve()
@ -292,19 +292,19 @@ func (ns *noiseServer) NoiseRegistrationHandler(
// getAndValidateNode retrieves the node from the database using the NodeKey // getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session. // and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (*types.Node, error) { func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusNotFound, "node not found", nil) return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
} }
return nil, err return types.NodeView{}, err
} }
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
if ns.machineKey != node.MachineKey { if ns.machineKey != nv.MachineKey() {
return nil, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil) return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
} }
return node, nil return nv, nil
} }

View File

@ -42,7 +42,7 @@ type mapSession struct {
keepAlive time.Duration keepAlive time.Duration
keepAliveTicker *time.Ticker keepAliveTicker *time.Ticker
node *types.Node node types.NodeView
w http.ResponseWriter w http.ResponseWriter
warnf func(string, ...any) warnf func(string, ...any)
@ -55,9 +55,9 @@ func (h *Headscale) newMapSession(
ctx context.Context, ctx context.Context,
req tailcfg.MapRequest, req tailcfg.MapRequest,
w http.ResponseWriter, w http.ResponseWriter,
node *types.Node, nv types.NodeView,
) *mapSession { ) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node) warnf, infof, tracef, errf := logPollFuncView(req, nv)
var updateChan chan types.StateUpdate var updateChan chan types.StateUpdate
if req.Stream { if req.Stream {
@ -75,7 +75,7 @@ func (h *Headscale) newMapSession(
ctx: ctx, ctx: ctx,
req: req, req: req,
w: w, w: w,
node: node, node: nv,
capVer: req.Version, capVer: req.Version,
mapper: h.mapper, mapper: h.mapper,
@ -112,13 +112,13 @@ func (m *mapSession) resetKeepAlive() {
func (m *mapSession) beforeServeLongPoll() { func (m *mapSession) beforeServeLongPoll() {
if m.node.IsEphemeral() { if m.node.IsEphemeral() {
m.h.ephemeralGC.Cancel(m.node.ID) m.h.ephemeralGC.Cancel(m.node.ID())
} }
} }
func (m *mapSession) afterServeLongPoll() { func (m *mapSession) afterServeLongPoll() {
if m.node.IsEphemeral() { if m.node.IsEphemeral() {
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout) m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
} }
} }
@ -168,10 +168,6 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() { func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll() m.beforeServeLongPoll()
// For now, mapSession uses a normal node, but since serveLongPoll is a read operation,
// convert the node to a view at the beginning.
nv := m.node.View()
// Clean up the session when the client disconnects // Clean up the session when the client disconnects
defer func() { defer func() {
m.cancelChMu.Lock() m.cancelChMu.Lock()
@ -183,16 +179,16 @@ func (m *mapSession) serveLongPoll() {
// in principal, it will be removed, but the client rapidly // in principal, it will be removed, but the client rapidly
// reconnects, the channel might be of another connection. // reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online. // In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(nv.ID(), m.ch) { if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
// TODO(kradalby): This can likely be made more effective, but likely most // TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal. // nodes has access to the same routes, so it might not be a big deal.
change, err := m.h.state.Disconnect(nv) change, err := m.h.state.Disconnect(m.node.ID())
if err != nil { if err != nil {
m.errf(err, "Failed to disconnect node %s", nv.Hostname()) m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
} }
if change { if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname()) ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
} }
@ -205,10 +201,7 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1) m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done() defer m.h.pollNetMapStreamWG.Done()
if m.h.state.Connect(nv) { m.h.state.Connect(m.node.ID())
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
// Upgrade the writer to a ResponseController // Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w) rc := http.NewResponseController(m.w)
@ -217,20 +210,12 @@ func (m *mapSession) serveLongPoll() {
// so it needs to be disabled. // so it needs to be disabled.
rc.SetWriteDeadline(time.Time{}) rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, nv.Hostname())) ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
defer cancel() defer cancel()
m.keepAliveTicker = time.NewTicker(m.keepAlive) m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(nv.ID(), m.ch) m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
go func() {
changed := m.h.state.Connect(nv)
if changed {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}()
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@ -257,7 +242,7 @@ func (m *mapSession) serveLongPoll() {
} }
// If the node has been removed from headscale, close the stream // If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, nv.ID()) { if slices.Contains(update.Removed, m.node.ID()) {
m.tracef("node removed, closing stream") m.tracef("node removed, closing stream")
return return
} }
@ -269,25 +254,21 @@ func (m *mapSession) serveLongPoll() {
var err error var err error
var lastMessage string var lastMessage string
// Ensure the node object is updated, for example, there // Ensure the node view is updated, for example, there
// might have been a hostinfo update in a sidechannel // might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response. // which contains data needed to generate a map response.
m.node, err = m.h.state.GetNodeByID(nv.ID()) m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
if err != nil { if err != nil {
m.errf(err, "Could not get machine from db") m.errf(err, "Could not get machine from db")
return return
} }
// Update the node view to reflect the latest node state
// TODO(kradalby): This should become a full read only path, with no update for the node view
// in the new mapper model.
nv = m.node.View()
updateType := "full" updateType := "full"
switch update.Type { switch update.Type {
case types.StateFullUpdate: case types.StateFullUpdate:
m.tracef("Sending Full MapResponse") m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, nv, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged: case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
@ -297,12 +278,12 @@ func (m *mapSession) serveLongPoll() {
lastMessage = update.Message lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change" updateType = "change"
case types.StatePeerChangedPatch: case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, nv, update.ChangePatches) data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch" updateType = "patch"
case types.StatePeerRemoved: case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed)) changed := make(map[types.NodeID]bool, len(update.Removed))
@ -311,17 +292,17 @@ func (m *mapSession) serveLongPoll() {
changed[nodeID] = false changed[nodeID] = false
} }
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove" updateType = "remove"
case types.StateSelfUpdate: case types.StateSelfUpdate:
lastMessage = update.Message lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent // create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, nv, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove" updateType = "remove"
case types.StateDERPUpdated: case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse") m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, nv, m.h.state.DERPMap()) data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
updateType = "derp" updateType = "derp"
} }
@ -348,10 +329,10 @@ func (m *mapSession) serveLongPoll() {
return return
} }
log.Trace().Str("node", nv.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", nv.MachineKey().String()).Msg("finished writing mapresp to node") log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics { if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, nv.ID().String()).Set(float64(time.Now().Unix())) mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
} }
mapResponseSent.WithLabelValues("ok", updateType).Inc() mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent") m.tracef("update sent")
@ -359,7 +340,7 @@ func (m *mapSession) serveLongPoll() {
} }
case <-m.keepAliveTicker.C: case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, nv) data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil { if err != nil {
m.errf(err, "Error generating the keep alive msg") m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc() mapResponseSent.WithLabelValues("error", "keepalive").Inc()
@ -379,7 +360,7 @@ func (m *mapSession) serveLongPoll() {
} }
if debugHighCardinalityMetrics { if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", nv.ID().String()).Set(float64(time.Now().Unix())) mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
} }
mapResponseSent.WithLabelValues("ok", "keepalive").Inc() mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
} }
@ -389,14 +370,23 @@ func (m *mapSession) serveLongPoll() {
func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleEndpointUpdate() {
m.tracef("received endpoint update") m.tracef("received endpoint update")
// Get fresh node state from database for accurate route calculations
node, err := m.h.state.GetNodeByID(m.node.ID())
if err != nil {
m.errf(err, "Failed to get fresh node from database for endpoint update")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
change := m.node.PeerChangeFromMapRequest(m.req) change := m.node.PeerChangeFromMapRequest(m.req)
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID) online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
change.Online = &online change.Online = &online
m.node.ApplyPeerChange(&change) node.ApplyPeerChange(&change)
sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo) sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
// The node might not set NetInfo if it has not changed and if // The node might not set NetInfo if it has not changed and if
// the full HostInfo object is overwritten, the information is lost. // the full HostInfo object is overwritten, the information is lost.
@ -405,12 +395,12 @@ func (m *mapSession) handleEndpointUpdate() {
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
// TODO(kradalby): evaluate if we need better comparing of hostinfo // TODO(kradalby): evaluate if we need better comparing of hostinfo
// before we take the changes. // before we take the changes.
if m.req.Hostinfo.NetInfo == nil && m.node.Hostinfo != nil { if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
m.req.Hostinfo.NetInfo = m.node.Hostinfo.NetInfo m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
} }
m.node.Hostinfo = m.req.Hostinfo node.Hostinfo = m.req.Hostinfo
logTracePeerChange(m.node.Hostname, sendUpdate, &change) logTracePeerChange(node.Hostname, sendUpdate, &change)
// If there is no changes and nothing to save, // If there is no changes and nothing to save,
// return early. // return early.
@ -419,47 +409,40 @@ func (m *mapSession) handleEndpointUpdate() {
return return
} }
// Check if the Hostinfo of the node has changed. // Auto approve any routes that have been defined in policy as
// If it has changed, check if there has been a change to // auto approved. Check if this actually changed the node.
// the routable IPs of the host and update them in routesAutoApproved := m.h.state.AutoApproveRoutes(node)
// the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the route change.
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
routesAutoApproved := m.h.state.AutoApproveRoutes(m.node)
// Update the routes of the given node in the route manager to // Always update routes for connected nodes to handle reconnection scenarios
// see if an update needs to be sent. // where routes need to be restored to the primary routes system
if m.h.state.SetNodeRoutes(m.node.ID, m.node.SubnetRoutes()...) { routesToSet := node.SubnetRoutes()
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(m.node.ID), m.node.ID)
// TODO(kradalby): I am not sure if we need this? if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
// Send an update to the node itself with to ensure it ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
// has an updated packetfilter allowing the new route m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
// if it is defined in the ACL. } else if routesChanged {
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", m.node.Hostname) // Only send peer changed notification if routes actually changed
m.h.nodeNotifier.NotifyByNodeID( ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
ctx, m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
types.UpdateSelf(m.node.ID),
m.node.ID)
}
// If routes were auto-approved, we need to save the node to persist the changes // TODO(kradalby): I am not sure if we need this?
if routesAutoApproved { // Send an update to the node itself with to ensure it
if _, _, err := m.h.state.SaveNode(m.node); err != nil { // has an updated packetfilter allowing the new route
m.errf(err, "Failed to save auto-approved routes to node") // if it is defined in the ACL.
http.Error(m.w, "", http.StatusInternalServerError) ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
mapResponseEndpointUpdates.WithLabelValues("error").Inc() m.h.nodeNotifier.NotifyByNodeID(
return ctx,
} types.UpdateSelf(node.ID),
node.ID)
}
// If routes were auto-approved, we need to save the node to persist the changes
if routesAutoApproved {
if _, _, err := m.h.state.SaveNode(node); err != nil {
m.errf(err, "Failed to save auto-approved routes to node")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
} }
} }
@ -467,9 +450,9 @@ func (m *mapSession) handleEndpointUpdate() {
// in the database. Then send a Changed update // in the database. Then send a Changed update
// (containing the whole node object) to peers to inform about // (containing the whole node object) to peers to inform about
// the hostname change. // the hostname change.
m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo) node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
_, policyChanged, err := m.h.state.SaveNode(m.node) _, policyChanged, err := m.h.state.SaveNode(node)
if err != nil { if err != nil {
m.errf(err, "Failed to persist/update node in the database") m.errf(err, "Failed to persist/update node in the database")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)
@ -480,15 +463,15 @@ func (m *mapSession) handleEndpointUpdate() {
// Send policy update notifications if needed // Send policy update notifications if needed
if policyChanged { if policyChanged {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} }
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore( m.h.nodeNotifier.NotifyWithIgnore(
ctx, ctx,
types.UpdatePeerChanged(m.node.ID), types.UpdatePeerChanged(node.ID),
m.node.ID, node.ID,
) )
m.w.WriteHeader(http.StatusOK) m.w.WriteHeader(http.StatusOK)
@ -498,7 +481,7 @@ func (m *mapSession) handleEndpointUpdate() {
func (m *mapSession) handleReadOnlyRequest() { func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers") m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node.View()) mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil { if err != nil {
m.errf(err, "Failed to create MapResponse") m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)
@ -611,6 +594,53 @@ func logPollFunc(
} }
} }
func logPollFuncView(
mapRequest tailcfg.MapRequest,
nodeView types.NodeView,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Err(err).
Msgf(msg, a...)
}
}
// hostInfoChanged reports if hostInfo has changed in two ways, // hostInfoChanged reports if hostInfo has changed in two ways,
// - first bool reports if an update needs to be sent to nodes // - first bool reports if an update needs to be sent to nodes
// - second reports if there has been changes to routes // - second reports if there has been changes to routes

View File

@ -18,6 +18,7 @@ import (
"github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock" "github.com/sasha-s/go-deadlock"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -400,22 +401,17 @@ func (s *State) DeleteNode(node *types.Node) (bool, error) {
return policyChanged, nil return policyChanged, nil
} }
func (s *State) Connect(node types.NodeView) bool { func (s *State) Connect(id types.NodeID) {
changed := s.primaryRoutes.SetRoutes(node.ID(), node.SubnetRoutes()...)
// TODO(kradalby): this should be more granular, allowing us to
// only send a online update change.
return changed
} }
func (s *State) Disconnect(node types.NodeView) (bool, error) { func (s *State) Disconnect(id types.NodeID) (bool, error) {
// TODO(kradalby): This node should update the in memory state // TODO(kradalby): This node should update the in memory state
_, polChanged, err := s.SetLastSeen(node.ID(), time.Now()) _, polChanged, err := s.SetLastSeen(id, time.Now())
if err != nil { if err != nil {
return false, fmt.Errorf("disconnecting node: %w", err) return false, fmt.Errorf("disconnecting node: %w", err)
} }
changed := s.primaryRoutes.SetRoutes(node.ID()) changed := s.primaryRoutes.SetRoutes(id)
// TODO(kradalby): the returned change should be more nuanced allowing us to // TODO(kradalby): the returned change should be more nuanced allowing us to
// send more directed updates. // send more directed updates.
@ -427,11 +423,29 @@ func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) {
return s.db.GetNodeByID(nodeID) return s.db.GetNodeByID(nodeID)
} }
// GetNodeViewByID retrieves a node view by ID.
func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
node, err := s.db.GetNodeByID(nodeID)
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
// GetNodeByNodeKey retrieves a node by its Tailscale public key. // GetNodeByNodeKey retrieves a node by its Tailscale public key.
func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return s.db.GetNodeByNodeKey(nodeKey) return s.db.GetNodeByNodeKey(nodeKey)
} }
// GetNodeViewByNodeKey retrieves a node view by its Tailscale public key.
func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
node, err := s.db.GetNodeByNodeKey(nodeKey)
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided. // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
if len(nodeIDs) == 0 { if len(nodeIDs) == 0 {
@ -682,8 +696,17 @@ func (s *State) HandleNodeFromPreAuthKey(
AuthKeyID: &pak.ID, AuthKeyID: &pak.ID,
} }
if !regReq.Expiry.IsZero() { // For auth key registration, ensure we don't keep an expired node
// This is especially important for re-registration after logout
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
nodeToRegister.Expiry = &regReq.Expiry nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired
log.Debug().
Time("requested_expiry", regReq.Expiry).
Str("node", regReq.Hostinfo.Hostname).
Msg("Ignoring expired expiry time from auth key registration")
} }
ipv4, ipv6, err := s.ipAlloc.Next() ipv4, ipv6, err := s.ipAlloc.Next()