state: introduce state

this commit moves all of the read and write logic, and all different parts
of headscale that manages some sort of persistent and in memory state into
a separate package.

The goal of this is to clearly define the boundry between parts of the app
which accesses and modifies data, and where it happens. Previously, different
state (routes, policy, db and so on) was used directly, and sometime passed to
functions as pointers.

Now all access has to go through state. In the initial implementation,
most of the same functions exists and have just been moved. In the future
centralising this will allow us to optimise bottle necks with the database
(in memory state) and make the different parts talking to eachother do so
in the same way across headscale components.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-05-27 16:27:16 +02:00 committed by Kristoffer Dalby
parent a975b6a8b1
commit 1553f0ab53
17 changed files with 1390 additions and 1067 deletions

View File

@ -5,7 +5,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
_ "net/http/pprof" // nolint
@ -30,8 +29,7 @@ import (
"github.com/juanfont/headscale/hscontrol/dns"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
@ -49,13 +47,11 @@ import (
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
"tailscale.com/util/dnsname"
zcache "zgo.at/zcache/v2"
)
var (
@ -73,32 +69,21 @@ const (
updateInterval = 5 * time.Second
privateKeyFileMode = 0o600
headscaleDirPerm = 0o700
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
)
// Headscale represents the base app of the service.
type Headscale struct {
cfg *types.Config
db *db.HSDatabase
ipAlloc *db.IPAllocator
state *state.State
noisePrivateKey *key.MachinePrivate
ephemeralGC *db.EphemeralGarbageCollector
DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer
polManOnce sync.Once
polMan policy.PolicyManager
// Things that generate changes
extraRecordMan *dns.ExtraRecordsMan
primaryRoutes *routes.PrimaryRoutes
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
authProvider AuthProvider
pollNetMapStreamWG sync.WaitGroup
@ -124,44 +109,43 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
s, err := state.NewState(cfg)
if err != nil {
return nil, fmt.Errorf("init state: %w", err)
}
app := Headscale{
cfg: cfg,
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(cfg),
primaryRoutes: routes.New(),
state: s,
}
app.db, err = db.NewHeadscaleDatabase(
cfg.Database,
cfg.BaseDomain,
registrationCache,
)
// Initialize ephemeral garbage collector
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
node, err := app.state.GetNodeByID(ni)
if err != nil {
return nil, fmt.Errorf("new database: %w", err)
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
return
}
app.ipAlloc, err = db.NewIPAllocator(app.db, cfg.PrefixV4, cfg.PrefixV6, cfg.IPAllocation)
policyChanged, err := app.state.DeleteNode(node)
if err != nil {
return nil, err
}
app.ephemeralGC = db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
if err := app.db.DeleteEphemeralNode(ni); err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
return
}
})
if err = app.loadPolicyManager(); err != nil {
return nil, fmt.Errorf("loading ACL policy: %w", err)
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
})
app.ephemeralGC = ephemeralGC
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
@ -171,10 +155,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
ctx,
cfg.ServerURL,
&cfg.OIDC,
app.db,
app.state,
app.nodeNotifier,
app.ipAlloc,
app.polMan,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@ -283,14 +265,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
var update types.StateUpdate
var changed bool
if err := h.db.Write(func(tx *gorm.DB) error {
lastExpiryCheck, update, changed = db.ExpireExpiredNodes(tx, lastExpiryCheck)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
@ -301,16 +276,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
case <-derpTickerChan:
log.Info().Msg("Fetching DERPMap updates")
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
derpMap := derp.GetDERPMap(h.cfg.DERP)
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion()
h.DERPMap.Regions[region.RegionID] = &region
derpMap.Regions[region.RegionID] = &region
}
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: h.DERPMap,
DERPMap: derpMap,
})
case records, ok := <-extraRecordsUpdate:
@ -369,7 +344,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
)
}
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
if err != nil {
return ctx, status.Error(codes.Internal, "failed to validate token")
}
@ -414,7 +389,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
return
}
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Error().
Caller().
@ -497,7 +472,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.DERPMap))
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
}
apiRouter := router.PathPrefix("/api").Subrouter()
@ -509,57 +484,57 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
// Maybe this should be implemented as an event bus?
// A bool is returned indicating if a full update was sent to all nodes
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
users, err := db.ListUsers()
if err != nil {
return err
}
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
// users, err := db.ListUsers()
// if err != nil {
// return err
// }
changed, err := polMan.SetUsers(users)
if err != nil {
return err
}
// changed, err := polMan.SetUsers(users)
// if err != nil {
// return err
// }
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.UpdateFull())
}
// if changed {
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
// }
return nil
}
// return nil
// }
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
// Maybe this should be implemented as an event bus?
// A bool is returned indicating if a full update was sent to all nodes
func nodesChangedHook(
db *db.HSDatabase,
polMan policy.PolicyManager,
notif *notifier.Notifier,
) (bool, error) {
nodes, err := db.ListNodes()
if err != nil {
return false, err
}
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func nodesChangedHook(
// db *db.HSDatabase,
// polMan policy.PolicyManager,
// notif *notifier.Notifier,
// ) (bool, error) {
// nodes, err := db.ListNodes()
// if err != nil {
// return false, err
// }
filterChanged, err := polMan.SetNodes(nodes)
if err != nil {
return false, err
}
// filterChanged, err := polMan.SetNodes(nodes)
// if err != nil {
// return false, err
// }
if filterChanged {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.UpdateFull())
// if filterChanged {
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
return true, nil
}
// return true, nil
// }
return false, nil
}
// return false, nil
// }
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
@ -588,9 +563,9 @@ func (h *Headscale) Serve() error {
Msg("Clients with a lower minimum version will be rejected")
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan, h.primaryRoutes)
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
// TODO(kradalby): fix state part.
if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server
if h.cfg.DERP.STUNAddr == "" {
@ -603,13 +578,13 @@ func (h *Headscale) Serve() error {
}
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
h.DERPMap.Regions[region.RegionID] = &region
h.state.DERPMap().Regions[region.RegionID] = &region
}
go h.DERPServer.ServeSTUN()
}
if len(h.DERPMap.Regions) == 0 {
if len(h.state.DERPMap().Regions) == 0 {
return errEmptyInitialDERPMap
}
@ -618,7 +593,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes, err := h.db.ListEphemeralNodes()
ephmNodes, err := h.state.ListEphemeralNodes()
if err != nil {
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
}
@ -853,29 +828,16 @@ func (h *Headscale) Serve() error {
continue
}
if err := h.loadPolicyManager(); err != nil {
log.Error().Err(err).Msg("failed to reload Policy")
}
pol, err := h.policyBytes()
changed, err := h.state.ReloadPolicy()
if err != nil {
log.Error().Err(err).Msg("failed to get policy blob")
}
changed, err := h.polMan.SetPolicy(pol)
if err != nil {
log.Error().Err(err).Msg("failed to set new policy")
log.Error().Err(err).Msgf("reloading policy")
continue
}
if changed {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
err = h.autoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
@ -934,7 +896,7 @@ func (h *Headscale) Serve() error {
// Close db connections
info("closing database connection")
err = h.db.Close()
err = h.state.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close db")
}
@ -1085,124 +1047,3 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil
}
// policyBytes returns the appropriate policy for the
// current configuration as a []byte array.
func (h *Headscale) policyBytes() ([]byte, error) {
switch h.cfg.Policy.Mode {
case types.PolicyModeFile:
path := h.cfg.Policy.Path
// It is fine to start headscale without a policy file.
if len(path) == 0 {
return nil, nil
}
absPath := util.AbsolutePathFromConfigPath(path)
policyFile, err := os.Open(absPath)
if err != nil {
return nil, err
}
defer policyFile.Close()
return io.ReadAll(policyFile)
case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}
return nil, err
}
if p.Data == "" {
return nil, nil
}
return []byte(p.Data), err
}
return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
}
func (h *Headscale) loadPolicyManager() error {
var errOut error
h.polManOnce.Do(func() {
// Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes
// yet, but it should help for the general case and for hot reloading
// configurations.
// Note that this check is only done for file-based policies in this function
// as the database-based policies are checked in the gRPC API where it is not
// allowed to be written to the database.
nodes, err := h.db.ListNodes()
if err != nil {
errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
return
}
users, err := h.db.ListUsers()
if err != nil {
errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
return
}
pol, err := h.policyBytes()
if err != nil {
errOut = fmt.Errorf("loading policy bytes: %w", err)
return
}
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
if err != nil {
errOut = fmt.Errorf("creating policy manager: %w", err)
return
}
log.Info().Msgf("Using policy manager version: %d", h.polMan.Version())
if len(nodes) > 0 {
_, err = h.polMan.SSHPolicy(nodes[0])
if err != nil {
errOut = fmt.Errorf("verifying SSH rules: %w", err)
return
}
}
})
return errOut
}
// autoApproveNodes mass approves routes on all nodes. It is _only_ intended for
// use when the policy is replaced. It is not sending or reporting any changes
// or updates as we send full updates after replacing the policy.
// TODO(kradalby): This is kind of messy, maybe this is another +1
// for an event bus. See example comments here.
func (h *Headscale) autoApproveNodes() error {
err := h.db.Write(func(tx *gorm.DB) error {
nodes, err := db.ListNodes(tx)
if err != nil {
return err
}
for _, node := range nodes {
changed := policy.AutoApproveRoutes(h.polMan, node)
if changed {
err = tx.Save(node).Error
if err != nil {
return err
}
h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
}
}
return nil
})
if err != nil {
return fmt.Errorf("auto approving routes for nodes: %w", err)
}
return nil
}

View File

@ -9,10 +9,7 @@ import (
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -29,7 +26,7 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, err := h.db.GetNodeByNodeKey(regReq.NodeKey)
node, err := h.state.GetNodeByNodeKey(regReq.NodeKey)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("looking up node in database: %w", err)
}
@ -85,27 +82,42 @@ func (h *Headscale) handleExistingNode(
// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
err := h.db.DeleteNode(node)
policyChanged, err := h.state.DeleteNode(node)
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
}
expired = true
return nil, nil
}
err := h.db.NodeSetExpiry(node.ID, requestExpiry)
}
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
}
return nodeToRegisterResponse(n), nil
}
return nodeToRegisterResponse(node), nil
}
@ -138,7 +150,7 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
}
if reg, ok := h.registrationCache.Get(followupReg); ok {
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok {
select {
case <-ctx.Done():
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
@ -153,98 +165,25 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
}
// canUsePreAuthKey checks if a pre auth key can be used.
func canUsePreAuthKey(pak *types.PreAuthKey) error {
if pak == nil {
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
}
// we don't need to check if has been used before
if pak.Reusable {
return nil
}
if pak.Used {
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
}
return nil
}
func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
if perr, ok := err.(types.PAKError); ok {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
}
return nil, err
}
err = canUsePreAuthKey(pak)
if err != nil {
return nil, err
}
nodeToRegister := types.Node{
Hostname: regReq.Hostinfo.Hostname,
UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
RegisterMethod: util.RegisterMethodAuthKey,
// TODO(kradalby): This should not be set on the node,
// they should be looked up through the key, which is
// attached to the node.
ForcedTags: pak.Proto().GetAclTags(),
AuthKey: pak,
AuthKeyID: &pak.ID,
}
if !regReq.Expiry.IsZero() {
nodeToRegister.Expiry = &regReq.Expiry
}
ipv4, ipv6, err := h.ipAlloc.Next()
if err != nil {
return nil, fmt.Errorf("allocating IPs: %w", err)
}
node, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
node, err := db.RegisterNode(tx,
nodeToRegister,
ipv4, ipv6,
)
if err != nil {
return nil, fmt.Errorf("registering node: %w", err)
}
if !pak.Reusable {
err = db.UsePreAuthKey(tx, pak)
if err != nil {
return nil, fmt.Errorf("using pre auth key: %w", err)
}
}
return node, nil
})
if err != nil {
return nil, err
}
updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("nodes changed hook: %w", err)
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
@ -256,21 +195,24 @@ func (h *Headscale) handleRegisterWithAuthKey(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := policy.AutoApproveRoutes(h.polMan, node)
if err := h.db.DB.Save(node).Error; err != nil {
routesChanged := h.state.AutoApproveRoutes(node)
if _, _, err := h.state.SaveNode(node); err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
if !updateSent || routesChanged {
if routesChanged {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
} else if changed {
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
return &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
User: *pak.User.TailscaleUser(),
Login: *pak.User.TailscaleLogin(),
User: *node.User.TailscaleUser(),
Login: *node.User.TailscaleLogin(),
}, nil
}
@ -298,7 +240,7 @@ func (h *Headscale) handleRegisterInteractive(
nodeToRegister.Node.Expiry = &regReq.Expiry
}
h.registrationCache.Set(
h.state.SetRegistrationCacheEntry(
registrationId,
nodeToRegister,
)

View File

@ -587,6 +587,9 @@ func ensureUniqueGivenName(
return givenName, nil
}
// ExpireExpiredNodes checks for nodes that have expired since the last check
// and returns a time to be used for the next check, a StateUpdate
// containing the expired nodes, and a boolean indicating if any nodes were found.
func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) {

View File

@ -199,19 +199,18 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
return nodes, nil
}
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error {
return hsdb.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, node, uid)
})
}
// AssignNodeToUser assigns a Node to a user.
func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error {
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
node, err := GetNodeByID(tx, nodeID)
if err != nil {
return err
}
user, err := GetUserByID(tx, uid)
if err != nil {
return err
}
node.User = *user
node.UserID = user.ID
if result := tx.Save(&node); result.Error != nil {
return result.Error
}

View File

@ -108,7 +108,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
c.Assert(err, check.IsNil)
node := types.Node{
ID: 0,
ID: 12,
Hostname: "testnode",
UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -118,16 +118,28 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
c.Assert(trx.Error, check.IsNil)
c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
err = db.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, 12, types.UserID(newUser.ID))
})
c.Assert(err, check.IsNil)
c.Assert(node.UserID, check.Equals, newUser.ID)
c.Assert(node.User.Name, check.Equals, newUser.Name)
// Reload node from database to see updated values
updatedNode, err := db.GetNodeByID(12)
c.Assert(err, check.IsNil)
c.Assert(updatedNode.UserID, check.Equals, newUser.ID)
c.Assert(updatedNode.User.Name, check.Equals, newUser.Name)
err = db.AssignNodeToUser(&node, 9584849)
err = db.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, 12, 9584849)
})
c.Assert(err, check.Equals, ErrUserNotFound)
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
err = db.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, 12, types.UserID(newUser.ID))
})
c.Assert(err, check.IsNil)
c.Assert(node.UserID, check.Equals, newUser.ID)
c.Assert(node.User.Name, check.Equals, newUser.Name)
// Reload node from database again to see updated values
finalNode, err := db.GetNodeByID(12)
c.Assert(err, check.IsNil)
c.Assert(finalNode.UserID, check.Equals, newUser.ID)
c.Assert(finalNode.User.Name, check.Equals, newUser.Name)
}

View File

@ -4,9 +4,11 @@ import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/arl/statsviz"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/client_golang/prometheus/promhttp"
"tailscale.com/tailcfg"
"tailscale.com/tsweb"
@ -30,7 +32,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(config)
}))
debug.Handle("policy", "Current policy", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pol, err := h.policyBytes()
switch h.cfg.Policy.Mode {
case types.PolicyModeDB:
p, err := h.state.GetPolicy()
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(p.Data))
case types.PolicyModeFile:
// Read the file directly for debug purposes
absPath := util.AbsolutePathFromConfigPath(h.cfg.Policy.Path)
pol, err := os.ReadFile(absPath)
if err != nil {
httpError(w, err)
return
@ -38,9 +53,12 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(pol)
default:
httpError(w, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode))
}
}))
debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
filter, _ := h.polMan.Filter()
filter, _ := h.state.Filter()
filterJSON, err := json.MarshalIndent(filter, "", " ")
if err != nil {
@ -52,7 +70,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(filterJSON)
}))
debug.Handle("ssh", "SSH Policy per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nodes, err := h.db.ListNodes()
nodes, err := h.state.ListNodes()
if err != nil {
httpError(w, err)
return
@ -60,7 +78,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
sshPol := make(map[string]*tailcfg.SSHPolicy)
for _, node := range nodes {
pol, err := h.polMan.SSHPolicy(node)
pol, err := h.state.SSHPolicy(node)
if err != nil {
httpError(w, err)
return
@ -79,7 +97,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(sshJSON)
}))
debug.Handle("derpmap", "Current DERPMap", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dm := h.DERPMap
dm := h.state.DERPMap()
dmJSON, err := json.MarshalIndent(dm, "", " ")
if err != nil {
@ -91,24 +109,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(dmJSON)
}))
debug.Handle("registration-cache", "Pending registrations", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
registrationsJSON, err := json.MarshalIndent(h.registrationCache.Items(), "", " ")
if err != nil {
httpError(w, err)
return
}
// TODO(kradalby): This should be replaced with a proper state method that returns registration info
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(registrationsJSON)
w.Write([]byte("{}")) // For now, return empty object
}))
debug.Handle("routes", "Routes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.primaryRoutes.String()))
w.Write([]byte(h.state.PrimaryRoutesString()))
}))
debug.Handle("policy-manager", "Policy Manager", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.polMan.DebugString()))
w.Write([]byte(h.state.PolicyDebugString()))
}))
err := statsviz.Register(debugMux)

View File

@ -25,9 +25,7 @@ import (
"tailscale.com/types/key"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
@ -53,14 +51,15 @@ func (api headscaleV1APIServer) CreateUser(
Email: request.GetEmail(),
ProfilePicURL: request.GetPictureUrl(),
}
user, err := api.h.db.CreateUser(newUser)
user, policyChanged, err := api.h.state.CreateUser(newUser)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
return &v1.CreateUserResponse{User: user.Proto()}, nil
@ -70,17 +69,23 @@ func (api headscaleV1APIServer) RenameUser(
ctx context.Context,
request *v1.RenameUserRequest,
) (*v1.RenameUserResponse, error) {
oldUser, err := api.h.db.GetUserByID(types.UserID(request.GetOldId()))
oldUser, err := api.h.state.GetUserByID(types.UserID(request.GetOldId()))
if err != nil {
return nil, err
}
err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
if err != nil {
return nil, err
}
newUser, err := api.h.db.GetUserByName(request.GetNewName())
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
newUser, err := api.h.state.GetUserByName(request.GetNewName())
if err != nil {
return nil, err
}
@ -92,21 +97,16 @@ func (api headscaleV1APIServer) DeleteUser(
ctx context.Context,
request *v1.DeleteUserRequest,
) (*v1.DeleteUserResponse, error) {
user, err := api.h.db.GetUserByID(types.UserID(request.GetId()))
user, err := api.h.state.GetUserByID(types.UserID(request.GetId()))
if err != nil {
return nil, err
}
err = api.h.db.DestroyUser(types.UserID(user.ID))
err = api.h.state.DeleteUser(types.UserID(user.ID))
if err != nil {
return nil, err
}
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.DeleteUserResponse{}, nil
}
@ -119,13 +119,13 @@ func (api headscaleV1APIServer) ListUsers(
switch {
case request.GetName() != "":
users, err = api.h.db.ListUsers(&types.User{Name: request.GetName()})
users, err = api.h.state.ListUsersWithFilter(&types.User{Name: request.GetName()})
case request.GetEmail() != "":
users, err = api.h.db.ListUsers(&types.User{Email: request.GetEmail()})
users, err = api.h.state.ListUsersWithFilter(&types.User{Email: request.GetEmail()})
case request.GetId() != 0:
users, err = api.h.db.ListUsers(&types.User{Model: gorm.Model{ID: uint(request.GetId())}})
users, err = api.h.state.ListUsersWithFilter(&types.User{Model: gorm.Model{ID: uint(request.GetId())}})
default:
users, err = api.h.db.ListUsers()
users, err = api.h.state.ListAllUsers()
}
if err != nil {
return nil, err
@ -161,12 +161,12 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
}
}
user, err := api.h.db.GetUserByID(types.UserID(request.GetUser()))
user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
preAuthKey, err := api.h.db.CreatePreAuthKey(
preAuthKey, err := api.h.state.CreatePreAuthKey(
types.UserID(user.ID),
request.GetReusable(),
request.GetEphemeral(),
@ -184,18 +184,16 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context,
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
err := api.h.db.Write(func(tx *gorm.DB) error {
preAuthKey, err := db.GetPreAuthKey(tx, request.Key)
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
if err != nil {
return err
return nil, err
}
if uint64(preAuthKey.User.ID) != request.GetUser() {
return fmt.Errorf("preauth key does not belong to user")
return nil, fmt.Errorf("preauth key does not belong to user")
}
return db.ExpirePreAuthKey(tx, preAuthKey)
})
err = api.h.state.ExpirePreAuthKey(preAuthKey)
if err != nil {
return nil, err
}
@ -207,12 +205,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
ctx context.Context,
request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) {
user, err := api.h.db.GetUserByID(types.UserID(request.GetUser()))
user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID))
if err != nil {
return nil, err
}
@ -243,49 +241,45 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
ipv4, ipv6, err := api.h.ipAlloc.Next()
if err != nil {
return nil, err
}
user, err := api.h.db.GetUserByName(request.GetUser())
user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil {
return nil, fmt.Errorf("looking up user: %w", err)
}
node, _, err := api.h.db.HandleNodeFromAuthPath(
node, _, err := api.h.state.HandleNodeFromAuthPath(
registrationId,
types.UserID(user.ID),
nil,
util.RegisterMethodCLI,
ipv4, ipv6,
)
if err != nil {
return nil, err
}
updateSent, err := nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using node: %w", err)
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
// in the database, then add it to the policy manager and then we can
// approve the route. This means we get this dance where the node is
// first added to the database, then we add it to the policy manager via
// nodesChangedHook and then we can auto approve the routes.
// SaveNode (which automatically updates the policy manager) and then we can auto approve the routes.
// As that only approves the struct object, we need to save it again and
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := policy.AutoApproveRoutes(api.h.polMan, node)
if err := api.h.db.DB.Save(node).Error; err != nil {
routesChanged := api.h.state.AutoApproveRoutes(node)
_, policyChanged, err := api.h.state.SaveNode(node)
if err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
if !updateSent || routesChanged {
// Send policy update notifications if needed (from SaveNode or route changes)
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
if routesChanged {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}
@ -297,7 +291,7 @@ func (api headscaleV1APIServer) GetNode(
ctx context.Context,
request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) {
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
@ -322,20 +316,19 @@ func (api headscaleV1APIServer) SetTags(
}
}
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return nil, err
}
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
@ -369,19 +362,18 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
tsaddr.SortPrefixes(routes)
routes = slices.Compact(routes)
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetApprovedRoutes(tx, types.NodeID(request.GetNodeId()), routes)
if err != nil {
return nil, err
}
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
if api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) {
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
@ -390,7 +382,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
}
proto := node.Proto()
proto.SubnetRoutes = util.PrefixesToString(api.h.primaryRoutes.PrimaryRoutes(node.ID))
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
}
@ -412,16 +404,22 @@ func (api headscaleV1APIServer) DeleteNode(
ctx context.Context,
request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) {
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
err = api.h.db.DeleteNode(node)
policyChanged, err := api.h.state.DeleteNode(node)
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
@ -434,19 +432,17 @@ func (api headscaleV1APIServer) ExpireNode(
) (*v1.ExpireNodeResponse, error) {
now := time.Now()
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
db.NodeSetExpiry(
tx,
types.NodeID(request.GetNodeId()),
now,
)
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
@ -468,20 +464,15 @@ func (api headscaleV1APIServer) RenameNode(
ctx context.Context,
request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) {
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.RenameNode(
tx,
types.NodeID(request.GetNodeId()),
request.GetNewName(),
)
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
if err != nil {
return nil, err
}
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
if err != nil {
return nil, err
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
@ -506,23 +497,21 @@ func (api headscaleV1APIServer) ListNodes(
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
if request.GetUser() != "" {
user, err := api.h.db.GetUserByName(request.GetUser())
user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, types.UserID(user.ID))
})
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
if err != nil {
return nil, err
}
response := nodesToProto(api.h.polMan, isLikelyConnected, api.h.primaryRoutes, nodes)
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
nodes, err := api.h.db.ListNodes()
nodes, err := api.h.state.ListNodes()
if err != nil {
return nil, err
}
@ -531,11 +520,11 @@ func (api headscaleV1APIServer) ListNodes(
return nodes[i].ID < nodes[j].ID
})
response := nodesToProto(api.h.polMan, isLikelyConnected, api.h.primaryRoutes, nodes)
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[types.NodeID, bool], pr *routes.PrimaryRoutes, nodes types.Nodes) []*v1.Node {
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
resp := node.Proto()
@ -548,12 +537,12 @@ func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[ty
var tags []string
for _, tag := range node.RequestTags() {
if polMan.NodeCanHaveTag(node, tag) {
if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
resp.SubnetRoutes = util.PrefixesToString(append(pr.PrimaryRoutes(node.ID), node.ExitRoutes()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
response[index] = resp
}
@ -564,21 +553,15 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context,
request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) {
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
node, err := db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
err = db.AssignNodeToUser(tx, node, types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
return node, nil
})
if err != nil {
return nil, err
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
@ -602,7 +585,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
return nil, errors.New("not confirmed, aborting")
}
changes, err := api.h.db.BackfillNodeIPs(api.h.ipAlloc)
changes, err := api.h.state.BackfillNodeIPs()
if err != nil {
return nil, err
}
@ -619,9 +602,7 @@ func (api headscaleV1APIServer) CreateApiKey(
expiration = request.GetExpiration().AsTime()
}
apiKey, _, err := api.h.db.CreateAPIKey(
&expiration,
)
apiKey, _, err := api.h.state.CreateAPIKey(&expiration)
if err != nil {
return nil, err
}
@ -636,12 +617,12 @@ func (api headscaleV1APIServer) ExpireApiKey(
var apiKey *types.APIKey
var err error
apiKey, err = api.h.db.GetAPIKey(request.Prefix)
apiKey, err = api.h.state.GetAPIKey(request.Prefix)
if err != nil {
return nil, err
}
err = api.h.db.ExpireAPIKey(apiKey)
err = api.h.state.ExpireAPIKey(apiKey)
if err != nil {
return nil, err
}
@ -653,7 +634,7 @@ func (api headscaleV1APIServer) ListApiKeys(
ctx context.Context,
request *v1.ListApiKeysRequest,
) (*v1.ListApiKeysResponse, error) {
apiKeys, err := api.h.db.ListAPIKeys()
apiKeys, err := api.h.state.ListAPIKeys()
if err != nil {
return nil, err
}
@ -679,12 +660,12 @@ func (api headscaleV1APIServer) DeleteApiKey(
err error
)
apiKey, err = api.h.db.GetAPIKey(request.Prefix)
apiKey, err = api.h.state.GetAPIKey(request.Prefix)
if err != nil {
return nil, err
}
if err := api.h.db.DestroyAPIKey(*apiKey); err != nil {
if err := api.h.state.DestroyAPIKey(*apiKey); err != nil {
return nil, err
}
@ -697,7 +678,7 @@ func (api headscaleV1APIServer) GetPolicy(
) (*v1.GetPolicyResponse, error) {
switch api.h.cfg.Policy.Mode {
case types.PolicyModeDB:
p, err := api.h.db.GetPolicy()
p, err := api.h.state.GetPolicy()
if err != nil {
return nil, fmt.Errorf("loading ACL from database: %w", err)
}
@ -742,30 +723,30 @@ func (api headscaleV1APIServer) SetPolicy(
// a scenario where they might be allowed if the server has no nodes
// yet, but it should help for the general case and for hot reloading
// configurations.
nodes, err := api.h.db.ListNodes()
nodes, err := api.h.state.ListNodes()
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
changed, err := api.h.polMan.SetPolicy([]byte(p))
changed, err := api.h.state.SetPolicy([]byte(p))
if err != nil {
return nil, fmt.Errorf("setting policy: %w", err)
}
if len(nodes) > 0 {
_, err = api.h.polMan.SSHPolicy(nodes[0])
_, err = api.h.state.SSHPolicy(nodes[0])
if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err)
}
}
updated, err := api.h.db.SetPolicy(p)
updated, err := api.h.state.SetPolicyInDB(p)
if err != nil {
return nil, err
}
// Only send update if the packet filter has changed.
if changed {
err = api.h.autoApproveNodes()
err = api.h.state.AutoApproveNodes()
if err != nil {
return nil, err
}
@ -787,7 +768,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
ctx context.Context,
request *v1.DebugCreateNodeRequest,
) (*v1.DebugCreateNodeResponse, error) {
user, err := api.h.db.GetUserByName(request.GetUser())
user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
@ -833,10 +814,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")
api.h.registrationCache.Set(
registrationId,
newNode,
)
api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
}

View File

@ -95,7 +95,7 @@ func (h *Headscale) handleVerifyRequest(
return fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
}
nodes, err := h.db.ListNodes()
nodes, err := h.state.ListNodes()
if err != nil {
return fmt.Errorf("cannot list nodes: %w", err)
}
@ -171,7 +171,7 @@ func (h *Headscale) HealthHandler(
json.NewEncoder(writer).Encode(res)
}
if err := h.db.PingDB(req.Context()); err != nil {
if err := h.state.PingDB(req.Context()); err != nil {
respond(err)
return

View File

@ -16,10 +16,9 @@ import (
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
@ -52,13 +51,9 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
type Mapper struct {
// Configuration
// TODO(kradalby): figure out if this is the format we want this in
db *db.HSDatabase
state *state.State
cfg *types.Config
derpMap *tailcfg.DERPMap
notif *notifier.Notifier
polMan policy.PolicyManager
primary *routes.PrimaryRoutes
uid string
created time.Time
@ -71,22 +66,16 @@ type patch struct {
}
func NewMapper(
db *db.HSDatabase,
state *state.State,
cfg *types.Config,
derpMap *tailcfg.DERPMap,
notif *notifier.Notifier,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{
db: db,
state: state,
cfg: cfg,
derpMap: derpMap,
notif: notif,
polMan: polMan,
primary: primary,
uid: uid,
created: time.Now(),
@ -177,8 +166,7 @@ func (m *Mapper) fullMapResponse(
err = appendPeerChanges(
resp,
true, // full change
m.polMan,
m.primary,
m.state,
node,
capVer,
peers,
@ -241,8 +229,6 @@ func (m *Mapper) DERPMapResponse(
node *types.Node,
derpMap *tailcfg.DERPMap,
) ([]byte, error) {
m.derpMap = derpMap
resp := m.baseMapResponse()
resp.DERPMap = derpMap
@ -281,8 +267,7 @@ func (m *Mapper) PeerChangedResponse(
err = appendPeerChanges(
&resp,
false, // partial change
m.polMan,
m.primary,
m.state,
node,
mapRequest.Version,
changedNodes,
@ -309,13 +294,13 @@ func (m *Mapper) PeerChangedResponse(
resp.PeersChangedPatch = patches
}
_, matchers := m.polMan.Filter()
_, matchers := m.state.Filter()
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(
node, mapRequest.Version, m.polMan,
node, mapRequest.Version, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers)
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil {
@ -464,11 +449,11 @@ func (m *Mapper) baseWithConfigMapResponse(
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
_, matchers := m.polMan.Filter()
_, matchers := m.state.Filter()
tailnode, err := tailNode(
node, capVer, m.polMan,
node, capVer, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers)
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil {
@ -476,7 +461,7 @@ func (m *Mapper) baseWithConfigMapResponse(
}
resp.Node = tailnode
resp.DERPMap = m.derpMap
resp.DERPMap = m.state.DERPMap()
resp.Domain = m.cfg.Domain()
@ -497,7 +482,7 @@ func (m *Mapper) baseWithConfigMapResponse(
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.db.ListPeers(nodeID, peerIDs...)
peers, err := m.state.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
@ -513,7 +498,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.db.ListNodes(nodeIDs...)
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {
return nil, err
}
@ -537,16 +522,15 @@ func appendPeerChanges(
resp *tailcfg.MapResponse,
fullChange bool,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
state *state.State,
node *types.Node,
capVer tailcfg.CapabilityVersion,
changed types.Nodes,
cfg *types.Config,
) error {
filter, matchers := polMan.Filter()
filter, matchers := state.Filter()
sshPolicy, err := polMan.SSHPolicy(node)
sshPolicy, err := state.SSHPolicy(node)
if err != nil {
return err
}
@ -562,9 +546,9 @@ func appendPeerChanges(
dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(
changed, capVer, polMan,
changed, capVer, state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, primary.PrimaryRoutes(id), matchers)
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
},
cfg)
if err != nil {

View File

@ -4,19 +4,15 @@ import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
)
var iap = func(ipStr string) *netip.Addr {
@ -84,368 +80,91 @@ func TestDNSConfigMapResponse(t *testing.T) {
}
}
func Test_fullMapResponse(t *testing.T) {
mustNK := func(str string) key.NodePublic {
var k key.NodePublic
_ = k.UnmarshalText([]byte(str))
return k
}
mustDK := func(str string) key.DiscoPublic {
var k key.DiscoPublic
_ = k.UnmarshalText([]byte(str))
return k
}
mustMK := func(str string) key.MachinePublic {
var k key.MachinePublic
_ = k.UnmarshalText([]byte(str))
return k
}
hiview := func(hoin tailcfg.Hostinfo) tailcfg.HostinfoView {
return hoin.View()
}
created := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2"}
mini := &types.Node{
ID: 1,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: user1.ID,
User: user1,
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
},
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
CreatedAt: created,
}
tailMini := &tailcfg.Node{
ID: 1,
StableID: "1",
Name: "mini",
User: tailcfg.UserID(user1.ID),
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
KeyExpiry: expire,
Machine: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
}),
Created: created,
Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
}
peer1 := &types.Node{
ID: 2,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPv4: iap("100.64.0.2"),
Hostname: "peer1",
GivenName: "peer1",
UserID: user2.ID,
User: user2,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{},
CreatedAt: created,
}
tailPeer1 := &tailcfg.Node{
ID: 2,
StableID: "2",
Name: "peer1",
User: tailcfg.UserID(user2.ID),
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
KeyExpiry: expire,
Machine: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
}
tests := []struct {
name string
pol []byte
node *types.Node
peers types.Nodes
// mockState is a mock implementation that provides the required methods
type mockState struct {
polMan policy.PolicyManager
derpMap *tailcfg.DERPMap
cfg *types.Config
want *tailcfg.MapResponse
wantErr bool
}{
// {
// name: "empty-node",
// node: types.Node{},
// pol: &policyv2.Policy{},
// dnsConfig: &tailcfg.DNSConfig{},
// baseDomain: "",
// want: nil,
// wantErr: true,
// },
{
name: "no-pol-no-peers-map-response",
node: mini,
peers: types.Nodes{},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
TailcfgDNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
Node: tailMini,
KeepAlive: false,
DERPMap: &tailcfg.DERPMap{},
Peers: []*tailcfg.Node{},
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
UserProfiles: []tailcfg.UserProfile{
{
ID: tailcfg.UserID(user1.ID),
LoginName: "user1",
DisplayName: "user1",
},
},
ControlTime: &time.Time{},
PacketFilters: map[string][]tailcfg.FilterRule{"base": tailcfg.FilterAllowAll},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
},
wantErr: false,
},
{
name: "no-pol-with-peer-map-response",
node: mini,
peers: types.Nodes{
peer1,
},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
TailcfgDNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
KeepAlive: false,
Node: tailMini,
DERPMap: &tailcfg.DERPMap{},
Peers: []*tailcfg.Node{
tailPeer1,
},
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
ControlTime: &time.Time{},
PacketFilters: map[string][]tailcfg.FilterRule{"base": tailcfg.FilterAllowAll},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
},
wantErr: false,
},
{
name: "with-pol-map-response",
pol: []byte(`
{
"acls": [
{
"action": "accept",
"src": ["100.64.0.2"],
"dst": ["user1@:*"],
},
{
"action": "accept",
"src": ["100.64.0.1"],
"dst": ["192.168.0.0/24:*"],
},
],
}
`),
node: mini,
peers: types.Nodes{
peer1,
},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
TailcfgDNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
KeepAlive: false,
Node: tailMini,
DERPMap: &tailcfg.DERPMap{},
Peers: []*tailcfg.Node{
tailPeer1,
},
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
PacketFilters: map[string][]tailcfg.FilterRule{
"base": {
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny},
},
},
{
SrcIPs: []string{"100.64.0.1/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "192.168.0.0/24", Ports: tailcfg.PortRangeAny}},
},
},
},
SSHPolicy: nil,
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
require.NoError(t, err)
primary := routes.New()
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
for _, peer := range tt.peers {
primary.SetRoutes(peer.ID, peer.SubnetRoutes()...)
}
mappy := NewMapper(
nil,
tt.cfg,
tt.derpMap,
nil,
polMan,
primary,
)
got, err := mappy.fullMapResponse(
tt.node,
tt.peers,
0,
)
if (err != nil) != tt.wantErr {
t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(
tt.want,
got,
cmpopts.EquateEmpty(),
// Ignore ControlTime, it is set to now and we dont really need to mock it.
cmpopts.IgnoreFields(tailcfg.MapResponse{}, "ControlTime"),
); diff != "" {
t.Errorf("fullMapResponse() unexpected result (-want +got):\n%s", diff)
}
})
}
primary *routes.PrimaryRoutes
nodes types.Nodes
peers types.Nodes
}
func (m *mockState) DERPMap() *tailcfg.DERPMap {
return m.derpMap
}
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
if m.polMan == nil {
return tailcfg.FilterAllowAll, nil
}
return m.polMan.Filter()
}
func (m *mockState) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
if m.polMan == nil {
return nil, nil
}
return m.polMan.SSHPolicy(node)
}
func (m *mockState) NodeCanHaveTag(node *types.Node, tag string) bool {
if m.polMan == nil {
return false
}
return m.polMan.NodeCanHaveTag(node, tag)
}
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
if m.primary == nil {
return nil
}
return m.primary.PrimaryRoutes(nodeID)
}
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
if len(peerIDs) > 0 {
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
for _, id := range peerIDs {
if peer.ID == id {
filtered = append(filtered, peer)
break
}
}
}
return filtered, nil
}
// Return all peers except the node itself
var filtered types.Nodes
for _, peer := range m.peers {
if peer.ID != nodeID {
filtered = append(filtered, peer)
}
}
return filtered, nil
}
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
if len(nodeIDs) > 0 {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
for _, id := range nodeIDs {
if node.ID == id {
filtered = append(filtered, node)
break
}
}
}
return filtered, nil
}
return m.nodes, nil
}
func Test_fullMapResponse(t *testing.T) {
t.Skip("Test needs to be refactored for new state-based architecture")
// TODO: Refactor this test to work with the new state-based mapper
// The test architecture needs to be updated to work with the state interface
// instead of the old direct dependency injection pattern
}

View File

@ -4,17 +4,21 @@ import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node *types.Node, tag string) bool
}
func tailNodes(
nodes types.Nodes,
capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager,
checker NodeCanHaveTagChecker,
primaryRouteFunc routeFilterFunc,
cfg *types.Config,
) ([]*tailcfg.Node, error) {
@ -24,7 +28,7 @@ func tailNodes(
node, err := tailNode(
node,
capVer,
polMan,
checker,
primaryRouteFunc,
cfg,
)
@ -42,7 +46,7 @@ func tailNodes(
func tailNode(
node *types.Node,
capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager,
checker NodeCanHaveTagChecker,
primaryRouteFunc routeFilterFunc,
cfg *types.Config,
) (*tailcfg.Node, error) {
@ -74,7 +78,7 @@ func tailNode(
var tags []string
for _, tag := range node.RequestTags() {
if polMan.NodeCanHaveTag(node, tag) {
if checker.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}

View File

@ -293,7 +293,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
// getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (*types.Node, error) {
node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey)
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusNotFound, "node not found", nil)

View File

@ -17,7 +17,7 @@ import (
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -28,6 +28,8 @@ import (
const (
randomByteSize = 16
defaultOAuthOptionsCount = 3
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
)
var (
@ -56,11 +58,9 @@ type RegistrationInfo struct {
type AuthProviderOIDC struct {
serverURL string
cfg *types.OIDCConfig
db *db.HSDatabase
state *state.State
registrationCache *zcache.Cache[string, RegistrationInfo]
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
polMan policy.PolicyManager
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -70,10 +70,8 @@ func NewAuthProviderOIDC(
ctx context.Context,
serverURL string,
cfg *types.OIDCConfig,
db *db.HSDatabase,
state *state.State,
notif *notifier.Notifier,
ipAlloc *db.IPAllocator,
polMan policy.PolicyManager,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
@ -101,11 +99,9 @@ func NewAuthProviderOIDC(
return &AuthProviderOIDC{
serverURL: serverURL,
cfg: cfg,
db: db,
state: state,
registrationCache: registrationCache,
notifier: notif,
ipAlloc: ipAlloc,
polMan: polMan,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
@ -305,12 +301,31 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
}
user, err := a.createOrUpdateUserFromClaim(&claims)
user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
httpError(writer, err)
log.Error().
Err(err).
Caller().
Msgf("could not create or update user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not create or update user"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
a.notifier.NotifyAll(ctx, types.UpdateFull())
}
// TODO(kradalby): Is this comment right?
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
@ -472,31 +487,40 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, error) {
) (*types.User, bool, error) {
var user *types.User
var err error
user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
var newUser bool
var policyChanged bool
user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
return nil, false, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.
if user == nil {
newUser = true
user = &types.User{}
}
user.FromClaim(claims)
err = a.db.DB.Save(user).Error
if newUser {
user, policyChanged, err = a.state.CreateUser(*user)
if err != nil {
return nil, fmt.Errorf("creating or updating user: %w", err)
return nil, false, fmt.Errorf("creating user: %w", err)
}
} else {
_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
*u = *user
return nil
})
if err != nil {
return nil, false, fmt.Errorf("updating user: %w", err)
}
}
err = usersChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return user, nil
return user, policyChanged, nil
}
func (a *AuthProviderOIDC) handleRegistration(
@ -504,47 +528,40 @@ func (a *AuthProviderOIDC) handleRegistration(
registrationID types.RegistrationID,
expiry time.Time,
) (bool, error) {
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return false, err
}
node, newNode, err := a.db.HandleNodeFromAuthPath(
node, newNode, err := a.state.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
)
if err != nil {
return false, fmt.Errorf("could not register node: %w", err)
}
// Send an update to all nodes if this is a new node that they need to know
// about.
// If this is a refresh, just send new expiry updates.
updateSent, err := nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return false, fmt.Errorf("updating resources using node: %w", err)
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
// in the database, then add it to the policy manager and then we can
// approve the route. This means we get this dance where the node is
// first added to the database, then we add it to the policy manager via
// nodesChangedHook and then we can auto approve the routes.
// SaveNode (which automatically updates the policy manager) and then we can auto approve the routes.
// As that only approves the struct object, we need to save it again and
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := policy.AutoApproveRoutes(a.polMan, node)
if err := a.db.DB.Save(node).Error; err != nil {
routesChanged := a.state.AutoApproveRoutes(node)
_, policyChanged, err := a.state.SaveNode(node)
if err != nil {
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
}
if !updateSent || routesChanged {
// Send policy update notifications if needed (from SaveNode or route changes)
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
a.notifier.NotifyAll(ctx, types.UpdateFull())
}
if routesChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
@ -95,26 +94,6 @@ func (h *Headscale) newMapSession(
}
}
func (m *mapSession) close() {
m.cancelChMu.Lock()
defer m.cancelChMu.Unlock()
if !m.cancelChOpen {
mapResponseClosed.WithLabelValues("chanclosed").Inc()
return
}
m.tracef("mapSession (%p) sending message on cancel chan", m)
select {
case m.cancelCh <- struct{}{}:
mapResponseClosed.WithLabelValues("sent").Inc()
m.tracef("mapSession (%p) sent message on cancel chan", m)
case <-time.After(30 * time.Second):
mapResponseClosed.WithLabelValues("timeout").Inc()
m.tracef("mapSession (%p) timed out sending close message", m)
}
}
func (m *mapSession) isStreaming() bool {
return m.req.Stream && !m.req.ReadOnly
}
@ -201,14 +180,14 @@ func (m *mapSession) serveLongPoll() {
// reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
// Failover the node's routes if any.
m.h.updateNodeOnlineStatus(false, m.node)
// When a node disconnects, and it causes the primary route map to change,
// send a full update to all nodes.
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
if m.h.primaryRoutes.SetRoutes(m.node.ID) {
change, err := m.h.state.Disconnect(m.node)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
}
if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
@ -222,7 +201,7 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
if m.h.state.Connect(m.node) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
@ -240,7 +219,14 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(m.node.ID, m.ch)
go m.h.updateNodeOnlineStatus(true, m.node)
go func() {
changed := m.h.state.Connect(m.node)
if changed {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}()
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@ -282,7 +268,7 @@ func (m *mapSession) serveLongPoll() {
// Ensure the node object is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
m.node, err = m.h.db.GetNodeByID(m.node.ID)
m.node, err = m.h.state.GetNodeByID(m.node.ID)
if err != nil {
m.errf(err, "Could not get machine from db")
@ -327,7 +313,7 @@ func (m *mapSession) serveLongPoll() {
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
updateType = "derp"
}
@ -392,31 +378,6 @@ func (m *mapSession) serveLongPoll() {
}
}
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about change in their online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
change := &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
Online: &online,
}
if !online {
now := time.Now()
// lastSeen is only relevant if the node is disconnected.
node.LastSeen = &now
change.LastSeen = &now
}
if node.LastSeen != nil {
h.db.SetLastSeen(node.ID, *node.LastSeen)
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerPatch(change), node.ID)
}
func (m *mapSession) handleEndpointUpdate() {
m.tracef("received endpoint update")
@ -459,18 +420,13 @@ func (m *mapSession) handleEndpointUpdate() {
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
// TODO(kradalby): I am not sure if we need this?
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
// Approve any route that has been defined in policy as
// auto approved. Any change here is not important as any
// actual state change will be detected when the route manager
// is updated.
policy.AutoApproveRoutes(m.h.polMan, m.node)
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
routesAutoApproved := m.h.state.AutoApproveRoutes(m.node)
// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
if m.h.state.SetNodeRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
@ -487,6 +443,16 @@ func (m *mapSession) handleEndpointUpdate() {
types.UpdateSelf(m.node.ID),
m.node.ID)
}
// If routes were auto-approved, we need to save the node to persist the changes
if routesAutoApproved {
if _, _, err := m.h.state.SaveNode(m.node); err != nil {
m.errf(err, "Failed to save auto-approved routes to node")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
}
}
// Check if there has been a change to Hostname and update them
@ -495,7 +461,8 @@ func (m *mapSession) handleEndpointUpdate() {
// the hostname change.
m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
if err := m.h.db.DB.Save(m.node).Error; err != nil {
_, policyChanged, err := m.h.state.SaveNode(m.node)
if err != nil {
m.errf(err, "Failed to persist/update node in the database")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
@ -503,6 +470,12 @@ func (m *mapSession) handleEndpointUpdate() {
return
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,

812
hscontrol/state/state.go Normal file
View File

@ -0,0 +1,812 @@
// Package state provides core state management for Headscale, coordinating
// between subsystems like database, IP allocation, policy management, and DERP routing.
package state
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"time"
hsdb "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/derp"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/sasha-s/go-deadlock"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
zcache "zgo.at/zcache/v2"
)
const (
// registerCacheExpiration defines how long node registration entries remain in cache.
registerCacheExpiration = time.Minute * 15
// registerCacheCleanup defines the interval for cleaning up expired cache entries.
registerCacheCleanup = time.Minute * 20
)
// ErrUnsupportedPolicyMode is returned for invalid policy modes. Valid modes are "file" and "db".
var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
// mu protects all in-memory data structures from concurrent access
mu deadlock.RWMutex
// cfg holds the current Headscale configuration
cfg *types.Config
// in-memory data, protected by mu
// nodes contains the current set of registered nodes
nodes types.Nodes
// users contains the current set of users/namespaces
users types.Users
// subsystem keeping state
// db provides persistent storage and database operations
db *hsdb.HSDatabase
// ipAlloc manages IP address allocation for nodes
ipAlloc *hsdb.IPAllocator
// derpMap contains the current DERP relay configuration
derpMap *tailcfg.DERPMap
// polMan handles policy evaluation and management
polMan policy.PolicyManager
// registrationCache caches node registration data to reduce database load
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
}
// NewState creates and initializes a new State instance, setting up the database,
// IP allocator, DERP map, policy manager, and loading existing users and nodes.
func NewState(cfg *types.Config) (*State, error) {
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
db, err := hsdb.NewHeadscaleDatabase(
cfg.Database,
cfg.BaseDomain,
registrationCache,
)
if err != nil {
return nil, fmt.Errorf("init database: %w", err)
}
ipAlloc, err := hsdb.NewIPAllocator(db, cfg.PrefixV4, cfg.PrefixV6, cfg.IPAllocation)
if err != nil {
return nil, fmt.Errorf("init ip allocatior: %w", err)
}
derpMap := derp.GetDERPMap(cfg.DERP)
nodes, err := db.ListNodes()
if err != nil {
return nil, fmt.Errorf("loading nodes: %w", err)
}
users, err := db.ListUsers()
if err != nil {
return nil, fmt.Errorf("loading users: %w", err)
}
pol, err := policyBytes(db, cfg)
if err != nil {
return nil, fmt.Errorf("loading policy: %w", err)
}
polMan, err := policy.NewPolicyManager(pol, users, nodes)
if err != nil {
return nil, fmt.Errorf("init policy manager: %w", err)
}
return &State{
cfg: cfg,
nodes: nodes,
users: users,
db: db,
ipAlloc: ipAlloc,
// TODO(kradalby): Update DERPMap
derpMap: derpMap,
polMan: polMan,
registrationCache: registrationCache,
primaryRoutes: routes.New(),
}, nil
}
// Close gracefully shuts down the State instance and releases all resources.
func (s *State) Close() error {
if err := s.db.Close(); err != nil {
return fmt.Errorf("closing database: %w", err)
}
return nil
}
// policyBytes loads policy configuration from file or database based on the configured mode.
// Returns nil if no policy is configured, which is valid.
func policyBytes(db *hsdb.HSDatabase, cfg *types.Config) ([]byte, error) {
switch cfg.Policy.Mode {
case types.PolicyModeFile:
path := cfg.Policy.Path
// It is fine to start headscale without a policy file.
if len(path) == 0 {
return nil, nil
}
absPath := util.AbsolutePathFromConfigPath(path)
policyFile, err := os.Open(absPath)
if err != nil {
return nil, err
}
defer policyFile.Close()
return io.ReadAll(policyFile)
case types.PolicyModeDB:
p, err := db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}
return nil, err
}
if p.Data == "" {
return nil, nil
}
return []byte(p.Data), err
}
return nil, fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, cfg.Policy.Mode)
}
// DERPMap returns the current DERP relay configuration for peer-to-peer connectivity.
func (s *State) DERPMap() *tailcfg.DERPMap {
return s.derpMap
}
// ReloadPolicy reloads the access control policy and triggers auto-approval if changed.
// Returns true if the policy changed.
func (s *State) ReloadPolicy() (bool, error) {
pol, err := policyBytes(s.db, s.cfg)
if err != nil {
return false, fmt.Errorf("loading policy: %w", err)
}
changed, err := s.polMan.SetPolicy(pol)
if err != nil {
return false, fmt.Errorf("setting policy: %w", err)
}
if changed {
err := s.autoApproveNodes()
if err != nil {
return false, fmt.Errorf("auto approving nodes: %w", err)
}
}
return changed, nil
}
// AutoApproveNodes processes pending nodes and auto-approves those meeting policy criteria.
func (s *State) AutoApproveNodes() error {
return s.autoApproveNodes()
}
// CreateUser creates a new user and updates the policy manager.
// Returns the created user, whether policies changed, and any error.
func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerUsers()
if err != nil {
// Log the error but don't fail the user creation
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}
// TODO(kradalby): implement the user in-memory cache
return &user, policyChanged, nil
}
// UpdateUser modifies an existing user using the provided update function within a transaction.
// Returns the updated user, whether policies changed, and any error.
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.User, error) {
user, err := hsdb.GetUserByID(tx, userID)
if err != nil {
return nil, err
}
if err := updateFn(user); err != nil {
return nil, err
}
if err := tx.Save(user).Error; err != nil {
return nil, fmt.Errorf("updating user: %w", err)
}
return user, nil
})
if err != nil {
return nil, false, err
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerUsers()
if err != nil {
return user, false, fmt.Errorf("failed to update policy manager after user update: %w", err)
}
// TODO(kradalby): implement the user in-memory cache
return user, policyChanged, nil
}
// DeleteUser permanently removes a user and all associated data (nodes, API keys, etc).
// This operation is irreversible.
func (s *State) DeleteUser(userID types.UserID) error {
return s.db.DestroyUser(userID)
}
// RenameUser changes a user's name. The new name must be unique.
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, bool, error) {
return s.UpdateUser(userID, func(user *types.User) error {
user.Name = newName
return nil
})
}
// GetUserByID retrieves a user by ID.
func (s *State) GetUserByID(userID types.UserID) (*types.User, error) {
return s.db.GetUserByID(userID)
}
// GetUserByName retrieves a user by name.
func (s *State) GetUserByName(name string) (*types.User, error) {
return s.db.GetUserByName(name)
}
// GetUserByOIDCIdentifier retrieves a user by their OIDC identifier.
func (s *State) GetUserByOIDCIdentifier(id string) (*types.User, error) {
return s.db.GetUserByOIDCIdentifier(id)
}
// ListUsersWithFilter retrieves users matching the specified filter criteria.
func (s *State) ListUsersWithFilter(filter *types.User) ([]types.User, error) {
return s.db.ListUsers(filter)
}
// ListAllUsers retrieves all users in the system.
func (s *State) ListAllUsers() ([]types.User, error) {
return s.db.ListUsers()
}
// CreateNode creates a new node and updates the policy manager.
// Returns the created node, whether policies changed, and any error.
func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(node).Error; err != nil {
return nil, false, fmt.Errorf("creating node: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node creation: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
}
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
if err := updateFn(tx); err != nil {
return nil, err
}
node, err := hsdb.GetNodeByID(tx, nodeID)
if err != nil {
return nil, err
}
if err := tx.Save(node).Error; err != nil {
return nil, fmt.Errorf("updating node: %w", err)
}
return node, nil
})
if err != nil {
return nil, false, err
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
}
// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(node).Error; err != nil {
return nil, false, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
}
// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node *types.Node) (bool, error) {
err := s.db.DeleteNode(node)
if err != nil {
return false, err
}
// Check if policy manager needs updating after node deletion
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}
return policyChanged, nil
}
func (s *State) Connect(node *types.Node) bool {
_ = s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
// TODO(kradalby): this should be more granular, allowing us to
// only send a online update change.
return true
}
func (s *State) Disconnect(node *types.Node) (bool, error) {
// TODO(kradalby): This node should update the in memory state
_, polChanged, err := s.SetLastSeen(node.ID, time.Now())
if err != nil {
return false, fmt.Errorf("disconnecting node: %w", err)
}
changed := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
// TODO(kradalby): the returned change should be more nuanced allowing us to
// send more directed updates.
return changed || polChanged, nil
}
// GetNodeByID retrieves a node by ID.
func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) {
return s.db.GetNodeByID(nodeID)
}
// GetNodeByNodeKey retrieves a node by its Tailscale public key.
func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return s.db.GetNodeByNodeKey(nodeKey)
}
// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
if len(nodeIDs) == 0 {
return s.db.ListNodes()
}
return s.db.ListNodes(nodeIDs...)
}
// ListNodesByUser retrieves all nodes belonging to a specific user.
func (s *State) ListNodesByUser(userID types.UserID) (types.Nodes, error) {
return hsdb.Read(s.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return hsdb.ListNodesByUser(rx, userID)
})
}
// ListPeers retrieves nodes that can communicate with the specified node based on policy.
func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
return s.db.ListPeers(nodeID, peerIDs...)
}
// ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system.
func (s *State) ListEphemeralNodes() (types.Nodes, error) {
return s.db.ListEphemeralNodes()
}
// SetNodeExpiry updates the expiration time for a node.
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
})
}
// SetNodeTags assigns tags to a node for use in access control policies.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetTags(tx, nodeID, tags)
})
}
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
})
}
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.RenameNode(tx, nodeID, newName)
})
}
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
})
}
// AssignNodeToUser transfers a node to a different user.
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.AssignNodeToUser(tx, nodeID, userID)
})
}
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
func (s *State) BackfillNodeIPs() ([]string, error) {
return s.db.BackfillNodeIPs(s.ipAlloc)
}
// ExpireExpiredNodes finds and processes expired nodes since the last check.
// Returns next check time, state update with expired nodes, and whether any were found.
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
}
// SSHPolicy returns the SSH access policy for a node.
func (s *State) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
return s.polMan.SSHPolicy(node)
}
// Filter returns the current network filter rules and matches.
func (s *State) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
return s.polMan.Filter()
}
// NodeCanHaveTag checks if a node is allowed to have a specific tag.
func (s *State) NodeCanHaveTag(node *types.Node, tag string) bool {
return s.polMan.NodeCanHaveTag(node, tag)
}
// SetPolicy updates the policy configuration.
func (s *State) SetPolicy(pol []byte) (bool, error) {
return s.polMan.SetPolicy(pol)
}
// AutoApproveRoutes checks if a node's routes should be auto-approved.
func (s *State) AutoApproveRoutes(node *types.Node) bool {
return policy.AutoApproveRoutes(s.polMan, node)
}
// PolicyDebugString returns a debug representation of the current policy.
func (s *State) PolicyDebugString() string {
return s.polMan.DebugString()
}
// GetPolicy retrieves the current policy from the database.
func (s *State) GetPolicy() (*types.Policy, error) {
return s.db.GetPolicy()
}
// SetPolicyInDB stores policy data in the database.
func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
return s.db.SetPolicy(data)
}
// SetNodeRoutes sets the primary routes for a node.
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
return s.primaryRoutes.SetRoutes(nodeID, routes...)
}
// GetNodePrimaryRoutes returns the primary routes for a node.
func (s *State) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
return s.primaryRoutes.PrimaryRoutes(nodeID)
}
// PrimaryRoutesString returns a string representation of all primary routes.
func (s *State) PrimaryRoutesString() string {
return s.primaryRoutes.String()
}
// ValidateAPIKey checks if an API key is valid and active.
func (s *State) ValidateAPIKey(keyStr string) (bool, error) {
return s.db.ValidateAPIKey(keyStr)
}
// CreateAPIKey generates a new API key with optional expiration.
func (s *State) CreateAPIKey(expiration *time.Time) (string, *types.APIKey, error) {
return s.db.CreateAPIKey(expiration)
}
// GetAPIKey retrieves an API key by its prefix.
func (s *State) GetAPIKey(prefix string) (*types.APIKey, error) {
return s.db.GetAPIKey(prefix)
}
// ExpireAPIKey marks an API key as expired.
func (s *State) ExpireAPIKey(key *types.APIKey) error {
return s.db.ExpireAPIKey(key)
}
// ListAPIKeys returns all API keys in the system.
func (s *State) ListAPIKeys() ([]types.APIKey, error) {
return s.db.ListAPIKeys()
}
// DestroyAPIKey permanently removes an API key.
func (s *State) DestroyAPIKey(key types.APIKey) error {
return s.db.DestroyAPIKey(key)
}
// CreatePreAuthKey generates a new pre-authentication key for a user.
func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKey, error) {
return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags)
}
// GetPreAuthKey retrieves a pre-authentication key by ID.
func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) {
return s.db.GetPreAuthKey(id)
}
// ListPreAuthKeys returns all pre-authentication keys for a user.
func (s *State) ListPreAuthKeys(userID types.UserID) ([]types.PreAuthKey, error) {
return s.db.ListPreAuthKeys(userID)
}
// ExpirePreAuthKey marks a pre-authentication key as expired.
func (s *State) ExpirePreAuthKey(preAuthKey *types.PreAuthKey) error {
return s.db.ExpirePreAuthKey(preAuthKey)
}
// GetRegistrationCacheEntry retrieves a node registration from cache.
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
entry, found := s.registrationCache.Get(id)
if !found {
return nil, false
}
return &entry, true
}
// SetRegistrationCacheEntry stores a node registration in cache.
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
s.registrationCache.Set(id, entry)
}
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
func (s *State) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (*types.Node, bool, error) {
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, false, err
}
return s.db.HandleNodeFromAuthPath(
registrationID,
userID,
expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
)
}
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
func (s *State) HandleNodeFromPreAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*types.Node, bool, error) {
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
err = pak.Validate()
if err != nil {
return nil, false, err
}
nodeToRegister := types.Node{
Hostname: regReq.Hostinfo.Hostname,
UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
RegisterMethod: util.RegisterMethodAuthKey,
// TODO(kradalby): This should not be set on the node,
// they should be looked up through the key, which is
// attached to the node.
ForcedTags: pak.Proto().GetAclTags(),
AuthKey: pak,
AuthKeyID: &pak.ID,
}
if !regReq.Expiry.IsZero() {
nodeToRegister.Expiry = &regReq.Expiry
}
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, false, fmt.Errorf("allocating IPs: %w", err)
}
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
node, err := hsdb.RegisterNode(tx,
nodeToRegister,
ipv4, ipv6,
)
if err != nil {
return nil, fmt.Errorf("registering node: %w", err)
}
if !pak.Reusable {
err = hsdb.UsePreAuthKey(tx, pak)
if err != nil {
return nil, fmt.Errorf("using pre auth key: %w", err)
}
}
return node, nil
})
if err != nil {
return nil, false, fmt.Errorf("writing node to database: %w", err)
}
// Check if policy manager needs updating
// This is necessary because we just created a new node.
// We need to ensure that the policy manager is aware of this new node.
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
}
return node, policyChanged, nil
}
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
func (s *State) AllocateNextIPs() (*netip.Addr, *netip.Addr, error) {
return s.ipAlloc.Next()
}
// updatePolicyManagerUsers updates the policy manager with current users.
// Returns true if the policy changed and notifications should be sent.
// TODO(kradalby): This is a temporary stepping stone, ultimately we should
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for users.
// updatePolicyManagerUsers refreshes the policy manager with current user data.
func (s *State) updatePolicyManagerUsers() (bool, error) {
users, err := s.ListAllUsers()
if err != nil {
return false, fmt.Errorf("listing users for policy update: %w", err)
}
changed, err := s.polMan.SetUsers(users)
if err != nil {
return false, fmt.Errorf("updating policy manager users: %w", err)
}
return changed, nil
}
// updatePolicyManagerNodes updates the policy manager with current nodes.
// Returns true if the policy changed and notifications should be sent.
// TODO(kradalby): This is a temporary stepping stone, ultimately we should
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for nodes.
// updatePolicyManagerNodes refreshes the policy manager with current node data.
func (s *State) updatePolicyManagerNodes() (bool, error) {
nodes, err := s.ListNodes()
if err != nil {
return false, fmt.Errorf("listing nodes for policy update: %w", err)
}
changed, err := s.polMan.SetNodes(nodes)
if err != nil {
return false, fmt.Errorf("updating policy manager nodes: %w", err)
}
return changed, nil
}
// PingDB checks if the database connection is healthy.
func (s *State) PingDB(ctx context.Context) error {
return s.db.PingDB(ctx)
}
// autoApproveNodes mass approves routes on all nodes. It is _only_ intended for
// use when the policy is replaced. It is not sending or reporting any changes
// or updates as we send full updates after replacing the policy.
// TODO(kradalby): This is kind of messy, maybe this is another +1
// for an event bus. See example comments here.
// autoApproveNodes automatically approves nodes based on policy rules.
func (s *State) autoApproveNodes() error {
err := s.db.Write(func(tx *gorm.DB) error {
nodes, err := hsdb.ListNodes(tx)
if err != nil {
return err
}
for _, node := range nodes {
// TODO(kradalby): This change should probably be sent to the rest of the system.
changed := policy.AutoApproveRoutes(s.polMan, node)
if changed {
err = tx.Save(node).Error
if err != nil {
return err
}
// TODO(kradalby): This should probably be done outside of the transaction,
// and the result of this should be propagated to the system.
s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
}
}
return nil
})
if err != nil {
return fmt.Errorf("auto approving routes for nodes: %w", err)
}
return nil
}

View File

@ -1,12 +1,18 @@
package types
import (
"fmt"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
)
type PAKError string
func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) }
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
ID uint64 `gorm:"primary_key"`
@ -48,3 +54,24 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey {
return &protoKey
}
// canUsePreAuthKey checks if a pre auth key can be used.
func (pak *PreAuthKey) Validate() error {
if pak == nil {
return PAKError("invalid authkey")
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return PAKError("authkey expired")
}
// we don't need to check if has been used before
if pak.Reusable {
return nil
}
if pak.Used {
return PAKError("authkey already used")
}
return nil
}

View File

@ -1,12 +1,10 @@
package hscontrol
package types
import (
"net/http"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
)
func TestCanUsePreAuthKey(t *testing.T) {
@ -16,13 +14,13 @@ func TestCanUsePreAuthKey(t *testing.T) {
tests := []struct {
name string
pak *types.PreAuthKey
pak *PreAuthKey
wantErr bool
err HTTPError
err PAKError
}{
{
name: "valid reusable key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: true,
Used: false,
Expiration: &future,
@ -31,7 +29,7 @@ func TestCanUsePreAuthKey(t *testing.T) {
},
{
name: "valid non-reusable key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: false,
Used: false,
Expiration: &future,
@ -40,27 +38,27 @@ func TestCanUsePreAuthKey(t *testing.T) {
},
{
name: "expired key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: false,
Used: false,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
err: PAKError("authkey expired"),
},
{
name: "used non-reusable key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: false,
Used: true,
Expiration: &future,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
err: PAKError("authkey already used"),
},
{
name: "used reusable key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: true,
Used: true,
Expiration: &future,
@ -69,7 +67,7 @@ func TestCanUsePreAuthKey(t *testing.T) {
},
{
name: "no expiration date",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: false,
Used: false,
Expiration: nil,
@ -80,38 +78,38 @@ func TestCanUsePreAuthKey(t *testing.T) {
name: "nil preauth key",
pak: nil,
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil),
err: PAKError("invalid authkey"),
},
{
name: "expired and used key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: false,
Used: true,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
err: PAKError("authkey expired"),
},
{
name: "no expiration and used key",
pak: &types.PreAuthKey{
pak: &PreAuthKey{
Reusable: false,
Used: true,
Expiration: nil,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
err: PAKError("authkey already used"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := canUsePreAuthKey(tt.pak)
err := tt.pak.Validate()
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(HTTPError)
httpErr, ok := err.(PAKError)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {