test for nested generic version Join/Preload
This commit is contained in:
parent
304baabb12
commit
774d957089
@ -110,7 +110,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
@ -124,12 +124,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
guessNestedRelations = append(guessNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
@ -139,22 +139,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
relations = guessNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := join.Alias
|
||||
|
||||
if tableAliasName == "" {
|
||||
tableAliasName = relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
}
|
||||
|
||||
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
@ -237,19 +228,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
for idx, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
curAliasName := rel.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||
}
|
||||
|
||||
if parentTableName != clause.CurrentTable {
|
||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||
} else {
|
||||
parentTableName = rel.Name
|
||||
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||
aliasName := curAliasName
|
||||
if idx == len(relations)-1 && join.Alias != "" {
|
||||
aliasName = join.Alias
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||
specifiedRelationsName[curAliasName] = aliasName
|
||||
}
|
||||
|
||||
parentTableName = curAliasName
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
|
19
generics.go
19
generics.go
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
@ -341,6 +342,9 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable
|
||||
}
|
||||
|
||||
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||
sort.Slice(db.Statement.Joins, func(i, j int) bool {
|
||||
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
|
||||
})
|
||||
return db
|
||||
})
|
||||
}
|
||||
@ -399,7 +403,22 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err
|
||||
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||
if !ok {
|
||||
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
|
||||
relationships := db.Statement.Schema.Relationships
|
||||
for _, field := range preloadFields {
|
||||
var ok bool
|
||||
relation, ok = relationships.Relations[field]
|
||||
if ok {
|
||||
relationships = relation.FieldSchema.Relationships
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if q.limitPerRecord > 0 {
|
||||
|
4
scan.go
4
scan.go
@ -245,9 +245,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
|
||||
for _, join := range db.Statement.Joins {
|
||||
if join.Alias == names[0] {
|
||||
if join.Alias == aliasName {
|
||||
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -378,6 +379,82 @@ func TestGenericsJoins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsNestedJoins(t *testing.T) {
|
||||
users := []User{
|
||||
{
|
||||
Name: "generics-nested-joins-1",
|
||||
Manager: &User{
|
||||
Name: "generics-nested-joins-manager-1",
|
||||
Company: Company{
|
||||
Name: "generics-nested-joins-manager-company-1",
|
||||
},
|
||||
NamedPet: &Pet{
|
||||
Name: "generics-nested-joins-manager-namepet-1",
|
||||
Toy: Toy{
|
||||
Name: "generics-nested-joins-manager-namepet-toy-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
|
||||
},
|
||||
{
|
||||
Name: "generics-nested-joins-2",
|
||||
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
db.CreateInBatches(ctx, &users, 100)
|
||||
|
||||
var userIDs []uint
|
||||
for _, user := range users {
|
||||
userIDs = append(userIDs, user.ID)
|
||||
}
|
||||
|
||||
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
|
||||
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
|
||||
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
|
||||
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
|
||||
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
|
||||
Where(map[string]any{"id": userIDs}).Find(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||
} else if len(users2) != len(users) {
|
||||
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||
}
|
||||
|
||||
sort.Slice(users2, func(i, j int) bool {
|
||||
return users2[i].ID > users2[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].ID > users[j].ID
|
||||
})
|
||||
|
||||
for idx, user := range users {
|
||||
// user
|
||||
CheckUser(t, user, users2[idx])
|
||||
if users2[idx].Manager == nil {
|
||||
t.Fatalf("Failed to load Manager")
|
||||
}
|
||||
// manager
|
||||
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||
// user pet
|
||||
if users2[idx].NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||
// manager pet
|
||||
if users2[idx].Manager.NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsPreloads(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
@ -499,6 +576,35 @@ func TestGenericsPreloads(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsNestedPreloads(t *testing.T) {
|
||||
user := *GetUser("generics_nested_preload", Config{Pets: 2})
|
||||
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
|
||||
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
for idx, pet := range user.Pets {
|
||||
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
|
||||
}
|
||||
|
||||
if err := db.Create(ctx, &user); err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(3)
|
||||
return nil
|
||||
}).Where(user.ID).Take(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to nested preload user")
|
||||
}
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 {
|
||||
t.Errorf("failed to nested preload with limit per record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsDistinct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@ -586,3 +692,40 @@ func TestGenericsSubQuery(t *testing.T) {
|
||||
t.Errorf("Three users should be found, instead found %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsUpsert(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
lang := Language{Code: "upsert", Name: "Upsert"}
|
||||
|
||||
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
|
||||
t.Fatalf("failed to upsert, got %v", err)
|
||||
}
|
||||
|
||||
lang2 := Language{Code: "upsert", Name: "Upsert"}
|
||||
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
|
||||
t.Fatalf("failed to upsert, got %v", err)
|
||||
}
|
||||
|
||||
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs) != 1 {
|
||||
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||
}
|
||||
|
||||
lang3 := Language{Code: "upsert", Name: "Upsert"}
|
||||
if err := gorm.G[Language](DB, clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "code"}},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
|
||||
}).Create(ctx, &lang3); err != nil {
|
||||
t.Fatalf("failed to upsert, got %v", err)
|
||||
}
|
||||
|
||||
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs) != 1 {
|
||||
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||
} else if langs[0].Name != "upsert-new" {
|
||||
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user