From 3c77eb0bb08ac48394d2b99752a77e9fe607a675 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Feb 2022 20:11:37 +0800 Subject: [PATCH] Add Serializer Interface --- schema/field.go | 58 ++++++++++++++++-------------------- schema/field_test.go | 13 ++++---- schema/interfaces.go | 23 ++++++++++++-- schema/schema_helper_test.go | 3 +- utils/utils.go | 17 +++++------ 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/schema/field.go b/schema/field.go index f060bc46..23de0405 100644 --- a/schema/field.go +++ b/schema/field.go @@ -15,12 +15,17 @@ import ( "gorm.io/gorm/utils" ) -type DataType string - -type TimeType int64 +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) +// TimeReflectType time's reflect type var TimeReflectType = reflect.TypeOf(time.Time{}) +// GORM time types const ( UnixTime TimeType = 1 UnixSecond TimeType = 2 @@ -28,6 +33,7 @@ const ( UnixNanosecond TimeType = 4 ) +// GORM fields types const ( Bool DataType = "bool" Int DataType = "int" @@ -38,6 +44,7 @@ const ( Bytes DataType = "bytes" ) +// Field is the representation of model schema's field type Field struct { Name string DBName string @@ -50,9 +57,9 @@ type Field struct { Creatable bool Updatable bool Readable bool - HasDefaultValue bool AutoCreateTime TimeType AutoUpdateTime TimeType + HasDefaultValue bool DefaultValue string DefaultValueInterface interface{} NotNull bool @@ -61,6 +68,7 @@ type Field struct { Size int Precision int Scale int + IgnoreMigration bool FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -72,24 +80,32 @@ type Field struct { ReflectValueOf func(context.Context, reflect.Value) reflect.Value ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) Set func(context.Context, reflect.Value, interface{}) error - IgnoreMigration bool } +// ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { - var err error + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") + ) field := &Field{ Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, Creatable: true, Updatable: true, Readable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), - Schema: schema, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], AutoIncrementIncrement: 1, } @@ -139,16 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if dbName, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = dbName - } - - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { field.AutoIncrement = true field.HasDefaultValue = true @@ -177,20 +183,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Scale, _ = strconv.Atoi(s) } - if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } - - if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { - field.Unique = true - } - - if val, ok := field.TagSettings["COMMENT"]; ok { - field.Comment = val - } - // default value is function or null or blank (primary keys) field.DefaultValue = strings.TrimSpace(field.DefaultValue) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && diff --git a/schema/field_test.go b/schema/field_test.go index 8fa46b87..300e375b 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "database/sql" "reflect" "sync" @@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } diff --git a/schema/interfaces.go b/schema/interfaces.go index 456867e8..ff9b0842 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -2,6 +2,7 @@ package schema import ( "context" + "database/sql/driver" "reflect" "gorm.io/gorm/clause" @@ -12,8 +13,26 @@ type GormDataTypeInterface interface { GormDataType() string } -// Serializer serializer interface -type Serializer interface { +// Serializer field value serializer +type Serializer struct { + Field *Field + Interface SerializerInterface + Destination reflect.Value + Context context.Context +} + +// Scan implements sql.Scanner interface +func (s *Serializer) Scan(value interface{}) error { + return s.Interface.Scan(s.Context, s.Field, s.Destination, value) +} + +// Value implements driver.Valuer interface +func (s Serializer) Value() (driver.Value, error) { + return s.Interface.Value(s.Context, s.Field, s.Destination) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error Value(ctx context.Context, field *Field, dst reflect.Value) (interface{}, error) } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 6d2bc664..9abaecba 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "fmt" "reflect" "strings" @@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - fv, _ := s.FieldsByDBName[k].ValueOf(value) + fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) tests.AssertEqual(t, v, fv) }) } diff --git a/utils/utils.go b/utils/utils.go index f00f92ba..28ca0daf 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } -func CheckTruth(val interface{}) bool { - if v, ok := val.(bool); ok { - return v +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if !strings.EqualFold(val, "false") && val != "" { + return true + } } - - if v, ok := val.(string); ok { - v = strings.ToLower(v) - return v != "false" - } - - return !reflect.ValueOf(val).IsZero() + return false } func ToStringKey(values ...interface{}) string {