mirror of
https://github.com/juanfont/headscale.git
synced 2025-04-15 08:45:41 -04:00
Only read relevant nodes from database in PeerChangedResponse (#2509)
* Only read relevant nodes from database in PeerChangedResponse * Rework to ensure transactional consistency in PeerChangedResponse again * An empty nodeIDs list should return an empty nodes list * Add test to ListNodesSubset * Link PR in CHANGELOG.md * combine ListNodes and ListNodesSubset into one function * query for all nodes in ListNodes if no parameter is given * also add optional filtering for relevant nodes to ListPeers
This commit is contained in:
parent
d2a6356d89
commit
0d3134720b
@ -87,6 +87,7 @@ The new policy can be used by setting the environment variable
|
|||||||
[#2493](https://github.com/juanfont/headscale/pull/2493)
|
[#2493](https://github.com/juanfont/headscale/pull/2493)
|
||||||
- If a OIDC provider doesn't include the `email_verified` claim in its ID
|
- If a OIDC provider doesn't include the `email_verified` claim in its ID
|
||||||
tokens, Headscale will attempt to get it from the UserInfo endpoint.
|
tokens, Headscale will attempt to get it from the UserInfo endpoint.
|
||||||
|
- Improve performance by only querying relevant nodes from the database for node updates [#2509](https://github.com/juanfont/headscale/pull/2509)
|
||||||
|
|
||||||
## 0.25.1 (2025-02-25)
|
## 0.25.1 (2025-02-25)
|
||||||
|
|
||||||
|
@ -35,21 +35,26 @@ var (
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||||
|
// If no peer IDs are given, all peers are returned.
|
||||||
|
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||||
|
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return ListPeers(rx, nodeID)
|
return ListPeers(rx, nodeID, peerIDs...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
|
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||||
func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
|
// If no peer IDs are given, all peers are returned.
|
||||||
|
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||||
|
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
if err := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Where("id <> ?",
|
Where("id <> ?", nodeID).
|
||||||
nodeID).Find(&nodes).Error; err != nil {
|
Where(peerIDs).Find(&nodes).Error; err != nil {
|
||||||
return types.Nodes{}, err
|
return types.Nodes{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,19 +63,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
|
|||||||
return nodes, nil
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
|
// ListNodes queries the database for either all nodes if no parameters are given
|
||||||
|
// or for the given nodes if at least one node ID is given as parameter
|
||||||
|
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return ListNodes(rx)
|
return ListNodes(rx, nodeIDs...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListNodes(tx *gorm.DB) (types.Nodes, error) {
|
// ListNodes queries the database for either all nodes if no parameters are given
|
||||||
|
// or for the given nodes if at least one node ID is given as parameter
|
||||||
|
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
if err := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Find(&nodes).Error; err != nil {
|
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -747,3 +747,174 @@ func TestRenameNode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.ErrorContains(t, err, "name is not unique")
|
assert.ErrorContains(t, err, "name is not unique")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListPeers(t *testing.T) {
|
||||||
|
// Setup test database
|
||||||
|
db, err := newSQLiteTestDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating db: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := db.CreateUser(types.User{Name: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user2, err := db.CreateUser(types.User{Name: "user2"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
node1 := types.Node{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: key.NewMachine().Public(),
|
||||||
|
NodeKey: key.NewNode().Public(),
|
||||||
|
Hostname: "test1",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
|
}
|
||||||
|
|
||||||
|
node2 := types.Node{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: key.NewMachine().Public(),
|
||||||
|
NodeKey: key.NewNode().Public(),
|
||||||
|
Hostname: "test2",
|
||||||
|
UserID: user2.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.DB.Save(&node1).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.DB.Save(&node2).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
_, err := RegisterNode(tx, node1, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = RegisterNode(tx, node2, nil, nil)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
nodes, err := db.ListNodes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, nodes, 2)
|
||||||
|
|
||||||
|
// No parameter means no filter, should return all peers
|
||||||
|
nodes, err = db.ListPeers(1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 1)
|
||||||
|
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||||
|
|
||||||
|
// Empty node list should return all peers
|
||||||
|
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 1)
|
||||||
|
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||||
|
|
||||||
|
// No match in IDs should return empty list and no error
|
||||||
|
nodes, err = db.ListPeers(1, types.NodeIDs{3, 4, 5}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 0)
|
||||||
|
|
||||||
|
// Partial match in IDs
|
||||||
|
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 1)
|
||||||
|
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||||
|
|
||||||
|
// Several matched IDs, but node ID is still filtered out
|
||||||
|
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 1)
|
||||||
|
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListNodes(t *testing.T) {
|
||||||
|
// Setup test database
|
||||||
|
db, err := newSQLiteTestDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating db: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := db.CreateUser(types.User{Name: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user2, err := db.CreateUser(types.User{Name: "user2"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
node1 := types.Node{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: key.NewMachine().Public(),
|
||||||
|
NodeKey: key.NewNode().Public(),
|
||||||
|
Hostname: "test1",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
|
}
|
||||||
|
|
||||||
|
node2 := types.Node{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: key.NewMachine().Public(),
|
||||||
|
NodeKey: key.NewNode().Public(),
|
||||||
|
Hostname: "test2",
|
||||||
|
UserID: user2.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.DB.Save(&node1).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.DB.Save(&node2).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
_, err := RegisterNode(tx, node1, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = RegisterNode(tx, node2, nil, nil)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
nodes, err := db.ListNodes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, nodes, 2)
|
||||||
|
|
||||||
|
// No parameter means no filter, should return all nodes
|
||||||
|
nodes, err = db.ListNodes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 2)
|
||||||
|
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||||
|
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||||
|
|
||||||
|
// Empty node list should return all nodes
|
||||||
|
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 2)
|
||||||
|
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||||
|
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||||
|
|
||||||
|
// No match in IDs should return empty list and no error
|
||||||
|
nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 0)
|
||||||
|
|
||||||
|
// Partial match in IDs
|
||||||
|
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 1)
|
||||||
|
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||||
|
|
||||||
|
// Several matched IDs
|
||||||
|
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(nodes), 2)
|
||||||
|
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||||
|
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||||
|
}
|
||||||
|
@ -255,27 +255,25 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
patches []*tailcfg.PeerChange,
|
patches []*tailcfg.PeerChange,
|
||||||
messages ...string,
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
|
var err error
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
peers, err := m.ListPeers(node.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var removedIDs []tailcfg.NodeID
|
var removedIDs []tailcfg.NodeID
|
||||||
var changedIDs []types.NodeID
|
var changedIDs []types.NodeID
|
||||||
for nodeID, nodeChanged := range changed {
|
for nodeID, nodeChanged := range changed {
|
||||||
if nodeChanged {
|
if nodeChanged {
|
||||||
changedIDs = append(changedIDs, nodeID)
|
if nodeID != node.ID {
|
||||||
|
changedIDs = append(changedIDs, nodeID)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
changedNodes := types.Nodes{}
|
||||||
changedNodes := make(types.Nodes, 0, len(changedIDs))
|
if len(changedIDs) > 0 {
|
||||||
for _, peer := range peers {
|
changedNodes, err = m.ListNodes(changedIDs...)
|
||||||
if slices.Contains(changedIDs, peer.ID) {
|
if err != nil {
|
||||||
changedNodes = append(changedNodes, peer)
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -482,8 +480,11 @@ func (m *Mapper) baseWithConfigMapResponse(
|
|||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||||
peers, err := m.db.ListPeers(nodeID)
|
// If no peer IDs are given, all peers are returned.
|
||||||
|
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||||
|
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
|
peers, err := m.db.ListPeers(nodeID, peerIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -496,6 +497,22 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
|||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListNodes queries the database for either all nodes if no parameters are given
|
||||||
|
// or for the given nodes if at least one node ID is given as parameter
|
||||||
|
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
|
nodes, err := m.db.ListNodes(nodeIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
online := m.notif.IsLikelyConnected(node.ID)
|
||||||
|
node.IsOnline = &online
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
|
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
|
||||||
ret := make(types.Nodes, 0)
|
ret := make(types.Nodes, 0)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user