mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-20 01:40:21 -05:00
hscontrol: use Updates() instead of Save() for partial updates
Changed UpdateUser and re-registration flows to use Updates() which only writes modified fields, preventing unintended overwrites of unchanged fields. Also updated UsePreAuthKey to use Model().Update() for single field updates and removed unused NodeSave wrapper.
This commit is contained in:
committed by
Kristoffer Dalby
parent
4a8dc2d445
commit
ddd31ba774
@@ -452,13 +452,6 @@ func NodeSetMachineKey(
|
||||
}).Error
|
||||
}
|
||||
|
||||
// NodeSave saves a node object to the database, prefer to use a specific save method rather
|
||||
// than this. It is intended to be used when we are changing or.
|
||||
// TODO(kradalby): Remove this func, just use Save.
|
||||
func NodeSave(tx *gorm.DB, node *types.Node) error {
|
||||
return tx.Save(node).Error
|
||||
}
|
||||
|
||||
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
// Strip invalid DNS characters for givenName
|
||||
suppliedName = strings.ToLower(suppliedName)
|
||||
|
||||
@@ -145,11 +145,12 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
k.Used = true
|
||||
if err := tx.Save(k).Error; err != nil {
|
||||
err := tx.Model(k).Update("used", true).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||
}
|
||||
|
||||
k.Used = true
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
134
hscontrol/db/user_update_test.go
Normal file
134
hscontrol/db/user_update_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TestUserUpdatePreservesUnchangedFields verifies that updating a user
|
||||
// preserves fields that aren't modified. This test validates the fix
|
||||
// for using Updates() instead of Save() in UpdateUser-like operations.
|
||||
func TestUserUpdatePreservesUnchangedFields(t *testing.T) {
|
||||
database := dbForTest(t)
|
||||
|
||||
// Create a user with all fields set
|
||||
initialUser := types.User{
|
||||
Name: "testuser",
|
||||
DisplayName: "Test User Display",
|
||||
Email: "test@example.com",
|
||||
ProviderIdentifier: sql.NullString{
|
||||
String: "provider-123",
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
createdUser, err := database.CreateUser(initialUser)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdUser)
|
||||
|
||||
// Verify initial state
|
||||
assert.Equal(t, "testuser", createdUser.Name)
|
||||
assert.Equal(t, "Test User Display", createdUser.DisplayName)
|
||||
assert.Equal(t, "test@example.com", createdUser.Email)
|
||||
assert.True(t, createdUser.ProviderIdentifier.Valid)
|
||||
assert.Equal(t, "provider-123", createdUser.ProviderIdentifier.String)
|
||||
|
||||
// Simulate what UpdateUser does: load user, modify one field, save
|
||||
_, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) {
|
||||
user, err := GetUserByID(tx, types.UserID(createdUser.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Modify ONLY DisplayName
|
||||
user.DisplayName = "Updated Display Name"
|
||||
|
||||
// This is the line being tested - currently uses Save() which writes ALL fields, potentially overwriting unchanged ones
|
||||
err = tx.Save(user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read user back from database
|
||||
updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUserByID(rx, types.UserID(createdUser.ID))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that DisplayName was updated
|
||||
assert.Equal(t, "Updated Display Name", updatedUser.DisplayName)
|
||||
|
||||
// CRITICAL: Verify that other fields were NOT overwritten
|
||||
// With Save(), these assertions should pass because the user object
|
||||
// was loaded from DB and has all fields populated.
|
||||
// But if Updates() is used, these will also pass (and it's safer).
|
||||
assert.Equal(t, "testuser", updatedUser.Name, "Name should be preserved")
|
||||
assert.Equal(t, "test@example.com", updatedUser.Email, "Email should be preserved")
|
||||
assert.True(t, updatedUser.ProviderIdentifier.Valid, "ProviderIdentifier should be preserved")
|
||||
assert.Equal(t, "provider-123", updatedUser.ProviderIdentifier.String, "ProviderIdentifier value should be preserved")
|
||||
}
|
||||
|
||||
// TestUserUpdateWithUpdatesMethod tests that using Updates() instead of Save()
|
||||
// works correctly and only updates modified fields.
|
||||
func TestUserUpdateWithUpdatesMethod(t *testing.T) {
|
||||
database := dbForTest(t)
|
||||
|
||||
// Create a user
|
||||
initialUser := types.User{
|
||||
Name: "testuser",
|
||||
DisplayName: "Original Display",
|
||||
Email: "original@example.com",
|
||||
ProviderIdentifier: sql.NullString{
|
||||
String: "provider-abc",
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
createdUser, err := database.CreateUser(initialUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update using Updates() method
|
||||
_, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) {
|
||||
user, err := GetUserByID(tx, types.UserID(createdUser.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Modify multiple fields
|
||||
user.DisplayName = "New Display"
|
||||
user.Email = "new@example.com"
|
||||
|
||||
// Use Updates() instead of Save()
|
||||
err = tx.Updates(user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify changes
|
||||
updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUserByID(rx, types.UserID(createdUser.ID))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated fields
|
||||
assert.Equal(t, "New Display", updatedUser.DisplayName)
|
||||
assert.Equal(t, "new@example.com", updatedUser.Email)
|
||||
|
||||
// Verify preserved fields
|
||||
assert.Equal(t, "testuser", updatedUser.Name)
|
||||
assert.True(t, updatedUser.ProviderIdentifier.Valid)
|
||||
assert.Equal(t, "provider-abc", updatedUser.ProviderIdentifier.String)
|
||||
}
|
||||
@@ -300,7 +300,9 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Save(user).Error; err != nil {
|
||||
// Use Updates() to only update modified fields, preserving unchanged values.
|
||||
err = tx.Updates(user).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("updating user: %w", err)
|
||||
}
|
||||
|
||||
@@ -1191,9 +1193,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
|
||||
}
|
||||
|
||||
// Use the node from UpdateNode to save to database
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil {
|
||||
// Use Updates() to preserve fields not modified by UpdateNode.
|
||||
err := tx.Updates(updatedNodeView.AsStruct()).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
return nil, nil
|
||||
@@ -1410,9 +1413,10 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
|
||||
}
|
||||
|
||||
// Use the node from UpdateNode to save to database
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil {
|
||||
// Use Updates() to preserve fields not modified by UpdateNode.
|
||||
err := tx.Updates(updatedNodeView.AsStruct()).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user