diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 70d3afaf..060196a9 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -452,13 +452,6 @@ func NodeSetMachineKey( }).Error } -// NodeSave saves a node object to the database, prefer to use a specific save method rather -// than this. It is intended to be used when we are changing or. -// TODO(kradalby): Remove this func, just use Save. -func NodeSave(tx *gorm.DB, node *types.Node) error { - return tx.Save(node).Error -} - func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { // Strip invalid DNS characters for givenName suppliedName = strings.ToLower(suppliedName) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index a36c1f13..94575269 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -145,11 +145,12 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { // UsePreAuthKey marks a PreAuthKey as used. func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { - k.Used = true - if err := tx.Save(k).Error; err != nil { + err := tx.Model(k).Update("used", true).Error + if err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } + k.Used = true return nil } diff --git a/hscontrol/db/user_update_test.go b/hscontrol/db/user_update_test.go new file mode 100644 index 00000000..180481e7 --- /dev/null +++ b/hscontrol/db/user_update_test.go @@ -0,0 +1,134 @@ +package db + +import ( + "database/sql" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +// TestUserUpdatePreservesUnchangedFields verifies that updating a user +// preserves fields that aren't modified. This test validates the fix +// for using Updates() instead of Save() in UpdateUser-like operations. +func TestUserUpdatePreservesUnchangedFields(t *testing.T) { + database := dbForTest(t) + + // Create a user with all fields set + initialUser := types.User{ + Name: "testuser", + DisplayName: "Test User Display", + Email: "test@example.com", + ProviderIdentifier: sql.NullString{ + String: "provider-123", + Valid: true, + }, + } + + createdUser, err := database.CreateUser(initialUser) + require.NoError(t, err) + require.NotNil(t, createdUser) + + // Verify initial state + assert.Equal(t, "testuser", createdUser.Name) + assert.Equal(t, "Test User Display", createdUser.DisplayName) + assert.Equal(t, "test@example.com", createdUser.Email) + assert.True(t, createdUser.ProviderIdentifier.Valid) + assert.Equal(t, "provider-123", createdUser.ProviderIdentifier.String) + + // Simulate what UpdateUser does: load user, modify one field, save + _, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := GetUserByID(tx, types.UserID(createdUser.ID)) + if err != nil { + return nil, err + } + + // Modify ONLY DisplayName + user.DisplayName = "Updated Display Name" + + // This is the line being tested - currently uses Save() which writes ALL fields, potentially overwriting unchanged ones + err = tx.Save(user).Error + if err != nil { + return nil, err + } + + return user, nil + }) + require.NoError(t, err) + + // Read user back from database + updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByID(rx, types.UserID(createdUser.ID)) + }) + require.NoError(t, err) + + // Verify that DisplayName was updated + assert.Equal(t, "Updated Display Name", updatedUser.DisplayName) + + // CRITICAL: Verify that other fields were NOT overwritten + // With Save(), these assertions should pass because the user object + // was loaded from DB and has all fields populated. + // But if Updates() is used, these will also pass (and it's safer). + assert.Equal(t, "testuser", updatedUser.Name, "Name should be preserved") + assert.Equal(t, "test@example.com", updatedUser.Email, "Email should be preserved") + assert.True(t, updatedUser.ProviderIdentifier.Valid, "ProviderIdentifier should be preserved") + assert.Equal(t, "provider-123", updatedUser.ProviderIdentifier.String, "ProviderIdentifier value should be preserved") +} + +// TestUserUpdateWithUpdatesMethod tests that using Updates() instead of Save() +// works correctly and only updates modified fields. +func TestUserUpdateWithUpdatesMethod(t *testing.T) { + database := dbForTest(t) + + // Create a user + initialUser := types.User{ + Name: "testuser", + DisplayName: "Original Display", + Email: "original@example.com", + ProviderIdentifier: sql.NullString{ + String: "provider-abc", + Valid: true, + }, + } + + createdUser, err := database.CreateUser(initialUser) + require.NoError(t, err) + + // Update using Updates() method + _, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := GetUserByID(tx, types.UserID(createdUser.ID)) + if err != nil { + return nil, err + } + + // Modify multiple fields + user.DisplayName = "New Display" + user.Email = "new@example.com" + + // Use Updates() instead of Save() + err = tx.Updates(user).Error + if err != nil { + return nil, err + } + + return user, nil + }) + require.NoError(t, err) + + // Verify changes + updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByID(rx, types.UserID(createdUser.ID)) + }) + require.NoError(t, err) + + // Verify updated fields + assert.Equal(t, "New Display", updatedUser.DisplayName) + assert.Equal(t, "new@example.com", updatedUser.Email) + + // Verify preserved fields + assert.Equal(t, "testuser", updatedUser.Name) + assert.True(t, updatedUser.ProviderIdentifier.Valid) + assert.Equal(t, "provider-abc", updatedUser.ProviderIdentifier.String) +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index ff876024..297004fc 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -300,7 +300,9 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return nil, err } - if err := tx.Save(user).Error; err != nil { + // Use Updates() to only update modified fields, preserving unchanged values. + err = tx.Updates(user).Error + if err != nil { return nil, fmt.Errorf("updating user: %w", err) } @@ -1191,9 +1193,10 @@ func (s *State) HandleNodeFromAuthPath( return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) } - // Use the node from UpdateNode to save to database _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { + // Use Updates() to preserve fields not modified by UpdateNode. + err := tx.Updates(updatedNodeView.AsStruct()).Error + if err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } return nil, nil @@ -1410,9 +1413,10 @@ func (s *State) HandleNodeFromPreAuthKey( return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) } - // Use the node from UpdateNode to save to database _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { + // Use Updates() to preserve fields not modified by UpdateNode. + err := tx.Updates(updatedNodeView.AsStruct()).Error + if err != nil { return nil, fmt.Errorf("failed to save node: %w", err) }