Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Steve Fan 2020-08-30 18:56:03 +08:00
commit a9650e563f
6 changed files with 60 additions and 13 deletions

View File

@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
// TODO support save slice data, sql with case?
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error
}
case reflect.Struct:
// clear old data
@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
if len(values) > 0 {
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error
}
}

View File

@ -319,7 +319,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
if stmt.UpdatingColumn {
if stmt.Schema != nil {
if stmt.Schema != nil && len(values.Columns) > 1 {
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {

View File

@ -32,26 +32,29 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.callbacks.Create().Execute(tx)
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
for idx, pf := range tx.Statement.Schema.PrimaryFields {
if pv, isZero := pf.ValueOf(reflectValue); isZero {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx)
return
} else {
where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv}
}
}
tx.Statement.AddClause(where)
}
fallthrough
default:
if len(tx.Statement.Selects) == 0 {
selectedUpdate := len(tx.Statement.Selects) != 0
// when updating, use all fields including those zero-value fields
if !selectedUpdate {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
}
}
return

View File

@ -86,7 +86,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
}
if v.Alias != "" {
writer.WriteString(" AS ")
writer.WriteByte(' ')
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
}
case clause.Column:

View File

@ -9,7 +9,7 @@ require (
gorm.io/driver/mysql v1.0.0
gorm.io/driver/postgres v1.0.0
gorm.io/driver/sqlite v1.1.0
gorm.io/driver/sqlserver v1.0.0
gorm.io/driver/sqlserver v1.0.1
gorm.io/gorm v1.9.19
)

View File

@ -2,6 +2,7 @@ package tests_test
import (
"errors"
"regexp"
"sort"
"strings"
"testing"
@ -586,3 +587,46 @@ func TestUpdateFromSubQuery(t *testing.T) {
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
}
}
func TestSave(t *testing.T) {
user := *GetUser("save", Config{})
DB.Create(&user)
if err := DB.First(&User{}, "name = ?", "save").Error; err != nil {
t.Fatalf("failed to find created user")
}
user.Name = "save2"
DB.Save(&user)
var result User
if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID {
t.Fatalf("failed to find updated user")
}
dryDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryDB.Save(&user).Statement
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) {
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
}
}
func TestSaveWithPrimaryValue(t *testing.T) {
lang := Language{Code: "save", Name: "save"}
if result := DB.Save(&lang); result.RowsAffected != 1 {
t.Errorf("should create language, rows affected: %v", result.RowsAffected)
}
var result Language
DB.First(&result, "code = ?", "save")
AssertEqual(t, result, lang)
lang.Name = "save name2"
if result := DB.Save(&lang); result.RowsAffected != 1 {
t.Errorf("should update language")
}
var result2 Language
DB.First(&result2, "code = ?", "save")
AssertEqual(t, result2, lang)
}