diff --git a/schema/field.go b/schema/field.go index 65d54ab9..319f3693 100644 --- a/schema/field.go +++ b/schema/field.go @@ -172,16 +172,20 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { field.DataType = String field.Serializer = v - } else if name, ok := field.TagSettings["SERIALIZER"]; ok { - field.DataType = String - if strings.ToLower(name) == "json" { - field.Serializer = JSONSerializer{} - } else { - schema.err = fmt.Errorf("invalid serializer type %v", name) + } else { + var serializerName = field.TagSettings["JSON"] + if serializerName == "" { + serializerName = field.TagSettings["SERIALIZER"] + } + if serializerName != "" { + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String + field.Serializer = serializer + } else { + schema.err = fmt.Errorf("invalid serializer type %v", serializerName) + } } - } else if _, ok := field.TagSettings["JSON"]; ok { - field.DataType = String - field.Serializer = JSONSerializer{} } if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { @@ -430,9 +434,9 @@ func (field *Field) setupValuerAndSetter() { if field.Serializer != nil { field.NewValuePool = &sync.Pool{ New: func() interface{} { - return &Serializer{ - Field: field, - Interface: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), } }, } @@ -489,17 +493,17 @@ func (field *Field) setupValuerAndSetter() { return value, zero } - serializer, ok := value.(SerializerValuerInterface) + s, ok := value.(SerializerValuerInterface) if !ok { - serializer = field.Serializer + s = field.Serializer } - return Serializer{ - Field: field, - Valuer: serializer, - Destination: v, - Context: ctx, - fieldValue: value, + return serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, }, false } } @@ -583,19 +587,6 @@ func (field *Field) setupValuerAndSetter() { } // Set - if field.Serializer != nil { - field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { - if serializer, ok := v.(*Serializer); ok { - serializer.Interface.Scan(ctx, field, value, serializer.value) - fallbackSetter(ctx, value, serializer.Interface, field.Set) - serializer.Interface = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) - } - return nil - } - - return - } - switch field.FieldType.Kind() { case reflect.Bool: field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { @@ -917,4 +908,33 @@ func (field *Field) setupValuerAndSetter() { } } } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } + } + } else { + err = oldFieldSetter(ctx, value, v) + } + return + } + } } diff --git a/schema/interfaces.go b/schema/interfaces.go index 2771660a..a75a33c0 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -1,13 +1,6 @@ package schema import ( - "context" - "database/sql/driver" - "encoding/json" - "errors" - "fmt" - "reflect" - "gorm.io/gorm/clause" ) @@ -22,71 +15,6 @@ type FieldNewValuePool interface { Put(interface{}) } -// Serializer field value serializer -type Serializer struct { - Field *Field - Interface SerializerInterface - Valuer SerializerValuerInterface - Destination reflect.Value - Context context.Context - value interface{} - fieldValue interface{} -} - -// Scan implements sql.Scanner interface -func (s *Serializer) Scan(value interface{}) error { - s.value = value - return nil -} - -// Value implements driver.Valuer interface -func (s Serializer) Value() (driver.Value, error) { - return s.Valuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) -} - -// SerializerInterface serializer interface -type SerializerInterface interface { - Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error - SerializerValuerInterface -} - -// SerializerValuerInterface serializer valuer interface -type SerializerValuerInterface interface { - Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) -} - -// JSONSerializer json serializer -type JSONSerializer struct { -} - -// Scan implements serializer interface -func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { - fieldValue := reflect.New(field.FieldType) - - if dbValue != nil { - var bytes []byte - switch v := dbValue.(type) { - case []byte: - bytes = v - case string: - bytes = []byte(v) - default: - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) - } - - err = json.Unmarshal(bytes, fieldValue.Interface()) - } - - field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) - return -} - -// Value implements serializer interface -func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { - result, err := json.Marshal(fieldValue) - return string(result), err -} - // CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface diff --git a/schema/serializer.go b/schema/serializer.go new file mode 100644 index 00000000..68597538 --- /dev/null +++ b/schema/serializer.go @@ -0,0 +1,125 @@ +package schema + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var serializerMap = sync.Map{} + +// RegisterSerializer register serializer +func RegisterSerializer(name string, serializer SerializerInterface) { + serializerMap.Store(strings.ToLower(name), serializer) +} + +// GetSerializer get serializer +func GetSerializer(name string) (serializer SerializerInterface, ok bool) { + v, ok := serializerMap.Load(strings.ToLower(name)) + if ok { + serializer, ok = v.(SerializerInterface) + } + return serializer, ok +} + +func init() { + RegisterSerializer("json", JSONSerializer{}) + RegisterSerializer("unixtime", UnixSecondSerializer{}) +} + +// Serializer field value serializer +type serializer struct { + Field *Field + Serializer SerializerInterface + SerializeValuer SerializerValuerInterface + Destination reflect.Value + Context context.Context + value interface{} + fieldValue interface{} +} + +// Scan implements sql.Scanner interface +func (s *serializer) Scan(value interface{}) error { + s.value = value + return nil +} + +// Value implements driver.Valuer interface +func (s serializer) Value() (driver.Value, error) { + return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { + Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface interface { + Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) +} + +// JSONSerializer json serializer +type JSONSerializer struct { +} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + } + + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + return string(result), err +} + +// UnixSecondSerializer json serializer +type UnixSecondSerializer struct { +} + +// Scan implements serializer interface +func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + t := sql.NullTime{} + if err = t.Scan(dbValue); err == nil { + err = field.Set(ctx, dst, t.Time) + } + + return +} + +// Value implements serializer interface +func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + switch v := fieldValue.(type) { + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.ValueOf(v).Int(), 0) + default: + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + } + return +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 4d465b8c..3ed733d9 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -7,6 +7,7 @@ import ( "reflect" "strings" "testing" + "time" "gorm.io/gorm" "gorm.io/gorm/schema" @@ -15,8 +16,10 @@ import ( type SerializerStruct struct { gorm.Model - Name []byte `gorm:"json"` - Roles Roles `gorm:"serializer:json"` + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type EncryptedString EncryptedString } @@ -45,10 +48,14 @@ func TestSerializer(t *testing.T) { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + data := SerializerStruct{ Name: []byte("jinzhu"), Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), } if err := DB.Create(&data).Error; err != nil {