diff --git a/association.go b/association.go index d1984229..2df571f5 100644 --- a/association.go +++ b/association.go @@ -177,7 +177,7 @@ func (association *Association) Delete(values ...interface{}) *Association { modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap, false) + scope.updatedAttrsWithValues(foreignKeyMap) } } else { association.setErr(results.Error) diff --git a/callback_update.go b/callback_update.go index 44f9a143..b71a47b4 100644 --- a/callback_update.go +++ b/callback_update.go @@ -22,17 +22,10 @@ func init() { func assignUpdatingAttributesCallback(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if maps := convertInterfaceToMap(attrs); len(maps) > 0 { - protected, ok := scope.Get("gorm:ignore_protected_attrs") - _, updateColumn := scope.Get("gorm:update_column") - updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) - - if updateColumn { - scope.InstanceSet("gorm:update_attrs", maps) - } else if len(updateAttrs) > 0 { - scope.InstanceSet("gorm:update_attrs", updateAttrs) - } else if !hasUpdate { + if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate { + scope.InstanceSet("gorm:update_attrs", updateMaps) + } else { scope.SkipLeft() - return } } } @@ -64,13 +57,7 @@ func updateCallback(scope *Scope) { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { for column, value := range updateAttrs.(map[string]interface{}) { - if field, ok := scope.FieldByName(column); ok { - if scope.changeableField(field) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(value))) - } - } else { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { fields := scope.Fields() diff --git a/main.go b/main.go index b7f0d2aa..cfa71b60 100644 --- a/main.go +++ b/main.go @@ -258,7 +258,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize() } else { - c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false) + c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs)) } return c } diff --git a/scope.go b/scope.go index 6d9303ec..c84a8179 100644 --- a/scope.go +++ b/scope.go @@ -154,20 +154,29 @@ func (scope *Scope) HasColumn(column string) bool { // SetColumn to set the column's value func (scope *Scope) SetColumn(column interface{}, value interface{}) error { + var updateAttrs = map[string]interface{}{} + if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { + updateAttrs = attrs.(map[string]interface{}) + defer scope.InstanceSet("gorm:update_attrs", updateAttrs) + } + if field, ok := column.(*Field); ok { + updateAttrs[field.DBName] = value return field.Set(value) } else if name, ok := column.(string); ok { - if field, ok := scope.Fields()[name]; ok { + updateAttrs[field.DBName] = value return field.Set(value) } dbName := ToDBName(name) if field, ok := scope.Fields()[dbName]; ok { + updateAttrs[field.DBName] = value return field.Set(value) } if field, ok := scope.FieldByName(name); ok { + updateAttrs[field.DBName] = value return field.Set(value) } } diff --git a/scope_private.go b/scope_private.go index 6b34a4b3..9b01dcb9 100644 --- a/scope_private.go +++ b/scope_private.go @@ -319,38 +319,30 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } -func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { - if !scope.IndirectValue().CanAddr() { +func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) { + if scope.IndirectValue().Kind() != reflect.Struct { return values, true } - var hasExpr bool + results = map[string]interface{}{} for key, value := range values { - if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() { + if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { - if _, ok := value.(*expr); ok { - hasExpr = true - } else if !equalAsString(field.Field.Interface(), value) { - hasUpdate = true + if field.IsNormal { + if _, ok := value.(*expr); ok { + hasUpdate = true + results[field.DBName] = value + } else if !equalAsString(field.Field.Interface(), value) { + hasUpdate = true + field.Set(value) + results[field.DBName] = field.Field.Interface() + } + } else { field.Set(value) } } } } - - if hasExpr { - var updateMap = map[string]interface{}{} - for key, field := range scope.Fields() { - if field.IsNormal { - if v, ok := values[key]; ok { - updateMap[key] = v - } else { - updateMap[key] = field.Field.Interface() - } - } - } - return updateMap, true - } return } @@ -370,10 +362,10 @@ func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) initialize() *Scope { for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"])) } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs)) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs)) return scope } diff --git a/update_test.go b/update_test.go index 417463bb..fd193ada 100644 --- a/update_test.go +++ b/update_test.go @@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) { } DB.First(&product4, product4.Id) + updatedAt4 := product4.UpdatedAt DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) var product5 Product DB.First(&product5, product4.Id) if product5.Price != product4.Price+100-50 { t.Errorf("Update with expression") } - if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { + if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { t.Errorf("Update with expression should update UpdatedAt") } } @@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) { t.Errorf("product2's code should be updated") } + updatedAt4 := product4.UpdatedAt DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) var product5 Product DB.First(&product5, product4.Id) if product5.Price != product4.Price+100 { t.Errorf("Updates with expression") } - if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { + // product4's UpdatedAt will be reset when updating + if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { t.Errorf("Updates with expression should update UpdatedAt") } } @@ -421,8 +424,6 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) { } func TestUpdatesWithBlankValues(t *testing.T) { - t.Skip("not implemented") - product := Product{Code: "product1", Price: 10} DB.Save(&product) diff --git a/utils.go b/utils.go index bfdaf9f7..55d75619 100644 --- a/utils.go +++ b/utils.go @@ -192,9 +192,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { switch value := values.(type) { case map[string]interface{}: - for k, v := range value { - attrs[k] = v - } + return value case []interface{}: for _, v := range value { for key, value := range convertInterfaceToMap(v) {