Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
a9650e563f
@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
|
||||
|
||||
// TODO support save slice data, sql with case?
|
||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error
|
||||
}
|
||||
case reflect.Struct:
|
||||
// clear old data
|
||||
@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error
|
||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,7 +319,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
|
||||
if stmt.UpdatingColumn {
|
||||
if stmt.Schema != nil {
|
||||
if stmt.Schema != nil && len(values.Columns) > 1 {
|
||||
columns := make([]string, 0, len(values.Columns)-1)
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
|
@ -32,26 +32,29 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
case reflect.Struct:
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if pv, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if _, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
return
|
||||
} else {
|
||||
where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClause(where)
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
selectedUpdate := len(tx.Statement.Selects) != 0
|
||||
// when updating, use all fields including those zero-value fields
|
||||
if !selectedUpdate {
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
return tx.Create(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -86,7 +86,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
}
|
||||
|
||||
if v.Alias != "" {
|
||||
writer.WriteString(" AS ")
|
||||
writer.WriteByte(' ')
|
||||
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
|
||||
}
|
||||
case clause.Column:
|
||||
|
@ -9,7 +9,7 @@ require (
|
||||
gorm.io/driver/mysql v1.0.0
|
||||
gorm.io/driver/postgres v1.0.0
|
||||
gorm.io/driver/sqlite v1.1.0
|
||||
gorm.io/driver/sqlserver v1.0.0
|
||||
gorm.io/driver/sqlserver v1.0.1
|
||||
gorm.io/gorm v1.9.19
|
||||
)
|
||||
|
||||
|
@ -2,6 +2,7 @@ package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -586,3 +587,46 @@ func TestUpdateFromSubQuery(t *testing.T) {
|
||||
t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
user := *GetUser("save", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "save").Error; err != nil {
|
||||
t.Fatalf("failed to find created user")
|
||||
}
|
||||
|
||||
user.Name = "save2"
|
||||
DB.Save(&user)
|
||||
|
||||
var result User
|
||||
if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID {
|
||||
t.Fatalf("failed to find updated user")
|
||||
}
|
||||
|
||||
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
stmt := dryDB.Save(&user).Statement
|
||||
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) {
|
||||
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveWithPrimaryValue(t *testing.T) {
|
||||
lang := Language{Code: "save", Name: "save"}
|
||||
if result := DB.Save(&lang); result.RowsAffected != 1 {
|
||||
t.Errorf("should create language, rows affected: %v", result.RowsAffected)
|
||||
}
|
||||
|
||||
var result Language
|
||||
DB.First(&result, "code = ?", "save")
|
||||
AssertEqual(t, result, lang)
|
||||
|
||||
lang.Name = "save name2"
|
||||
if result := DB.Save(&lang); result.RowsAffected != 1 {
|
||||
t.Errorf("should update language")
|
||||
}
|
||||
|
||||
var result2 Language
|
||||
DB.First(&result2, "code = ?", "save")
|
||||
AssertEqual(t, result2, lang)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user