fix: returning all columns with "on conflict do update" must considered as ScanUpdate (#7534)
This commit is contained in:
parent
22d5239dec
commit
eb90a02a07
@ -80,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
ok, mode := hasReturning(db, supportReturning)
|
ok, mode := hasReturning(db, supportReturning)
|
||||||
if ok {
|
if ok {
|
||||||
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
|
onConflict, _ := c.Expression.(clause.OnConflict)
|
||||||
|
if onConflict.DoNothing {
|
||||||
mode |= gorm.ScanOnConflictDoNothing
|
mode |= gorm.ScanOnConflictDoNothing
|
||||||
|
} else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll {
|
||||||
|
mode |= gorm.ScanUpdate
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,6 +135,53 @@ func TestUpsertSlice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpsertSliceWithReturning(t *testing.T) {
|
||||||
|
langs := []Language{
|
||||||
|
{Code: "upsert-slice1", Name: "Upsert-slice1"},
|
||||||
|
{Code: "upsert-slice2", Name: "Upsert-slice2"},
|
||||||
|
{Code: "upsert-slice3", Name: "Upsert-slice3"},
|
||||||
|
}
|
||||||
|
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
|
||||||
|
|
||||||
|
var langs2 []Language
|
||||||
|
if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs2) != 3 {
|
||||||
|
t.Errorf("should only find only 3 languages, but got %+v", langs2)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
|
||||||
|
var langs3 []Language
|
||||||
|
if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs3) != 3 {
|
||||||
|
t.Errorf("should only find only 3 languages, but got %+v", langs3)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, lang := range langs {
|
||||||
|
lang.Name = lang.Name + "_new"
|
||||||
|
langs[idx] = lang
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "code"}},
|
||||||
|
DoUpdates: clause.AssignmentColumns([]string{"name"}),
|
||||||
|
}, clause.Returning{}).CreateInBatches(&langs, len(langs)).Error; err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, lang := range langs {
|
||||||
|
var results []Language
|
||||||
|
if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(results) != 1 {
|
||||||
|
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||||
|
} else if results[0].Name != lang.Name {
|
||||||
|
t.Errorf("should update name on conflict, but got name %+v", results[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpsertWithSave(t *testing.T) {
|
func TestUpsertWithSave(t *testing.T) {
|
||||||
langs := []Language{
|
langs := []Language{
|
||||||
{Code: "upsert-save-1", Name: "Upsert-save-1"},
|
{Code: "upsert-save-1", Name: "Upsert-save-1"},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user