Enhance preload functionality and add custom preload tests
This commit is contained in:
		
							parent
							
								
									6bfccf8afa
								
							
						
					
					
						commit
						c6bce1826f
					
				| @ -12,29 +12,6 @@ import ( | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| // parsePreloadMap extracts nested preloads. e.g.
 | ||||
| //
 | ||||
| //	// schema has a "k0" relation and a "k7.k8" embedded relation
 | ||||
| //	parsePreloadMap(schema, map[string][]interface{}{
 | ||||
| //		clause.Associations: {"arg1"},
 | ||||
| //		"k1":                {"arg2"},
 | ||||
| //		"k2.k3":             {"arg3"},
 | ||||
| //		"k4.k5.k6":          {"arg4"},
 | ||||
| //	})
 | ||||
| //	// preloadMap is
 | ||||
| //	map[string]map[string][]interface{}{
 | ||||
| //		"k0": {},
 | ||||
| //		"k7": {
 | ||||
| //			"k8": {},
 | ||||
| //		},
 | ||||
| //		"k1": {},
 | ||||
| //		"k2": {
 | ||||
| //			"k3": {"arg3"},
 | ||||
| //		},
 | ||||
| //		"k4": {
 | ||||
| //			"k5.k6": {"arg4"},
 | ||||
| //		},
 | ||||
| //	}
 | ||||
| func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { | ||||
| 	preloadMap := map[string]map[string][]interface{}{} | ||||
| 	setPreloadMap := func(name, value string, args []interface{}) { | ||||
| @ -74,7 +51,6 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { | ||||
| 	} | ||||
| 	names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) | ||||
| 	for _, relation := range embeddedRelations.Relations { | ||||
| 		// skip first struct name
 | ||||
| 		names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) | ||||
| 	} | ||||
| 	for _, relations := range embeddedRelations.EmbeddedRelations { | ||||
| @ -84,10 +60,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { | ||||
| } | ||||
| 
 | ||||
| // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
 | ||||
| // If the current relationship is embedded or joined, current query will be ignored.
 | ||||
| //
 | ||||
| //nolint:cyclop
 | ||||
| func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { | ||||
| func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, customJoin func(*gorm.DB) *gorm.DB) error { | ||||
| 	preloadMap := parsePreloadMap(db.Statement.Schema, preloads) | ||||
| 
 | ||||
| 	// avoid random traversal of the map
 | ||||
| @ -116,7 +89,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati | ||||
| 
 | ||||
| 	for _, name := range preloadNames { | ||||
| 		if relations := relationships.EmbeddedRelations[name]; relations != nil { | ||||
| 			if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { | ||||
| 			if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds, customJoin); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} else if rel := relationships.Relations[name]; rel != nil { | ||||
| @ -138,14 +111,14 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati | ||||
| 						} | ||||
| 
 | ||||
| 						tx := preloadDB(db, reflectValue, reflectValue.Interface()) | ||||
| 						if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { | ||||
| 						if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil { | ||||
| 							return err | ||||
| 						} | ||||
| 					} | ||||
| 				case reflect.Struct, reflect.Pointer: | ||||
| 					reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) | ||||
| 					tx := preloadDB(db, reflectValue, reflectValue.Interface()) | ||||
| 					if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { | ||||
| 					if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil { | ||||
| 						return err | ||||
| 					} | ||||
| 				default: | ||||
| @ -155,7 +128,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati | ||||
| 				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) | ||||
| 				tx.Statement.ReflectValue = db.Statement.ReflectValue | ||||
| 				tx.Statement.Unscoped = db.Statement.Unscoped | ||||
| 				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { | ||||
| 				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name], customJoin); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 			} | ||||
| @ -182,7 +155,7 @@ func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm. | ||||
| 	return tx | ||||
| } | ||||
| 
 | ||||
| func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { | ||||
| func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}, customJoin func(*gorm.DB) *gorm.DB) error { | ||||
| 	var ( | ||||
| 		reflectValue     = tx.Statement.ReflectValue | ||||
| 		relForeignKeys   []string | ||||
| @ -193,6 +166,10 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 		inlineConds      []interface{} | ||||
| 	) | ||||
| 
 | ||||
| 	if customJoin != nil { | ||||
| 		tx = customJoin(tx) | ||||
| 	} | ||||
| 
 | ||||
| 	if rel.JoinTable != nil { | ||||
| 		var ( | ||||
| 			joinForeignFields    = make([]*schema.Field, 0, len(rel.References)) | ||||
| @ -268,7 +245,13 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 
 | ||||
| 	// nested preload
 | ||||
| 	for p, pvs := range preloads { | ||||
| 		tx = tx.Preload(p, pvs...) | ||||
| 		if customJoin != nil { | ||||
| 			tx = tx.Preload(p, pvs, func(tx *gorm.DB) *gorm.DB { | ||||
| 				return customJoin(tx) | ||||
| 			}) | ||||
| 		} else { | ||||
| 			tx = tx.Preload(p, pvs...) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	reflectResults := rel.FieldSchema.MakeSlice().Elem() | ||||
|  | ||||
| @ -280,7 +280,7 @@ func Preload(db *gorm.DB) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) | ||||
| 		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations], nil)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @ -7,3 +7,8 @@ require ( | ||||
| 	github.com/jinzhu/now v1.1.5 | ||||
| 	golang.org/x/text v0.20.0 | ||||
| ) | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/mattn/go-sqlite3 v1.14.22 // indirect | ||||
| 	gorm.io/driver/sqlite v1.5.6 // indirect | ||||
| ) | ||||
|  | ||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							| @ -2,5 +2,9 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= | ||||
| github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= | ||||
| golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= | ||||
| golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= | ||||
| gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= | ||||
| gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= | ||||
|  | ||||
							
								
								
									
										218
									
								
								tests/preload_custom_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										218
									
								
								tests/preload_custom_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,218 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 	"gorm.io/driver/sqlite" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| 
 | ||||
| type Item struct { | ||||
| 	ID        uint | ||||
| 	Name      string | ||||
| 	Tags      []Tag `gorm:"many2many:item_tags"` | ||||
| 	CreatedAt time.Time | ||||
| } | ||||
| 
 | ||||
| type Tag struct { | ||||
| 	ID        uint | ||||
| 	Name      string | ||||
| 	Status    string | ||||
| 	SubTags   []SubTag `gorm:"many2many:tag_sub_tags"` | ||||
| } | ||||
| 
 | ||||
| type SubTag struct { | ||||
| 	ID     uint | ||||
| 	Name   string | ||||
| 	Status string | ||||
| } | ||||
| 
 | ||||
| func setupTestDB(t *testing.T) *gorm.DB { | ||||
| 	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to connect database: %v", err) | ||||
| 	} | ||||
| 	db.AutoMigrate(&Item{}, &Tag{}, &SubTag{}) | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| func TestDefaultPreload(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 
 | ||||
| 	tag1 := Tag{Name: "Tag1", Status: "active"} | ||||
| 	item := Item{Name: "Item1", Tags: []Tag{tag1}} | ||||
| 	db.Create(&item) | ||||
| 
 | ||||
| 	var items []Item | ||||
| 	err := db.Preload("Tags").Find(&items).Error | ||||
| 
 | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("default preload failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Name != "Tag1" { | ||||
| 		t.Errorf("unexpected default preload results: %v", items) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCustomJoinsWithConditions(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 
 | ||||
| 	tag1 := Tag{Name: "Tag1", Status: "active"} | ||||
| 	tag2 := Tag{Name: "Tag2", Status: "inactive"} | ||||
| 	item := Item{Name: "Item1", Tags: []Tag{tag1, tag2}} | ||||
| 	db.Create(&item) | ||||
| 
 | ||||
| 	var items []Item | ||||
| 	err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { | ||||
| 		return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id"). | ||||
| 			Where("tags.status = ?", "active") | ||||
| 	}).Find(&items).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("custom join with conditions failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Status != "active" { | ||||
| 		t.Errorf("unexpected results with custom join: %v", items) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNestedPreloadWithCustomJoins(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 
 | ||||
| 	subTag := SubTag{Name: "SubTag1", Status: "active"} | ||||
| 	tag := Tag{Name: "Tag1", Status: "active", SubTags: []SubTag{subTag}} | ||||
| 	item := Item{Name: "Item1", Tags: []Tag{tag}} | ||||
| 	db.Create(&item) | ||||
| 
 | ||||
| 	var items []Item | ||||
| 	err := db.Preload("Tags.SubTags", func(tx *gorm.DB) *gorm.DB { | ||||
| 		return tx.Joins("JOIN tag_sub_tags ON tag_sub_tags.sub_tag_id = sub_tags.id"). | ||||
| 			Where("sub_tags.status = ?", "active") | ||||
| 	}).Find(&items).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("nested preload with custom joins failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(items) != 1 || len(items[0].Tags) != 1 || len(items[0].Tags[0].SubTags) != 1 || items[0].Tags[0].SubTags[0].Name != "SubTag1" { | ||||
| 		t.Errorf("unexpected nested preload results: %v", items) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNoMatchingRecords(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 
 | ||||
| 	tag := Tag{Name: "Tag1", Status: "inactive"} | ||||
| 	item := Item{Name: "Item1", Tags: []Tag{tag}} | ||||
| 	db.Create(&item) | ||||
| 
 | ||||
| 	var items []Item | ||||
| 	err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { | ||||
| 		return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id"). | ||||
| 			Where("tags.status = ?", "active") | ||||
| 	}).Find(&items).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("preload with no matching records failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(items) != 1 || len(items[0].Tags) != 0 { | ||||
| 		t.Errorf("unexpected results when no records match: %v", items) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEmptyDatabase(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 
 | ||||
| 	var items []Item | ||||
| 	err := db.Preload("Tags").Find(&items).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("preload with empty database failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(items) != 0 { | ||||
| 		t.Errorf("unexpected results with empty database: %v", items) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestMultipleItemsWithDifferentTagStatuses(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 
 | ||||
| 	tag1 := Tag{Name: "Tag1", Status: "active"} | ||||
| 	tag2 := Tag{Name: "Tag2", Status: "inactive"} | ||||
| 	item1 := Item{Name: "Item1", Tags: []Tag{tag1}} | ||||
| 	item2 := Item{Name: "Item2", Tags: []Tag{tag2}} | ||||
| 	db.Create(&item1) | ||||
| 	db.Create(&item2) | ||||
| 
 | ||||
| 	var items []Item | ||||
| 	err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { | ||||
| 		return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id"). | ||||
| 			Where("tags.status = ?", "active") | ||||
| 	}).Find(&items).Error | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("preload with multiple items failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(items) != 2 || len(items[0].Tags) != 1 || len(items[1].Tags) != 0 { | ||||
| 		t.Errorf("unexpected results with multiple items: %v", items) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNoRelationshipsDefined(t *testing.T) { | ||||
|     db := setupTestDB(t) | ||||
|     item := Item{Name: "Item1"} | ||||
|     db.Create(&item) | ||||
| 
 | ||||
|     var items []Item | ||||
|     err := db.Preload("Tags").Find(&items).Error | ||||
| 
 | ||||
|     if err != nil { | ||||
|         t.Fatalf("preload with no relationships failed: %v", err) | ||||
|     } | ||||
| 
 | ||||
|     if len(items) != 1 || len(items[0].Tags) != 0 { | ||||
|         t.Errorf("unexpected results when no relationships are defined: %v", items) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| func TestDuplicatePreloadConditions(t *testing.T) { | ||||
|     db := setupTestDB(t) | ||||
| 
 | ||||
|     tag1 := Tag{Name: "Tag1", Status: "active"} | ||||
|     tag2 := Tag{Name: "Tag2", Status: "inactive"} | ||||
|     item := Item{Name: "Item1", Tags: []Tag{tag1, tag2}} | ||||
|     db.Create(&item) | ||||
| 
 | ||||
|     var activeTagsItems []Item | ||||
|     var inactiveTagsItems []Item | ||||
| 
 | ||||
|     // Query for active tags
 | ||||
|     err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { | ||||
|         return tx.Where("status = ?", "active") | ||||
|     }).Find(&activeTagsItems).Error | ||||
|     if err != nil { | ||||
|         t.Fatalf("preload for active tags failed: %v", err) | ||||
|     } | ||||
| 
 | ||||
|     // Query for inactive tags
 | ||||
|     err = db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { | ||||
|         return tx.Where("status = ?", "inactive") | ||||
|     }).Find(&inactiveTagsItems).Error | ||||
|     if err != nil { | ||||
|         t.Fatalf("preload for inactive tags failed: %v", err) | ||||
|     } | ||||
| 
 | ||||
|     // Validate the results
 | ||||
|     if len(activeTagsItems) != 1 || len(activeTagsItems[0].Tags) != 1 || activeTagsItems[0].Tags[0].Status != "active" { | ||||
|         t.Errorf("unexpected active tag results: %v", activeTagsItems) | ||||
|     } | ||||
|     if len(inactiveTagsItems) != 1 || len(inactiveTagsItems[0].Tags) != 1 || inactiveTagsItems[0].Tags[0].Status != "inactive" { | ||||
|         t.Errorf("unexpected inactive tag results: %v", inactiveTagsItems) | ||||
|     } | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Mohammad_Oveisi
						Mohammad_Oveisi