From fb52b97363d527f262a125ebce9cfcac9608b58b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Feb 2022 16:28:47 +0800 Subject: [PATCH] Refactor setter, valuer --- scan.go | 14 ++--- schema/field.go | 119 ++++++++++++++----------------------------- tests/create_test.go | 2 +- 3 files changed, 47 insertions(+), 88 deletions(-) diff --git a/scan.go b/scan.go index 64ea8dbd..e8ab805e 100644 --- a/scan.go +++ b/scan.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/schema" ) +// prepareValues prepare values slice func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { @@ -99,14 +100,17 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re } } +// ScanMode scan data mode type ScanMode uint8 +// scan modes const ( ScanInitialized ScanMode = 1 << 0 // 1 ScanUpdate ScanMode = 1 << 1 // 2 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) +// Scan scan rows into db statement func Scan(rows *sql.Rows, db *DB, mode ScanMode) { var ( columns, _ = rows.Columns() @@ -138,7 +142,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } scanIntoMap(mapValue, values, columns) } - case *[]map[string]interface{}, []map[string]interface{}: + case *[]map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) @@ -149,11 +153,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) - if values, ok := dest.([]map[string]interface{}); ok { - values = append(values, mapValue) - } else if values, ok := dest.(*[]map[string]interface{}); ok { - *values = append(*values, mapValue) - } + *dest = append(*dest, mapValue) } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, @@ -174,7 +174,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - if reflectValue.Kind() == reflect.Interface { + for reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } diff --git a/schema/field.go b/schema/field.go index 773e5414..85f737b5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -167,6 +167,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if _, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String + } else if _, ok := field.TagSettings["SERIALIZER"]; ok { + field.DataType = String + } + if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) } @@ -406,95 +412,52 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } -type GormFieldValuer interface { - GormFieldValue(context.Context, *Field) (interface{}, bool) -} - // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // ValueOf - switch { - case len(field.StructField.Index) == 1: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - fv, zero := fieldValue.Interface(), fieldValue.IsZero() - if vr, ok := fv.(GormFieldValuer); ok { - fv, zero = vr.GormFieldValue(ctx, field) - } - return fv, zero - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - fv, zero := fieldValue.Interface(), fieldValue.IsZero() - if vr, ok := fv.(GormFieldValuer); ok { - fv, zero = vr.GormFieldValue(ctx, field) - } - return fv, zero - } - default: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - v := reflect.Indirect(value) + // ValueOf returns field's value and if it is zero - for _, idx := range field.StructField.Index { - if idx >= 0 { - v = v.Field(idx) + // if vr, ok := fv.(GormFieldValuer); ok { + // fv, zero = vr.GormFieldValue(ctx, field) + // } + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() } else { - v = v.Field(-idx - 1) - - if v.Type().Elem().Kind() != reflect.Struct { - return nil, true - } - - if !v.IsNil() { - v = v.Elem() - } else { - return nil, true - } + return nil, true } } - fv, zero := v.Interface(), v.IsZero() - if vr, ok := fv.(GormFieldValuer); ok { - fv, zero = vr.GormFieldValue(ctx, field) - } - return fv, zero } + + fv, zero := v.Interface(), v.IsZero() + return fv, zero } - // ReflectValueOf - switch { - case len(field.StructField.Index) == 1: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]) - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - } - default: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - v := reflect.Indirect(value) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) + // ReflectValueOf returns field's reflect value + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - if v.Kind() == reflect.Ptr { - if v.Type().Elem().Kind() == reflect.Struct { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - } - - if idx < len(field.StructField.Index)-1 { - v = v.Elem() - } + if idx < len(field.StructField.Index)-1 { + v = v.Elem() } } - return v } + return v } fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { @@ -565,11 +528,7 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(ctx, value).SetBool(false) } case int64: - if data > 0 { - field.ReflectValueOf(ctx, value).SetBool(true) - } else { - field.ReflectValueOf(ctx, value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data > 0) case string: b, _ := strconv.ParseBool(data) field.ReflectValueOf(ctx, value).SetBool(b) diff --git a/tests/create_test.go b/tests/create_test.go index af2abdb0..2b23d440 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) { {"name": "create_from_map_3", "Age": 20}, } - if err := DB.Model(&User{}).Create(datas).Error; err != nil { + if err := DB.Model(&User{}).Create(&datas).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) }