Refactor gorm field

This commit is contained in:
Jinzhu 2022-02-14 15:51:32 +08:00
parent 3c77eb0bb0
commit 69b851ed95
2 changed files with 36 additions and 29 deletions

View File

@ -15,6 +15,13 @@ import (
"gorm.io/gorm/utils"
)
// special types' reflect type
var (
TimeReflectType = reflect.TypeOf(time.Time{})
TimePtrReflectType = reflect.TypeOf(&time.Time{})
ByteReflectType = reflect.TypeOf(uint8(0))
)
type (
// DataType GORM data type
DataType string
@ -22,9 +29,6 @@ type (
TimeType int64
)
// TimeReflectType time's reflect type
var TimeReflectType = reflect.TypeOf(time.Time{})
// GORM time types
const (
UnixTime TimeType = 1
@ -103,6 +107,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Updatable: true,
Readable: true,
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
@ -114,7 +120,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type
// if field is valuer, used its value or first field as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
@ -122,31 +128,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue = reflect.ValueOf(v)
}
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v)
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type
var (
rv = reflect.Indirect(v)
rvType = rv.Type()
)
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem()
}
fieldValue = reflect.New(newFieldType)
if rv.Type() != reflect.Indirect(fieldValue).Type() {
if rvType != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
}
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
}
}
@ -155,11 +167,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
field.AutoIncrement = true
field.HasDefaultValue = true
}
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
}
@ -218,7 +225,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
case reflect.String:
field.DataType = String
if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
@ -229,17 +235,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
field.DataType = Time
}
case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) {
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType {
field.DataType = Bytes
}
}
field.GORMDataType = field.DataType
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
field.DataType = DataType(dataTyper.GormDataType())
}
@ -339,8 +343,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes &&
(ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) {
// Normal anonymous field or having `EMBEDDED` tag
if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
kind := reflect.Indirect(fieldValue).Kind()
switch kind {
case reflect.Struct:

View File

@ -9,7 +9,7 @@ require (
github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.11 // indirect
golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect
golang.org/x/crypto v0.0.0-20220213190939-1e6e3497d506 // indirect
gorm.io/driver/mysql v1.2.3
gorm.io/driver/postgres v1.2.3
gorm.io/driver/sqlite v1.2.6
@ -18,3 +18,5 @@ require (
)
replace gorm.io/gorm => ../
replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/gorm/sqlserver