fix: preload shouldn't overwrite the value of join (#6771)
* fix: preload shouldn't overwrite the value of join * fix lint * fix: join may automatically add nested query
This commit is contained in:
		
							parent
							
								
									e043924fe7
								
							
						
					
					
						commit
						418ee3fc19
					
				| @ -3,6 +3,7 @@ package callbacks | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"sort" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| @ -82,27 +83,80 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { | |||||||
| 	return names | 	return names | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { | // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
 | ||||||
| 	if relationships == nil { | // If the current relationship is embedded or joined, current query will be ignored.
 | ||||||
| 		return nil | //
 | ||||||
|  | //nolint:cyclop
 | ||||||
|  | func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { | ||||||
|  | 	preloadMap := parsePreloadMap(db.Statement.Schema, preloads) | ||||||
|  | 
 | ||||||
|  | 	// avoid random traversal of the map
 | ||||||
|  | 	preloadNames := make([]string, 0, len(preloadMap)) | ||||||
|  | 	for key := range preloadMap { | ||||||
|  | 		preloadNames = append(preloadNames, key) | ||||||
| 	} | 	} | ||||||
| 	preloadMap := parsePreloadMap(s, preloads) | 	sort.Strings(preloadNames) | ||||||
| 	for name := range preloadMap { | 
 | ||||||
| 		if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { | 	isJoined := func(name string) (joined bool, nestedJoins []string) { | ||||||
| 			if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { | 		for _, join := range joins { | ||||||
|  | 			if _, ok := relationships.Relations[join]; ok && name == join { | ||||||
|  | 				joined = true | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			joinNames := strings.SplitN(join, ".", 2) | ||||||
|  | 			if len(joinNames) == 2 { | ||||||
|  | 				if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { | ||||||
|  | 					joined = true | ||||||
|  | 					nestedJoins = append(nestedJoins, joinNames[1]) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return joined, nestedJoins | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, name := range preloadNames { | ||||||
|  | 		if relations := relationships.EmbeddedRelations[name]; relations != nil { | ||||||
|  | 			if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} else if rel := relationships.Relations[name]; rel != nil { | 		} else if rel := relationships.Relations[name]; rel != nil { | ||||||
| 			if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { | 			if joined, nestedJoins := isJoined(name); joined { | ||||||
| 				return err | 				reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) | ||||||
|  | 				tx := preloadDB(db, reflectValue, reflectValue.Interface()) | ||||||
|  | 				if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 			} else { | ||||||
|  | 				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) | ||||||
|  | 				tx.Statement.ReflectValue = db.Statement.ReflectValue | ||||||
|  | 				tx.Statement.Unscoped = db.Statement.Unscoped | ||||||
|  | 				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) | 			return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { | ||||||
|  | 	tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) | ||||||
|  | 	db.Statement.Settings.Range(func(k, v interface{}) bool { | ||||||
|  | 		tx.Statement.Settings.Store(k, v) | ||||||
|  | 		return true | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if err := tx.Statement.Parse(dest); err != nil { | ||||||
|  | 		tx.AddError(err) | ||||||
|  | 		return tx | ||||||
|  | 	} | ||||||
|  | 	tx.Statement.ReflectValue = reflectValue | ||||||
|  | 	tx.Statement.Unscoped = db.Statement.Unscoped | ||||||
|  | 	return tx | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { | func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { | ||||||
| 	var ( | 	var ( | ||||||
| 		reflectValue     = tx.Statement.ReflectValue | 		reflectValue     = tx.Statement.ReflectValue | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ package callbacks | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"sort" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| @ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			db.Statement.AddClause(fromClause) | 			db.Statement.AddClause(fromClause) | ||||||
| 			db.Statement.Joins = nil |  | ||||||
| 		} else { | 		} else { | ||||||
| 			db.Statement.AddClauseIfNotExists(clause.From{}) | 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||||
| 		} | 		} | ||||||
| @ -272,38 +270,23 @@ func Preload(db *gorm.DB) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) | 		joins := make([]string, 0, len(db.Statement.Joins)) | ||||||
| 		preloadNames := make([]string, 0, len(preloadMap)) | 		for _, join := range db.Statement.Joins { | ||||||
| 		for key := range preloadMap { | 			joins = append(joins, join.Name) | ||||||
| 			preloadNames = append(preloadNames, key) |  | ||||||
| 		} | 		} | ||||||
| 		sort.Strings(preloadNames) |  | ||||||
| 
 | 
 | ||||||
| 		preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) | 		tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) | ||||||
| 		db.Statement.Settings.Range(func(k, v interface{}) bool { | 		if tx.Error != nil { | ||||||
| 			preloadDB.Statement.Settings.Store(k, v) |  | ||||||
| 			return true |  | ||||||
| 		}) |  | ||||||
| 
 |  | ||||||
| 		if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		preloadDB.Statement.ReflectValue = db.Statement.ReflectValue |  | ||||||
| 		preloadDB.Statement.Unscoped = db.Statement.Unscoped |  | ||||||
| 
 | 
 | ||||||
| 		for _, name := range preloadNames { | 		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) | ||||||
| 			if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { |  | ||||||
| 				db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) |  | ||||||
| 			} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { |  | ||||||
| 				db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) |  | ||||||
| 			} else { |  | ||||||
| 				db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterQuery(db *gorm.DB) { | func AfterQuery(db *gorm.DB) { | ||||||
|  | 	// clear the joins after query because preload need it
 | ||||||
|  | 	db.Statement.Joins = nil | ||||||
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { | 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | ||||||
| 			if i, ok := value.(AfterFindInterface); ok { | 			if i, ok := value.(AfterFindInterface); ok { | ||||||
|  | |||||||
| @ -307,6 +307,63 @@ func TestNestedPreloadWithUnscoped(t *testing.T) { | |||||||
| 	CheckUserUnscoped(t, *user6, user) | 	CheckUserUnscoped(t, *user6, user) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestNestedPreloadWithNestedJoin(t *testing.T) { | ||||||
|  | 	type ( | ||||||
|  | 		Preload struct { | ||||||
|  | 			ID       uint | ||||||
|  | 			Value    string | ||||||
|  | 			NestedID uint | ||||||
|  | 		} | ||||||
|  | 		Join struct { | ||||||
|  | 			ID       uint | ||||||
|  | 			Value    string | ||||||
|  | 			NestedID uint | ||||||
|  | 		} | ||||||
|  | 		Nested struct { | ||||||
|  | 			ID       uint | ||||||
|  | 			Preloads []*Preload | ||||||
|  | 			Join     Join | ||||||
|  | 			ValueID  uint | ||||||
|  | 		} | ||||||
|  | 		Value struct { | ||||||
|  | 			ID     uint | ||||||
|  | 			Name   string | ||||||
|  | 			Nested Nested | ||||||
|  | 		} | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) | ||||||
|  | 	DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) | ||||||
|  | 
 | ||||||
|  | 	value := Value{ | ||||||
|  | 		Name: "value", | ||||||
|  | 		Nested: Nested{ | ||||||
|  | 			Preloads: []*Preload{ | ||||||
|  | 				{Value: "p1"}, {Value: "p2"}, | ||||||
|  | 			}, | ||||||
|  | 			Join: Join{Value: "j1"}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	if err := DB.Create(&value).Error; err != nil { | ||||||
|  | 		t.Errorf("failed to create value, got err: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var find1 Value | ||||||
|  | 	err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to find value, got err: %v", err) | ||||||
|  | 	} | ||||||
|  | 	AssertEqual(t, find1, value) | ||||||
|  | 
 | ||||||
|  | 	var find2 Value | ||||||
|  | 	// Joins will automatically add Nested queries.
 | ||||||
|  | 	err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to find value, got err: %v", err) | ||||||
|  | 	} | ||||||
|  | 	AssertEqual(t, find2, value) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestEmbedPreload(t *testing.T) { | func TestEmbedPreload(t *testing.T) { | ||||||
| 	type Country struct { | 	type Country struct { | ||||||
| 		ID   int `gorm:"primaryKey"` | 		ID   int `gorm:"primaryKey"` | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 black-06
						black-06