Only read relevant nodes from database in PeerChangedResponse (#2509)

* Only read relevant nodes from database in PeerChangedResponse

* Rework to ensure transactional consistency in PeerChangedResponse again

* An empty nodeIDs list should return an empty nodes list

* Add test to ListNodesSubset

* Link PR in CHANGELOG.md

* combine ListNodes and ListNodesSubset into one function

* query for all nodes in ListNodes if no parameter is given

* also add optional filtering for relevant nodes to ListPeers
This commit is contained in:
Enkelmann 2025-04-08 14:56:44 +02:00 committed by GitHub
parent d2a6356d89
commit 0d3134720b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 221 additions and 23 deletions

View File

@ -87,6 +87,7 @@ The new policy can be used by setting the environment variable
[#2493](https://github.com/juanfont/headscale/pull/2493)
- If a OIDC provider doesn't include the `email_verified` claim in its ID
tokens, Headscale will attempt to get it from the UserInfo endpoint.
- Improve performance by only querying relevant nodes from the database for node updates [#2509](https://github.com/juanfont/headscale/pull/2509)
## 0.25.1 (2025-02-25)

View File

@ -35,21 +35,26 @@ var (
)
)
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, nodeID)
return ListPeers(rx, nodeID, peerIDs...)
})
}
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Where("id <> ?",
nodeID).Find(&nodes).Error; err != nil {
Where("id <> ?", nodeID).
Where(peerIDs).Find(&nodes).Error; err != nil {
return types.Nodes{}, err
}
@ -58,19 +63,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
return nodes, nil
}
func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
return ListNodes(rx, nodeIDs...)
})
}
func ListNodes(tx *gorm.DB) (types.Nodes, error) {
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Find(&nodes).Error; err != nil {
Where(nodeIDs).Find(&nodes).Error; err != nil {
return nil, err
}

View File

@ -747,3 +747,174 @@ func TestRenameNode(t *testing.T) {
})
assert.ErrorContains(t, err, "name is not unique")
}
func TestListPeers(t *testing.T) {
// Setup test database
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
user2, err := db.CreateUser(types.User{Name: "user2"})
require.NoError(t, err)
node1 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test1",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
node2 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test2",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
err = db.DB.Save(&node1).Error
require.NoError(t, err)
err = db.DB.Save(&node2).Error
require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node1, nil, nil)
if err != nil {
return err
}
_, err = RegisterNode(tx, node2, nil, nil)
return err
})
require.NoError(t, err)
nodes, err := db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
// No parameter means no filter, should return all peers
nodes, err = db.ListPeers(1)
require.NoError(t, err)
assert.Equal(t, len(nodes), 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// Empty node list should return all peers
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// No match in IDs should return empty list and no error
nodes, err = db.ListPeers(1, types.NodeIDs{3, 4, 5}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 0)
// Partial match in IDs
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs, but node ID is still filtered out
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 1)
assert.Equal(t, "test2", nodes[0].Hostname)
}
func TestListNodes(t *testing.T) {
// Setup test database
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
user2, err := db.CreateUser(types.User{Name: "user2"})
require.NoError(t, err)
node1 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test1",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
node2 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test2",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
err = db.DB.Save(&node1).Error
require.NoError(t, err)
err = db.DB.Save(&node2).Error
require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node1, nil, nil)
if err != nil {
return err
}
_, err = RegisterNode(tx, node2, nil, nil)
return err
})
require.NoError(t, err)
nodes, err := db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
// No parameter means no filter, should return all nodes
nodes, err = db.ListNodes()
require.NoError(t, err)
assert.Equal(t, len(nodes), 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
// Empty node list should return all nodes
nodes, err = db.ListNodes(types.NodeIDs{}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
// No match in IDs should return empty list and no error
nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 0)
// Partial match in IDs
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Equal(t, len(nodes), 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
}

View File

@ -255,27 +255,25 @@ func (m *Mapper) PeerChangedResponse(
patches []*tailcfg.PeerChange,
messages ...string,
) ([]byte, error) {
var err error
resp := m.baseMapResponse()
peers, err := m.ListPeers(node.ID)
if err != nil {
return nil, err
}
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
if nodeChanged {
changedIDs = append(changedIDs, nodeID)
if nodeID != node.ID {
changedIDs = append(changedIDs, nodeID)
}
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := make(types.Nodes, 0, len(changedIDs))
for _, peer := range peers {
if slices.Contains(changedIDs, peer.ID) {
changedNodes = append(changedNodes, peer)
changedNodes := types.Nodes{}
if len(changedIDs) > 0 {
changedNodes, err = m.ListNodes(changedIDs...)
if err != nil {
return nil, err
}
}
@ -482,8 +480,11 @@ func (m *Mapper) baseWithConfigMapResponse(
return &resp, nil
}
func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
peers, err := m.db.ListPeers(nodeID)
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.db.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
@ -496,6 +497,22 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
return peers, nil
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.db.ListNodes(nodeIDs...)
if err != nil {
return nil, err
}
for _, node := range nodes {
online := m.notif.IsLikelyConnected(node.ID)
node.IsOnline = &online
}
return nodes, nil
}
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
ret := make(types.Nodes, 0)