diff --git a/callbacks/update.go b/callbacks/update.go index db5b52fb..5cf3fdcf 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -3,6 +3,7 @@ package callbacks import ( "reflect" "sort" + "unsafe" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -163,10 +164,46 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - switch value := updatingValue.Interface().(type) { - case map[string]interface{}: - set = make([]clause.Assignment, 0, len(value)) + switch updatingValue.Kind() { + case reflect.Struct: + set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } + isZero = false + } + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + case reflect.Map: + value, ok := convertMap(updatingValue) + if !ok { + stmt.AddError(gorm.ErrInvalidData) + return + } + + set = make([]clause.Assignment, 0, len(value)) keys := make([]string, 0, len(value)) for k := range value { keys = append(keys, k) @@ -220,42 +257,20 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: - switch updatingValue.Kind() { - case reflect.Struct: - set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) - for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.LookUpField(dbName) - if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) - if !stmt.SkipHooks && field.AutoUpdateTime > 0 { - if field.AutoUpdateTime == schema.UnixNanosecond { - value = stmt.DB.NowFunc().UnixNano() - } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { - value = stmt.DB.NowFunc().Unix() - } - isZero = false - } - - if ok || !isZero { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) - } - } - } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) - } - } - } - default: - stmt.AddError(gorm.ErrInvalidData) - } + stmt.AddError(gorm.ErrInvalidData) } - return } + +func convertMap(src reflect.Value) (map[string]interface{}, bool) { + mt := reflect.TypeOf(map[string]interface{}{}) + if !src.Type().AssignableTo(mt) { + return nil, false + } + + srcData := src.Interface() + // get interface pointer and convert to map + v := (unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(&srcData)) + unsafe.Sizeof(int(0)))) + dst := (*map[string]interface{})(v) + return *dst, true +}