fix: join may automatically add nested query

This commit is contained in:
black 2024-01-03 15:25:39 +08:00
parent b1cdc9348f
commit f9fab7f930
3 changed files with 29 additions and 14 deletions

View File

@ -87,7 +87,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
// If the current relationship is embedded or joined, current query will be ignored. // If the current relationship is embedded or joined, current query will be ignored.
// //
//nolint:cyclop //nolint:cyclop
func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads) preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map // avoid random traversal of the map
@ -97,26 +97,33 @@ func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships
} }
sort.Strings(preloadNames) sort.Strings(preloadNames)
joined := func(name string) bool { isJoined := func(name string) (joined bool, nestedJoins []string) {
fullPath := strings.Join(append(preloadFields, name), ".")
for _, join := range joins { for _, join := range joins {
if fullPath == join { if _, ok := relationships.Relations[join]; ok && name == join {
return true joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
} }
} }
return false }
return joined, nestedJoins
} }
for _, name := range preloadNames { for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil { if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, append(preloadFields, name), joins, relations, preloadMap[name], associationsConds); err != nil { if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err return err
} }
} else if rel := relationships.Relations[name]; rel != nil { } else if rel := relationships.Relations[name]; rel != nil {
if joined(name) { if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface()) tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, append(preloadFields, name), joins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err return err
} }
} else { } else {

View File

@ -280,7 +280,7 @@ func Preload(db *gorm.DB) {
return return
} }
db.AddError(preloadEntryPoint(tx, nil, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
} }
} }

View File

@ -348,12 +348,20 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
t.Errorf("failed to create value, got err: %v", err) t.Errorf("failed to create value, got err: %v", err)
} }
var find Value var find1 Value
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find).Error err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
if err != nil { if err != nil {
t.Errorf("failed to find org, got err: %v", err) t.Errorf("failed to find value, got err: %v", err)
} }
AssertEqual(t, find, value) AssertEqual(t, find1, value)
var find2 Value
// Joins will automatically add Nested queries.
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find2, value)
} }
func TestEmbedPreload(t *testing.T) { func TestEmbedPreload(t *testing.T) {