From 0923458e528845f4ce40ec81c607c9e594f80c03 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Sat, 21 Nov 2015 15:35:31 +0100 Subject: [PATCH] Fix update of fields with a default value --- callback.go | 6 ++++++ callback_create.go | 10 ++-------- callback_update.go | 9 +++++++-- scope.go | 2 +- update_test.go | 9 +++------ 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/callback.go b/callback.go index 603e5111..d6bad42c 100644 --- a/callback.go +++ b/callback.go @@ -197,4 +197,10 @@ func (c *callback) sort() { c.rowQueries = sortProcessors(rowQueries) } +func ForceReload(scope *Scope) { + if _, ok := scope.InstanceGet("gorm:force_reload"); ok { + scope.DB().New().First(scope.Value) + } +} + var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_create.go b/callback_create.go index 71db4ef0..3f61b785 100644 --- a/callback_create.go +++ b/callback_create.go @@ -33,7 +33,7 @@ func Create(scope *Scope) { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } else if field.HasDefaultValue { - scope.InstanceSet("gorm:force_reload_after_create", true) + scope.InstanceSet("gorm:force_reload", true) } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { @@ -97,12 +97,6 @@ func Create(scope *Scope) { } } -func ForceReloadAfterCreate(scope *Scope) { - if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok { - scope.DB().New().First(scope.Value) - } -} - func AfterCreate(scope *Scope) { scope.CallMethodWithErrorCheck("AfterCreate") scope.CallMethodWithErrorCheck("AfterSave") @@ -114,7 +108,7 @@ func init() { DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) DefaultCallback.Create().Register("gorm:create", Create) - DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) + DefaultCallback.Create().Register("gorm:force_reload", ForceReload) DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("gorm:after_create", AfterCreate) DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callback_update.go b/callback_update.go index 4c9952d2..b710f40e 100644 --- a/callback_update.go +++ b/callback_update.go @@ -43,7 +43,7 @@ func Update(scope *Scope) { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { for key, value := range updateAttrs.(map[string]interface{}) { - if scope.changeableDBColumn(key) { + if scope.isChangeableDBColumn(key) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) } } @@ -51,7 +51,11 @@ func Update(scope *Scope) { fields := scope.Fields() for _, field := range fields { if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.HasDefaultValue || !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } else if field.HasDefaultValue { + scope.InstanceSet("gorm:force_reload", true) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, dbName := range relationship.ForeignDBNames { if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { @@ -92,4 +96,5 @@ func init() { DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) DefaultCallback.Update().Register("gorm:after_update", AfterUpdate) DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + DefaultCallback.Update().Register("gorm:force_reload", ForceReload) } diff --git a/scope.go b/scope.go index 0003d575..3c7e7109 100644 --- a/scope.go +++ b/scope.go @@ -411,7 +411,7 @@ func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } -func (scope *Scope) changeableDBColumn(column string) bool { +func (scope *Scope) isChangeableDBColumn(column string) bool { selectAttrs := scope.SelectAttrs() omitAttrs := scope.OmitAttrs() diff --git a/update_test.go b/update_test.go index 75877488..71aa7c43 100644 --- a/update_test.go +++ b/update_test.go @@ -101,7 +101,6 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) if animal.Name != "galeone" { t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) } @@ -109,17 +108,15 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { // When changing a field with a default value, the change must occur animal.Name = "amazing horse" DB.Save(&animal) - DB.First(&animal, animal.Counter) if animal.Name != "amazing horse" { t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) } - // When changing a field with a default value with blank value + // When changing a field with a default value with blank value, the DBMS should insert the default value. Not the empty one. animal.Name = "" DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) + if animal.Name == "" { + t.Errorf("Update a filed with an associated default value should not occur when trying to insert an empty field. The default one should be inserted\n") } }