From f47c2578fa3fa972ae8180629b468184ba529d46 Mon Sep 17 00:00:00 2001 From: yetone Date: Tue, 9 Apr 2024 00:06:19 +0800 Subject: [PATCH] fix: duplicated preload --- callbacks/preload.go | 12 ++++++++---- callbacks/query.go | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index cf7a0d2b..ecb8fee3 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -87,7 +87,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { // If the current relationship is embedded or joined, current query will be ignored. // //nolint:cyclop -func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, preloaded map[string]struct{}) error { preloadMap := parsePreloadMap(db.Statement.Schema, preloads) // avoid random traversal of the map @@ -116,7 +116,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati for _, name := range preloadNames { if relations := relationships.EmbeddedRelations[name]; relations != nil { - if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds, preloaded); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { @@ -126,20 +126,24 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati for i := 0; i < rv.Len(); i++ { reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, preloaded); err != nil { return err } } case reflect.Struct: reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, preloaded); err != nil { return err } default: return gorm.ErrInvalidData } } else { + if _, ok := preloaded[rel.Name]; ok { + continue + } + preloaded[rel.Name] = struct{}{} tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped diff --git a/callbacks/query.go b/callbacks/query.go index 2a82eaba..47665bb3 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -280,7 +280,7 @@ func Preload(db *gorm.DB) { return } - db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations], map[string]struct{}{})) } }