From 62be27d3cafd48d3dcb348bd1d17a5be31867f13 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 16 Nov 2020 20:22:08 +0800 Subject: [PATCH] Add OnConflict UpdateAll support --- callbacks/create.go | 33 ++++++++++++++++++--------------- clause/on_conflict.go | 1 + finisher_api.go | 4 +++- tests/upsert_test.go | 10 ++++++++++ 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 67f3ab14..ad91ebc3 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -329,26 +329,29 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - if stmt.UpdatingColumn { - if stmt.Schema != nil && len(values.Columns) > 1 { - columns := make([]string, 0, len(values.Columns)-1) - for _, column := range values.Columns { - if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) > 1 { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } } } - } - onConflict := clause.OnConflict{ - Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), - DoUpdates: clause.AssignmentColumns(columns), - } + onConflict := clause.OnConflict{ + Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), + DoUpdates: clause.AssignmentColumns(columns), + } - for idx, field := range stmt.Schema.PrimaryFields { - onConflict.Columns[idx] = clause.Column{Name: field.DBName} + for idx, field := range stmt.Schema.PrimaryFields { + onConflict.Columns[idx] = clause.Column{Name: field.DBName} + } + + stmt.AddClause(onConflict) } - stmt.AddClause(onConflict) } } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 47f69fc9..47fe169c 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -5,6 +5,7 @@ type OnConflict struct { Where Where DoNothing bool DoUpdates Set + UpdateAll bool } func (OnConflict) Name() string { diff --git a/finisher_api.go b/finisher_api.go index 2e7e5f4e..67423b23 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -29,7 +29,9 @@ func (db *DB) Save(value interface{}) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - tx.Statement.UpdatingColumn = true + if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } tx.callbacks.Create().Execute(tx) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index ba7c1a9d..0ba8b9f0 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -41,6 +41,16 @@ func TestUpsert(t *testing.T) { } else if langs[0].Name != "upsert-new" { t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) } + + lang = Language{Code: "upsert", Name: "Upsert-Newname"} + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var result Language + if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { + t.Fatalf("failed to upsert, got name %v", result.Name) + } } func TestUpsertSlice(t *testing.T) {