Add Serializer Interface

This commit is contained in:
Jinzhu 2022-02-10 20:11:37 +08:00
parent 1b6cc25e19
commit 3c77eb0bb0
5 changed files with 62 additions and 52 deletions

View File

@ -15,12 +15,17 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
type DataType string type (
// DataType GORM data type
type TimeType int64 DataType string
// TimeType GORM time type
TimeType int64
)
// TimeReflectType time's reflect type
var TimeReflectType = reflect.TypeOf(time.Time{}) var TimeReflectType = reflect.TypeOf(time.Time{})
// GORM time types
const ( const (
UnixTime TimeType = 1 UnixTime TimeType = 1
UnixSecond TimeType = 2 UnixSecond TimeType = 2
@ -28,6 +33,7 @@ const (
UnixNanosecond TimeType = 4 UnixNanosecond TimeType = 4
) )
// GORM fields types
const ( const (
Bool DataType = "bool" Bool DataType = "bool"
Int DataType = "int" Int DataType = "int"
@ -38,6 +44,7 @@ const (
Bytes DataType = "bytes" Bytes DataType = "bytes"
) )
// Field is the representation of model schema's field
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
@ -50,9 +57,9 @@ type Field struct {
Creatable bool Creatable bool
Updatable bool Updatable bool
Readable bool Readable bool
HasDefaultValue bool
AutoCreateTime TimeType AutoCreateTime TimeType
AutoUpdateTime TimeType AutoUpdateTime TimeType
HasDefaultValue bool
DefaultValue string DefaultValue string
DefaultValueInterface interface{} DefaultValueInterface interface{}
NotNull bool NotNull bool
@ -61,6 +68,7 @@ type Field struct {
Size int Size int
Precision int Precision int
Scale int Scale int
IgnoreMigration bool
FieldType reflect.Type FieldType reflect.Type
IndirectFieldType reflect.Type IndirectFieldType reflect.Type
StructField reflect.StructField StructField reflect.StructField
@ -72,24 +80,32 @@ type Field struct {
ReflectValueOf func(context.Context, reflect.Value) reflect.Value ReflectValueOf func(context.Context, reflect.Value) reflect.Value
ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
Set func(context.Context, reflect.Value, interface{}) error Set func(context.Context, reflect.Value, interface{}) error
IgnoreMigration bool
} }
// ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error var (
err error
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
field := &Field{ field := &Field{
Name: fieldStruct.Name, Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type, FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct, StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true, Creatable: true,
Updatable: true, Updatable: true,
Readable: true, Readable: true,
Tag: fieldStruct.Tag, PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Schema: schema, Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1, AutoIncrementIncrement: 1,
} }
@ -139,16 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if dbName, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = dbName
}
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
field.PrimaryKey = true
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
field.PrimaryKey = true
}
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
field.AutoIncrement = true field.AutoIncrement = true
field.HasDefaultValue = true field.HasDefaultValue = true
@ -177,20 +183,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Scale, _ = strconv.Atoi(s) field.Scale, _ = strconv.Atoi(s)
} }
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
} else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
}
if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
field.Unique = true
}
if val, ok := field.TagSettings["COMMENT"]; ok {
field.Comment = val
}
// default value is function or null or blank (primary keys) // default value is function or null or blank (primary keys)
field.DefaultValue = strings.TrimSpace(field.DefaultValue) field.DefaultValue = strings.TrimSpace(field.DefaultValue)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"sync" "sync"
@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }

View File

@ -2,6 +2,7 @@ package schema
import ( import (
"context" "context"
"database/sql/driver"
"reflect" "reflect"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -12,8 +13,26 @@ type GormDataTypeInterface interface {
GormDataType() string GormDataType() string
} }
// Serializer serializer interface // Serializer field value serializer
type Serializer interface { type Serializer struct {
Field *Field
Interface SerializerInterface
Destination reflect.Value
Context context.Context
}
// Scan implements sql.Scanner interface
func (s *Serializer) Scan(value interface{}) error {
return s.Interface.Scan(s.Context, s.Field, s.Destination, value)
}
// Value implements driver.Valuer interface
func (s Serializer) Value() (driver.Value, error) {
return s.Interface.Value(s.Context, s.Field, s.Destination)
}
// SerializerInterface serializer interface
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
Value(ctx context.Context, field *Field, dst reflect.Value) (interface{}, error) Value(ctx context.Context, field *Field, dst reflect.Value) (interface{}, error)
} }

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values { for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) { t.Run("CheckField/"+k, func(t *testing.T) {
fv, _ := s.FieldsByDBName[k].ValueOf(value) fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value)
tests.AssertEqual(t, v, fv) tests.AssertEqual(t, v, fv)
}) })
} }

View File

@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
} }
func CheckTruth(val interface{}) bool { // CheckTruth check string true or not
if v, ok := val.(bool); ok { func CheckTruth(vals ...string) bool {
return v for _, val := range vals {
if !strings.EqualFold(val, "false") && val != "" {
return true
} }
if v, ok := val.(string); ok {
v = strings.ToLower(v)
return v != "false"
} }
return false
return !reflect.ValueOf(val).IsZero()
} }
func ToStringKey(values ...interface{}) string { func ToStringKey(values ...interface{}) string {