From 0f2d805740353d4f685c6a83c69e5a1a9b7fa843 Mon Sep 17 00:00:00 2001 From: bingoohuang Date: Thu, 24 Jun 2021 10:20:40 +0800 Subject: [PATCH] force tag to force zero when updates --- schema/field.go | 11 ++++++----- tests/update_test.go | 12 ++++++++++++ utils/tests/models.go | 5 +++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index 9efaa44a..071833bb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -401,17 +401,18 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { + _, force := field.TagSettings["FORCE"] // ValueOf switch { case len(field.StructField.Index) == 1: field.ValueOf = func(value reflect.Value) (interface{}, bool) { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() + return fieldValue.Interface(), !force && fieldValue.IsZero() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: field.ValueOf = func(value reflect.Value) (interface{}, bool) { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - return fieldValue.Interface(), fieldValue.IsZero() + return fieldValue.Interface(), !force && fieldValue.IsZero() } default: field.ValueOf = func(value reflect.Value) (interface{}, bool) { @@ -424,17 +425,17 @@ func (field *Field) setupValuerAndSetter() { v = v.Field(-idx - 1) if v.Type().Elem().Kind() != reflect.Struct { - return nil, true + return nil, !force } if !v.IsNil() { v = v.Elem() } else { - return nil, true + return nil, !force } } } - return v.Interface(), v.IsZero() + return v.Interface(), !force && v.IsZero() } } diff --git a/tests/update_test.go b/tests/update_test.go index 5ad1bb39..9eb31cf9 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -13,6 +13,18 @@ import ( . "gorm.io/gorm/utils/tests" ) +func TestUpdateZero(t *testing.T) { + c := &Company{ + ID: 0, + Name: "bingoo", + Foreign: true, + } + + DB.Create(c) + c.Foreign = false + DB.Updates(c) +} + func TestUpdate(t *testing.T) { var ( users = []*User{ diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..71497944 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -50,8 +50,9 @@ type Toy struct { } type Company struct { - ID int - Name string + ID int + Name string + Foreign bool `gorm:"force"` } type Language struct {