From a2d8f1b2c8ff1391ecf9876e2f7c27fd477932cc Mon Sep 17 00:00:00 2001 From: Leonid Bugaev Date: Thu, 12 Nov 2020 21:30:01 +0300 Subject: [PATCH] Update --- finisher_api.go | 2 +- gorm.go | 7 ++++ migrator/migrator.go | 4 ++ scan.go | 24 +++++++++-- schema/field.go | 95 +++++++++++++++++++++++++++++++++++++++++- schema/relationship.go | 4 +- schema/schema.go | 21 +++++++++- statement.go | 4 +- 8 files changed, 149 insertions(+), 12 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 857f9419..01ec16ec 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -162,7 +162,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { exprs := tx.Statement.BuildCondition(value) tx.assignInterfacesToValue(exprs) default: - if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy, tx.AutoEmbedd, tx.UseJSONTags); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Struct: diff --git a/gorm.go b/gorm.go index affa8e69..a73d4979 100644 --- a/gorm.go +++ b/gorm.go @@ -37,6 +37,13 @@ type Config struct { // AllowGlobalUpdate allow global update AllowGlobalUpdate bool + // Automatically embed structs + AutoEmbedd bool + + UseJSONTags bool + + UnknownToJSON bool + // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder // ConnPool db conn pool diff --git a/migrator/migrator.go b/migrator/migrator.go index 016ebfc7..a4ab3127 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -52,6 +52,10 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { } } + if field.DataType == "raw_json" { + return "string" + } + return m.Dialector.DataTypeOf(field) } diff --git a/scan.go b/scan.go index 8d737b17..453db2d6 100644 --- a/scan.go +++ b/scan.go @@ -49,6 +49,10 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } +type Stringer interface { + String() string +} + func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) @@ -110,7 +114,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if Schema != nil { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.AutoEmbedd, db.UseJSONTags) } for idx, column := range columns { @@ -155,7 +159,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } else { for idx, field := range fields { if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + if field.DataType == "raw_json" { + values[idx] = &sql.RawBytes{} + } else { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } } } @@ -188,13 +196,21 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } case reflect.Struct: if db.Statement.ReflectValue.Type() != Schema.ModelType { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.AutoEmbedd, db.UseJSONTags) } if initialized || rows.Next() { for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + if field.DataType == "raw_json" { + values[idx] = &sql.RawBytes{} + } else { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + // if field.DBName == "_id" { + // var str = "" + // values[idx] = &str + // } + } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { diff --git a/schema/field.go b/schema/field.go index b303fb30..86021ed9 100644 --- a/schema/field.go +++ b/schema/field.go @@ -3,6 +3,7 @@ package schema import ( "database/sql" "database/sql/driver" + "encoding/json" "fmt" "reflect" "strconv" @@ -34,6 +35,7 @@ const ( String DataType = "string" Time DataType = "time" Bytes DataType = "bytes" + JSON DataType = "json" ) type Field struct { @@ -63,6 +65,7 @@ type Field struct { StructField reflect.StructField Tag reflect.StructTag TagSettings map[string]string + JSONTagSettings map[string]string Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema @@ -85,6 +88,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Readable: true, Tag: fieldStruct.Tag, TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), + JSONTagSettings: ParseTagSetting(fieldStruct.Tag.Get("json"), ","), Schema: schema, } @@ -126,6 +130,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.TagSettings[key] = value } } + + ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") } } } @@ -136,6 +142,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if dbName, ok := field.TagSettings["COLUMN"]; ok { field.DBName = dbName + } else { + if schema.UseJSONTags { + jsonTag := strings.Split(fieldStruct.Tag.Get("json"), ",")[0] + if jsonTag != "omniempty" && jsonTag != "-" { + field.DBName = jsonTag + } + } } if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { @@ -233,7 +246,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Array, reflect.Slice: if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { field.DataType = Bytes + } else { + field.DataType = "raw_json" } + default: + field.DataType = "raw_json" } field.GORMDataType = field.DataType @@ -321,6 +338,18 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + if schema.AutoEmbedd { + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Elem().Interface().(type) { + // Apparently time.Time also a struct...Skip. + case time.Time, *time.Time: + default: + field.TagSettings["EMBEDDED"] = "true" + } + } + } + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error @@ -330,7 +359,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.AutoEmbedd, schema.UseJSONTags); err != nil { schema.err = err } @@ -345,6 +374,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) } + if _, ok := field.TagSettings["EMBEDDEDPREFIX"]; !ok { + if schema.UseJSONTags { + if _, ok := field.JSONTagSettings["INLINE"]; !ok { + field.TagSettings["EMBEDDEDPREFIX"] = field.DBName + "_" + } + } + } + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { ef.DBName = prefix + ef.DBName } @@ -386,11 +423,23 @@ func (field *Field) setupValuerAndSetter() { case len(field.StructField.Index) == 1: field.ValueOf = func(value reflect.Value) (interface{}, bool) { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) + + if field.DataType == "raw_json" { + bytes, _ := json.Marshal(fieldValue.Interface()) + return bytes, false + } + return fieldValue.Interface(), fieldValue.IsZero() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: field.ValueOf = func(value reflect.Value) (interface{}, bool) { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + + if field.DataType == "raw_json" { + bytes, _ := json.Marshal(fieldValue.Interface()) + return bytes, false + } + return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -414,6 +463,12 @@ func (field *Field) setupValuerAndSetter() { } } } + + if field.DataType == "raw_json" { + bytes, _ := json.Marshal(v.Interface()) + return bytes, false + } + return v.Interface(), v.IsZero() } } @@ -699,6 +754,17 @@ func (field *Field) setupValuerAndSetter() { case float64, float32: field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } + return fallbackSetter(value, v, field.Set) } return err @@ -759,6 +825,10 @@ func (field *Field) setupValuerAndSetter() { } default: if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // if field.DBName == "_id" { + // panic(values[idx].(**string)) + // } + // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) @@ -811,6 +881,29 @@ func (field *Field) setupValuerAndSetter() { } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { + if field.DataType == "raw_json" { + var bytes []byte + + switch t := v.(type) { + case []uint8: + bytes = []byte(t) + case *sql.RawBytes: + bytes = []byte(*t) + default: + panic(v) + } + + valueV := field.ReflectValueOf(value) + + if valueV.Kind() == reflect.Ptr { + err = json.Unmarshal(bytes, field.ReflectValueOf(value).Interface()) + } else { + err = json.Unmarshal(bytes, field.ReflectValueOf(value).Addr().Interface()) + } + + return + } + return fallbackSetter(value, v, field.Set) } } diff --git a/schema/relationship.go b/schema/relationship.go index 35af111f..275e86a9 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -71,7 +71,7 @@ func (schema *Schema) parseRelation(field *Field) { cacheStore = field.OwnerSchema.cacheStore } - if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer, schema.AutoEmbedd, schema.UseJSONTags); err != nil { schema.err = err return } @@ -263,7 +263,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer, schema.AutoEmbedd, schema.UseJSONTags); err != nil { schema.err = err } relation.JoinTable.Name = many2many diff --git a/schema/schema.go b/schema/schema.go index cffc19a7..1547ccca 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -28,6 +28,8 @@ type Schema struct { FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships + AutoEmbedd bool + UseJSONTags bool CreateClauses []clause.Interface QueryClauses []clause.Interface UpdateClauses []clause.Interface @@ -71,7 +73,7 @@ type Tabler interface { } // get data type from dialector -func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer, AutoEmbedd bool, UseJSONTags bool) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -110,6 +112,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, + AutoEmbedd: AutoEmbedd, + UseJSONTags: UseJSONTags, } defer func() { @@ -129,7 +133,19 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + // if modelType.Field(0).Name == "Enabled" { + // if _, ok := namer.(embeddedNamer); ok { + // fmt.Printf("%#v\n", schema.Fields) + // } + // } + for _, field := range schema.Fields { + // if modelType.Field(0).Name == "Enabled" { + // if _, ok := namer.(embeddedNamer); ok { + // fmt.Printf("%#v\n", field) + // } + // } + if field.DBName == "" && field.DataType != "" { field.DBName = namer.ColumnName(schema.Table, field.Name) } @@ -224,7 +240,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + schema.err = nil + // return schema, schema.err } } diff --git a/statement.go b/statement.go index 82ebdd91..dfaf71a1 100644 --- a/statement.go +++ b/statement.go @@ -317,7 +317,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) - if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.AutoEmbedd, stmt.DB.UseJSONTags); err == nil { switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { @@ -391,7 +391,7 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.AutoEmbedd, stmt.DB.UseJSONTags); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1]