From 0923458e528845f4ce40ec81c607c9e594f80c03 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Sat, 21 Nov 2015 15:35:31 +0100 Subject: [PATCH 1/4] Fix update of fields with a default value --- callback.go | 6 ++++++ callback_create.go | 10 ++-------- callback_update.go | 9 +++++++-- scope.go | 2 +- update_test.go | 9 +++------ 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/callback.go b/callback.go index 603e5111..d6bad42c 100644 --- a/callback.go +++ b/callback.go @@ -197,4 +197,10 @@ func (c *callback) sort() { c.rowQueries = sortProcessors(rowQueries) } +func ForceReload(scope *Scope) { + if _, ok := scope.InstanceGet("gorm:force_reload"); ok { + scope.DB().New().First(scope.Value) + } +} + var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_create.go b/callback_create.go index 71db4ef0..3f61b785 100644 --- a/callback_create.go +++ b/callback_create.go @@ -33,7 +33,7 @@ func Create(scope *Scope) { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } else if field.HasDefaultValue { - scope.InstanceSet("gorm:force_reload_after_create", true) + scope.InstanceSet("gorm:force_reload", true) } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { @@ -97,12 +97,6 @@ func Create(scope *Scope) { } } -func ForceReloadAfterCreate(scope *Scope) { - if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok { - scope.DB().New().First(scope.Value) - } -} - func AfterCreate(scope *Scope) { scope.CallMethodWithErrorCheck("AfterCreate") scope.CallMethodWithErrorCheck("AfterSave") @@ -114,7 +108,7 @@ func init() { DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) DefaultCallback.Create().Register("gorm:create", Create) - DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) + DefaultCallback.Create().Register("gorm:force_reload", ForceReload) DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("gorm:after_create", AfterCreate) DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callback_update.go b/callback_update.go index 4c9952d2..b710f40e 100644 --- a/callback_update.go +++ b/callback_update.go @@ -43,7 +43,7 @@ func Update(scope *Scope) { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { for key, value := range updateAttrs.(map[string]interface{}) { - if scope.changeableDBColumn(key) { + if scope.isChangeableDBColumn(key) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) } } @@ -51,7 +51,11 @@ func Update(scope *Scope) { fields := scope.Fields() for _, field := range fields { if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.HasDefaultValue || !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } else if field.HasDefaultValue { + scope.InstanceSet("gorm:force_reload", true) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, dbName := range relationship.ForeignDBNames { if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { @@ -92,4 +96,5 @@ func init() { DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) DefaultCallback.Update().Register("gorm:after_update", AfterUpdate) DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + DefaultCallback.Update().Register("gorm:force_reload", ForceReload) } diff --git a/scope.go b/scope.go index 0003d575..3c7e7109 100644 --- a/scope.go +++ b/scope.go @@ -411,7 +411,7 @@ func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } -func (scope *Scope) changeableDBColumn(column string) bool { +func (scope *Scope) isChangeableDBColumn(column string) bool { selectAttrs := scope.SelectAttrs() omitAttrs := scope.OmitAttrs() diff --git a/update_test.go b/update_test.go index 75877488..71aa7c43 100644 --- a/update_test.go +++ b/update_test.go @@ -101,7 +101,6 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) if animal.Name != "galeone" { t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) } @@ -109,17 +108,15 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { // When changing a field with a default value, the change must occur animal.Name = "amazing horse" DB.Save(&animal) - DB.First(&animal, animal.Counter) if animal.Name != "amazing horse" { t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) } - // When changing a field with a default value with blank value + // When changing a field with a default value with blank value, the DBMS should insert the default value. Not the empty one. animal.Name = "" DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) + if animal.Name == "" { + t.Errorf("Update a filed with an associated default value should not occur when trying to insert an empty field. The default one should be inserted\n") } } From 40b51709ff396efe40d2cd884c88e2b60cd5f03e Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Sun, 22 Nov 2015 21:48:32 +0100 Subject: [PATCH 2/4] Use the sql default value (of the tag field) when updating a field to a blank value --- callback_update.go | 12 +++++++++--- structs_test.go | 1 + update_test.go | 18 ++++++++++++++++-- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/callback_update.go b/callback_update.go index b710f40e..e45a3a84 100644 --- a/callback_update.go +++ b/callback_update.go @@ -51,10 +51,16 @@ func Update(scope *Scope) { fields := scope.Fields() for _, field := range fields { if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { - if !field.HasDefaultValue || !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if field.HasDefaultValue { + if field.HasDefaultValue { + if field.IsBlank { + defaultValue := strings.Trim(parseTagSetting(field.Tag.Get("sql"))["DEFAULT"], "'") + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(defaultValue))) + } else { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } scope.InstanceSet("gorm:force_reload", true) + } else { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, dbName := range relationship.ForeignDBNames { diff --git a/structs_test.go b/structs_test.go index 9a9b23d1..b52e64ef 100644 --- a/structs_test.go +++ b/structs_test.go @@ -134,6 +134,7 @@ type Animal struct { unexported string // unexported value CreatedAt time.Time UpdatedAt time.Time + Cool bool `sql:"default:false"` } type JoinTable struct { diff --git a/update_test.go b/update_test.go index 71aa7c43..244e0dec 100644 --- a/update_test.go +++ b/update_test.go @@ -87,10 +87,11 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { DB.Save(&animal) updatedAt1 := animal.UpdatedAt + // Sleep for a second and than update a field + time.Sleep(1000 * time.Millisecond) DB.Save(&animal).Update("name", "Francis") - if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") + t.Errorf("updatedAt should be updated when changing a field") } var animals []Animal @@ -118,6 +119,19 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { if animal.Name == "" { t.Errorf("Update a filed with an associated default value should not occur when trying to insert an empty field. The default one should be inserted\n") } + + // Animal.Cool has a default value thats equal to the Zero of its type. (false) I have to update this field to true and false without problems + animal.Cool = true + DB.Save(&animal) + if !animal.Cool { + t.Errorf("I should update a field with a default value to someother value") + } + + animal.Cool = false + DB.Save(&animal) + if animal.Cool { + t.Errorf("I should update a field with an associated blank value to its blank value") + } } func TestUpdates(t *testing.T) { From 5399fd879f634f880648cd96427891a843ff7e11 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Sun, 22 Nov 2015 23:22:34 +0100 Subject: [PATCH 3/4] Add support for sql expressions in struct tag default fileds, on update. Fixes create_test.go time comparison --- callback_update.go | 12 ++++++++++-- create_test.go | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/callback_update.go b/callback_update.go index e45a3a84..6c960723 100644 --- a/callback_update.go +++ b/callback_update.go @@ -37,6 +37,14 @@ func UpdateTimeStampWhenUpdate(scope *Scope) { } } +func escapeIfNeeded(scope *Scope, value string) string { + // default:'string value' OR sql expression, like: default:"(now() at timezone 'utc')" + if (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) || (strings.HasPrefix(value, "(") && strings.HasSuffix(value, ")")) { + return value + } + return scope.AddToVars(value) // default:'something' like:default:'false' should be between quotes (what AddToVars do) +} + func Update(scope *Scope) { if !scope.HasError() { var sqls []string @@ -53,8 +61,8 @@ func Update(scope *Scope) { if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { if field.HasDefaultValue { if field.IsBlank { - defaultValue := strings.Trim(parseTagSetting(field.Tag.Get("sql"))["DEFAULT"], "'") - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(defaultValue))) + defaultValue := parseTagSetting(field.Tag.Get("sql"))["DEFAULT"] + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), escapeIfNeeded(scope, defaultValue))) } else { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } diff --git a/create_test.go b/create_test.go index 97175980..56ac19c2 100644 --- a/create_test.go +++ b/create_test.go @@ -51,7 +51,7 @@ func TestCreate(t *testing.T) { DB.Model(user).Update("name", "create_user_new_name") DB.First(&user, user.Id) - if user.CreatedAt != newUser.CreatedAt { + if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { t.Errorf("CreatedAt should not be changed after update") } } From 2d2d926b6e93260451ccee79425feb9d53616482 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Mon, 23 Nov 2015 16:41:29 +0100 Subject: [PATCH 4/4] Added support for default fields on create. Unified logics of default values between update and create --- callback.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++ callback_create.go | 7 ++++-- callback_update.go | 15 +------------ create_test.go | 22 +++++++++++++++++++ 4 files changed, 81 insertions(+), 16 deletions(-) diff --git a/callback.go b/callback.go index d6bad42c..d2f8e2bd 100644 --- a/callback.go +++ b/callback.go @@ -2,6 +2,9 @@ package gorm import ( "fmt" + "reflect" + "strconv" + "strings" ) type callback struct { @@ -203,4 +206,54 @@ func ForceReload(scope *Scope) { } } +func escapeIfNeeded(scope *Scope, value string) string { + trimmed := strings.TrimSpace(value) + // default:'string value' OR + if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) || + strings.HasSuffix(trimmed, ")") { //sql expression, like: default:"(now() at timezone 'utc') or now() or user_defined_function(parameters.. ) + return trimmed + } + + lowered := strings.ToLower(trimmed) + if lowered == "null" || strings.HasPrefix(lowered, "current_") { // null and other sql reserved keyworks (used a default values) can't be placed between apices + return lowered + } + return scope.AddToVars(trimmed) // default:'something' like:default:'false' should be between quotes (what AddToVars do) +} + +func handleDefaultValue(scope *Scope, field *Field) string { + if field.IsBlank { + defaultValue := strings.TrimSpace(parseTagSetting(field.Tag.Get("sql"))["DEFAULT"]) + switch field.Field.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if numericValue, err := strconv.ParseInt(defaultValue, 10, 64); err == nil { + if numericValue != field.Field.Int() { + return escapeIfNeeded(scope, fmt.Sprintf("%d", field.Field.Int())) + } else { + return escapeIfNeeded(scope, defaultValue) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if numericValue, err := strconv.ParseUint(defaultValue, 10, 64); err == nil { + if numericValue != field.Field.Uint() { + return escapeIfNeeded(scope, escapeIfNeeded(scope, fmt.Sprintf("%d", field.Field.Int()))) + } else { + return escapeIfNeeded(scope, defaultValue) + } + } + case reflect.Bool: + if boolValue, err := strconv.ParseBool(defaultValue); err == nil { + if boolValue != field.Field.Bool() { + return escapeIfNeeded(scope, fmt.Sprintf("%t", field.Field.Bool())) + } else { + return escapeIfNeeded(scope, defaultValue) + } + } + default: + return escapeIfNeeded(scope, defaultValue) + } + } + return scope.AddToVars(field.Field.Interface()) +} + var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_create.go b/callback_create.go index 3f61b785..9a74db0d 100644 --- a/callback_create.go +++ b/callback_create.go @@ -29,12 +29,15 @@ func Create(scope *Scope) { if scope.changeableField(field) { if field.IsNormal { if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { - if !field.IsBlank || !field.HasDefaultValue { + if !field.HasDefaultValue { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) - } else if field.HasDefaultValue { + } else { + columns = append(columns, scope.Quote(field.DBName)) + sqls = append(sqls, handleDefaultValue(scope, field)) scope.InstanceSet("gorm:force_reload", true) } + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, dbName := range relationship.ForeignDBNames { diff --git a/callback_update.go b/callback_update.go index 6c960723..38353cc5 100644 --- a/callback_update.go +++ b/callback_update.go @@ -37,14 +37,6 @@ func UpdateTimeStampWhenUpdate(scope *Scope) { } } -func escapeIfNeeded(scope *Scope, value string) string { - // default:'string value' OR sql expression, like: default:"(now() at timezone 'utc')" - if (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) || (strings.HasPrefix(value, "(") && strings.HasSuffix(value, ")")) { - return value - } - return scope.AddToVars(value) // default:'something' like:default:'false' should be between quotes (what AddToVars do) -} - func Update(scope *Scope) { if !scope.HasError() { var sqls []string @@ -60,12 +52,7 @@ func Update(scope *Scope) { for _, field := range fields { if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { if field.HasDefaultValue { - if field.IsBlank { - defaultValue := parseTagSetting(field.Tag.Get("sql"))["DEFAULT"] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), escapeIfNeeded(scope, defaultValue))) - } else { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), handleDefaultValue(scope, field))) scope.InstanceSet("gorm:force_reload", true) } else { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) diff --git a/create_test.go b/create_test.go index 56ac19c2..4a5245e6 100644 --- a/create_test.go +++ b/create_test.go @@ -157,3 +157,25 @@ func TestOmitWithCreate(t *testing.T) { t.Errorf("Should not create omited relationships") } } + +// Test from: https://github.com/jinzhu/gorm/issues/689 +func TestCreateWithBoolDefaultValue(t *testing.T) { + type Data struct { + ID int `gorm:"column:id;primary_key" json:"id"` + Name string `sql:"type:varchar(100);not null;unique" json:"name"` + DeleteAllowed bool `sql:"not null;DEFAULT:true" json:"delete_allowed"` + } + + DB.AutoMigrate(&Data{}) + + data := Data{ + Name: "test", + DeleteAllowed: false, + } + + DB.Create(&data) + + if data.DeleteAllowed { + t.Error("Test failed") + } +}