diff --git a/callbacks/preload.go b/callbacks/preload.go index fd8214bb..607c22bc 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) if len(values) != 0 { + tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) + for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { tx = fc(tx) @@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + if len(inlineConds) > 0 { + tx = tx.Where(inlineConds[0], inlineConds[1:]...) + } + + if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { return err } } diff --git a/generics.go b/generics.go index 7c7257f6..1fab7078 100644 --- a/generics.go +++ b/generics.go @@ -77,7 +77,7 @@ type PreloadBuilder interface { Limit(offset int) PreloadBuilder Offset(offset int) PreloadBuilder Order(value interface{}) PreloadBuilder - Scopes(scopes ...func(db *Statement)) PreloadBuilder + LimitPerRecord(num int) PreloadBuilder } type op func(*DB) *DB @@ -214,76 +214,78 @@ type joinBuilder struct { db *DB } -func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { +func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { +func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { +func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q joinBuilder) Select(columns ...string) JoinBuilder { +func (q *joinBuilder) Select(columns ...string) JoinBuilder { q.db.Select(columns) return q } -func (q joinBuilder) Omit(columns ...string) JoinBuilder { +func (q *joinBuilder) Omit(columns ...string) JoinBuilder { q.db.Omit(columns...) return q } type preloadBuilder struct { - db *DB + limitPerRecord int + db *DB } -func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { +func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } -func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { +func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } -func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { +func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } -func (q preloadBuilder) Select(columns ...string) PreloadBuilder { +func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { q.db.Select(columns) return q } -func (q preloadBuilder) Omit(columns ...string) PreloadBuilder { +func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { q.db.Omit(columns...) return q } -func (q preloadBuilder) Limit(limit int) PreloadBuilder { +func (q *preloadBuilder) Limit(limit int) PreloadBuilder { q.db.Limit(limit) return q } -func (q preloadBuilder) Offset(offset int) PreloadBuilder { + +func (q *preloadBuilder) Offset(offset int) PreloadBuilder { q.db.Offset(offset) return q } -func (q preloadBuilder) Order(value interface{}) PreloadBuilder { + +func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { q.db.Order(value) return q } -func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { - for _, fc := range scopes { - fc(q.db.Statement) - } + +func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { + q.limitPerRecord = num return q } @@ -295,7 +297,7 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} if on != nil { - if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { + if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { db.AddError(err) } } @@ -390,10 +392,69 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err return db.Preload(association, func(tx *DB) *DB { q := preloadBuilder{db: tx.getInstance()} if query != nil { - if err := query(q); err != nil { + if err := query(&q); err != nil { db.AddError(err) } } + + relation, ok := db.Statement.Schema.Relationships.Relations[association] + if !ok { + db.AddError(fmt.Errorf("relation %s not found", association)) + } + + if q.limitPerRecord > 0 { + if relation.JoinTable != nil { + err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association) + tx.AddError(err) + return tx + } + + refColumns := []clause.Column{} + for _, rel := range relation.References { + if rel.OwnPrimaryKey { + refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName}) + } + } + + if len(refColumns) != 0 { + selects := q.db.Statement.Selects + selectExpr := clause.CommaExpression{} + if len(selects) == 0 { + selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} + } else { + for _, column := range selects { + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) + } + } + + partitionBy := clause.CommaExpression{} + for _, column := range refColumns { + partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}}) + } + + rnnColumn := clause.Column{Name: "gorm_preload_rnn"} + sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)" + vars := []interface{}{partitionBy} + if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok { + vars = append(vars, orderBy) + } else { + vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ + Columns: []clause.OrderByColumn{ + {Column: clause.PrimaryColumn, Desc: false}, + }, + }}) + } + vars = append(vars, rnnColumn) + + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) + + q.db.Clauses(clause.Select{ + Expression: selectExpr, + }) + + return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) + } + } return q.db }) }) diff --git a/statement.go b/statement.go index 63f78006..19cdbbaf 100644 --- a/statement.go +++ b/statement.go @@ -209,6 +209,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } case interface{ getInstance() *DB }: cv := v.getInstance() + subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() if cv.Statement.SQL.Len() > 0 { var ( diff --git a/tests/generics_test.go b/tests/generics_test.go index 2e0dbc28..2f0f722b 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -277,7 +277,7 @@ func TestGenericsScopes(t *testing.T) { } } -func TestGenericsJoinsAndPreload(t *testing.T) { +func TestGenericsJoins(t *testing.T) { ctx := context.Background() db := gorm.G[User](DB) @@ -374,20 +374,32 @@ func TestGenericsJoinsAndPreload(t *testing.T) { if err == nil { t.Fatalf("Joins should got error, but got nil") } +} - // Preload - result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) +func TestGenericsPreloads(t *testing.T) { + ctx := context.Background() + db := gorm.G[User](DB) + + u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7}) + u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5}) + u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3}) + names := []string{u.Name, u2.Name, u3.Name} + + db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) + + result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx) if err != nil { t.Fatalf("Preload failed: %v", err) } - if result3.Name != u.Name || result3.Company.Name != u.Company.Name { + + if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) { t.Fatalf("Preload expected %s, got %+v", u.Name, result) } results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { db.Where("name = ?", u.Company.Name) return nil - }).Find(ctx) + }).Where("name in ?", names).Find(ctx) if err != nil { t.Fatalf("Preload failed: %v", err) } @@ -403,10 +415,80 @@ func TestGenericsJoinsAndPreload(t *testing.T) { _, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { return errors.New("preload error") - }).Find(ctx) + }).Where("name in ?", names).Find(ctx) if err == nil { t.Fatalf("Preload should failed, but got nil") } + + results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { + db.LimitPerRecord(5) + return nil + }).Where("name in ?", names).Find(ctx) + + for _, result := range results { + if result.Name == u.Name { + if len(result.Pets) != len(u.Pets) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + } else if len(result.Pets) != 5 { + t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) + } + } + + if DB.Dialector.Name() == "sqlserver" { + // sqlserver doesn't support order by in subquery + return + } + results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { + db.Order("name desc").LimitPerRecord(5) + return nil + }).Where("name in ?", names).Find(ctx) + + for _, result := range results { + if result.Name == u.Name { + if len(result.Pets) != len(u.Pets) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + } else if len(result.Pets) != 5 { + t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) + } + for i := 1; i < len(result.Pets); i++ { + if result.Pets[i-1].Name < result.Pets[i].Name { + t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) + } + } + } + + results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { + db.Order("name").LimitPerRecord(5) + return nil + }).Preload("Friends", func(db gorm.PreloadBuilder) error { + db.Order("name") + return nil + }).Where("name in ?", names).Find(ctx) + + for _, result := range results { + if result.Name == u.Name { + if len(result.Pets) != len(u.Pets) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + if len(result.Friends) != len(u.Friends) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + } else if len(result.Pets) != 5 || len(result.Friends) == 0 { + t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) + } + for i := 1; i < len(result.Pets); i++ { + if result.Pets[i-1].Name > result.Pets[i].Name { + t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) + } + } + for i := 1; i < len(result.Pets); i++ { + if result.Pets[i-1].Name > result.Pets[i].Name { + t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) + } + } + } } func TestGenericsDistinct(t *testing.T) {