feat: support embedded preload (#6137)
* feat: support embedded preload * fix lint and test * fix test...
This commit is contained in:
parent
4b0da0e97a
commit
828e22b17f
@ -3,6 +3,7 @@ package callbacks
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -10,6 +11,98 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// parsePreloadMap extracts nested preloads. e.g.
|
||||||
|
//
|
||||||
|
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||||
|
// parsePreloadMap(schema, map[string][]interface{}{
|
||||||
|
// clause.Associations: {"arg1"},
|
||||||
|
// "k1": {"arg2"},
|
||||||
|
// "k2.k3": {"arg3"},
|
||||||
|
// "k4.k5.k6": {"arg4"},
|
||||||
|
// })
|
||||||
|
// // preloadMap is
|
||||||
|
// map[string]map[string][]interface{}{
|
||||||
|
// "k0": {},
|
||||||
|
// "k7": {
|
||||||
|
// "k8": {},
|
||||||
|
// },
|
||||||
|
// "k1": {},
|
||||||
|
// "k2": {
|
||||||
|
// "k3": {"arg3"},
|
||||||
|
// },
|
||||||
|
// "k4": {
|
||||||
|
// "k5.k6": {"arg4"},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||||
|
preloadMap := map[string]map[string][]interface{}{}
|
||||||
|
setPreloadMap := func(name, value string, args []interface{}) {
|
||||||
|
if _, ok := preloadMap[name]; !ok {
|
||||||
|
preloadMap[name] = map[string][]interface{}{}
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
preloadMap[name][value] = args
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, args := range preloads {
|
||||||
|
preloadFields := strings.Split(name, ".")
|
||||||
|
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||||
|
if preloadFields[0] == clause.Associations {
|
||||||
|
for _, relation := range s.Relationships.Relations {
|
||||||
|
if relation.Schema == s {
|
||||||
|
setPreloadMap(relation.Name, value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||||
|
for _, value := range embeddedValues(embeddedRelations) {
|
||||||
|
setPreloadMap(embedded, value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
setPreloadMap(preloadFields[0], value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return preloadMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||||
|
if embeddedRelations == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||||
|
for _, relation := range embeddedRelations.Relations {
|
||||||
|
// skip first struct name
|
||||||
|
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||||
|
}
|
||||||
|
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||||
|
names = append(names, embeddedValues(relations)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
|
||||||
|
if relationships == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
preloadMap := parsePreloadMap(s, preloads)
|
||||||
|
for name := range preloadMap {
|
||||||
|
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
|
||||||
|
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if rel := relationships.Relations[name]; rel != nil {
|
||||||
|
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||||
var (
|
var (
|
||||||
reflectValue = tx.Statement.ReflectValue
|
reflectValue = tx.Statement.ReflectValue
|
||||||
|
@ -267,32 +267,7 @@ func Preload(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadMap := map[string]map[string][]interface{}{}
|
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
|
||||||
for name := range db.Statement.Preloads {
|
|
||||||
preloadFields := strings.Split(name, ".")
|
|
||||||
if preloadFields[0] == clause.Associations {
|
|
||||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
|
||||||
if rel.Schema == db.Statement.Schema {
|
|
||||||
if _, ok := preloadMap[rel.Name]; !ok {
|
|
||||||
preloadMap[rel.Name] = map[string][]interface{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
|
||||||
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
|
||||||
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
|
||||||
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadNames := make([]string, 0, len(preloadMap))
|
preloadNames := make([]string, 0, len(preloadMap))
|
||||||
for key := range preloadMap {
|
for key := range preloadMap {
|
||||||
preloadNames = append(preloadNames, key)
|
preloadNames = append(preloadNames, key)
|
||||||
@ -312,7 +287,9 @@ func Preload(db *gorm.DB) {
|
|||||||
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
for _, name := range preloadNames {
|
||||||
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
|
||||||
|
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
|
||||||
|
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||||
} else {
|
} else {
|
||||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||||
|
@ -89,6 +89,10 @@ type Field struct {
|
|||||||
NewValuePool FieldNewValuePool
|
NewValuePool FieldNewValuePool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (field *Field) BindName() string {
|
||||||
|
return strings.Join(field.BindNames, ".")
|
||||||
|
}
|
||||||
|
|
||||||
// ParseField parses reflect.StructField to Field
|
// ParseField parses reflect.StructField to Field
|
||||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
var (
|
var (
|
||||||
|
@ -27,6 +27,8 @@ type Relationships struct {
|
|||||||
HasMany []*Relationship
|
HasMany []*Relationship
|
||||||
Many2Many []*Relationship
|
Many2Many []*Relationship
|
||||||
Relations map[string]*Relationship
|
Relations map[string]*Relationship
|
||||||
|
|
||||||
|
EmbeddedRelations map[string]*Relationships
|
||||||
}
|
}
|
||||||
|
|
||||||
type Relationship struct {
|
type Relationship struct {
|
||||||
@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if schema.err == nil {
|
if schema.err == nil {
|
||||||
schema.Relationships.Relations[relation.Name] = relation
|
schema.setRelation(relation)
|
||||||
switch relation.Type {
|
switch relation.Type {
|
||||||
case HasOne:
|
case HasOne:
|
||||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||||
@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
return relation
|
return relation
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (schema *Schema) setRelation(relation *Relationship) {
|
||||||
|
// set non-embedded relation
|
||||||
|
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||||
|
if len(rel.Field.BindNames) > 1 {
|
||||||
|
schema.Relationships.Relations[relation.Name] = relation
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
schema.Relationships.Relations[relation.Name] = relation
|
||||||
|
}
|
||||||
|
|
||||||
|
// set embedded relation
|
||||||
|
if len(relation.Field.BindNames) <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
relationships := &schema.Relationships
|
||||||
|
for i, name := range relation.Field.BindNames {
|
||||||
|
if i < len(relation.Field.BindNames)-1 {
|
||||||
|
if relationships.EmbeddedRelations == nil {
|
||||||
|
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||||
|
}
|
||||||
|
if r := relationships.EmbeddedRelations[name]; r == nil {
|
||||||
|
relationships.EmbeddedRelations[name] = &Relationships{}
|
||||||
|
}
|
||||||
|
relationships = relationships.EmbeddedRelations[name]
|
||||||
|
} else {
|
||||||
|
if relationships.Relations == nil {
|
||||||
|
relationships.Relations = map[string]*Relationship{}
|
||||||
|
}
|
||||||
|
relationships.Relations[relation.Name] = relation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||||
//
|
//
|
||||||
// type User struct {
|
// type User struct {
|
||||||
@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if primaryKeyField == nil {
|
||||||
|
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// use same data type for foreign keys
|
// use same data type for foreign keys
|
||||||
if copyableDataType(primaryKeyField.DataType) {
|
if copyableDataType(primaryKeyField.DataType) {
|
||||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||||
@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||||||
primaryFields = primarySchema.PrimaryFields
|
primaryFields = primarySchema.PrimaryFields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
primaryFieldLoop:
|
||||||
for _, primaryField := range primaryFields {
|
for _, primaryField := range primaryFields {
|
||||||
lookUpName := primarySchemaName + primaryField.Name
|
lookUpName := primarySchemaName + primaryField.Name
|
||||||
if gl == guessBelongs {
|
if gl == guessBelongs {
|
||||||
@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, name := range lookUpNames {
|
||||||
|
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
|
||||||
|
foreignFields = append(foreignFields, f)
|
||||||
|
primaryFields = append(primaryFields, primaryField)
|
||||||
|
continue primaryFieldLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, name := range lookUpNames {
|
for _, name := range lookUpNames {
|
||||||
if f := foreignSchema.LookUpField(name); f != nil {
|
if f := foreignSchema.LookUpField(name); f != nil {
|
||||||
foreignFields = append(foreignFields, f)
|
foreignFields = append(foreignFields, f)
|
||||||
primaryFields = append(primaryFields, primaryField)
|
primaryFields = append(primaryFields, primaryField)
|
||||||
break
|
continue primaryFieldLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedHas(t *testing.T) {
|
||||||
|
type Toy struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
OwnerID int
|
||||||
|
OwnerType string
|
||||||
|
}
|
||||||
|
type User struct {
|
||||||
|
ID int
|
||||||
|
Cat struct {
|
||||||
|
Name string
|
||||||
|
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
} `gorm:"embedded;embeddedPrefix:cat_"`
|
||||||
|
Dog struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
UserID int
|
||||||
|
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
}
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||||
|
"Cat": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Toy": {
|
||||||
|
Name: "Toy",
|
||||||
|
Type: schema.HasOne,
|
||||||
|
Schema: "User",
|
||||||
|
FieldSchema: "Toy",
|
||||||
|
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||||
|
References: []Reference{
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Toys": {
|
||||||
|
Name: "Toys",
|
||||||
|
Type: schema.HasMany,
|
||||||
|
Schema: "User",
|
||||||
|
FieldSchema: "Toy",
|
||||||
|
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||||
|
References: []Reference{
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedBelongsTo(t *testing.T) {
|
||||||
|
type Country struct {
|
||||||
|
ID int `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
type Address struct {
|
||||||
|
CountryID int
|
||||||
|
Country Country
|
||||||
|
}
|
||||||
|
type NestedAddress struct {
|
||||||
|
Address
|
||||||
|
}
|
||||||
|
type Org struct {
|
||||||
|
ID int
|
||||||
|
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||||
|
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||||
|
AddressID int
|
||||||
|
Address struct {
|
||||||
|
ID int
|
||||||
|
Address
|
||||||
|
}
|
||||||
|
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse schema, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||||
|
"PostalAddress": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Country": {
|
||||||
|
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||||
|
References: []Reference{
|
||||||
|
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"VisitingAddress": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Country": {
|
||||||
|
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||||
|
References: []Reference{
|
||||||
|
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"NestedAddress": {
|
||||||
|
EmbeddedRelations: map[string]EmbeddedRelations{
|
||||||
|
"Address": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Country": {
|
||||||
|
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||||
|
References: []Reference{
|
||||||
|
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestVariableRelation(t *testing.T) {
|
func TestVariableRelation(t *testing.T) {
|
||||||
var result struct {
|
var result struct {
|
||||||
User
|
User
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -25,6 +26,7 @@ type Schema struct {
|
|||||||
PrimaryFieldDBNames []string
|
PrimaryFieldDBNames []string
|
||||||
Fields []*Field
|
Fields []*Field
|
||||||
FieldsByName map[string]*Field
|
FieldsByName map[string]*Field
|
||||||
|
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||||
FieldsByDBName map[string]*Field
|
FieldsByDBName map[string]*Field
|
||||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||||
Relationships Relationships
|
Relationships Relationships
|
||||||
@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LookUpFieldByBindName looks for the closest field in the embedded struct.
|
||||||
|
//
|
||||||
|
// type Struct struct {
|
||||||
|
// Embedded struct {
|
||||||
|
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
|
||||||
|
// }
|
||||||
|
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
|
||||||
|
// }
|
||||||
|
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
|
||||||
|
if len(bindNames) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for i := len(bindNames) - 1; i >= 0; i-- {
|
||||||
|
find := strings.Join(bindNames[:i], ".") + "." + name
|
||||||
|
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Tabler interface {
|
type Tabler interface {
|
||||||
TableName() string
|
TableName() string
|
||||||
}
|
}
|
||||||
@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
}
|
}
|
||||||
|
|
||||||
schema := &Schema{
|
schema := &Schema{
|
||||||
Name: modelType.Name(),
|
Name: modelType.Name(),
|
||||||
ModelType: modelType,
|
ModelType: modelType,
|
||||||
Table: tableName,
|
Table: tableName,
|
||||||
FieldsByName: map[string]*Field{},
|
FieldsByName: map[string]*Field{},
|
||||||
FieldsByDBName: map[string]*Field{},
|
FieldsByBindName: map[string]*Field{},
|
||||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
FieldsByDBName: map[string]*Field{},
|
||||||
cacheStore: cacheStore,
|
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||||
namer: namer,
|
cacheStore: cacheStore,
|
||||||
initialized: make(chan struct{}),
|
namer: namer,
|
||||||
|
initialized: make(chan struct{}),
|
||||||
}
|
}
|
||||||
// When the schema initialization is completed, the channel will be closed
|
// When the schema initialization is completed, the channel will be closed
|
||||||
defer close(schema.initialized)
|
defer close(schema.initialized)
|
||||||
@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bindName := field.BindName()
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
// nonexistence or shortest path or first appear prioritized if has permission
|
// nonexistence or shortest path or first appear prioritized if has permission
|
||||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||||
@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
}
|
}
|
||||||
schema.FieldsByDBName[field.DBName] = field
|
schema.FieldsByDBName[field.DBName] = field
|
||||||
schema.FieldsByName[field.Name] = field
|
schema.FieldsByName[field.Name] = field
|
||||||
|
schema.FieldsByBindName[bindName] = field
|
||||||
|
|
||||||
if v != nil && v.PrimaryKey {
|
if v != nil && v.PrimaryKey {
|
||||||
for idx, f := range schema.PrimaryFields {
|
for idx, f := range schema.PrimaryFields {
|
||||||
@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
||||||
schema.FieldsByName[field.Name] = field
|
schema.FieldsByName[field.Name] = field
|
||||||
}
|
}
|
||||||
|
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
|
||||||
|
schema.FieldsByBindName[bindName] = field
|
||||||
|
}
|
||||||
|
|
||||||
field.setupValuerAndSetter()
|
field.setupValuerAndSetter()
|
||||||
}
|
}
|
||||||
@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
return schema, schema.err
|
return schema, schema.err
|
||||||
} else {
|
} else {
|
||||||
schema.FieldsByName[field.Name] = field
|
schema.FieldsByName[field.Name] = field
|
||||||
|
schema.FieldsByBindName[field.BindName()] = field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddedRelations struct {
|
||||||
|
Relations map[string]Relation
|
||||||
|
EmbeddedRelations map[string]EmbeddedRelations
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
|
||||||
|
for name, relations := range actual {
|
||||||
|
rs := expected[name]
|
||||||
|
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
|
||||||
|
if len(relations.Relations) != len(rs.Relations) {
|
||||||
|
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
|
||||||
|
}
|
||||||
|
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
|
||||||
|
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
|
||||||
|
}
|
||||||
|
for n, rel := range relations.Relations {
|
||||||
|
if r, ok := rs.Relations[n]; !ok {
|
||||||
|
t.Errorf("failed to find relation by name %s", n)
|
||||||
|
} else {
|
||||||
|
checkSchemaRelation(t, &schema.Schema{
|
||||||
|
Relationships: schema.Relationships{
|
||||||
|
Relations: map[string]*schema.Relationship{n: rel},
|
||||||
|
},
|
||||||
|
}, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||||
for k, v := range values {
|
for k, v := range values {
|
||||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||||
|
@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
|
|||||||
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
|
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
|
||||||
CheckUserUnscoped(t, *user6, user)
|
CheckUserUnscoped(t, *user6, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbedPreload(t *testing.T) {
|
||||||
|
type Country struct {
|
||||||
|
ID int `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
type EmbeddedAddress struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
CountryID *int
|
||||||
|
Country *Country
|
||||||
|
}
|
||||||
|
type NestedAddress struct {
|
||||||
|
EmbeddedAddress
|
||||||
|
}
|
||||||
|
type Org struct {
|
||||||
|
ID int
|
||||||
|
PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||||
|
VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||||
|
AddressID *int
|
||||||
|
Address *EmbeddedAddress
|
||||||
|
NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||||
|
DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||||
|
|
||||||
|
org := Org{
|
||||||
|
PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
|
||||||
|
VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
|
||||||
|
Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
|
||||||
|
NestedAddress: NestedAddress{
|
||||||
|
EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := DB.Create(&org).Error; err != nil {
|
||||||
|
t.Errorf("failed to create org, got err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
preloads map[string][]interface{}
|
||||||
|
expect Org
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "address country",
|
||||||
|
preloads: map[string][]interface{}{"Address.Country": {}},
|
||||||
|
expect: Org{
|
||||||
|
ID: org.ID,
|
||||||
|
PostalAddress: EmbeddedAddress{
|
||||||
|
ID: org.PostalAddress.ID,
|
||||||
|
Name: org.PostalAddress.Name,
|
||||||
|
CountryID: org.PostalAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
VisitingAddress: EmbeddedAddress{
|
||||||
|
ID: org.VisitingAddress.ID,
|
||||||
|
Name: org.VisitingAddress.Name,
|
||||||
|
CountryID: org.VisitingAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
AddressID: org.AddressID,
|
||||||
|
Address: org.Address,
|
||||||
|
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||||
|
ID: org.NestedAddress.ID,
|
||||||
|
Name: org.NestedAddress.Name,
|
||||||
|
CountryID: org.NestedAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "postal address country",
|
||||||
|
preloads: map[string][]interface{}{"PostalAddress.Country": {}},
|
||||||
|
expect: Org{
|
||||||
|
ID: org.ID,
|
||||||
|
PostalAddress: org.PostalAddress,
|
||||||
|
VisitingAddress: EmbeddedAddress{
|
||||||
|
ID: org.VisitingAddress.ID,
|
||||||
|
Name: org.VisitingAddress.Name,
|
||||||
|
CountryID: org.VisitingAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
AddressID: org.AddressID,
|
||||||
|
Address: nil,
|
||||||
|
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||||
|
ID: org.NestedAddress.ID,
|
||||||
|
Name: org.NestedAddress.Name,
|
||||||
|
CountryID: org.NestedAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "nested address country",
|
||||||
|
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
|
||||||
|
expect: Org{
|
||||||
|
ID: org.ID,
|
||||||
|
PostalAddress: EmbeddedAddress{
|
||||||
|
ID: org.PostalAddress.ID,
|
||||||
|
Name: org.PostalAddress.Name,
|
||||||
|
CountryID: org.PostalAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
VisitingAddress: EmbeddedAddress{
|
||||||
|
ID: org.VisitingAddress.ID,
|
||||||
|
Name: org.VisitingAddress.Name,
|
||||||
|
CountryID: org.VisitingAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
AddressID: org.AddressID,
|
||||||
|
Address: nil,
|
||||||
|
NestedAddress: org.NestedAddress,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "associations",
|
||||||
|
preloads: map[string][]interface{}{
|
||||||
|
clause.Associations: {},
|
||||||
|
// clause.Associations won’t preload nested associations
|
||||||
|
"Address.Country": {},
|
||||||
|
},
|
||||||
|
expect: org,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB = DB.Debug()
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
actual := Org{}
|
||||||
|
tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
|
||||||
|
for name, args := range test.preloads {
|
||||||
|
tx = tx.Preload(name, args...)
|
||||||
|
}
|
||||||
|
if err := tx.Find(&actual).Error; err != nil {
|
||||||
|
t.Errorf("failed to find org, got err: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, actual, test.expect)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user