fix: join may automatically add nested query
This commit is contained in:
parent
b1cdc9348f
commit
f9fab7f930
@ -87,7 +87,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
|||||||
// If the current relationship is embedded or joined, current query will be ignored.
|
// If the current relationship is embedded or joined, current query will be ignored.
|
||||||
//
|
//
|
||||||
//nolint:cyclop
|
//nolint:cyclop
|
||||||
func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||||
|
|
||||||
// avoid random traversal of the map
|
// avoid random traversal of the map
|
||||||
@ -97,26 +97,33 @@ func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships
|
|||||||
}
|
}
|
||||||
sort.Strings(preloadNames)
|
sort.Strings(preloadNames)
|
||||||
|
|
||||||
joined := func(name string) bool {
|
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||||
fullPath := strings.Join(append(preloadFields, name), ".")
|
|
||||||
for _, join := range joins {
|
for _, join := range joins {
|
||||||
if fullPath == join {
|
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||||
return true
|
joined = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
joinNames := strings.SplitN(join, ".", 2)
|
||||||
|
if len(joinNames) == 2 {
|
||||||
|
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||||
|
joined = true
|
||||||
|
nestedJoins = append(nestedJoins, joinNames[1])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return joined, nestedJoins
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
for _, name := range preloadNames {
|
||||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||||
if err := preloadEntryPoint(db, append(preloadFields, name), joins, relations, preloadMap[name], associationsConds); err != nil {
|
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if rel := relationships.Relations[name]; rel != nil {
|
} else if rel := relationships.Relations[name]; rel != nil {
|
||||||
if joined(name) {
|
if joined, nestedJoins := isJoined(name); joined {
|
||||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||||
if err := preloadEntryPoint(tx, append(preloadFields, name), joins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -280,7 +280,7 @@ func Preload(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(preloadEntryPoint(tx, nil, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,12 +348,20 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
|
|||||||
t.Errorf("failed to create value, got err: %v", err)
|
t.Errorf("failed to create value, got err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var find Value
|
var find1 Value
|
||||||
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find).Error
|
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to find org, got err: %v", err)
|
t.Errorf("failed to find value, got err: %v", err)
|
||||||
}
|
}
|
||||||
AssertEqual(t, find, value)
|
AssertEqual(t, find1, value)
|
||||||
|
|
||||||
|
var find2 Value
|
||||||
|
// Joins will automatically add Nested queries.
|
||||||
|
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to find value, got err: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, find2, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbedPreload(t *testing.T) {
|
func TestEmbedPreload(t *testing.T) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user