Refactor setter, valuer

This commit is contained in:
Jinzhu 2022-02-14 16:28:47 +08:00
parent 69b851ed95
commit fb52b97363
3 changed files with 47 additions and 88 deletions

14
scan.go
View File

@ -10,6 +10,7 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
// prepareValues prepare values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
for idx, name := range columns { for idx, name := range columns {
@ -99,14 +100,17 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
} }
} }
// ScanMode scan data mode
type ScanMode uint8 type ScanMode uint8
// scan modes
const ( const (
ScanInitialized ScanMode = 1 << 0 // 1 ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2 ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
) )
// Scan scan rows into db statement
func Scan(rows *sql.Rows, db *DB, mode ScanMode) { func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
var ( var (
columns, _ = rows.Columns() columns, _ = rows.Columns()
@ -138,7 +142,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
} }
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
} }
case *[]map[string]interface{}, []map[string]interface{}: case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes() columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() { for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns) prepareValues(values, db, columnTypes, columns)
@ -149,11 +153,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
mapValue := map[string]interface{}{} mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
if values, ok := dest.([]map[string]interface{}); ok { *dest = append(*dest, mapValue)
values = append(values, mapValue)
} else if values, ok := dest.(*[]map[string]interface{}); ok {
*values = append(*values, mapValue)
}
} }
case *int, *int8, *int16, *int32, *int64, case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
@ -174,7 +174,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
reflectValue = db.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
) )
if reflectValue.Kind() == reflect.Interface { for reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem() reflectValue = reflectValue.Elem()
} }

View File

@ -167,6 +167,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if _, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
field.DataType = String
} else if _, ok := field.TagSettings["SERIALIZER"]; ok {
field.DataType = String
}
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
} }
@ -406,45 +412,20 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
return field return field
} }
type GormFieldValuer interface {
GormFieldValue(context.Context, *Field) (interface{}, bool)
}
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter() {
// ValueOf // ValueOf returns field's value and if it is zero
switch {
case len(field.StructField.Index) == 1:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
fv, zero := fieldValue.Interface(), fieldValue.IsZero()
if vr, ok := fv.(GormFieldValuer); ok {
fv, zero = vr.GormFieldValue(ctx, field)
}
return fv, zero
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
fv, zero := fieldValue.Interface(), fieldValue.IsZero()
if vr, ok := fv.(GormFieldValuer); ok {
fv, zero = vr.GormFieldValue(ctx, field)
}
return fv, zero
}
default:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value)
for _, idx := range field.StructField.Index { // if vr, ok := fv.(GormFieldValuer); ok {
if idx >= 0 { // fv, zero = vr.GormFieldValue(ctx, field)
v = v.Field(idx) // }
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else { } else {
v = v.Field(-idx - 1) v = v.Field(-fieldIdx - 1)
if v.Type().Elem().Kind() != reflect.Struct {
return nil, true
}
if !v.IsNil() { if !v.IsNil() {
v = v.Elem() v = v.Elem()
@ -453,40 +434,23 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
} }
fv, zero := v.Interface(), v.IsZero() fv, zero := v.Interface(), v.IsZero()
if vr, ok := fv.(GormFieldValuer); ok {
fv, zero = vr.GormFieldValue(ctx, field)
}
return fv, zero return fv, zero
} }
}
// ReflectValueOf // ReflectValueOf returns field's reflect value
switch { field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
case len(field.StructField.Index) == 1: v = reflect.Indirect(v)
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0])
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
}
default:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
v := reflect.Indirect(value)
for idx, fieldIdx := range field.StructField.Index { for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 { if fieldIdx >= 0 {
v = v.Field(fieldIdx) v = v.Field(fieldIdx)
} else { } else {
v = v.Field(-fieldIdx - 1) v = v.Field(-fieldIdx - 1)
}
if v.Kind() == reflect.Ptr {
if v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
}
if idx < len(field.StructField.Index)-1 { if idx < len(field.StructField.Index)-1 {
v = v.Elem() v = v.Elem()
@ -495,7 +459,6 @@ func (field *Field) setupValuerAndSetter() {
} }
return v return v
} }
}
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
if v == nil { if v == nil {
@ -565,11 +528,7 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(ctx, value).SetBool(false) field.ReflectValueOf(ctx, value).SetBool(false)
} }
case int64: case int64:
if data > 0 { field.ReflectValueOf(ctx, value).SetBool(data > 0)
field.ReflectValueOf(ctx, value).SetBool(true)
} else {
field.ReflectValueOf(ctx, value).SetBool(false)
}
case string: case string:
b, _ := strconv.ParseBool(data) b, _ := strconv.ParseBool(data)
field.ReflectValueOf(ctx, value).SetBool(b) field.ReflectValueOf(ctx, value).SetBool(b)

View File

@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) {
{"name": "create_from_map_3", "Age": 20}, {"name": "create_from_map_3", "Age": 20},
} }
if err := DB.Model(&User{}).Create(datas).Error; err != nil { if err := DB.Model(&User{}).Create(&datas).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err) t.Fatalf("failed to create data from slice of map, got error: %v", err)
} }