state: introduce state

This commit moves all of the read and write logic, and all the different parts
of headscale that manage some sort of persistent or in-memory state, into
a separate package.

The goal of this is to clearly define the boundary between the parts of the app
that access and modify data, and where that happens. Previously, different
state (routes, policy, db and so on) was used directly, and sometimes passed to
functions as pointers.

Now all access has to go through state. In this initial implementation,
most of the same functions exist and have simply been moved. In the future,
centralising this will allow us to optimise bottlenecks with the database
(in-memory state) and make the different parts that talk to each other do so
in the same way across headscale components.
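
To make the shape of the new boundary concrete, the following is a rough sketch of the facade this commit introduces, reconstructed only from the call sites visible in the diff below (NewState, GetNodeByID, DeleteNode, ExpireExpiredNodes, ValidateAPIKey, DERPMap, ListEphemeralNodes, ReloadPolicy, Close). The struct contents, exact signatures and return types are assumptions for illustration, not the actual hscontrol/state package.

// Illustrative sketch only: signatures reconstructed from call sites in
// this diff; the real hscontrol/state package may differ.
package state

import (
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
)

// State bundles what used to be separate db, policy, routes and DERP
// handles behind a single type, so the rest of headscale no longer
// reaches into *db.HSDatabase or policy.PolicyManager directly.
type State struct {
	// database handle, policy manager, primary routes, registration
	// cache and DERP map live here as unexported fields (assumed).
}

// NewState opens the database and loads the policy (assumed behaviour).
func NewState(cfg *types.Config) (*State, error) { return &State{}, nil }

// Accessors used elsewhere in this commit (stubbed bodies, assumed types):
func (s *State) ValidateAPIKey(key string) (bool, error)          { return false, nil }
func (s *State) GetNodeByID(id types.NodeID) (*types.Node, error) { return nil, nil }
func (s *State) DeleteNode(node *types.Node) (policyChanged bool, err error) {
	return false, nil
}
func (s *State) ListEphemeralNodes() (types.Nodes, error) { return nil, nil }
func (s *State) ExpireExpiredNodes(last time.Time) (time.Time, types.StateUpdate, bool) {
	return last, types.StateUpdate{}, false
}
func (s *State) DERPMap() *tailcfg.DERPMap   { return nil }
func (s *State) ReloadPolicy() (bool, error) { return false, nil }
func (s *State) Close() error                { return nil }

The effect is visible throughout the diff: call sites such as h.db.ValidateAPIKey(...) become h.state.ValidateAPIKey(...), and NewHeadscale no longer wires up the database, IP allocator and registration cache itself.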

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby
2025-05-27 16:27:16 +02:00
committed by Kristoffer Dalby
parent a975b6a8b1
commit 1553f0ab53
17 changed files with 1390 additions and 1067 deletions


@@ -5,7 +5,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
_ "net/http/pprof" // nolint
@@ -30,8 +29,7 @@ import (
"github.com/juanfont/headscale/hscontrol/dns"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
@@ -49,13 +47,11 @@ import (
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
"tailscale.com/util/dnsname"
zcache "zgo.at/zcache/v2"
)
var (
@@ -73,33 +69,22 @@ const (
updateInterval = 5 * time.Second
privateKeyFileMode = 0o600
headscaleDirPerm = 0o700
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
)
// Headscale represents the base app of the service.
type Headscale struct {
cfg *types.Config
db *db.HSDatabase
ipAlloc *db.IPAllocator
state *state.State
noisePrivateKey *key.MachinePrivate
ephemeralGC *db.EphemeralGarbageCollector
DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer
polManOnce sync.Once
polMan policy.PolicyManager
// Things that generate changes
extraRecordMan *dns.ExtraRecordsMan
primaryRoutes *routes.PrimaryRoutes
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
authProvider AuthProvider
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
authProvider AuthProvider
pollNetMapStreamWG sync.WaitGroup
}
@@ -124,43 +109,42 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
s, err := state.NewState(cfg)
if err != nil {
return nil, fmt.Errorf("init state: %w", err)
}
app := Headscale{
cfg: cfg,
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(cfg),
primaryRoutes: routes.New(),
state: s,
}
app.db, err = db.NewHeadscaleDatabase(
cfg.Database,
cfg.BaseDomain,
registrationCache,
)
if err != nil {
return nil, fmt.Errorf("new database: %w", err)
}
app.ipAlloc, err = db.NewIPAllocator(app.db, cfg.PrefixV4, cfg.PrefixV6, cfg.IPAllocation)
if err != nil {
return nil, err
}
app.ephemeralGC = db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
if err := app.db.DeleteEphemeralNode(ni); err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
// Initialize ephemeral garbage collector
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
node, err := app.state.GetNodeByID(ni)
if err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
return
}
})
if err = app.loadPolicyManager(); err != nil {
return nil, fmt.Errorf("loading ACL policy: %w", err)
}
policyChanged, err := app.state.DeleteNode(node)
if err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
return
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
})
app.ephemeralGC = ephemeralGC
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
@@ -171,10 +155,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
ctx,
cfg.ServerURL,
&cfg.OIDC,
app.db,
app.state,
app.nodeNotifier,
app.ipAlloc,
app.polMan,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@@ -283,14 +265,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
var update types.StateUpdate
var changed bool
if err := h.db.Write(func(tx *gorm.DB) error {
lastExpiryCheck, update, changed = db.ExpireExpiredNodes(tx, lastExpiryCheck)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
@@ -301,16 +276,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
case <-derpTickerChan:
log.Info().Msg("Fetching DERPMap updates")
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
derpMap := derp.GetDERPMap(h.cfg.DERP)
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion()
h.DERPMap.Regions[region.RegionID] = &region
derpMap.Regions[region.RegionID] = &region
}
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: h.DERPMap,
DERPMap: derpMap,
})
case records, ok := <-extraRecordsUpdate:
@@ -369,7 +344,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
)
}
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
if err != nil {
return ctx, status.Error(codes.Internal, "failed to validate token")
}
@@ -414,7 +389,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
return
}
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Error().
Caller().
@@ -497,7 +472,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.DERPMap))
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
}
apiRouter := router.PathPrefix("/api").Subrouter()
@@ -509,57 +484,57 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
// Maybe this should be implemented as an event bus?
// A bool is returned indicating if a full update was sent to all nodes
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
users, err := db.ListUsers()
if err != nil {
return err
}
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
// users, err := db.ListUsers()
// if err != nil {
// return err
// }
changed, err := polMan.SetUsers(users)
if err != nil {
return err
}
// changed, err := polMan.SetUsers(users)
// if err != nil {
// return err
// }
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.UpdateFull())
}
// if changed {
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
// }
return nil
}
// return nil
// }
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
// Maybe this should be implemented as an event bus?
// A bool is returned indicating if a full update was sent to all nodes
func nodesChangedHook(
db *db.HSDatabase,
polMan policy.PolicyManager,
notif *notifier.Notifier,
) (bool, error) {
nodes, err := db.ListNodes()
if err != nil {
return false, err
}
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func nodesChangedHook(
// db *db.HSDatabase,
// polMan policy.PolicyManager,
// notif *notifier.Notifier,
// ) (bool, error) {
// nodes, err := db.ListNodes()
// if err != nil {
// return false, err
// }
filterChanged, err := polMan.SetNodes(nodes)
if err != nil {
return false, err
}
// filterChanged, err := polMan.SetNodes(nodes)
// if err != nil {
// return false, err
// }
if filterChanged {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.UpdateFull())
// if filterChanged {
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
return true, nil
}
// return true, nil
// }
return false, nil
}
// return false, nil
// }
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
@@ -588,9 +563,9 @@ func (h *Headscale) Serve() error {
Msg("Clients with a lower minimum version will be rejected")
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan, h.primaryRoutes)
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
// TODO(kradalby): fix state part.
if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server
if h.cfg.DERP.STUNAddr == "" {
@@ -603,13 +578,13 @@ func (h *Headscale) Serve() error {
}
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
h.DERPMap.Regions[region.RegionID] = &region
h.state.DERPMap().Regions[region.RegionID] = &region
}
go h.DERPServer.ServeSTUN()
}
if len(h.DERPMap.Regions) == 0 {
if len(h.state.DERPMap().Regions) == 0 {
return errEmptyInitialDERPMap
}
@@ -618,7 +593,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes, err := h.db.ListEphemeralNodes()
ephmNodes, err := h.state.ListEphemeralNodes()
if err != nil {
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
}
@@ -853,29 +828,16 @@ func (h *Headscale) Serve() error {
continue
}
if err := h.loadPolicyManager(); err != nil {
log.Error().Err(err).Msg("failed to reload Policy")
}
pol, err := h.policyBytes()
changed, err := h.state.ReloadPolicy()
if err != nil {
log.Error().Err(err).Msg("failed to get policy blob")
}
changed, err := h.polMan.SetPolicy(pol)
if err != nil {
log.Error().Err(err).Msg("failed to set new policy")
log.Error().Err(err).Msgf("reloading policy")
continue
}
if changed {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
err = h.autoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
@@ -934,7 +896,7 @@ func (h *Headscale) Serve() error {
// Close db connections
info("closing database connection")
err = h.db.Close()
err = h.state.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close db")
}
@@ -1085,124 +1047,3 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil
}
// policyBytes returns the appropriate policy for the
// current configuration as a []byte array.
func (h *Headscale) policyBytes() ([]byte, error) {
switch h.cfg.Policy.Mode {
case types.PolicyModeFile:
path := h.cfg.Policy.Path
// It is fine to start headscale without a policy file.
if len(path) == 0 {
return nil, nil
}
absPath := util.AbsolutePathFromConfigPath(path)
policyFile, err := os.Open(absPath)
if err != nil {
return nil, err
}
defer policyFile.Close()
return io.ReadAll(policyFile)
case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}
return nil, err
}
if p.Data == "" {
return nil, nil
}
return []byte(p.Data), err
}
return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
}
func (h *Headscale) loadPolicyManager() error {
var errOut error
h.polManOnce.Do(func() {
// Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes
// yet, but it should help for the general case and for hot reloading
// configurations.
// Note that this check is only done for file-based policies in this function
// as the database-based policies are checked in the gRPC API where it is not
// allowed to be written to the database.
nodes, err := h.db.ListNodes()
if err != nil {
errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
return
}
users, err := h.db.ListUsers()
if err != nil {
errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
return
}
pol, err := h.policyBytes()
if err != nil {
errOut = fmt.Errorf("loading policy bytes: %w", err)
return
}
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
if err != nil {
errOut = fmt.Errorf("creating policy manager: %w", err)
return
}
log.Info().Msgf("Using policy manager version: %d", h.polMan.Version())
if len(nodes) > 0 {
_, err = h.polMan.SSHPolicy(nodes[0])
if err != nil {
errOut = fmt.Errorf("verifying SSH rules: %w", err)
return
}
}
})
return errOut
}
// autoApproveNodes mass approves routes on all nodes. It is _only_ intended for
// use when the policy is replaced. It is not sending or reporting any changes
// or updates as we send full updates after replacing the policy.
// TODO(kradalby): This is kind of messy, maybe this is another +1
// for an event bus. See example comments here.
func (h *Headscale) autoApproveNodes() error {
err := h.db.Write(func(tx *gorm.DB) error {
nodes, err := db.ListNodes(tx)
if err != nil {
return err
}
for _, node := range nodes {
changed := policy.AutoApproveRoutes(h.polMan, node)
if changed {
err = tx.Save(node).Error
if err != nil {
return err
}
h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
}
}
return nil
})
if err != nil {
return fmt.Errorf("auto approving routes for nodes: %w", err)
}
return nil
}