refactor: support map[string]interface{} alias
This commit is contained in:
parent
7701c88507
commit
e246b39d89
@ -3,6 +3,7 @@ package callbacks
|
|||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -163,10 +164,46 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch value := updatingValue.Interface().(type) {
|
switch updatingValue.Kind() {
|
||||||
case map[string]interface{}:
|
case reflect.Struct:
|
||||||
set = make([]clause.Assignment, 0, len(value))
|
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||||
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
|
field := stmt.Schema.LookUpField(dbName)
|
||||||
|
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
||||||
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
||||||
|
value, isZero := field.ValueOf(updatingValue)
|
||||||
|
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
||||||
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||||
|
value = stmt.DB.NowFunc().UnixNano()
|
||||||
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||||
|
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||||
|
} else if field.GORMDataType == schema.Time {
|
||||||
|
value = stmt.DB.NowFunc()
|
||||||
|
} else {
|
||||||
|
value = stmt.DB.NowFunc().Unix()
|
||||||
|
}
|
||||||
|
isZero = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok || !isZero {
|
||||||
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||||
|
assignValue(field, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
value, ok := convertMap(updatingValue)
|
||||||
|
if !ok {
|
||||||
|
stmt.AddError(gorm.ErrInvalidData)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
set = make([]clause.Assignment, 0, len(value))
|
||||||
keys := make([]string, 0, len(value))
|
keys := make([]string, 0, len(value))
|
||||||
for k := range value {
|
for k := range value {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
@ -220,42 +257,20 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
switch updatingValue.Kind() {
|
stmt.AddError(gorm.ErrInvalidData)
|
||||||
case reflect.Struct:
|
|
||||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
|
||||||
for _, dbName := range stmt.Schema.DBNames {
|
|
||||||
field := stmt.Schema.LookUpField(dbName)
|
|
||||||
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
|
||||||
value, isZero := field.ValueOf(updatingValue)
|
|
||||||
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
|
||||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
|
||||||
value = stmt.DB.NowFunc().UnixNano()
|
|
||||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
|
||||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
|
||||||
} else if field.GORMDataType == schema.Time {
|
|
||||||
value = stmt.DB.NowFunc()
|
|
||||||
} else {
|
|
||||||
value = stmt.DB.NowFunc().Unix()
|
|
||||||
}
|
|
||||||
isZero = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok || !isZero {
|
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
|
||||||
assignValue(field, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
stmt.AddError(gorm.ErrInvalidData)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertMap(src reflect.Value) (map[string]interface{}, bool) {
|
||||||
|
mt := reflect.TypeOf(map[string]interface{}{})
|
||||||
|
if !src.Type().AssignableTo(mt) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcData := src.Interface()
|
||||||
|
// get interface pointer and convert to map
|
||||||
|
v := (unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(&srcData)) + unsafe.Sizeof(int(0))))
|
||||||
|
dst := (*map[string]interface{})(v)
|
||||||
|
return *dst, true
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user