From f9fab7f9304d6350840fca396be39e41f3c0b4d7 Mon Sep 17 00:00:00 2001 From: black Date: Wed, 3 Jan 2024 15:25:39 +0800 Subject: [PATCH] fix: join may automatically add nested query --- callbacks/preload.go | 25 ++++++++++++++++--------- callbacks/query.go | 2 +- tests/preload_test.go | 16 ++++++++++++---- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 4133c2c3..25ecfe76 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -87,7 +87,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { // If the current relationship is embedded or joined, current query will be ignored. // //nolint:cyclop -func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { preloadMap := parsePreloadMap(db.Statement.Schema, preloads) // avoid random traversal of the map @@ -97,26 +97,33 @@ func preloadEntryPoint(db *gorm.DB, preloadFields, joins []string, relationships } sort.Strings(preloadNames) - joined := func(name string) bool { - fullPath := strings.Join(append(preloadFields, name), ".") + isJoined := func(name string) (joined bool, nestedJoins []string) { for _, join := range joins { - if fullPath == join { - return true + if _, ok := relationships.Relations[join]; ok && name == join { + joined = true + continue + } + joinNames := strings.SplitN(join, ".", 2) + if len(joinNames) == 2 { + if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { + joined = true + nestedJoins = append(nestedJoins, joinNames[1]) + } } } - return false + return joined, nestedJoins } for _, name := range preloadNames { if relations := relationships.EmbeddedRelations[name]; relations != nil { - if err := preloadEntryPoint(db, append(preloadFields, name), joins, relations, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { - if joined(name) { + if joined, nestedJoins := isJoined(name); joined { reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, append(preloadFields, name), joins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } } else { diff --git a/callbacks/query.go b/callbacks/query.go index 8a8f3be1..2a82eaba 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -280,7 +280,7 @@ func Preload(db *gorm.DB) { return } - db.AddError(preloadEntryPoint(tx, nil, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) } } diff --git a/tests/preload_test.go b/tests/preload_test.go index 971c4c6c..26b08d7d 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -348,12 +348,20 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { t.Errorf("failed to create value, got err: %v", err) } - var find Value - err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find).Error + var find1 Value + err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error if err != nil { - t.Errorf("failed to find org, got err: %v", err) + t.Errorf("failed to find value, got err: %v", err) } - AssertEqual(t, find, value) + AssertEqual(t, find1, value) + + var find2 Value + // Joins will automatically add Nested queries. + err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find2, value) } func TestEmbedPreload(t *testing.T) {