Enhance preload functionality and add custom preload tests
This commit is contained in:
parent
6bfccf8afa
commit
c6bce1826f
@ -12,29 +12,6 @@ import (
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
|
||||
//
|
||||
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||
// parsePreloadMap(schema, map[string][]interface{}{
|
||||
// clause.Associations: {"arg1"},
|
||||
// "k1": {"arg2"},
|
||||
// "k2.k3": {"arg3"},
|
||||
// "k4.k5.k6": {"arg4"},
|
||||
// })
|
||||
// // preloadMap is
|
||||
// map[string]map[string][]interface{}{
|
||||
// "k0": {},
|
||||
// "k7": {
|
||||
// "k8": {},
|
||||
// },
|
||||
// "k1": {},
|
||||
// "k2": {
|
||||
// "k3": {"arg3"},
|
||||
// },
|
||||
// "k4": {
|
||||
// "k5.k6": {"arg4"},
|
||||
// },
|
||||
// }
|
||||
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
setPreloadMap := func(name, value string, args []interface{}) {
|
||||
@ -74,7 +51,6 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
@ -84,10 +60,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, customJoin func(*gorm.DB) *gorm.DB) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
@ -116,7 +89,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds, customJoin); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
@ -138,14 +111,14 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
||||
}
|
||||
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Pointer:
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
@ -155,7 +128,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name], customJoin); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -182,7 +155,7 @@ func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}, customJoin func(*gorm.DB) *gorm.DB) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
relForeignKeys []string
|
||||
@ -193,6 +166,10 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
inlineConds []interface{}
|
||||
)
|
||||
|
||||
if customJoin != nil {
|
||||
tx = customJoin(tx)
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
var (
|
||||
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
@ -268,8 +245,14 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
|
||||
// nested preload
|
||||
for p, pvs := range preloads {
|
||||
if customJoin != nil {
|
||||
tx = tx.Preload(p, pvs, func(tx *gorm.DB) *gorm.DB {
|
||||
return customJoin(tx)
|
||||
})
|
||||
} else {
|
||||
tx = tx.Preload(p, pvs...)
|
||||
}
|
||||
}
|
||||
|
||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
@ -280,7 +280,7 @@ func Preload(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations], nil))
|
||||
}
|
||||
}
|
||||
|
||||
|
5
go.mod
5
go.mod
@ -7,3 +7,8 @@ require (
|
||||
github.com/jinzhu/now v1.1.5
|
||||
golang.org/x/text v0.20.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
gorm.io/driver/sqlite v1.5.6 // indirect
|
||||
)
|
||||
|
4
go.sum
4
go.sum
@ -2,5 +2,9 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
|
||||
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
|
||||
|
218
tests/preload_custom_test.go
Normal file
218
tests/preload_custom_test.go
Normal file
@ -0,0 +1,218 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Item struct {
|
||||
ID uint
|
||||
Name string
|
||||
Tags []Tag `gorm:"many2many:item_tags"`
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID uint
|
||||
Name string
|
||||
Status string
|
||||
SubTags []SubTag `gorm:"many2many:tag_sub_tags"`
|
||||
}
|
||||
|
||||
type SubTag struct {
|
||||
ID uint
|
||||
Name string
|
||||
Status string
|
||||
}
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database: %v", err)
|
||||
}
|
||||
db.AutoMigrate(&Item{}, &Tag{}, &SubTag{})
|
||||
return db
|
||||
}
|
||||
|
||||
func TestDefaultPreload(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
tag1 := Tag{Name: "Tag1", Status: "active"}
|
||||
item := Item{Name: "Item1", Tags: []Tag{tag1}}
|
||||
db.Create(&item)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags").Find(&items).Error
|
||||
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("default preload failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Name != "Tag1" {
|
||||
t.Errorf("unexpected default preload results: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomJoinsWithConditions(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
tag1 := Tag{Name: "Tag1", Status: "active"}
|
||||
tag2 := Tag{Name: "Tag2", Status: "inactive"}
|
||||
item := Item{Name: "Item1", Tags: []Tag{tag1, tag2}}
|
||||
db.Create(&item)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id").
|
||||
Where("tags.status = ?", "active")
|
||||
}).Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("custom join with conditions failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Status != "active" {
|
||||
t.Errorf("unexpected results with custom join: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreloadWithCustomJoins(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
subTag := SubTag{Name: "SubTag1", Status: "active"}
|
||||
tag := Tag{Name: "Tag1", Status: "active", SubTags: []SubTag{subTag}}
|
||||
item := Item{Name: "Item1", Tags: []Tag{tag}}
|
||||
db.Create(&item)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags.SubTags", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Joins("JOIN tag_sub_tags ON tag_sub_tags.sub_tag_id = sub_tags.id").
|
||||
Where("sub_tags.status = ?", "active")
|
||||
}).Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("nested preload with custom joins failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 1 || len(items[0].Tags) != 1 || len(items[0].Tags[0].SubTags) != 1 || items[0].Tags[0].SubTags[0].Name != "SubTag1" {
|
||||
t.Errorf("unexpected nested preload results: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoMatchingRecords(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
tag := Tag{Name: "Tag1", Status: "inactive"}
|
||||
item := Item{Name: "Item1", Tags: []Tag{tag}}
|
||||
db.Create(&item)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id").
|
||||
Where("tags.status = ?", "active")
|
||||
}).Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("preload with no matching records failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 1 || len(items[0].Tags) != 0 {
|
||||
t.Errorf("unexpected results when no records match: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyDatabase(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags").Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("preload with empty database failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 0 {
|
||||
t.Errorf("unexpected results with empty database: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleItemsWithDifferentTagStatuses(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
tag1 := Tag{Name: "Tag1", Status: "active"}
|
||||
tag2 := Tag{Name: "Tag2", Status: "inactive"}
|
||||
item1 := Item{Name: "Item1", Tags: []Tag{tag1}}
|
||||
item2 := Item{Name: "Item2", Tags: []Tag{tag2}}
|
||||
db.Create(&item1)
|
||||
db.Create(&item2)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Joins("JOIN item_tags ON item_tags.tag_id = tags.id").
|
||||
Where("tags.status = ?", "active")
|
||||
}).Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("preload with multiple items failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 2 || len(items[0].Tags) != 1 || len(items[1].Tags) != 0 {
|
||||
t.Errorf("unexpected results with multiple items: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoRelationshipsDefined(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
item := Item{Name: "Item1"}
|
||||
db.Create(&item)
|
||||
|
||||
var items []Item
|
||||
err := db.Preload("Tags").Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("preload with no relationships failed: %v", err)
|
||||
}
|
||||
|
||||
if len(items) != 1 || len(items[0].Tags) != 0 {
|
||||
t.Errorf("unexpected results when no relationships are defined: %v", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicatePreloadConditions(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
tag1 := Tag{Name: "Tag1", Status: "active"}
|
||||
tag2 := Tag{Name: "Tag2", Status: "inactive"}
|
||||
item := Item{Name: "Item1", Tags: []Tag{tag1, tag2}}
|
||||
db.Create(&item)
|
||||
|
||||
var activeTagsItems []Item
|
||||
var inactiveTagsItems []Item
|
||||
|
||||
// Query for active tags
|
||||
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Where("status = ?", "active")
|
||||
}).Find(&activeTagsItems).Error
|
||||
if err != nil {
|
||||
t.Fatalf("preload for active tags failed: %v", err)
|
||||
}
|
||||
|
||||
// Query for inactive tags
|
||||
err = db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Where("status = ?", "inactive")
|
||||
}).Find(&inactiveTagsItems).Error
|
||||
if err != nil {
|
||||
t.Fatalf("preload for inactive tags failed: %v", err)
|
||||
}
|
||||
|
||||
// Validate the results
|
||||
if len(activeTagsItems) != 1 || len(activeTagsItems[0].Tags) != 1 || activeTagsItems[0].Tags[0].Status != "active" {
|
||||
t.Errorf("unexpected active tag results: %v", activeTagsItems)
|
||||
}
|
||||
if len(inactiveTagsItems) != 1 || len(inactiveTagsItems[0].Tags) != 1 || inactiveTagsItems[0].Tags[0].Status != "inactive" {
|
||||
t.Errorf("unexpected inactive tag results: %v", inactiveTagsItems)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user