fix: join may automatically add nested query

This commit is contained in:
black 2024-01-03 15:25:39 +08:00
parent b1cdc9348f
commit f9fab7f930
3 changed files with 29 additions and 14 deletions

View File

@ -87,7 +87,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
@ -97,26 +97,33 @@ func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships
}
sort.Strings(preloadNames)
joined := func(name string) bool {
fullPath := strings.Join(append(preloadFields, name), ".")
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if fullPath == join {
return true
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
}
}
return false
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, append(preloadFields, name), joins, relations, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined(name) {
if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, append(preloadFields, name), joins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
} else {

View File

@ -280,7 +280,7 @@ func Preload(db *gorm.DB) {
return
}
db.AddError(preloadEntryPoint(tx, nil, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
}
}

View File

@ -348,12 +348,20 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
t.Errorf("failed to create value, got err: %v", err)
}
var find Value
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find).Error
var find1 Value
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
if err != nil {
t.Errorf("failed to find org, got err: %v", err)
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find, value)
AssertEqual(t, find1, value)
var find2 Value
// Joins will automatically add Nested queries.
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find2, value)
}
func TestEmbedPreload(t *testing.T) {