From 6ede807f8dba479f1ca6cba59ef357165fbd15cb Mon Sep 17 00:00:00 2001 From: Franco Liberali Date: Thu, 7 Sep 2023 19:08:12 +0200 Subject: [PATCH] update returning preload --- callbacks/callbacks.go | 1 + callbacks/query.go | 4 +++- tests/update_test.go | 13 +++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..63c90249 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -68,6 +68,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:update", Update(config)) + updateCallback.Register("gorm:preload", Preload) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/query.go b/callbacks/query.go index b71ff5f5..510ff79f 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -134,7 +134,9 @@ func Preload(db *gorm.DB) { }) if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { - return + if err := preloadDB.Statement.Parse(db.Statement.Model); err != nil { + return + } } preloadDB.Statement.ReflectValue = db.Statement.ReflectValue preloadDB.Statement.Unscoped = db.Statement.Unscoped diff --git a/tests/update_test.go b/tests/update_test.go index 9eb9dbfc..a18e68bc 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -775,6 +775,7 @@ func TestUpdateReturning(t *testing.T) { GetUser("update-returning-1", Config{}), GetUser("update-returning-2", Config{}), GetUser("update-returning-3", Config{}), + GetUser("update-returning-4", Config{Pets: 1}), } DB.Create(&users) @@ -795,6 +796,18 @@ func TestUpdateReturning(t *testing.T) { if results[1].Age-results[0].Age != 100 { t.Errorf("failed to return updated age column") } + + var result User + DB.Model(&result).Where("name = ?", users[3].Name).Clauses(clause.Returning{}).Preload("Pets").Update("age", 38) + if result.Age != 38 { + t.Errorf("failed to return updated data, got %v", results) + } + + if len(result.Pets) != 1 { + t.Errorf("failed to preload pets, got %v", result.Pets) + } + + CheckPet(t, *result.Pets[0], *users[3].Pets[0]) } func TestUpdateWithDiffSchema(t *testing.T) {