diff --git a/callbacks/query.go b/callbacks/query.go index 56a5944a..c8632cc5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -110,7 +110,7 @@ func BuildQuerySQL(db *gorm.DB) { } } - specifiedRelationsName := make(map[string]interface{}) + specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} for _, join := range db.Statement.Joins { if db.Statement.Schema != nil { var isRelations bool // is relations or raw sql @@ -124,12 +124,12 @@ func BuildQuerySQL(db *gorm.DB) { nestedJoinNames := strings.Split(join.Name, ".") if len(nestedJoinNames) > 1 { isNestedJoin := true - gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) currentRelations := db.Statement.Schema.Relationships.Relations for _, relname := range nestedJoinNames { // incomplete match, only treated as raw sql if relation, ok = currentRelations[relname]; ok { - gussNestedRelations = append(gussNestedRelations, relation) + guessNestedRelations = append(guessNestedRelations, relation) currentRelations = relation.FieldSchema.Relationships.Relations } else { isNestedJoin = false @@ -139,22 +139,13 @@ func BuildQuerySQL(db *gorm.DB) { if isNestedJoin { isRelations = true - relations = gussNestedRelations + relations = guessNestedRelations } } } if isRelations { - genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := join.Alias - - if tableAliasName == "" { - tableAliasName = relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) - } - } - + genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { columnStmt := gorm.Statement{ Table: tableAliasName, DB: db, Schema: relation.FieldSchema, Selects: join.Selects, Omits: join.Omits, @@ -237,19 +228,24 @@ func BuildQuerySQL(db *gorm.DB) { } parentTableName := clause.CurrentTable - for _, rel := range relations { + for idx, rel := range relations { // joins table alias like "Manager, Company, Manager__Company" - nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) - if _, ok := specifiedRelationsName[nestedAlias]; !ok { - fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) - specifiedRelationsName[nestedAlias] = nil + curAliasName := rel.Name + if parentTableName != clause.CurrentTable { + curAliasName = utils.NestedRelationName(parentTableName, curAliasName) } - if parentTableName != clause.CurrentTable { - parentTableName = utils.NestedRelationName(parentTableName, rel.Name) - } else { - parentTableName = rel.Name + if _, ok := specifiedRelationsName[curAliasName]; !ok { + aliasName := curAliasName + if idx == len(relations)-1 && join.Alias != "" { + aliasName = join.Alias + } + + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) + specifiedRelationsName[curAliasName] = aliasName } + + parentTableName = curAliasName } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ diff --git a/generics.go b/generics.go index 0b4d48b8..f2863dac 100644 --- a/generics.go +++ b/generics.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "sort" "strings" "gorm.io/gorm/clause" @@ -341,6 +342,9 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable } db.Statement.Joins = append(db.Statement.Joins, j) + sort.Slice(db.Statement.Joins, func(i, j int) bool { + return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name + }) return db }) } @@ -399,7 +403,22 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { - db.AddError(fmt.Errorf("relation %s not found", association)) + if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { + relationships := db.Statement.Schema.Relationships + for _, field := range preloadFields { + var ok bool + relation, ok = relationships.Relations[field] + if ok { + relationships = relation.FieldSchema.Relationships + } else { + db.AddError(fmt.Errorf("relation %s not found", association)) + return nil + } + } + } else { + db.AddError(fmt.Errorf("relation %s not found", association)) + return nil + } } if q.limitPerRecord > 0 { diff --git a/scan.go b/scan.go index 624f822f..9a99d024 100644 --- a/scan.go +++ b/scan.go @@ -245,9 +245,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { matchedFieldCount[column] = 1 } } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1]) for _, join := range db.Statement.Joins { - if join.Alias == names[0] { + if join.Alias == aliasName { names = append(strings.Split(join.Name, "."), names[len(names)-1]) + break } } diff --git a/tests/generics_test.go b/tests/generics_test.go index 32881ce5..876c7409 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "sort" + "strconv" "strings" "testing" @@ -378,6 +379,82 @@ func TestGenericsJoins(t *testing.T) { } } +func TestGenericsNestedJoins(t *testing.T) { + users := []User{ + { + Name: "generics-nested-joins-1", + Manager: &User{ + Name: "generics-nested-joins-manager-1", + Company: Company{ + Name: "generics-nested-joins-manager-company-1", + }, + NamedPet: &Pet{ + Name: "generics-nested-joins-manager-namepet-1", + Toy: Toy{ + Name: "generics-nested-joins-manager-namepet-toy-1", + }, + }, + }, + NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}}, + }, + { + Name: "generics-nested-joins-2", + Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}}, + }, + } + + ctx := context.Background() + db := gorm.G[User](DB) + db.CreateInBatches(ctx, &users, 100) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil). + Joins(clause.LeftJoin.Association("Manager.Company"), nil). + Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil). + Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil). + Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil). + Where(map[string]any{"id": userIDs}).Find(ctx) + + if err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + // user + CheckUser(t, user, users2[idx]) + if users2[idx].Manager == nil { + t.Fatalf("Failed to load Manager") + } + // manager + CheckUser(t, *user.Manager, *users2[idx].Manager) + // user pet + if users2[idx].NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) + // manager pet + if users2[idx].Manager.NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) + } +} + func TestGenericsPreloads(t *testing.T) { ctx := context.Background() db := gorm.G[User](DB) @@ -499,6 +576,35 @@ func TestGenericsPreloads(t *testing.T) { } } +func TestGenericsNestedPreloads(t *testing.T) { + user := *GetUser("generics_nested_preload", Config{Pets: 2}) + user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})} + + ctx := context.Background() + db := gorm.G[User](DB) + + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} + } + + if err := db.Create(ctx, &user); err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { + db.LimitPerRecord(3) + return nil + }).Where(user.ID).Take(ctx) + if err != nil { + t.Errorf("failed to nested preload user") + } + CheckUser(t, user2, user) + + if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { + t.Errorf("failed to nested preload with limit per record") + } +} + func TestGenericsDistinct(t *testing.T) { ctx := context.Background() @@ -586,3 +692,40 @@ func TestGenericsSubQuery(t *testing.T) { t.Errorf("Three users should be found, instead found %d", len(results)) } } + +func TestGenericsUpsert(t *testing.T) { + ctx := context.Background() + lang := Language{Code: "upsert", Name: "Upsert"} + + if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + lang2 := Language{Code: "upsert", Name: "Upsert"} + if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx) + if err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } + + lang3 := Language{Code: "upsert", Name: "Upsert"} + if err := gorm.G[Language](DB, clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), + }).Create(ctx, &lang3); err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if langs[0].Name != "upsert-new" { + t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) + } +}