diff --git a/schema/field.go b/schema/field.go index 23de0405..773e5414 100644 --- a/schema/field.go +++ b/schema/field.go @@ -15,6 +15,13 @@ import ( "gorm.io/gorm/utils" ) +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) + type ( // DataType GORM data type DataType string @@ -22,9 +29,6 @@ type ( TimeType int64 ) -// TimeReflectType time's reflect type -var TimeReflectType = reflect.TypeOf(time.Time{}) - // GORM time types const ( UnixTime TimeType = 1 @@ -103,6 +107,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Updatable: true, Readable: true, PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Comment: tagSetting["COMMENT"], @@ -114,7 +120,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - // if field is valuer, used its value or first fields as data type + // if field is valuer, used its value or first field as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { @@ -122,31 +128,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { - rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { - for i := 0; i < rv.Type().NumField(); i++ { - newFieldType := rv.Type().Field(i).Type + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - - if rv.Type() != reflect.Indirect(fieldValue).Type() { + if rvType != reflect.Indirect(fieldValue).Type() { getRealFieldValue(fieldValue) } if fieldValue.IsValid() { return } - - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value - } - } } } } @@ -155,11 +167,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { - field.AutoIncrement = true - field.HasDefaultValue = true - } - if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) } @@ -218,7 +225,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } case reflect.String: field.DataType = String - if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, `"`) @@ -229,17 +235,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } case reflect.Array, reflect.Slice: - if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType { field.DataType = Bytes } } - field.GORMDataType = field.DataType - if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } @@ -339,8 +343,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && - (ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { + // Normal anonymous field or having `EMBEDDED` tag + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { kind := reflect.Indirect(fieldValue).Kind() switch kind { case reflect.Struct: diff --git a/tests/go.mod b/tests/go.mod index 3453f77b..f414a2dd 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect + golang.org/x/crypto v0.0.0-20220213190939-1e6e3497d506 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 @@ -18,3 +18,5 @@ require ( ) replace gorm.io/gorm => ../ + +replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/gorm/sqlserver