diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2e07fd81 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +gorm.test \ No newline at end of file diff --git a/callback_update.go b/callback_update.go index c59bcf1a..7ff264be 100644 --- a/callback_update.go +++ b/callback_update.go @@ -7,7 +7,7 @@ import ( func AssignUpdateAttributes(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if maps := convertInterfaceToMap(attrs); len(maps) > 0 { + if maps := convertInterfaceToMap(attrs, false); len(maps) > 0 { protected, ok := scope.Get("gorm:ignore_protected_attrs") _, updateColumn := scope.Get("gorm:update_column") updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) diff --git a/main.go b/main.go index af2c26c4..9ab700a4 100644 --- a/main.go +++ b/main.go @@ -200,7 +200,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize() } else { - c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.AssignAttrs), false) + c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.AssignAttrs, false), false) } return c } @@ -223,6 +223,12 @@ func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } +func (s *DB) UpdateAll(value interface{}) *DB { + return s.clone().NewScope(value). + InstanceSet("gorm:update_interface", convertInterfaceToMap(value, true)). + callCallbacks(s.parent.callback.updates).db +} + func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { return s.clone().NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). diff --git a/scope_private.go b/scope_private.go index b6949db3..ab800e6d 100644 --- a/scope_private.go +++ b/scope_private.go @@ -375,10 +375,10 @@ func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) initialize() *Scope { for _, clause := range scope.Search.WhereConditions { - scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"], false), false) } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs, false), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs, false), false) return scope } diff --git a/update_test.go b/update_test.go index 7302bb8a..8d5b8507 100644 --- a/update_test.go +++ b/update_test.go @@ -158,3 +158,44 @@ func TestUpdateColumn(t *testing.T) { t.Errorf("updatedAt should not be updated with update column") } } + +func TestAlwaysUpdate(t *testing.T) { + type Always struct { + Id int64 + Name string + Code string + Price int64 + IsActive bool + } + + DB.DropTable(&Always{}) + DB.CreateTable(&Always{}) + + obj1 := Always{Name: "obj1", Code: "code_1", Price: 10, IsActive: true} + obj2 := Always{Name: "obj2", Code: "code_2", Price: 20, IsActive: true} + obj3 := Always{Name: "obj3", Code: "code_10", Price: 100, IsActive: true} + + // save initial + DB.Save(&obj1).Save(&obj2).Save(&obj3) + + // now update via struct price should change to zero + obj2.Price = 0 + DB.UpdateAll(obj2) + + var obj2_1 Always + DB.First(&obj2_1, obj2.Id) + if obj2_1.Price != 0 { + t.Errorf("UpdateAll did not update Price for obj2: %#v", obj2_1) + } + + // test bool + obj3.IsActive = false + DB.UpdateAll(obj3) + + var obj3_1 Always + DB.First(&obj3_1, obj3.Id) + if obj3_1.IsActive { + t.Errorf("UpdateAll did not update IsActive for obj3: %#v", obj3_1) + } + +} diff --git a/utils_private.go b/utils_private.go index cf941f61..e1ac22f5 100644 --- a/utils_private.go +++ b/utils_private.go @@ -40,7 +40,7 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) { return } -func convertInterfaceToMap(values interface{}) map[string]interface{} { +func convertInterfaceToMap(values interface{}, skipBlankCheck bool) map[string]interface{} { attrs := map[string]interface{}{} switch value := values.(type) { @@ -50,7 +50,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { } case []interface{}: for _, v := range value { - for key, value := range convertInterfaceToMap(v) { + for key, value := range convertInterfaceToMap(v, skipBlankCheck) { attrs[key] = value } } @@ -65,7 +65,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { default: scope := Scope{Value: values} for _, field := range scope.Fields() { - if !field.IsBlank { + if skipBlankCheck || !field.IsBlank { attrs[field.DBName] = field.Field.Interface() } }