diff --git a/CHANGELOG.md b/CHANGELOG.md index c5d5f36c..2a322dcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,7 @@ The new policy can be used by setting the environment variable [#2493](https://github.com/juanfont/headscale/pull/2493) - If a OIDC provider doesn't include the `email_verified` claim in its ID tokens, Headscale will attempt to get it from the UserInfo endpoint. +- Improve performance by only querying relevant nodes from the database for node updates [#2509](https://github.com/juanfont/headscale/pull/2509) ## 0.25.1 (2025-02-25) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index f36f66b7..6aa75018 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -35,21 +35,26 @@ var ( ) ) -func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) { +// ListPeers returns peers of node, regardless of any Policy or if the node is expired. +// If no peer IDs are given, all peers are returned. +// If at least one peer ID is given, only these peer nodes will be returned. +func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListPeers(rx, nodeID) + return ListPeers(rx, nodeID, peerIDs...) }) } -// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. -func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { +// ListPeers returns peers of node, regardless of any Policy or if the node is expired. +// If no peer IDs are given, all peers are returned. +// If at least one peer ID is given, only these peer nodes will be returned. +func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Where("id <> ?", - nodeID).Find(&nodes).Error; err != nil { + Where("id <> ?", nodeID). + Where(peerIDs).Find(&nodes).Error; err != nil { return types.Nodes{}, err } @@ -58,19 +63,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter +func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListNodes(rx) + return ListNodes(rx, nodeIDs...) }) } -func ListNodes(tx *gorm.DB) (types.Nodes, error) { +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter +func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Find(&nodes).Error; err != nil { + Where(nodeIDs).Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e5f0661c..fd9313e1 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -747,3 +747,174 @@ func TestRenameNode(t *testing.T) { }) assert.ErrorContains(t, err, "name is not unique") } + +func TestListPeers(t *testing.T) { + // Setup test database + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + user2, err := db.CreateUser(types.User{Name: "user2"}) + require.NoError(t, err) + + node1 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test1", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test2", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + err = db.DB.Save(&node1).Error + require.NoError(t, err) + + err = db.DB.Save(&node2).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNode(tx, node1, nil, nil) + if err != nil { + return err + } + _, err = RegisterNode(tx, node2, nil, nil) + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 2) + + // No parameter means no filter, should return all peers + nodes, err = db.ListPeers(1) + require.NoError(t, err) + assert.Equal(t, len(nodes), 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Empty node list should return all peers + nodes, err = db.ListPeers(1, types.NodeIDs{}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // No match in IDs should return empty list and no error + nodes, err = db.ListPeers(1, types.NodeIDs{3, 4, 5}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 0) + + // Partial match in IDs + nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Several matched IDs, but node ID is still filtered out + nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 1) + assert.Equal(t, "test2", nodes[0].Hostname) +} + +func TestListNodes(t *testing.T) { + // Setup test database + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + user2, err := db.CreateUser(types.User{Name: "user2"}) + require.NoError(t, err) + + node1 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test1", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test2", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + err = db.DB.Save(&node1).Error + require.NoError(t, err) + + err = db.DB.Save(&node2).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNode(tx, node1, nil, nil) + if err != nil { + return err + } + _, err = RegisterNode(tx, node2, nil, nil) + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 2) + + // No parameter means no filter, should return all nodes + nodes, err = db.ListNodes() + require.NoError(t, err) + assert.Equal(t, len(nodes), 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) + + // Empty node list should return all nodes + nodes, err = db.ListNodes(types.NodeIDs{}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) + + // No match in IDs should return empty list and no error + nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 0) + + // Partial match in IDs + nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Several matched IDs + nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) + require.NoError(t, err) + assert.Equal(t, len(nodes), 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 7a297bd3..b85bf3b0 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -255,27 +255,25 @@ func (m *Mapper) PeerChangedResponse( patches []*tailcfg.PeerChange, messages ...string, ) ([]byte, error) { + var err error resp := m.baseMapResponse() - peers, err := m.ListPeers(node.ID) - if err != nil { - return nil, err - } - var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { if nodeChanged { - changedIDs = append(changedIDs, nodeID) + if nodeID != node.ID { + changedIDs = append(changedIDs, nodeID) + } } else { removedIDs = append(removedIDs, nodeID.NodeID()) } } - - changedNodes := make(types.Nodes, 0, len(changedIDs)) - for _, peer := range peers { - if slices.Contains(changedIDs, peer.ID) { - changedNodes = append(changedNodes, peer) + changedNodes := types.Nodes{} + if len(changedIDs) > 0 { + changedNodes, err = m.ListNodes(changedIDs...) + if err != nil { + return nil, err } } @@ -482,8 +480,11 @@ func (m *Mapper) baseWithConfigMapResponse( return &resp, nil } -func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { - peers, err := m.db.ListPeers(nodeID) +// ListPeers returns peers of node, regardless of any Policy or if the node is expired. +// If no peer IDs are given, all peers are returned. +// If at least one peer ID is given, only these peer nodes will be returned. +func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { + peers, err := m.db.ListPeers(nodeID, peerIDs...) if err != nil { return nil, err } @@ -496,6 +497,22 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { return peers, nil } +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter +func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { + nodes, err := m.db.ListNodes(nodeIDs...) + if err != nil { + return nil, err + } + + for _, node := range nodes { + online := m.notif.IsLikelyConnected(node.ID) + node.IsOnline = &online + } + + return nodes, nil +} + func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { ret := make(types.Nodes, 0)