test for nested generic version Join/Preload
This commit is contained in:
parent
304baabb12
commit
774d957089
@ -110,7 +110,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
specifiedRelationsName := make(map[string]interface{})
|
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||||
for _, join := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil {
|
||||||
var isRelations bool // is relations or raw sql
|
var isRelations bool // is relations or raw sql
|
||||||
@ -124,12 +124,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
nestedJoinNames := strings.Split(join.Name, ".")
|
nestedJoinNames := strings.Split(join.Name, ".")
|
||||||
if len(nestedJoinNames) > 1 {
|
if len(nestedJoinNames) > 1 {
|
||||||
isNestedJoin := true
|
isNestedJoin := true
|
||||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||||
for _, relname := range nestedJoinNames {
|
for _, relname := range nestedJoinNames {
|
||||||
// incomplete match, only treated as raw sql
|
// incomplete match, only treated as raw sql
|
||||||
if relation, ok = currentRelations[relname]; ok {
|
if relation, ok = currentRelations[relname]; ok {
|
||||||
gussNestedRelations = append(gussNestedRelations, relation)
|
guessNestedRelations = append(guessNestedRelations, relation)
|
||||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||||
} else {
|
} else {
|
||||||
isNestedJoin = false
|
isNestedJoin = false
|
||||||
@ -139,22 +139,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
if isNestedJoin {
|
if isNestedJoin {
|
||||||
isRelations = true
|
isRelations = true
|
||||||
relations = gussNestedRelations
|
relations = guessNestedRelations
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRelations {
|
if isRelations {
|
||||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||||
tableAliasName := join.Alias
|
|
||||||
|
|
||||||
if tableAliasName == "" {
|
|
||||||
tableAliasName = relation.Name
|
|
||||||
if parentTableName != clause.CurrentTable {
|
|
||||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
columnStmt := gorm.Statement{
|
columnStmt := gorm.Statement{
|
||||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||||
Selects: join.Selects, Omits: join.Omits,
|
Selects: join.Selects, Omits: join.Omits,
|
||||||
@ -237,19 +228,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
parentTableName := clause.CurrentTable
|
parentTableName := clause.CurrentTable
|
||||||
for _, rel := range relations {
|
for idx, rel := range relations {
|
||||||
// joins table alias like "Manager, Company, Manager__Company"
|
// joins table alias like "Manager, Company, Manager__Company"
|
||||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
curAliasName := rel.Name
|
||||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
if parentTableName != clause.CurrentTable {
|
||||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||||
specifiedRelationsName[nestedAlias] = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if parentTableName != clause.CurrentTable {
|
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
aliasName := curAliasName
|
||||||
} else {
|
if idx == len(relations)-1 && join.Alias != "" {
|
||||||
parentTableName = rel.Name
|
aliasName = join.Alias
|
||||||
|
}
|
||||||
|
|
||||||
|
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||||
|
specifiedRelationsName[curAliasName] = aliasName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parentTableName = curAliasName
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
|
21
generics.go
21
generics.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -341,6 +342,9 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.Joins = append(db.Statement.Joins, j)
|
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||||
|
sort.Slice(db.Statement.Joins, func(i, j int) bool {
|
||||||
|
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
|
||||||
|
})
|
||||||
return db
|
return db
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -399,7 +403,22 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err
|
|||||||
|
|
||||||
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||||
if !ok {
|
if !ok {
|
||||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
|
||||||
|
relationships := db.Statement.Schema.Relationships
|
||||||
|
for _, field := range preloadFields {
|
||||||
|
var ok bool
|
||||||
|
relation, ok = relationships.Relations[field]
|
||||||
|
if ok {
|
||||||
|
relationships = relation.FieldSchema.Relationships
|
||||||
|
} else {
|
||||||
|
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if q.limitPerRecord > 0 {
|
if q.limitPerRecord > 0 {
|
||||||
|
4
scan.go
4
scan.go
@ -245,9 +245,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||||||
matchedFieldCount[column] = 1
|
matchedFieldCount[column] = 1
|
||||||
}
|
}
|
||||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||||
|
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
|
||||||
for _, join := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if join.Alias == names[0] {
|
if join.Alias == aliasName {
|
||||||
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -378,6 +379,82 @@ func TestGenericsJoins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenericsNestedJoins(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{
|
||||||
|
Name: "generics-nested-joins-1",
|
||||||
|
Manager: &User{
|
||||||
|
Name: "generics-nested-joins-manager-1",
|
||||||
|
Company: Company{
|
||||||
|
Name: "generics-nested-joins-manager-company-1",
|
||||||
|
},
|
||||||
|
NamedPet: &Pet{
|
||||||
|
Name: "generics-nested-joins-manager-namepet-1",
|
||||||
|
Toy: Toy{
|
||||||
|
Name: "generics-nested-joins-manager-namepet-toy-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "generics-nested-joins-2",
|
||||||
|
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||||
|
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
db.CreateInBatches(ctx, &users, 100)
|
||||||
|
|
||||||
|
var userIDs []uint
|
||||||
|
for _, user := range users {
|
||||||
|
userIDs = append(userIDs, user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
|
||||||
|
Where(map[string]any{"id": userIDs}).Find(ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
|
} else if len(users2) != len(users) {
|
||||||
|
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(users2, func(i, j int) bool {
|
||||||
|
return users2[i].ID > users2[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(users, func(i, j int) bool {
|
||||||
|
return users[i].ID > users[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
for idx, user := range users {
|
||||||
|
// user
|
||||||
|
CheckUser(t, user, users2[idx])
|
||||||
|
if users2[idx].Manager == nil {
|
||||||
|
t.Fatalf("Failed to load Manager")
|
||||||
|
}
|
||||||
|
// manager
|
||||||
|
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||||
|
// user pet
|
||||||
|
if users2[idx].NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||||
|
// manager pet
|
||||||
|
if users2[idx].Manager.NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenericsPreloads(t *testing.T) {
|
func TestGenericsPreloads(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db := gorm.G[User](DB)
|
db := gorm.G[User](DB)
|
||||||
@ -499,6 +576,35 @@ func TestGenericsPreloads(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenericsNestedPreloads(t *testing.T) {
|
||||||
|
user := *GetUser("generics_nested_preload", Config{Pets: 2})
|
||||||
|
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
|
for idx, pet := range user.Pets {
|
||||||
|
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(ctx, &user); err != nil {
|
||||||
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.LimitPerRecord(3)
|
||||||
|
return nil
|
||||||
|
}).Where(user.ID).Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to nested preload user")
|
||||||
|
}
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 {
|
||||||
|
t.Errorf("failed to nested preload with limit per record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenericsDistinct(t *testing.T) {
|
func TestGenericsDistinct(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@ -586,3 +692,40 @@ func TestGenericsSubQuery(t *testing.T) {
|
|||||||
t.Errorf("Three users should be found, instead found %d", len(results))
|
t.Errorf("Three users should be found, instead found %d", len(results))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenericsUpsert(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
lang := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
|
||||||
|
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lang2 := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs) != 1 {
|
||||||
|
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||||
|
}
|
||||||
|
|
||||||
|
lang3 := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
if err := gorm.G[Language](DB, clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "code"}},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
|
||||||
|
}).Create(ctx, &lang3); err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs) != 1 {
|
||||||
|
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||||
|
} else if langs[0].Name != "upsert-new" {
|
||||||
|
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user