test for nested generic version Join/Preload

This commit is contained in:
Jinzhu 2025-05-22 19:45:34 +08:00
parent 304baabb12
commit 774d957089
4 changed files with 185 additions and 25 deletions

View File

@ -110,7 +110,7 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
specifiedRelationsName := make(map[string]interface{})
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
@ -124,12 +124,12 @@ func BuildQuerySQL(db *gorm.DB) {
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
@ -139,22 +139,13 @@ func BuildQuerySQL(db *gorm.DB) {
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
relations = guessNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := join.Alias
if tableAliasName == "" {
tableAliasName = relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
}
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
@ -237,19 +228,24 @@ func BuildQuerySQL(db *gorm.DB) {
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{

View File

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
@ -341,6 +342,9 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
@ -399,7 +403,22 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
db.AddError(fmt.Errorf("relation %s not found", association))
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {

View File

@ -245,9 +245,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == names[0] {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"testing"
@ -378,6 +379,82 @@ func TestGenericsJoins(t *testing.T) {
}
}
func TestGenericsNestedJoins(t *testing.T) {
users := []User{
{
Name: "generics-nested-joins-1",
Manager: &User{
Name: "generics-nested-joins-manager-1",
Company: Company{
Name: "generics-nested-joins-manager-company-1",
},
NamedPet: &Pet{
Name: "generics-nested-joins-manager-namepet-1",
Toy: Toy{
Name: "generics-nested-joins-manager-namepet-toy-1",
},
},
},
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
},
{
Name: "generics-nested-joins-2",
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
},
}
ctx := context.Background()
db := gorm.G[User](DB)
db.CreateInBatches(ctx, &users, 100)
var userIDs []uint
for _, user := range users {
userIDs = append(userIDs, user.ID)
}
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
Where(map[string]any{"id": userIDs}).Find(ctx)
if err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err)
} else if len(users2) != len(users) {
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
}
sort.Slice(users2, func(i, j int) bool {
return users2[i].ID > users2[j].ID
})
sort.Slice(users, func(i, j int) bool {
return users[i].ID > users[j].ID
})
for idx, user := range users {
// user
CheckUser(t, user, users2[idx])
if users2[idx].Manager == nil {
t.Fatalf("Failed to load Manager")
}
// manager
CheckUser(t, *user.Manager, *users2[idx].Manager)
// user pet
if users2[idx].NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
// manager pet
if users2[idx].Manager.NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
}
}
func TestGenericsPreloads(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
@ -499,6 +576,35 @@ func TestGenericsPreloads(t *testing.T) {
}
}
func TestGenericsNestedPreloads(t *testing.T) {
user := *GetUser("generics_nested_preload", Config{Pets: 2})
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
ctx := context.Background()
db := gorm.G[User](DB)
for idx, pet := range user.Pets {
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
}
if err := db.Create(ctx, &user); err != nil {
t.Fatalf("errors happened when create: %v", err)
}
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(3)
return nil
}).Where(user.ID).Take(ctx)
if err != nil {
t.Errorf("failed to nested preload user")
}
CheckUser(t, user2, user)
if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 {
t.Errorf("failed to nested preload with limit per record")
}
}
func TestGenericsDistinct(t *testing.T) {
ctx := context.Background()
@ -586,3 +692,40 @@ func TestGenericsSubQuery(t *testing.T) {
t.Errorf("Three users should be found, instead found %d", len(results))
}
}
func TestGenericsUpsert(t *testing.T) {
ctx := context.Background()
lang := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
lang2 := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
if err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
}
lang3 := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{
Columns: []clause.Column{{Name: "code"}},
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
}).Create(ctx, &lang3); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
} else if langs[0].Name != "upsert-new" {
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
}
}