This commit is contained in:
Leonid Bugaev 2020-11-12 21:30:01 +03:00
parent 85e9f66d26
commit a2d8f1b2c8
8 changed files with 149 additions and 12 deletions

View File

@ -162,7 +162,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
exprs := tx.Statement.BuildCondition(value)
tx.assignInterfacesToValue(exprs)
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy, tx.AutoEmbedd, tx.UseJSONTags); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:

View File

@ -37,6 +37,13 @@ type Config struct {
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// Automatically embed structs
AutoEmbedd bool
UseJSONTags bool
UnknownToJSON bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool

View File

@ -52,6 +52,10 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
}
}
if field.DataType == "raw_json" {
return "string"
}
return m.Dialector.DataTypeOf(field)
}

24
scan.go
View File

@ -49,6 +49,10 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
}
}
type Stringer interface {
String() string
}
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
@ -110,7 +114,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.AutoEmbedd, db.UseJSONTags)
}
for idx, column := range columns {
@ -155,7 +159,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
if field.DataType == "raw_json" {
values[idx] = &sql.RawBytes{}
} else {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
}
}
@ -188,13 +196,21 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
case reflect.Struct:
if db.Statement.ReflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.AutoEmbedd, db.UseJSONTags)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
if field.DataType == "raw_json" {
values[idx] = &sql.RawBytes{}
} else {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
// if field.DBName == "_id" {
// var str = ""
// values[idx] = &str
// }
}
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {

View File

@ -3,6 +3,7 @@ package schema
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strconv"
@ -34,6 +35,7 @@ const (
String DataType = "string"
Time DataType = "time"
Bytes DataType = "bytes"
JSON DataType = "json"
)
type Field struct {
@ -63,6 +65,7 @@ type Field struct {
StructField reflect.StructField
Tag reflect.StructTag
TagSettings map[string]string
JSONTagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
OwnerSchema *Schema
@ -85,6 +88,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Readable: true,
Tag: fieldStruct.Tag,
TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
JSONTagSettings: ParseTagSetting(fieldStruct.Tag.Get("json"), ","),
Schema: schema,
}
@ -126,6 +130,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.TagSettings[key] = value
}
}
ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";")
}
}
}
@ -136,6 +142,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
if dbName, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = dbName
} else {
if schema.UseJSONTags {
jsonTag := strings.Split(fieldStruct.Tag.Get("json"), ",")[0]
if jsonTag != "omniempty" && jsonTag != "-" {
field.DBName = jsonTag
}
}
}
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
@ -233,7 +246,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) {
field.DataType = Bytes
} else {
field.DataType = "raw_json"
}
default:
field.DataType = "raw_json"
}
field.GORMDataType = field.DataType
@ -321,6 +338,18 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if reflect.Indirect(fieldValue).Kind() == reflect.Struct {
if schema.AutoEmbedd {
fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) {
// Apparently time.Time also a struct...Skip.
case time.Time, *time.Time:
default:
field.TagSettings["EMBEDDED"] = "true"
}
}
}
if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) {
if reflect.Indirect(fieldValue).Kind() == reflect.Struct {
var err error
@ -330,7 +359,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.AutoEmbedd, schema.UseJSONTags); err != nil {
schema.err = err
}
@ -345,6 +374,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
}
if _, ok := field.TagSettings["EMBEDDEDPREFIX"]; !ok {
if schema.UseJSONTags {
if _, ok := field.JSONTagSettings["INLINE"]; !ok {
field.TagSettings["EMBEDDEDPREFIX"] = field.DBName + "_"
}
}
}
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
ef.DBName = prefix + ef.DBName
}
@ -386,11 +423,23 @@ func (field *Field) setupValuerAndSetter() {
case len(field.StructField.Index) == 1:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if field.DataType == "raw_json" {
bytes, _ := json.Marshal(fieldValue.Interface())
return bytes, false
}
return fieldValue.Interface(), fieldValue.IsZero()
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
if field.DataType == "raw_json" {
bytes, _ := json.Marshal(fieldValue.Interface())
return bytes, false
}
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
@ -414,6 +463,12 @@ func (field *Field) setupValuerAndSetter() {
}
}
}
if field.DataType == "raw_json" {
bytes, _ := json.Marshal(v.Interface())
return bytes, false
}
return v.Interface(), v.IsZero()
}
}
@ -699,6 +754,17 @@ func (field *Field) setupValuerAndSetter() {
case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
reflectV := reflect.ValueOf(v)
if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
}
return fallbackSetter(value, v, field.Set)
}
return err
@ -759,6 +825,10 @@ func (field *Field) setupValuerAndSetter() {
}
default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// if field.DBName == "_id" {
// panic(values[idx].(**string))
// }
// pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
@ -811,6 +881,29 @@ func (field *Field) setupValuerAndSetter() {
}
} else {
field.Set = func(value reflect.Value, v interface{}) (err error) {
if field.DataType == "raw_json" {
var bytes []byte
switch t := v.(type) {
case []uint8:
bytes = []byte(t)
case *sql.RawBytes:
bytes = []byte(*t)
default:
panic(v)
}
valueV := field.ReflectValueOf(value)
if valueV.Kind() == reflect.Ptr {
err = json.Unmarshal(bytes, field.ReflectValueOf(value).Interface())
} else {
err = json.Unmarshal(bytes, field.ReflectValueOf(value).Addr().Interface())
}
return
}
return fallbackSetter(value, v, field.Set)
}
}

View File

@ -71,7 +71,7 @@ func (schema *Schema) parseRelation(field *Field) {
cacheStore = field.OwnerSchema.cacheStore
}
if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil {
if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer, schema.AutoEmbedd, schema.UseJSONTags); err != nil {
schema.err = err
return
}
@ -263,7 +263,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer, schema.AutoEmbedd, schema.UseJSONTags); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many

View File

@ -28,6 +28,8 @@ type Schema struct {
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
AutoEmbedd bool
UseJSONTags bool
CreateClauses []clause.Interface
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
@ -71,7 +73,7 @@ type Tabler interface {
}
// get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer, AutoEmbedd bool, UseJSONTags bool) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
@ -110,6 +112,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
AutoEmbedd: AutoEmbedd,
UseJSONTags: UseJSONTags,
}
defer func() {
@ -129,7 +133,19 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
}
// if modelType.Field(0).Name == "Enabled" {
// if _, ok := namer.(embeddedNamer); ok {
// fmt.Printf("%#v\n", schema.Fields)
// }
// }
for _, field := range schema.Fields {
// if modelType.Field(0).Name == "Enabled" {
// if _, ok := namer.(embeddedNamer); ok {
// fmt.Printf("%#v\n", field)
// }
// }
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
@ -224,7 +240,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
schema.err = nil
// return schema, schema.err
}
}

View File

@ -317,7 +317,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
}
default:
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.AutoEmbedd, stmt.DB.UseJSONTags); err == nil {
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.Fields {
@ -391,7 +391,7 @@ func (stmt *Statement) Build(clauses ...string) {
}
func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.AutoEmbedd, stmt.DB.UseJSONTags); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]