From 20abf83a21af3f997bc3ca0af20335ce1bef4c47 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Feb 2022 15:08:43 +0800 Subject: [PATCH] Fix pool manager --- scan.go | 10 +++--- schema/field.go | 92 ++++++++++++++++++++++++++++--------------------- 2 files changed, 57 insertions(+), 45 deletions(-) diff --git a/scan.go b/scan.go index 462d42c7..0da12daf 100644 --- a/scan.go +++ b/scan.go @@ -55,15 +55,13 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re if sch == nil { values[idx] = reflectValue.Interface() } else if field := sch.LookUpField(column); field != nil && field.Readable { - fieldValue := field.NewValuePool.Get() - values[idx] = &fieldValue - defer field.NewValuePool.Put(fieldValue) + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fieldValue := field.NewValuePool.Get() - values[idx] = &fieldValue - defer field.NewValuePool.Put(fieldValue) + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) continue } } diff --git a/schema/field.go b/schema/field.go index b2f48a6a..ea065cd1 100644 --- a/schema/field.go +++ b/schema/field.go @@ -417,36 +417,43 @@ var ( stringPool = &sync.Pool{ New: func() interface{} { var v string - return &v + ptrV := &v + return &ptrV }, } intPool = &sync.Pool{ New: func() interface{} { var v int64 - return &v + ptrV := &v + return &ptrV }, } uintPool = &sync.Pool{ New: func() interface{} { var v uint64 - return &v + ptrV := &v + return &ptrV }, } floatPool = &sync.Pool{ New: func() interface{} { var v float64 - return &v + ptrV := &v + return &ptrV }, } boolPool = &sync.Pool{ New: func() interface{} { var v bool - return &v + ptrV := &v + return &ptrV }, } timePool = &sync.Pool{ New: func() interface{} { - return &time.Time{} + var v time.Time + ptrV := &v + return &ptrV }, } ) @@ -454,31 +461,34 @@ var ( // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // Setup NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - if field.NewValuePool == nil { - field.NewValuePool = fieldNewValuePool{ - getter: func() interface{} { - return reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - }, - putter: func(interface{}) {}, + if _, ok := reflect.New(field.IndirectFieldType).Interface().(sql.Scanner); !ok { + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool } } } + if field.NewValuePool == nil { + field.NewValuePool = fieldNewValuePool{ + getter: func() interface{} { + return reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + }, + putter: func(interface{}) {}, + } + } + // ValueOf returns field's value and if it is zero field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { v = reflect.Indirect(v) @@ -580,14 +590,12 @@ func (field *Field) setupValuerAndSetter() { case reflect.Bool: field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) + } case bool: field.ReflectValueOf(ctx, value).SetBool(data) - case **bool: - if data != nil { - field.ReflectValueOf(ctx, value).SetBool(**data) - } else { - field.ReflectValueOf(ctx, value).SetBool(false) - } case int64: field.ReflectValueOf(ctx, value).SetBool(data > 0) case string: @@ -602,7 +610,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case **int64: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) } case int64: @@ -666,7 +674,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case **uint64: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) } case uint64: @@ -718,7 +726,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case **float64: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) } case float64: @@ -761,10 +769,8 @@ func (field *Field) setupValuerAndSetter() { case reflect.String: field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { - case *string: - field.ReflectValueOf(ctx, value).SetString(*data) case **string: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetString(**data) } case string: @@ -786,6 +792,10 @@ func (field *Field) setupValuerAndSetter() { case time.Time: field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } case time.Time: field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case *time.Time: @@ -808,6 +818,10 @@ func (field *Field) setupValuerAndSetter() { case *time.Time: field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } case time.Time: fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() {