From eb90a02a07a950340d9412dfd3a77a74714aaa81 Mon Sep 17 00:00:00 2001 From: Phongphan Phuttha Date: Tue, 29 Jul 2025 18:06:13 +0700 Subject: [PATCH] fix: returning all columns with "on conflict do update" must considered as ScanUpdate (#7534) --- callbacks/create.go | 5 ++++- tests/upsert_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index cb8429b3..e5929adb 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -80,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index e84dc14a..860b87fd 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -135,6 +135,53 @@ func TestUpsertSlice(t *testing.T) { } } +func TestUpsertSliceWithReturning(t *testing.T) { + langs := []Language{ + {Code: "upsert-slice1", Name: "Upsert-slice1"}, + {Code: "upsert-slice2", Name: "Upsert-slice2"}, + {Code: "upsert-slice3", Name: "Upsert-slice3"}, + } + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + + var langs2 []Language + if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs2) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs2) + } + + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + var langs3 []Language + if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs3) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs3) + } + + for idx, lang := range langs { + lang.Name = lang.Name + "_new" + langs[idx] = lang + } + + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.AssignmentColumns([]string{"name"}), + }, clause.Returning{}).CreateInBatches(&langs, len(langs)).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + for _, lang := range langs { + var results []Language + if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(results) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if results[0].Name != lang.Name { + t.Errorf("should update name on conflict, but got name %+v", results[0].Name) + } + } +} + func TestUpsertWithSave(t *testing.T) { langs := []Language{ {Code: "upsert-save-1", Name: "Upsert-save-1"},