diff --git a/callback_shared.go b/callback_shared.go index f13cec9d..729a0461 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -16,7 +16,7 @@ func CommitOrRollbackTransaction(scope *Scope) { func SaveBeforeAssociations(scope *Scope) { for _, field := range scope.Fields() { - if !field.IsBlank && !field.IsIgnored { + if field.AlwaysUpdate || !field.IsBlank && !field.IsIgnored { relationship := field.Relationship if relationship != nil && relationship.Kind == "belongs_to" { value := field.Field @@ -42,7 +42,7 @@ func SaveBeforeAssociations(scope *Scope) { func SaveAfterAssociations(scope *Scope) { for _, field := range scope.Fields() { - if !field.IsBlank && !field.IsIgnored { + if field.AlwaysUpdate || !field.IsBlank && !field.IsIgnored { relationship := field.Relationship if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { diff --git a/field.go b/field.go index 7791d42a..550f365f 100644 --- a/field.go +++ b/field.go @@ -23,6 +23,7 @@ type Field struct { IsBlank bool IsIgnored bool IsPrimaryKey bool + AlwaysUpdate bool } func (field *Field) IsScanner() bool { diff --git a/scope.go b/scope.go index 9db9f259..87aaeda2 100644 --- a/scope.go +++ b/scope.go @@ -304,7 +304,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio if fieldStruct.Tag.Get(tagIdentifier) == "-" { field.IsIgnored = true } - + if _, ok := settings["UPDATE"]; ok { + field.AlwaysUpdate = true + } if !field.IsIgnored { // parse association if !indirectValue.IsValid() { diff --git a/update_test.go b/update_test.go index 7302bb8a..20d6b909 100644 --- a/update_test.go +++ b/update_test.go @@ -158,3 +158,55 @@ func TestUpdateColumn(t *testing.T) { t.Errorf("updatedAt should not be updated with update column") } } + +func TestAlwaysUpdate(t *testing.T) { + type Always struct { + Id int64 + Name string + Code string + Price int64 `gorm:"update"` + } + + DB.DropTable(&Always{}) + DB.CreateTable(&Always{}) + + obj1 := Always{Name: "obj1", Code: "code_1", Price: 10} + obj2 := Always{Name: "obj2", Code: "code_2", Price: 20} + + // save initial + DB.Save(&obj1).Save(&obj2).UpdateColumn(map[string]interface{}{"code": "columnUpdate2"}) + + // fetch and verify + var obj3, obj4 Always + DB.First(&obj3, obj1.Id) + if obj3.Code != "code_1" || obj3.Price != 10 { + t.Errorf("obj1 was not saved correctly: expected: %#v got: %#v", obj1, obj3) + } + DB.First(&obj4, obj2.Id) + if obj4.Code != "columnUpdate2" || obj4.Price != 20 { + t.Errorf("obj2 was not saved correctly: expected: %#v got: %#v", obj2, obj4) + } + + // now update via struct price should change to zero + obj5 := Always{Name: "obj2update", Code: "code_2"} + DB.Model(obj5).Updates(obj5) + + var obj6 Always + DB.First(&obj6, obj2.Id) + if obj6.Code != "code_2" || obj6.Name != "obj2update" || obj6.Price != 0 { + t.Errorf("obj2 was not saved correctly: got: %#v", obj6) + } + + var res []Always + DB.Find(&res) + if len(res) != 2 { + t.Error("Should have 2 objects") + } + + // test where clause + var res1 []Always + DB.Where(Always{}).Find(&res1) + if len(res1) != 2 { + t.Error("Where() Should have 2 returned objects") + } +} diff --git a/utils_private.go b/utils_private.go index cf941f61..2b0707f9 100644 --- a/utils_private.go +++ b/utils_private.go @@ -65,7 +65,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { default: scope := Scope{Value: values} for _, field := range scope.Fields() { - if !field.IsBlank { + if field.AlwaysUpdate || !field.IsBlank { attrs[field.DBName] = field.Field.Interface() } }