From 386d981a98f0fe49b1f0e29cdc40d2e17649681a Mon Sep 17 00:00:00 2001 From: yiranzai Date: Sun, 18 Apr 2021 16:18:05 +0800 Subject: [PATCH] CanZero specifies fields that you can be set to zero value when creating and updating. --- callbacks/create.go | 12 +++++++++++- callbacks/update.go | 9 +++++++++ chainable_api.go | 13 +++++++++++++ statement.go | 42 ++++++++++++++++++++++++++++++++++++++++++ tests/update_test.go | 29 +++++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 909d984a..ebf5de43 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -328,9 +328,19 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } + canZeroColumns := stmt.CanZeroColumns(true, false) for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + v, isZero := field.ValueOf(stmt.ReflectValue) + + // can zero + if isZero { + if v, ok := canZeroColumns[field.DBName]; ok && v { + isZero = false + } + } + + if !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], v) } diff --git a/callbacks/update.go b/callbacks/update.go index db5b52fb..8e6f233a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -222,12 +222,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { default: switch updatingValue.Kind() { case reflect.Struct: + canZeroColumns := stmt.CanZeroColumns(false, true) set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { value, isZero := field.ValueOf(updatingValue) + + // can zero + if isZero { + if v, ok := canZeroColumns[field.DBName]; ok && v { + isZero = false + } + } + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() diff --git a/chainable_api.go b/chainable_api.go index e17d9bb2..15e74874 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -142,6 +142,19 @@ func (db *DB) Omit(columns ...string) (tx *DB) { return } +// CanZero specifies fields that you can be set to zero value when creating and updating. +// Priority is lower than Select and Omit +func (db *DB) CanZero(columns ...string) (tx *DB) { + tx = db.getInstance() + + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { + tx.Statement.CanZeros = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + } else { + tx.Statement.CanZeros = columns + } + return +} + // Where add conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/statement.go b/statement.go index 099c66d2..938b0a1f 100644 --- a/statement.go +++ b/statement.go @@ -30,6 +30,7 @@ type Statement struct { Distinct bool Selects []string // selected columns Omits []string // omit columns + CanZeros []string // can zero columns, priority is lower than selected columns and omit columns Joins []join Preloads map[string][]interface{} Settings sync.Map @@ -666,3 +667,44 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( return results, !notRestricted && len(stmt.Selects) > 0 } + +// CanZeroColumns get can zero columns +func (stmt *Statement) CanZeroColumns(requireCreate, requireUpdate bool) map[string]bool { + results := map[string]bool{} + + // can zero columns + for _, canZero := range stmt.CanZeros { + if stmt.Schema == nil { + results[canZero] = true + } else if canZero == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + } else if canZero == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(canZero); field != nil && field.DBName != "" { + results[field.DBName] = true + } else { + results[canZero] = true + } + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByName { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results +} diff --git a/tests/update_test.go b/tests/update_test.go index 5ad1bb39..1b8d5bae 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -685,3 +685,32 @@ func TestSaveWithPrimaryValue(t *testing.T) { t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) } } + +func TestCanZeroWithUpdate(t *testing.T) { + user := *GetUser("can_zero_update", Config{}) + user.Active = true + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("can_zero_update_new", Config{}) + result.Name = user2.Name + result.Active = false + result.Age = 0 + + DB.Model(User{}).Where("ID", user.ID).CanZero("Age").Updates(User{ + Name: user2.Name, + Active: false, + Age: 0, + }) + + var result2 User + DB.First(&result2, user.ID) + + AssertObjEqual(t, result2, result, "Name", "Age") + + if !result2.Active || result.Active { + t.Fatalf("Update struct should only update can zero columns, was %+v, got %+v", result2.Active, result.Active) + } +}