refactor: support map[string]interface{} alias

This commit is contained in:
Harman 2021-04-21 16:55:07 +07:00
parent 7701c88507
commit e246b39d89

View File

@ -3,6 +3,7 @@ package callbacks
import ( import (
"reflect" "reflect"
"sort" "sort"
"unsafe"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -163,10 +164,46 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
switch value := updatingValue.Interface().(type) { switch updatingValue.Kind() {
case map[string]interface{}: case reflect.Struct:
set = make([]clause.Assignment, 0, len(value)) set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
}
isZero = false
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value)
}
}
} else {
if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
case reflect.Map:
value, ok := convertMap(updatingValue)
if !ok {
stmt.AddError(gorm.ErrInvalidData)
return
}
set = make([]clause.Assignment, 0, len(value))
keys := make([]string, 0, len(value)) keys := make([]string, 0, len(value))
for k := range value { for k := range value {
keys = append(keys, k) keys = append(keys, k)
@ -220,42 +257,20 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
default: default:
switch updatingValue.Kind() { stmt.AddError(gorm.ErrInvalidData)
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
}
isZero = false
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value)
}
}
} else {
if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
} }
return return
} }
func convertMap(src reflect.Value) (map[string]interface{}, bool) {
mt := reflect.TypeOf(map[string]interface{}{})
if !src.Type().AssignableTo(mt) {
return nil, false
}
srcData := src.Interface()
// get interface pointer and convert to map
v := (unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(&srcData)) + unsafe.Sizeof(int(0))))
dst := (*map[string]interface{})(v)
return *dst, true
}