test for nested generic version Join/Preload
This commit is contained in:
		
							parent
							
								
									304baabb12
								
							
						
					
					
						commit
						774d957089
					
				| @ -110,7 +110,7 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			specifiedRelationsName := make(map[string]interface{}) | ||||
| 			specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} | ||||
| 			for _, join := range db.Statement.Joins { | ||||
| 				if db.Statement.Schema != nil { | ||||
| 					var isRelations bool // is relations or raw sql
 | ||||
| @ -124,12 +124,12 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 						nestedJoinNames := strings.Split(join.Name, ".") | ||||
| 						if len(nestedJoinNames) > 1 { | ||||
| 							isNestedJoin := true | ||||
| 							gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 							guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 							currentRelations := db.Statement.Schema.Relationships.Relations | ||||
| 							for _, relname := range nestedJoinNames { | ||||
| 								// incomplete match, only treated as raw sql
 | ||||
| 								if relation, ok = currentRelations[relname]; ok { | ||||
| 									gussNestedRelations = append(gussNestedRelations, relation) | ||||
| 									guessNestedRelations = append(guessNestedRelations, relation) | ||||
| 									currentRelations = relation.FieldSchema.Relationships.Relations | ||||
| 								} else { | ||||
| 									isNestedJoin = false | ||||
| @ -139,22 +139,13 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 
 | ||||
| 							if isNestedJoin { | ||||
| 								isRelations = true | ||||
| 								relations = gussNestedRelations | ||||
| 								relations = guessNestedRelations | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					if isRelations { | ||||
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							tableAliasName := join.Alias | ||||
| 
 | ||||
| 							if tableAliasName == "" { | ||||
| 								tableAliasName = relation.Name | ||||
| 								if parentTableName != clause.CurrentTable { | ||||
| 									tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 						genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							columnStmt := gorm.Statement{ | ||||
| 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | ||||
| 								Selects: join.Selects, Omits: join.Omits, | ||||
| @ -237,19 +228,24 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 						} | ||||
| 
 | ||||
| 						parentTableName := clause.CurrentTable | ||||
| 						for _, rel := range relations { | ||||
| 						for idx, rel := range relations { | ||||
| 							// joins table alias like "Manager, Company, Manager__Company"
 | ||||
| 							nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							if _, ok := specifiedRelationsName[nestedAlias]; !ok { | ||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) | ||||
| 								specifiedRelationsName[nestedAlias] = nil | ||||
| 							curAliasName := rel.Name | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								curAliasName = utils.NestedRelationName(parentTableName, curAliasName) | ||||
| 							} | ||||
| 
 | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								parentTableName = utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							} else { | ||||
| 								parentTableName = rel.Name | ||||
| 							if _, ok := specifiedRelationsName[curAliasName]; !ok { | ||||
| 								aliasName := curAliasName | ||||
| 								if idx == len(relations)-1 && join.Alias != "" { | ||||
| 									aliasName = join.Alias | ||||
| 								} | ||||
| 
 | ||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) | ||||
| 								specifiedRelationsName[curAliasName] = aliasName | ||||
| 							} | ||||
| 
 | ||||
| 							parentTableName = curAliasName | ||||
| 						} | ||||
| 					} else { | ||||
| 						fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
|  | ||||
							
								
								
									
										21
									
								
								generics.go
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								generics.go
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| @ -341,6 +342,9 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.Joins = append(db.Statement.Joins, j) | ||||
| 		sort.Slice(db.Statement.Joins, func(i, j int) bool { | ||||
| 			return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name | ||||
| 		}) | ||||
| 		return db | ||||
| 	}) | ||||
| } | ||||
| @ -399,7 +403,22 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err | ||||
| 
 | ||||
| 			relation, ok := db.Statement.Schema.Relationships.Relations[association] | ||||
| 			if !ok { | ||||
| 				db.AddError(fmt.Errorf("relation %s not found", association)) | ||||
| 				if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { | ||||
| 					relationships := db.Statement.Schema.Relationships | ||||
| 					for _, field := range preloadFields { | ||||
| 						var ok bool | ||||
| 						relation, ok = relationships.Relations[field] | ||||
| 						if ok { | ||||
| 							relationships = relation.FieldSchema.Relationships | ||||
| 						} else { | ||||
| 							db.AddError(fmt.Errorf("relation %s not found", association)) | ||||
| 							return nil | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					db.AddError(fmt.Errorf("relation %s not found", association)) | ||||
| 					return nil | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if q.limitPerRecord > 0 { | ||||
|  | ||||
							
								
								
									
										4
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								scan.go
									
									
									
									
									
								
							| @ -245,9 +245,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 							matchedFieldCount[column] = 1 | ||||
| 						} | ||||
| 					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | ||||
| 						aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1]) | ||||
| 						for _, join := range db.Statement.Joins { | ||||
| 							if join.Alias == names[0] { | ||||
| 							if join.Alias == aliasName { | ||||
| 								names = append(strings.Split(join.Name, "."), names[len(names)-1]) | ||||
| 								break | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
|  | ||||
| @ -6,6 +6,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| @ -378,6 +379,82 @@ func TestGenericsJoins(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsNestedJoins(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{ | ||||
| 			Name: "generics-nested-joins-1", | ||||
| 			Manager: &User{ | ||||
| 				Name: "generics-nested-joins-manager-1", | ||||
| 				Company: Company{ | ||||
| 					Name: "generics-nested-joins-manager-company-1", | ||||
| 				}, | ||||
| 				NamedPet: &Pet{ | ||||
| 					Name: "generics-nested-joins-manager-namepet-1", | ||||
| 					Toy: Toy{ | ||||
| 						Name: "generics-nested-joins-manager-namepet-toy-1", | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name:     "generics-nested-joins-2", | ||||
| 			Manager:  GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}), | ||||
| 			NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 	db.CreateInBatches(ctx, &users, 100) | ||||
| 
 | ||||
| 	var userIDs []uint | ||||
| 	for _, user := range users { | ||||
| 		userIDs = append(userIDs, user.ID) | ||||
| 	} | ||||
| 
 | ||||
| 	users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("Manager.Company"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil). | ||||
| 		Where(map[string]any{"id": userIDs}).Find(ctx) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to load with joins, got error: %v", err) | ||||
| 	} else if len(users2) != len(users) { | ||||
| 		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Slice(users2, func(i, j int) bool { | ||||
| 		return users2[i].ID > users2[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(users, func(i, j int) bool { | ||||
| 		return users[i].ID > users[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	for idx, user := range users { | ||||
| 		// user
 | ||||
| 		CheckUser(t, user, users2[idx]) | ||||
| 		if users2[idx].Manager == nil { | ||||
| 			t.Fatalf("Failed to load Manager") | ||||
| 		} | ||||
| 		// manager
 | ||||
| 		CheckUser(t, *user.Manager, *users2[idx].Manager) | ||||
| 		// user pet
 | ||||
| 		if users2[idx].NamedPet == nil { | ||||
| 			t.Fatalf("Failed to load NamedPet") | ||||
| 		} | ||||
| 		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) | ||||
| 		// manager pet
 | ||||
| 		if users2[idx].Manager.NamedPet == nil { | ||||
| 			t.Fatalf("Failed to load NamedPet") | ||||
| 		} | ||||
| 		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsPreloads(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| @ -499,6 +576,35 @@ func TestGenericsPreloads(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsNestedPreloads(t *testing.T) { | ||||
| 	user := *GetUser("generics_nested_preload", Config{Pets: 2}) | ||||
| 	user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})} | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| 	for idx, pet := range user.Pets { | ||||
| 		pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} | ||||
| 	} | ||||
| 
 | ||||
| 	if err := db.Create(ctx, &user); err != nil { | ||||
| 		t.Fatalf("errors happened when create: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.LimitPerRecord(3) | ||||
| 		return nil | ||||
| 	}).Where(user.ID).Take(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("failed to nested preload user") | ||||
| 	} | ||||
| 	CheckUser(t, user2, user) | ||||
| 
 | ||||
| 	if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { | ||||
| 		t.Errorf("failed to nested preload with limit per record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsDistinct(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| @ -586,3 +692,40 @@ func TestGenericsSubQuery(t *testing.T) { | ||||
| 		t.Errorf("Three users should be found, instead found %d", len(results)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsUpsert(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	lang := Language{Code: "upsert", Name: "Upsert"} | ||||
| 
 | ||||
| 	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil { | ||||
| 		t.Fatalf("failed to upsert, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	lang2 := Language{Code: "upsert", Name: "Upsert"} | ||||
| 	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil { | ||||
| 		t.Fatalf("failed to upsert, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("no error should happen when find languages with code, but got %v", err) | ||||
| 	} else if len(langs) != 1 { | ||||
| 		t.Errorf("should only find only 1 languages, but got %+v", langs) | ||||
| 	} | ||||
| 
 | ||||
| 	lang3 := Language{Code: "upsert", Name: "Upsert"} | ||||
| 	if err := gorm.G[Language](DB, clause.OnConflict{ | ||||
| 		Columns:   []clause.Column{{Name: "code"}}, | ||||
| 		DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), | ||||
| 	}).Create(ctx, &lang3); err != nil { | ||||
| 		t.Fatalf("failed to upsert, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil { | ||||
| 		t.Errorf("no error should happen when find languages with code, but got %v", err) | ||||
| 	} else if len(langs) != 1 { | ||||
| 		t.Errorf("should only find only 1 languages, but got %+v", langs) | ||||
| 	} else if langs[0].Name != "upsert-new" { | ||||
| 		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu