Refactor setter, valuer

This commit is contained in:
Jinzhu 2022-02-14 16:28:47 +08:00
parent 69b851ed95
commit fb52b97363
3 changed files with 47 additions and 88 deletions

14
scan.go
View File

@ -10,6 +10,7 @@ import (
"gorm.io/gorm/schema"
)
// prepareValues prepare values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
@ -99,14 +100,17 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
}
}
// ScanMode scan data mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
// Scan scan rows into db statement
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
@ -138,7 +142,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}, []map[string]interface{}:
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
@ -149,11 +153,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
if values, ok := dest.([]map[string]interface{}); ok {
values = append(values, mapValue)
} else if values, ok := dest.(*[]map[string]interface{}); ok {
*values = append(*values, mapValue)
}
*dest = append(*dest, mapValue)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
@ -174,7 +174,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface {
for reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}

View File

@ -167,6 +167,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if _, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
field.DataType = String
} else if _, ok := field.TagSettings["SERIALIZER"]; ok {
field.DataType = String
}
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
}
@ -406,95 +412,52 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
return field
}
type GormFieldValuer interface {
GormFieldValue(context.Context, *Field) (interface{}, bool)
}
// create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() {
// ValueOf
switch {
case len(field.StructField.Index) == 1:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
fv, zero := fieldValue.Interface(), fieldValue.IsZero()
if vr, ok := fv.(GormFieldValuer); ok {
fv, zero = vr.GormFieldValue(ctx, field)
}
return fv, zero
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
fv, zero := fieldValue.Interface(), fieldValue.IsZero()
if vr, ok := fv.(GormFieldValuer); ok {
fv, zero = vr.GormFieldValue(ctx, field)
}
return fv, zero
}
default:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value)
// ValueOf returns field's value and if it is zero
for _, idx := range field.StructField.Index {
if idx >= 0 {
v = v.Field(idx)
// if vr, ok := fv.(GormFieldValuer); ok {
// fv, zero = vr.GormFieldValue(ctx, field)
// }
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
if !v.IsNil() {
v = v.Elem()
} else {
v = v.Field(-idx - 1)
if v.Type().Elem().Kind() != reflect.Struct {
return nil, true
}
if !v.IsNil() {
v = v.Elem()
} else {
return nil, true
}
return nil, true
}
}
fv, zero := v.Interface(), v.IsZero()
if vr, ok := fv.(GormFieldValuer); ok {
fv, zero = vr.GormFieldValue(ctx, field)
}
return fv, zero
}
fv, zero := v.Interface(), v.IsZero()
return fv, zero
}
// ReflectValueOf
switch {
case len(field.StructField.Index) == 1:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0])
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
}
default:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
v := reflect.Indirect(value)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
// ReflectValueOf returns field's reflect value
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Kind() == reflect.Ptr {
if v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
}
if idx < len(field.StructField.Index)-1 {
v = v.Elem()
}
if idx < len(field.StructField.Index)-1 {
v = v.Elem()
}
}
return v
}
return v
}
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
@ -565,11 +528,7 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(ctx, value).SetBool(false)
}
case int64:
if data > 0 {
field.ReflectValueOf(ctx, value).SetBool(true)
} else {
field.ReflectValueOf(ctx, value).SetBool(false)
}
field.ReflectValueOf(ctx, value).SetBool(data > 0)
case string:
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(ctx, value).SetBool(b)

View File

@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) {
{"name": "create_from_map_3", "Age": 20},
}
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
if err := DB.Model(&User{}).Create(&datas).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}