package hscontrol

import (
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math/rand/v2"
	"net/http"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/rs/zerolog/log"
	"github.com/sasha-s/go-deadlock"
	"tailscale.com/tailcfg"
	"tailscale.com/util/zstdframe"
)
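
// keepAliveInterval is the base interval between keep-alive messages;
// newMapSession adds up to 9s of per-session jitter on top of it.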
const (
	keepAliveInterval = 50 * time.Second
)
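
// contextKey is a private type for context values, so keys defined here
// cannot collide with keys from other packages.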
type contextKey string

const nodeNameContextKey = contextKey("nodeName")
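
// mapSession tracks a single MapRequest from a node: the response channel,
// cancellation state, keep-alive timing, and per-session loggers.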
type mapSession struct {
	h      *Headscale
	req    tailcfg.MapRequest
	ctx    context.Context
	capVer tailcfg.CapabilityVersion

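	// cancelChMu guards cancelCh and cancelChOpen; cancelCh is only
	// closed while holding it (see serveLongPoll).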
	cancelChMu deadlock.Mutex

	ch           chan *tailcfg.MapResponse
	cancelCh     chan struct{}
	cancelChOpen bool

	keepAlive       time.Duration
	keepAliveTicker *time.Ticker

	node *types.Node
	w    http.ResponseWriter

	warnf  func(string, ...any)
	infof  func(string, ...any)
	tracef func(string, ...any)
	errf   func(error, string, ...any)
}
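
// newMapSession creates a mapSession for the given MapRequest, wiring up
// per-node loggers and a jittered keep-alive interval.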
func (h *Headscale) newMapSession(
	ctx context.Context,
	req tailcfg.MapRequest,
	w http.ResponseWriter,
	node *types.Node,
) *mapSession {
	warnf, infof, tracef, errf := logPollFunc(req, node)

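	// Spread keep-alives across sessions by adding up to 9s of random
	// jitter to the base interval.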
	ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)

	return &mapSession{
		h:      h,
		ctx:    ctx,
		req:    req,
		w:      w,
		node:   node,
		capVer: req.Version,

		ch:           make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
		cancelCh:     make(chan struct{}),
		cancelChOpen: true,

		keepAlive:       ka,
		keepAliveTicker: nil,

		// Loggers
		warnf:  warnf,
		infof:  infof,
		tracef: tracef,
		errf:   errf,
	}
}
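
// isStreaming reports whether this is a long-poll streaming session.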
func (m *mapSession) isStreaming() bool {
	return m.req.Stream
}
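
// isEndpointUpdate reports whether this is a non-streaming "lite" update,
// in which the node only reports its own state and expects no peers back.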
func (m *mapSession) isEndpointUpdate() bool {
	return !m.req.Stream && m.req.OmitPeers
}
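
// resetKeepAlive restarts the keep-alive timer, typically after a real map
// response has just been written.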
func (m *mapSession) resetKeepAlive() {
	m.keepAliveTicker.Reset(m.keepAlive)
}
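
// beforeServeLongPoll cancels any pending ephemeral garbage collection for
// the node while it is connected.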
func (m *mapSession) beforeServeLongPoll() {
	if m.node.IsEphemeral() {
		m.h.ephemeralGC.Cancel(m.node.ID)
	}
}
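
// afterServeLongPoll schedules ephemeral nodes for garbage collection once
// they disconnect.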
func (m *mapSession) afterServeLongPoll() {
	if m.node.IsEphemeral() {
		m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
	}
}

// serve handles non-streaming requests.
func (m *mapSession) serve() {
	// This is the mechanism where the node gives us information about its
	// current configuration.
	//
	// If OmitPeers is true and Stream is false, the server lets clients
	// update their endpoints without breaking existing long-polling
	// (Stream == true) connections. In this case, the server can omit the
	// entire response; the client only checks the HTTP response status code.
	//
	// This is what Tailscale calls a Lite update: the client ignores the
	// response body and just wants a 200.
	// (!req.Stream && req.OmitPeers)
	if m.isEndpointUpdate() {
		c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
		if err != nil {
			httpError(m.w, err)
			return
		}

		m.h.Change(c)

		m.w.WriteHeader(http.StatusOK)
		mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
	}
}

// serveLongPoll ensures the node gets the appropriate updates from either
// polling or immediate responses.
//
//nolint:gocyclo
func (m *mapSession) serveLongPoll() {
	m.beforeServeLongPoll()

	// Clean up the session when the client disconnects
	defer func() {
		m.cancelChMu.Lock()
		m.cancelChOpen = false
		close(m.cancelCh)
		m.cancelChMu.Unlock()

		// TODO(kradalby): This can likely be made more efficient, but most
		// nodes likely have access to the same routes, so it might not be a
		// big deal.
		disconnectChange, err := m.h.state.Disconnect(m.node)
		if err != nil {
			m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
		}
		m.h.Change(disconnectChange)

		m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())

		m.afterServeLongPoll()
		m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
	}()

	// Set up the client stream
	m.h.pollNetMapStreamWG.Add(1)
	defer m.h.pollNetMapStreamWG.Done()

	ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
	defer cancel()

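	// Start the per-session keep-alive ticker; resetKeepAlive pushes the
	// next tick back whenever a real map response is written.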
	m.keepAliveTicker = time.NewTicker(m.keepAlive)

	// Add node to batcher BEFORE sending Connect change to prevent race condition
	// where the change is sent before the node is in the batcher's node map
	if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
		m.errf(err, "failed to add node to batcher")
		// Send empty response to client to fail fast for invalid/non-existent nodes
		select {
		case m.ch <- &tailcfg.MapResponse{}:
		default:
			// Channel might be closed
		}
		return
	}

	// Now send the Connect change - the batcher handles NodeCameOnline internally
	// but we still need to update routes and other state-level changes
	connectChange := m.h.state.Connect(m.node)
	if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
		m.h.Change(connectChange)
	}

	m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)

	// Loop through updates and continuously send them to the client.
	for {
		// Consume channels carrying updates, keep-alives, or "batch" blocking signals.
		select {
		case <-m.cancelCh:
			m.tracef("poll cancellation received")
			mapResponseEnded.WithLabelValues("cancelled").Inc()
			return

		case <-ctx.Done():
			m.tracef("poll context done")
			mapResponseEnded.WithLabelValues("done").Inc()
			return

		// Consume updates sent to node
		case update, ok := <-m.ch:
			m.tracef("received update from channel, ok: %t", ok)
			if !ok {
				m.tracef("update channel closed, streaming session is likely being replaced")
				return
			}

			if err := m.writeMap(update); err != nil {
				m.errf(err, "cannot write update to client")
				return
			}

			m.tracef("update sent")
			m.resetKeepAlive()

		case <-m.keepAliveTicker.C:
			if err := m.writeMap(&keepAlive); err != nil {
				m.errf(err, "cannot write keep alive")
				return
			}

			if debugHighCardinalityMetrics {
				mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
			}
			mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
		}
	}
}

// writeMap writes the map response to the client.
// It handles compression if requested and any headers that need to be set.
// It also handles flushing the response if the ResponseWriter
// implements http.Flusher.
func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
	jsonBody, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshalling map response: %w", err)
	}

	if m.req.Compress == util.ZstdCompression {
		jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
	}

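	// Frame the response: reserve a fixed-size header, write the body
	// length into it as a little-endian uint32, then append the (possibly
	// compressed) JSON body.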
	data := make([]byte, reservedResponseHeaderSize)
	binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
	data = append(data, jsonBody...)

	startWrite := time.Now()

	_, err = m.w.Write(data)
	if err != nil {
		return err
	}

	if m.isStreaming() {
		if f, ok := m.w.(http.Flusher); ok {
			f.Flush()
		} else {
			m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
		}
	}

	log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")

	return nil
}
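
// keepAlive is the canned response sent on every keep-alive tick; it
// carries no map data.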
var keepAlive = tailcfg.MapResponse{
	KeepAlive: true,
}
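
// logTracePeerChange logs the contents of a PeerChange at trace level,
// including only the fields that are actually set.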
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
	trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)

	if peerChange.Key != nil {
		trace = trace.Str("node_key", peerChange.Key.ShortString())
	}

	if peerChange.DiscoKey != nil {
		trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
	}

	if peerChange.Online != nil {
		trace = trace.Bool("online", *peerChange.Online)
	}

	if peerChange.Endpoints != nil {
		eps := make([]string, len(peerChange.Endpoints))
		for idx, ep := range peerChange.Endpoints {
			eps[idx] = ep.String()
		}

		trace = trace.Strs("endpoints", eps)
	}

	if hostinfoChange {
		trace = trace.Bool("hostinfo_changed", hostinfoChange)
	}

	if peerChange.DERPRegion != 0 {
		trace = trace.Int("derp_region", peerChange.DERPRegion)
	}

	// Guard against a nil LastSeen to avoid a panic; every other optional
	// field above is already nil-checked.
	if peerChange.LastSeen != nil {
		trace = trace.Time("last_seen", *peerChange.LastSeen)
	}

	trace.Msg("PeerChange received")
}
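
// logPollFunc returns warn, info, trace, and error log helpers that are
// pre-tagged with the MapRequest flags and the node's ID and hostname.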
func logPollFunc(
	mapRequest tailcfg.MapRequest,
	node *types.Node,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
	return func(msg string, a ...any) {
			log.Warn().
				Caller().
				Bool("omitPeers", mapRequest.OmitPeers).
				Bool("stream", mapRequest.Stream).
				Uint64("node.id", node.ID.Uint64()).
				Str("node", node.Hostname).
				Msgf(msg, a...)
		},
		func(msg string, a ...any) {
			log.Info().
				Caller().
				Bool("omitPeers", mapRequest.OmitPeers).
				Bool("stream", mapRequest.Stream).
				Uint64("node.id", node.ID.Uint64()).
				Str("node", node.Hostname).
				Msgf(msg, a...)
		},
		func(msg string, a ...any) {
			log.Trace().
				Caller().
				Bool("omitPeers", mapRequest.OmitPeers).
				Bool("stream", mapRequest.Stream).
				Uint64("node.id", node.ID.Uint64()).
				Str("node", node.Hostname).
				Msgf(msg, a...)
		},
		func(err error, msg string, a ...any) {
			log.Error().
				Caller().
				Bool("omitPeers", mapRequest.OmitPeers).
				Bool("stream", mapRequest.Stream).
				Uint64("node.id", node.ID.Uint64()).
				Str("node", node.Hostname).
				Err(err).
				Msgf(msg, a...)
		}
}