From ddd31ba774a78eaae845c52eae0260692d8e31c4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 10 Nov 2025 19:15:05 +0100 Subject: [PATCH] hscontrol: use Updates() instead of Save() for partial updates Changed UpdateUser and re-registration flows to use Updates() which only writes modified fields, preventing unintended overwrites of unchanged fields. Also updated UsePreAuthKey to use Model().Update() for single field updates and removed unused NodeSave wrapper. --- hscontrol/db/node.go | 7 -- hscontrol/db/preauth_keys.go | 5 +- hscontrol/db/user_update_test.go | 134 +++++++++++++++++++++++++++++++ hscontrol/state/state.go | 14 ++-- 4 files changed, 146 insertions(+), 14 deletions(-) create mode 100644 hscontrol/db/user_update_test.go diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 70d3afaf..060196a9 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -452,13 +452,6 @@ func NodeSetMachineKey( }).Error } -// NodeSave saves a node object to the database, prefer to use a specific save method rather -// than this. It is intended to be used when we are changing or. -// TODO(kradalby): Remove this func, just use Save. -func NodeSave(tx *gorm.DB, node *types.Node) error { - return tx.Save(node).Error -} - func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { // Strip invalid DNS characters for givenName suppliedName = strings.ToLower(suppliedName) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index a36c1f13..94575269 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -145,11 +145,12 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { // UsePreAuthKey marks a PreAuthKey as used. func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { - k.Used = true - if err := tx.Save(k).Error; err != nil { + err := tx.Model(k).Update("used", true).Error + if err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } + k.Used = true return nil } diff --git a/hscontrol/db/user_update_test.go b/hscontrol/db/user_update_test.go new file mode 100644 index 00000000..180481e7 --- /dev/null +++ b/hscontrol/db/user_update_test.go @@ -0,0 +1,134 @@ +package db + +import ( + "database/sql" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +// TestUserUpdatePreservesUnchangedFields verifies that updating a user +// preserves fields that aren't modified. This test validates the fix +// for using Updates() instead of Save() in UpdateUser-like operations. +func TestUserUpdatePreservesUnchangedFields(t *testing.T) { + database := dbForTest(t) + + // Create a user with all fields set + initialUser := types.User{ + Name: "testuser", + DisplayName: "Test User Display", + Email: "test@example.com", + ProviderIdentifier: sql.NullString{ + String: "provider-123", + Valid: true, + }, + } + + createdUser, err := database.CreateUser(initialUser) + require.NoError(t, err) + require.NotNil(t, createdUser) + + // Verify initial state + assert.Equal(t, "testuser", createdUser.Name) + assert.Equal(t, "Test User Display", createdUser.DisplayName) + assert.Equal(t, "test@example.com", createdUser.Email) + assert.True(t, createdUser.ProviderIdentifier.Valid) + assert.Equal(t, "provider-123", createdUser.ProviderIdentifier.String) + + // Simulate what UpdateUser does: load user, modify one field, save + _, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := GetUserByID(tx, types.UserID(createdUser.ID)) + if err != nil { + return nil, err + } + + // Modify ONLY DisplayName + user.DisplayName = "Updated Display Name" + + // This is the line being tested - currently uses Save() which writes ALL fields, potentially overwriting unchanged ones + err = tx.Save(user).Error + if err != nil { + return nil, err + } + + return user, nil + }) + require.NoError(t, err) + + // Read user back from database + updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByID(rx, types.UserID(createdUser.ID)) + }) + require.NoError(t, err) + + // Verify that DisplayName was updated + assert.Equal(t, "Updated Display Name", updatedUser.DisplayName) + + // CRITICAL: Verify that other fields were NOT overwritten + // With Save(), these assertions should pass because the user object + // was loaded from DB and has all fields populated. + // But if Updates() is used, these will also pass (and it's safer). + assert.Equal(t, "testuser", updatedUser.Name, "Name should be preserved") + assert.Equal(t, "test@example.com", updatedUser.Email, "Email should be preserved") + assert.True(t, updatedUser.ProviderIdentifier.Valid, "ProviderIdentifier should be preserved") + assert.Equal(t, "provider-123", updatedUser.ProviderIdentifier.String, "ProviderIdentifier value should be preserved") +} + +// TestUserUpdateWithUpdatesMethod tests that using Updates() instead of Save() +// works correctly and only updates modified fields. +func TestUserUpdateWithUpdatesMethod(t *testing.T) { + database := dbForTest(t) + + // Create a user + initialUser := types.User{ + Name: "testuser", + DisplayName: "Original Display", + Email: "original@example.com", + ProviderIdentifier: sql.NullString{ + String: "provider-abc", + Valid: true, + }, + } + + createdUser, err := database.CreateUser(initialUser) + require.NoError(t, err) + + // Update using Updates() method + _, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := GetUserByID(tx, types.UserID(createdUser.ID)) + if err != nil { + return nil, err + } + + // Modify multiple fields + user.DisplayName = "New Display" + user.Email = "new@example.com" + + // Use Updates() instead of Save() + err = tx.Updates(user).Error + if err != nil { + return nil, err + } + + return user, nil + }) + require.NoError(t, err) + + // Verify changes + updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByID(rx, types.UserID(createdUser.ID)) + }) + require.NoError(t, err) + + // Verify updated fields + assert.Equal(t, "New Display", updatedUser.DisplayName) + assert.Equal(t, "new@example.com", updatedUser.Email) + + // Verify preserved fields + assert.Equal(t, "testuser", updatedUser.Name) + assert.True(t, updatedUser.ProviderIdentifier.Valid) + assert.Equal(t, "provider-abc", updatedUser.ProviderIdentifier.String) +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index ff876024..297004fc 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -300,7 +300,9 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return nil, err } - if err := tx.Save(user).Error; err != nil { + // Use Updates() to only update modified fields, preserving unchanged values. + err = tx.Updates(user).Error + if err != nil { return nil, fmt.Errorf("updating user: %w", err) } @@ -1191,9 +1193,10 @@ func (s *State) HandleNodeFromAuthPath( return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) } - // Use the node from UpdateNode to save to database _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { + // Use Updates() to preserve fields not modified by UpdateNode. + err := tx.Updates(updatedNodeView.AsStruct()).Error + if err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } return nil, nil @@ -1410,9 +1413,10 @@ func (s *State) HandleNodeFromPreAuthKey( return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) } - // Use the node from UpdateNode to save to database _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { + // Use Updates() to preserve fields not modified by UpdateNode. + err := tx.Updates(updatedNodeView.AsStruct()).Error + if err != nil { return nil, fmt.Errorf("failed to save node: %w", err) }