diff --git a/interfaces.go b/interfaces.go index 44b2fced..ff0ca60a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -40,14 +40,17 @@ type SavePointerDialectorInterface interface { RollbackTo(tx *DB, name string) error } +// TxBeginner tx beginner type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +// ConnPoolBeginner conn pool beginner type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxCommitter tx commiter type TxCommitter interface { Commit() error Rollback() error @@ -58,6 +61,7 @@ type Valuer interface { GormValue(context.Context, *DB) clause.Expr } +// GetDBConnector SQL db connector type GetDBConnector interface { GetDBConn() (*sql.DB, error) } diff --git a/schema/field.go b/schema/field.go index d2172adc..683e8492 100644 --- a/schema/field.go +++ b/schema/field.go @@ -84,6 +84,7 @@ type Field struct { ReflectValueOf func(context.Context, reflect.Value) reflect.Value ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface NewValuePool FieldNewValuePool } @@ -168,10 +169,16 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { field.DataType = String - } else if _, ok := field.TagSettings["SERIALIZER"]; ok { + field.Serializer = v + } else if name, ok := field.TagSettings["SERIALIZER"]; ok { field.DataType = String + if strings.ToLower(name) == "json" { + field.Serializer = &JSONSerializer{} + } else { + schema.err = fmt.Errorf("invalid serializer type %v", name) + } } if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { @@ -413,69 +420,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } -// sync pools -var ( - normalPool sync.Map - stringPool = &sync.Pool{ - New: func() interface{} { - var v string - ptrV := &v - return &ptrV - }, - } - intPool = &sync.Pool{ - New: func() interface{} { - var v int64 - ptrV := &v - return &ptrV - }, - } - uintPool = &sync.Pool{ - New: func() interface{} { - var v uint64 - ptrV := &v - return &ptrV - }, - } - floatPool = &sync.Pool{ - New: func() interface{} { - var v float64 - ptrV := &v - return &ptrV - }, - } - boolPool = &sync.Pool{ - New: func() interface{} { - var v bool - ptrV := &v - return &ptrV - }, - } - timePool = &sync.Pool{ - New: func() interface{} { - var v time.Time - ptrV := &v - return &ptrV - }, - } - poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { - if v, ok := normalPool.Load(reflectType); ok { - return v.(FieldNewValuePool) - } - - v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ - New: func() interface{} { - return reflect.New(reflectType).Interface() - }, - }) - return v.(FieldNewValuePool) - } -) - // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // Setup NewValuePool - if _, ok := reflect.New(field.IndirectFieldType).Interface().(sql.Scanner); !ok { + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &Serializer{ + Field: field, + Interface: reflect.New(reflect.ValueOf(field.Serializer).Type()).Interface().(SerializerInterface), + } + }, + } + } else if _, ok := fieldValue.(sql.Scanner); !ok { + // set default NewValuePool switch field.IndirectFieldType.Kind() { case reflect.String: field.NewValuePool = stringPool @@ -595,6 +554,19 @@ func (field *Field) setupValuerAndSetter() { } // Set + if field.Serializer != nil { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + if serializer, ok := v.(*Serializer); ok { + serializer.Interface.Scan(ctx, field, value, serializer.value) + fallbackSetter(ctx, value, serializer.Interface, field.Set) + serializer.Interface = reflect.New(reflect.ValueOf(field.Serializer).Type()).Interface().(SerializerInterface) + } + return nil + } + + return + } + switch field.FieldType.Kind() { case reflect.Bool: field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { diff --git a/schema/interfaces.go b/schema/interfaces.go index 638b96cf..0a233f09 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -3,6 +3,9 @@ package schema import ( "context" "database/sql/driver" + "encoding/json" + "errors" + "fmt" "reflect" "gorm.io/gorm/clause" @@ -19,30 +22,19 @@ type FieldNewValuePool interface { Put(interface{}) } -type fieldNewValuePool struct { - getter func() interface{} - putter func(interface{}) -} - -func (fp fieldNewValuePool) Get() interface{} { - return fp.getter() -} - -func (fp fieldNewValuePool) Put(v interface{}) { - fp.putter(v) -} - // Serializer field value serializer type Serializer struct { Field *Field Interface SerializerInterface Destination reflect.Value Context context.Context + value interface{} } // Scan implements sql.Scanner interface func (s *Serializer) Scan(value interface{}) error { - return s.Interface.Scan(s.Context, s.Field, s.Destination, value) + s.value = value + return nil } // Value implements driver.Valuer interface @@ -56,6 +48,38 @@ type SerializerInterface interface { Value(ctx context.Context, field *Field, dst reflect.Value) (interface{}, error) } +// JSONSerializer json serializer +type JSONSerializer struct { +} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + } + + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value) (interface{}, error) { + fv, _ := field.ValueOf(ctx, dst) + return fv, nil +} + // CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface diff --git a/schema/pool.go b/schema/pool.go new file mode 100644 index 00000000..f5c73153 --- /dev/null +++ b/schema/pool.go @@ -0,0 +1,62 @@ +package schema + +import ( + "reflect" + "sync" + "time" +) + +// sync pools +var ( + normalPool sync.Map + stringPool = &sync.Pool{ + New: func() interface{} { + var v string + ptrV := &v + return &ptrV + }, + } + intPool = &sync.Pool{ + New: func() interface{} { + var v int64 + ptrV := &v + return &ptrV + }, + } + uintPool = &sync.Pool{ + New: func() interface{} { + var v uint64 + ptrV := &v + return &ptrV + }, + } + floatPool = &sync.Pool{ + New: func() interface{} { + var v float64 + ptrV := &v + return &ptrV + }, + } + boolPool = &sync.Pool{ + New: func() interface{} { + var v bool + ptrV := &v + return &ptrV + }, + } + timePool = &sync.Pool{ + New: func() interface{} { + var v time.Time + ptrV := &v + return &ptrV + }, + } + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +)