all: use immutable node view in read path

This commit changes most of our (*)types.Node to
types.NodeView, which is a readonly version of the
underlying node ensuring that there is no mutations
happening in the read path.

Based on the migration, there didnt seem to be any, but the
idea here is to prevent it in the future and simplify other
new implementations.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2025-07-05 23:31:13 +02:00
committed by Kristoffer Dalby
parent 5ba7120418
commit 73023c2ec3
24 changed files with 866 additions and 196 deletions

View File

@@ -168,6 +168,10 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll()
// For now, mapSession uses a normal node, but since serveLongPoll is a read operation,
// convert the node to a view at the beginning.
nv := m.node.View()
// Clean up the session when the client disconnects
defer func() {
m.cancelChMu.Lock()
@@ -179,16 +183,16 @@ func (m *mapSession) serveLongPoll() {
// in principal, it will be removed, but the client rapidly
// reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
if m.h.nodeNotifier.RemoveNode(nv.ID(), m.ch) {
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
change, err := m.h.state.Disconnect(m.node)
change, err := m.h.state.Disconnect(nv)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
m.errf(err, "Failed to disconnect node %s", nv.Hostname())
}
if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
@@ -201,8 +205,8 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
if m.h.state.Connect(m.node) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
if m.h.state.Connect(nv) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
@@ -213,17 +217,17 @@ func (m *mapSession) serveLongPoll() {
// so it needs to be disabled.
rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, nv.Hostname()))
defer cancel()
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(m.node.ID, m.ch)
m.h.nodeNotifier.AddNode(nv.ID(), m.ch)
go func() {
changed := m.h.state.Connect(m.node)
changed := m.h.state.Connect(nv)
if changed {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", nv.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}()
@@ -253,7 +257,7 @@ func (m *mapSession) serveLongPoll() {
}
// If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, m.node.ID) {
if slices.Contains(update.Removed, nv.ID()) {
m.tracef("node removed, closing stream")
return
}
@@ -268,18 +272,22 @@ func (m *mapSession) serveLongPoll() {
// Ensure the node object is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
m.node, err = m.h.state.GetNodeByID(m.node.ID)
m.node, err = m.h.state.GetNodeByID(nv.ID())
if err != nil {
m.errf(err, "Could not get machine from db")
return
}
// Update the node view to reflect the latest node state
// TODO(kradalby): This should become a full read only path, with no update for the node view
// in the new mapper model.
nv = m.node.View()
updateType := "full"
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
data, err = m.mapper.FullMapResponse(m.req, nv, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
@@ -289,12 +297,12 @@ func (m *mapSession) serveLongPoll() {
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage)
updateType = "change"
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
data, err = m.mapper.PeerChangedPatchResponse(m.req, nv, update.ChangePatches)
updateType = "patch"
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
@@ -303,17 +311,17 @@ func (m *mapSession) serveLongPoll() {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, nv, changed, update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, nv, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
data, err = m.mapper.DERPMapResponse(m.req, nv, m.h.state.DERPMap())
updateType = "derp"
}
@@ -340,10 +348,10 @@ func (m *mapSession) serveLongPoll() {
return
}
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
log.Trace().Str("node", nv.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", nv.MachineKey().String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID.String()).Set(float64(time.Now().Unix()))
mapResponseLastSentSeconds.WithLabelValues(updateType, nv.ID().String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent")
@@ -351,7 +359,7 @@ func (m *mapSession) serveLongPoll() {
}
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
data, err := m.mapper.KeepAliveResponse(m.req, nv)
if err != nil {
m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
@@ -371,7 +379,7 @@ func (m *mapSession) serveLongPoll() {
}
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
mapResponseLastSentSeconds.WithLabelValues("keepalive", nv.ID().String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
@@ -490,7 +498,7 @@ func (m *mapSession) handleEndpointUpdate() {
func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node.View())
if err != nil {
m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError)