Refactor gorm field

This commit is contained in:
Jinzhu 2022-02-14 15:51:32 +08:00
parent 3c77eb0bb0
commit 69b851ed95
2 changed files with 36 additions and 29 deletions

View File

@ -15,6 +15,13 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// special types' reflect type
var (
TimeReflectType = reflect.TypeOf(time.Time{})
TimePtrReflectType = reflect.TypeOf(&time.Time{})
ByteReflectType = reflect.TypeOf(uint8(0))
)
type ( type (
// DataType GORM data type // DataType GORM data type
DataType string DataType string
@ -22,9 +29,6 @@ type (
TimeType int64 TimeType int64
) )
// TimeReflectType time's reflect type
var TimeReflectType = reflect.TypeOf(time.Time{})
// GORM time types // GORM time types
const ( const (
UnixTime TimeType = 1 UnixTime TimeType = 1
@ -103,6 +107,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Updatable: true, Updatable: true,
Readable: true, Readable: true,
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"], Comment: tagSetting["COMMENT"],
@ -114,7 +120,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first field as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer) valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer { if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
@ -122,31 +128,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue = reflect.ValueOf(v) fieldValue = reflect.ValueOf(v)
} }
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value) var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) { getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v) var (
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { rv = reflect.Indirect(v)
for i := 0; i < rv.Type().NumField(); i++ { rvType = rv.Type()
newFieldType := rv.Type().Field(i).Type )
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem() newFieldType = newFieldType.Elem()
} }
fieldValue = reflect.New(newFieldType) fieldValue = reflect.New(newFieldType)
if rvType != reflect.Indirect(fieldValue).Type() {
if rv.Type() != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue) getRealFieldValue(fieldValue)
} }
if fieldValue.IsValid() { if fieldValue.IsValid() {
return return
} }
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
} }
} }
} }
@ -155,11 +167,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
field.AutoIncrement = true
field.HasDefaultValue = true
}
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
} }
@ -218,7 +225,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
case reflect.String: case reflect.String:
field.DataType = String field.DataType = String
if field.HasDefaultValue && !skipParseDefaultValue { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, `"`) field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
@ -229,17 +235,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) { } else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
field.DataType = Time field.DataType = Time
} }
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType {
field.DataType = Bytes field.DataType = Bytes
} }
} }
field.GORMDataType = field.DataType
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
field.DataType = DataType(dataTyper.GormDataType()) field.DataType = DataType(dataTyper.GormDataType())
} }
@ -339,8 +343,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && // Normal anonymous field or having `EMBEDDED` tag
(ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
kind := reflect.Indirect(fieldValue).Kind() kind := reflect.Indirect(fieldValue).Kind()
switch kind { switch kind {
case reflect.Struct: case reflect.Struct:

View File

@ -9,7 +9,7 @@ require (
github.com/jinzhu/now v1.1.4 github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4 github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.11 // indirect github.com/mattn/go-sqlite3 v1.14.11 // indirect
golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect golang.org/x/crypto v0.0.0-20220213190939-1e6e3497d506 // indirect
gorm.io/driver/mysql v1.2.3 gorm.io/driver/mysql v1.2.3
gorm.io/driver/postgres v1.2.3 gorm.io/driver/postgres v1.2.3
gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlite v1.2.6
@ -18,3 +18,5 @@ require (
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../
replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/gorm/sqlserver