diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index ed1d1221..bf55e2de 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -42,6 +42,7 @@ jobs: - TestPingAllByIPPublicDERP - TestAuthKeyLogoutAndRelogin - TestEphemeral + - TestEphemeral2006DeletedTooQuickly - TestPingAllByHostname - TestTaildrop - TestResolveMagicDNS diff --git a/hscontrol/app.go b/hscontrol/app.go index 726b9d0b..0a23f07d 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -91,6 +91,7 @@ type Headscale struct { db *db.HSDatabase ipAlloc *db.IPAllocator noisePrivateKey *key.MachinePrivate + ephemeralGC *db.EphemeralGarbageCollector DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer @@ -153,6 +154,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, err } + app.ephemeralGC = db.NewEphemeralGarbageCollector(func(ni types.NodeID) { + if err := app.db.DeleteEphemeralNode(ni); err != nil { + log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node") + } + }) + if cfg.OIDC.Issuer != "" { err = app.initOIDC() if err != nil { @@ -217,47 +224,6 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { http.Redirect(w, req, target, http.StatusFound) } -// deleteExpireEphemeralNodes deletes ephemeral node records that have not been -// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. -func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) { - ticker := time.NewTicker(every) - - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - var removed []types.NodeID - var changed []types.NodeID - if err := h.db.Write(func(tx *gorm.DB) error { - removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) - - return nil - }); err != nil { - log.Error().Err(err).Msg("database error while expiring ephemeral nodes") - continue - } - - if removed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: removed, - }) - } - - if changed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changed, - }) - } - } - } -} - // expireExpiredNodes expires nodes that have an explicit expiry set // after that expiry time has passed. func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) { @@ -557,9 +523,18 @@ func (h *Headscale) Serve() error { return errEmptyInitialDERPMap } - expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background()) - defer expireEphemeralCancel() - go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval) + // Start ephemeral node garbage collector and schedule all nodes + // that are already in the database and ephemeral. If they are still + // around between restarts, they will reconnect and the GC will + // be cancelled. + go h.ephemeralGC.Start() + ephmNodes, err := h.db.ListEphemeralNodes() + if err != nil { + return fmt.Errorf("failed to list ephemeral nodes: %w", err) + } + for _, node := range ephmNodes { + h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout) + } expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) defer expireNodeCancel() @@ -809,7 +784,7 @@ func (h *Headscale) Serve() error { Msg("Received signal to stop, shutting down gracefully") expireNodeCancel() - expireEphemeralCancel() + h.ephemeralGC.Close() trace("waiting for netmap stream to close") h.pollNetMapStreamWG.Wait() diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 5ee925a6..010d15a2 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -16,6 +16,7 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func logAuthFunc( @@ -314,9 +315,8 @@ func (h *Headscale) handleAuthKey( Msg("node was already registered before, refreshing with new auth key") node.NodeKey = nodeKey - pakID := uint(pak.ID) - if pakID != 0 { - node.AuthKeyID = &pakID + if pak.ID != 0 { + node.AuthKeyID = ptr.To(pak.ID) } node.Expiry = ®isterRequest.Expiry @@ -394,7 +394,7 @@ func (h *Headscale) handleAuthKey( pakID := uint(pak.ID) if pakID != 0 { - nodeToRegister.AuthKeyID = &pakID + nodeToRegister.AuthKeyID = ptr.To(pak.ID) } node, err = h.db.RegisterNode( nodeToRegister, diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index e36d6ed1..a2515ebf 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -12,6 +12,7 @@ import ( "github.com/patrickmn/go-cache" "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" + "github.com/sasha-s/go-deadlock" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -78,6 +79,17 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) { return nodes, nil } +func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil { + return nil, err + } + + return nodes, nil + }) +} + func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { nodes := types.Nodes{} if err := tx. @@ -286,6 +298,20 @@ func DeleteNode(tx *gorm.DB, return changed, nil } +// DeleteEphemeralNode deletes a Node from the database, note that this method +// will remove it straight, and not notify any changes or consider any routes. +// It is intended for Ephemeral nodes. +func (hsdb *HSDatabase) DeleteEphemeralNode( + nodeID types.NodeID, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil { + return err + } + return nil + }) +} + // SetLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { @@ -660,51 +686,6 @@ func GenerateGivenName( return givenName, nil } -func DeleteExpiredEphemeralNodes(tx *gorm.DB, - inactivityThreshold time.Duration, -) ([]types.NodeID, []types.NodeID) { - users, err := ListUsers(tx) - if err != nil { - return nil, nil - } - - var expired []types.NodeID - var changedNodes []types.NodeID - for _, user := range users { - nodes, err := ListNodesByUser(tx, user.Name) - if err != nil { - return nil, nil - } - - for idx, node := range nodes { - if node.IsEphemeral() && node.LastSeen != nil && - time.Now(). - After(node.LastSeen.Add(inactivityThreshold)) { - expired = append(expired, node.ID) - - log.Info(). - Str("node", node.Hostname). - Msg("Ephemeral client removed from database") - - // empty isConnected map as ephemeral nodes are not routes - changed, err := DeleteNode(tx, nodes[idx], nil) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("🤮 Cannot delete ephemeral node from the database") - } - - changedNodes = append(changedNodes, changed...) - } - } - - // TODO(kradalby): needs to be moved out of transaction - } - - return expired, changedNodes -} - func ExpireExpiredNodes(tx *gorm.DB, lastCheck time.Time, ) (time.Time, types.StateUpdate, bool) { @@ -737,3 +718,78 @@ func ExpireExpiredNodes(tx *gorm.DB, return started, types.StateUpdate{}, false } + +// EphemeralGarbageCollector is a garbage collector that will delete nodes after +// a certain amount of time. +// It is used to delete ephemeral nodes that have disconnected and should be +// cleaned up. +type EphemeralGarbageCollector struct { + mu deadlock.Mutex + + deleteFunc func(types.NodeID) + toBeDeleted map[types.NodeID]*time.Timer + + deleteCh chan types.NodeID + cancelCh chan struct{} +} + +// NewEphemeralGarbageCollector creates a new EphemeralGarbageCollector, it takes +// a deleteFunc that will be called when a node is scheduled for deletion. +func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarbageCollector { + return &EphemeralGarbageCollector{ + toBeDeleted: make(map[types.NodeID]*time.Timer), + deleteCh: make(chan types.NodeID, 10), + cancelCh: make(chan struct{}), + deleteFunc: deleteFunc, + } +} + +// Close stops the garbage collector. +func (e *EphemeralGarbageCollector) Close() { + e.cancelCh <- struct{}{} +} + +// Schedule schedules a node for deletion after the expiry duration. +func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { + e.mu.Lock() + defer e.mu.Unlock() + + timer := time.NewTimer(expiry) + e.toBeDeleted[nodeID] = timer + + go func() { + select { + case _, ok := <-timer.C: + if ok { + e.deleteCh <- nodeID + } + } + }() +} + +// Cancel cancels the deletion of a node. +func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) { + e.mu.Lock() + defer e.mu.Unlock() + + if timer, ok := e.toBeDeleted[nodeID]; ok { + timer.Stop() + delete(e.toBeDeleted, nodeID) + } +} + +// Start starts the garbage collector. +func (e *EphemeralGarbageCollector) Start() { + for { + select { + case <-e.cancelCh: + return + case nodeID := <-e.deleteCh: + e.mu.Lock() + delete(e.toBeDeleted, nodeID) + e.mu.Unlock() + + go e.deleteFunc(nodeID) + } + } +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index f1762a44..d88d0458 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -1,17 +1,23 @@ package db import ( + "crypto/rand" "fmt" + "math/big" "net/netip" "regexp" "strconv" + "sync" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/puzpuzpuz/xsync/v3" + "github.com/stretchr/testify/assert" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" @@ -30,7 +36,6 @@ func (s *Suite) TestGetNode(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, @@ -39,7 +44,7 @@ func (s *Suite) TestGetNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(node) c.Assert(trx.Error, check.IsNil) @@ -61,7 +66,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -69,7 +73,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -93,7 +97,6 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { machineKey := key.NewMachine() - pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -101,7 +104,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -145,7 +148,6 @@ func (s *Suite) TestListPeers(c *check.C) { _, err = db.GetNodeByID(0) c.Assert(err, check.NotNil) - pakID := uint(pak.ID) for index := 0; index <= 10; index++ { nodeKey := key.NewNode() machineKey := key.NewMachine() @@ -157,7 +159,7 @@ func (s *Suite) TestListPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -197,7 +199,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for index := 0; index <= 10; index++ { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(stor[index%2].key.ID) v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) node := types.Node{ @@ -208,7 +209,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: stor[index%2].user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(stor[index%2].key.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -283,7 +284,6 @@ func (s *Suite) TestExpireNode(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, @@ -292,7 +292,7 @@ func (s *Suite) TestExpireNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Expiry: &time.Time{}, } db.DB.Save(node) @@ -328,7 +328,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { machineKey2 := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -337,7 +336,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { GivenName: "hostname-1", UserID: user1.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(node) @@ -372,7 +371,6 @@ func (s *Suite) TestSetTags(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -380,7 +378,7 @@ func (s *Suite) TestSetTags(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(node) @@ -566,7 +564,6 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { route2 := netip.MustParsePrefix("10.11.0.0/24") v4 := netip.MustParseAddr("100.64.0.1") - pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -574,7 +571,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { Hostname: "test", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:exit"}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, @@ -600,3 +597,121 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(enabledRoutes, check.HasLen, 4) } + +func TestEphemeralGarbageCollectorOrder(t *testing.T) { + want := []types.NodeID{1, 3} + got := []types.NodeID{} + + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + got = append(got, ni) + }) + go e.Start() + + e.Schedule(1, 1*time.Second) + e.Schedule(2, 2*time.Second) + e.Schedule(3, 3*time.Second) + e.Schedule(4, 4*time.Second) + e.Cancel(2) + e.Cancel(4) + + time.Sleep(6 * time.Second) + + e.Close() + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong nodes deleted, unexpected result (-want +got):\n%s", diff) + } +} + +func TestEphemeralGarbageCollectorLoads(t *testing.T) { + var got []types.NodeID + var mu sync.Mutex + + want := 1000 + + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + defer mu.Unlock() + mu.Lock() + + time.Sleep(time.Duration(generateRandomNumber(t, 3)) * time.Millisecond) + got = append(got, ni) + }) + go e.Start() + + for i := 0; i < want; i++ { + go e.Schedule(types.NodeID(i), 1*time.Second) + } + + time.Sleep(10 * time.Second) + + e.Close() + if len(got) != want { + t.Errorf("expected %d, got %d", want, len(got)) + } +} + +func generateRandomNumber(t *testing.T, max int64) int64 { + t.Helper() + maxB := big.NewInt(max) + n, err := rand.Int(rand.Reader, maxB) + if err != nil { + t.Fatalf("getting random number: %s", err) + } + return n.Int64() + 1 +} + +func TestListEphemeralNodes(t *testing.T) { + db, err := newTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser("test") + assert.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + assert.NoError(t, err) + + pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) + assert.NoError(t, err) + + node := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + + nodeEph := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "ephemeral", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pakEph.ID), + } + + err = db.DB.Save(&node).Error + assert.NoError(t, err) + + err = db.DB.Save(&nodeEph).Error + assert.NoError(t, err) + + nodes, err := db.ListNodes() + assert.NoError(t, err) + + ephemeralNodes, err := db.ListEphemeralNodes() + assert.NoError(t, err) + + assert.Len(t, nodes, 2) + assert.Len(t, ephemeralNodes, 1) + + assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID) + assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID) + assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) + assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) +} diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index adfd289a..5ea59a9c 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" + "tailscale.com/types/ptr" ) var ( @@ -197,10 +198,9 @@ func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { } nodes := types.Nodes{} - pakID := uint(pak.ID) if err := tx. Preload("AuthKey"). - Where(&types.Node{AuthKeyID: &pakID}). + Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}). Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9cdcba80..9dd5b199 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,7 +6,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" - "gorm.io/gorm" + "tailscale.com/types/ptr" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -76,13 +76,12 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -99,13 +98,12 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 1, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -127,77 +125,6 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { c.Assert(key.ID, check.Equals, pak.ID) } -func (*Suite) TestEphemeralKeyReusable(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, true, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - pakID := uint(pak.ID) - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - - _, err = db.getNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.Write(func(tx *gorm.DB) error { - DeleteExpiredEphemeralNodes(tx, time.Second*20) - return nil - }) - - // The machine record should have been deleted - _, err = db.getNode("test7", "testest") - c.Assert(err, check.NotNil) -} - -func (*Suite) TestEphemeralKeyNotReusable(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - pakId := uint(pak.ID) - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: &pakId, - } - db.DB.Save(&node) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.NotNil) - - _, err = db.getNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.Write(func(tx *gorm.DB) error { - DeleteExpiredEphemeralNodes(tx, time.Second*20) - return nil - }) - - // The machine record should have been deleted - _, err = db.getNode("test7", "testest") - c.Assert(err, check.NotNil) -} - func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 8bbc5948..122a7ff3 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -14,6 +14,7 @@ import ( "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] { @@ -43,13 +44,12 @@ func (s *Suite) TestGetRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route}, } - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "test_get_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo, } trx := db.DB.Save(&node) @@ -95,13 +95,12 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo, } trx := db.DB.Save(&node) @@ -169,13 +168,12 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { hostInfo1 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route, route2}, } - pakID := uint(pak.ID) node1 := types.Node{ ID: 1, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo1, } trx := db.DB.Save(&node1) @@ -199,7 +197,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo2, } db.DB.Save(&node2) @@ -253,13 +251,12 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } now := time.Now() - pakID := uint(pak.ID) node1 := types.Node{ ID: 1, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo1, LastSeen: &now, } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1b97ce06..d546b33d 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -36,10 +36,18 @@ func (s *Suite) ResetDB(c *check.C) { // } var err error - tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + db, err = newTestDB() if err != nil { c.Fatal(err) } +} + +func newTestDB() (*HSDatabase, error) { + var err error + tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + if err != nil { + return nil, err + } log.Printf("database path: %s", tmpDir+"/headscale_test.db") @@ -53,6 +61,8 @@ func (s *Suite) ResetDB(c *check.C) { "", ) if err != nil { - c.Fatal(err) + return nil, err } + + return db, nil } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 98dea6c0..0629480c 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -5,6 +5,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "gorm.io/gorm" + "tailscale.com/types/ptr" ) func (s *Suite) TestCreateAndDestroyUser(c *check.C) { @@ -46,13 +47,12 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -100,13 +100,12 @@ func (s *Suite) TestSetMachineUser(c *check.C) { pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testnode", UserID: oldUser.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index d3c82117..8122064b 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -135,6 +135,18 @@ func (m *mapSession) resetKeepAlive() { m.keepAliveTicker.Reset(m.keepAlive) } +func (m *mapSession) beforeServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Cancel(m.node.ID) + } +} + +func (m *mapSession) afterServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout) + } +} + // serve handles non-streaming requests. func (m *mapSession) serve() { // TODO(kradalby): A set todos to harden: @@ -180,6 +192,8 @@ func (m *mapSession) serve() { // //nolint:gocyclo func (m *mapSession) serveLongPoll() { + m.beforeServeLongPoll() + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -197,6 +211,7 @@ func (m *mapSession) serveLongPoll() { m.pollFailoverRoutes("node closing connection", m.node) } + m.afterServeLongPoll() m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) }() diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 19b287a1..24e36535 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -119,7 +119,7 @@ type Node struct { ForcedTags StringList // TODO(kradalby): This seems like irrelevant information? - AuthKeyID *uint `sql:"DEFAULT:NULL"` + AuthKeyID *uint64 `sql:"DEFAULT:NULL"` AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"` LastSeen *time.Time diff --git a/integration/general_test.go b/integration/general_test.go index 245e8f09..c17b977e 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -297,6 +297,122 @@ func TestEphemeral(t *testing.T) { } } +// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not +// deleted by accident if they are still online and active. +func TestEphemeral2006DeletedTooQuickly(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + headscale, err := scenario.Headscale( + hsic.WithTestName("ephemeral2006"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s", + }), + ) + assertNoErrHeadscaleEnv(t, err) + + for userName, clientCount := range spec { + err = scenario.CreateUser(userName) + if err != nil { + t.Fatalf("failed to create user %s: %s", userName, err) + } + + err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) + if err != nil { + t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) + } + + key, err := scenario.CreatePreAuthKey(userName, true, true) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + // All ephemeral nodes should be online and reachable. + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Take down all clients, this should start an expiry timer for each. + for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // Wait a bit and bring up the clients again before the expiry + // time of the ephemeral nodes. + // Nodes should be able to reconnect and work fine. + time.Sleep(30 * time.Second) + + for _, client := range allClients { + err := client.Up() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + success = pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Take down all clients, this should start an expiry timer for each. + for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // This time wait for all of the nodes to expire and check that they are no longer + // registered. + time.Sleep(3 * time.Minute) + + for userName := range spec { + nodes, err := headscale.ListNodesInUser(userName) + if err != nil { + log.Error(). + Err(err). + Str("user", userName). + Msg("Error listing nodes in user") + + return + } + + if len(nodes) != 0 { + t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName) + } + } +} + func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) t.Parallel()