diff --git a/association.go b/association.go index db77cc4e..140ae6ac 100644 --- a/association.go +++ b/association.go @@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error + association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error + association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error } } diff --git a/callbacks/create.go b/callbacks/create.go index c59b14b5..5de19d35 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -319,7 +319,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } if stmt.UpdatingColumn { - if stmt.Schema != nil { + if stmt.Schema != nil && len(values.Columns) > 1 { columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { diff --git a/finisher_api.go b/finisher_api.go index 2cde3c31..a205b859 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -32,26 +32,29 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Create().Execute(tx) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { - where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - for idx, pf := range tx.Statement.Schema.PrimaryFields { - if pv, isZero := pf.ValueOf(reflectValue); isZero { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) return - } else { - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} } } - - tx.Statement.AddClause(where) } fallthrough default: - if len(tx.Statement.Selects) == 0 { + selectedUpdate := len(tx.Statement.Selects) != 0 + // when updating, use all fields including those zero-value fields + if !selectedUpdate { tx.Statement.Selects = append(tx.Statement.Selects, "*") } tx.callbacks.Update().Execute(tx) + + if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { + if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { + return tx.Create(value) + } + } } return diff --git a/statement.go b/statement.go index fba1991d..d72a086f 100644 --- a/statement.go +++ b/statement.go @@ -86,7 +86,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } if v.Alias != "" { - writer.WriteString(" AS ") + writer.WriteByte(' ') stmt.DB.Dialector.QuoteTo(writer, v.Alias) } case clause.Column: diff --git a/tests/go.mod b/tests/go.mod index 1a6fe7a8..c09747ab 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v1.0.0 gorm.io/driver/postgres v1.0.0 gorm.io/driver/sqlite v1.1.0 - gorm.io/driver/sqlserver v1.0.0 + gorm.io/driver/sqlserver v1.0.1 gorm.io/gorm v1.9.19 ) diff --git a/tests/update_test.go b/tests/update_test.go index e52dc652..1944ed3f 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "regexp" "sort" "strings" "testing" @@ -586,3 +587,46 @@ func TestUpdateFromSubQuery(t *testing.T) { t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) } } + +func TestSave(t *testing.T) { + user := *GetUser("save", Config{}) + DB.Create(&user) + + if err := DB.First(&User{}, "name = ?", "save").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user.Name = "save2" + DB.Save(&user) + + var result User + if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID { + t.Fatalf("failed to find updated user") + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Save(&user).Statement + if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } +} + +func TestSaveWithPrimaryValue(t *testing.T) { + lang := Language{Code: "save", Name: "save"} + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should create language, rows affected: %v", result.RowsAffected) + } + + var result Language + DB.First(&result, "code = ?", "save") + AssertEqual(t, result, lang) + + lang.Name = "save name2" + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should update language") + } + + var result2 Language + DB.First(&result2, "code = ?", "save") + AssertEqual(t, result2, lang) +}