Enhance preload functionality and add custom preload tests

This commit is contained in:
Mohammad_Oveisi 2024-11-26 00:15:56 +00:00
parent 6bfccf8afa
commit c6bce1826f
5 changed files with 245 additions and 35 deletions

View File

@ -12,29 +12,6 @@ import (
"gorm.io/gorm/utils"
)
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
@ -74,7 +51,6 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
@ -84,10 +60,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, customJoin func(*gorm.DB) *gorm.DB) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
@ -116,7 +89,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds, customJoin); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
@ -138,14 +111,14 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil {
return err
}
default:
@ -155,7 +128,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name], customJoin); err != nil {
return err
}
}
@ -182,7 +155,7 @@ func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}, customJoin func(*gorm.DB) *gorm.DB) error {
var (
reflectValue = tx.Statement.ReflectValue
relForeignKeys []string
@ -193,6 +166,10 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{}
)
if customJoin != nil {
tx = customJoin(tx)
}
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
@ -268,7 +245,13 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
if customJoin != nil {
tx = tx.Preload(p, pvs, func(tx *gorm.DB) *gorm.DB {
return customJoin(tx)
})
} else {
tx = tx.Preload(p, pvs...)
}
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()

View File

@ -280,7 +280,7 @@ func Preload(db *gorm.DB) {
return
}
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations], nil))
}
}

5
go.mod
View File

@ -7,3 +7,8 @@ require (
github.com/jinzhu/now v1.1.5
golang.org/x/text v0.20.0
)
require (
github.com/mattn/go-sqlite3 v1.14.22 // indirect
gorm.io/driver/sqlite v1.5.6 // indirect
)

4
go.sum
View File

@ -2,5 +2,9 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=

View File

@ -0,0 +1,218 @@
package tests_test
import (
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type Item struct {
ID uint
Name string
Tags []Tag `gorm:"many2many:item_tags"`
CreatedAt time.Time
}
type Tag struct {
ID uint
Name string
Status string
SubTags []SubTag `gorm:"many2many:tag_sub_tags"`
}
type SubTag struct {
ID uint
Name string
Status string
}
func setupTestDB(t *testing.T) *gorm.DB {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to connect database: %v", err)
}
db.AutoMigrate(&Item{}, &Tag{}, &SubTag{})
return db
}
func TestDefaultPreload(t *testing.T) {
db := setupTestDB(t)
tag1 := Tag{Name: "Tag1", Status: "active"}
item := Item{Name: "Item1", Tags: []Tag{tag1}}
db.Create(&item)
var items []Item
err := db.Preload("Tags").Find(&items).Error
if err != nil {
t.Fatalf("default preload failed: %v", err)
}
if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Name != "Tag1" {
t.Errorf("unexpected default preload results: %v", items)
}
}
func TestCustomJoinsWithConditions(t *testing.T) {
db := setupTestDB(t)
tag1 := Tag{Name: "Tag1", Status: "active"}
tag2 := Tag{Name: "Tag2", Status: "inactive"}
item := Item{Name: "Item1", Tags: []Tag{tag1, tag2}}
db.Create(&item)
var items []Item
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id").
Where("tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("custom join with conditions failed: %v", err)
}
if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Status != "active" {
t.Errorf("unexpected results with custom join: %v", items)
}
}
func TestNestedPreloadWithCustomJoins(t *testing.T) {
db := setupTestDB(t)
subTag := SubTag{Name: "SubTag1", Status: "active"}
tag := Tag{Name: "Tag1", Status: "active", SubTags: []SubTag{subTag}}
item := Item{Name: "Item1", Tags: []Tag{tag}}
db.Create(&item)
var items []Item
err := db.Preload("Tags.SubTags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN tag_sub_tags ON tag_sub_tags.sub_tag_id = sub_tags.id").
Where("sub_tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("nested preload with custom joins failed: %v", err)
}
if len(items) != 1 || len(items[0].Tags) != 1 || len(items[0].Tags[0].SubTags) != 1 || items[0].Tags[0].SubTags[0].Name != "SubTag1" {
t.Errorf("unexpected nested preload results: %v", items)
}
}
func TestNoMatchingRecords(t *testing.T) {
db := setupTestDB(t)
tag := Tag{Name: "Tag1", Status: "inactive"}
item := Item{Name: "Item1", Tags: []Tag{tag}}
db.Create(&item)
var items []Item
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id").
Where("tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("preload with no matching records failed: %v", err)
}
if len(items) != 1 || len(items[0].Tags) != 0 {
t.Errorf("unexpected results when no records match: %v", items)
}
}
func TestEmptyDatabase(t *testing.T) {
db := setupTestDB(t)
var items []Item
err := db.Preload("Tags").Find(&items).Error
if err != nil {
t.Fatalf("preload with empty database failed: %v", err)
}
if len(items) != 0 {
t.Errorf("unexpected results with empty database: %v", items)
}
}
func TestMultipleItemsWithDifferentTagStatuses(t *testing.T) {
db := setupTestDB(t)
tag1 := Tag{Name: "Tag1", Status: "active"}
tag2 := Tag{Name: "Tag2", Status: "inactive"}
item1 := Item{Name: "Item1", Tags: []Tag{tag1}}
item2 := Item{Name: "Item2", Tags: []Tag{tag2}}
db.Create(&item1)
db.Create(&item2)
var items []Item
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id").
Where("tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("preload with multiple items failed: %v", err)
}
if len(items) != 2 || len(items[0].Tags) != 1 || len(items[1].Tags) != 0 {
t.Errorf("unexpected results with multiple items: %v", items)
}
}
func TestNoRelationshipsDefined(t *testing.T) {
db := setupTestDB(t)
item := Item{Name: "Item1"}
db.Create(&item)
var items []Item
err := db.Preload("Tags").Find(&items).Error
if err != nil {
t.Fatalf("preload with no relationships failed: %v", err)
}
if len(items) != 1 || len(items[0].Tags) != 0 {
t.Errorf("unexpected results when no relationships are defined: %v", items)
}
}
func TestDuplicatePreloadConditions(t *testing.T) {
db := setupTestDB(t)
tag1 := Tag{Name: "Tag1", Status: "active"}
tag2 := Tag{Name: "Tag2", Status: "inactive"}
item := Item{Name: "Item1", Tags: []Tag{tag1, tag2}}
db.Create(&item)
var activeTagsItems []Item
var inactiveTagsItems []Item
// Query for active tags
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Where("status = ?", "active")
}).Find(&activeTagsItems).Error
if err != nil {
t.Fatalf("preload for active tags failed: %v", err)
}
// Query for inactive tags
err = db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Where("status = ?", "inactive")
}).Find(&inactiveTagsItems).Error
if err != nil {
t.Fatalf("preload for inactive tags failed: %v", err)
}
// Validate the results
if len(activeTagsItems) != 1 || len(activeTagsItems[0].Tags) != 1 || activeTagsItems[0].Tags[0].Status != "active" {
t.Errorf("unexpected active tag results: %v", activeTagsItems)
}
if len(inactiveTagsItems) != 1 || len(inactiveTagsItems[0].Tags) != 1 || inactiveTagsItems[0].Tags[0].Status != "inactive" {
t.Errorf("unexpected inactive tag results: %v", inactiveTagsItems)
}
}