Add LimitPerRecord for generic version Preload

This commit is contained in:
Jinzhu 2025-05-21 18:04:20 +08:00
parent 91eb9477f2
commit 6307f69f18
4 changed files with 178 additions and 28 deletions

View File

@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
}
}

View File

@ -77,7 +77,7 @@ type PreloadBuilder interface {
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
Scopes(scopes ...func(db *Statement)) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
@ -214,76 +214,78 @@ type joinBuilder struct {
db *DB
}
func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q joinBuilder) Select(columns ...string) JoinBuilder {
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q joinBuilder) Omit(columns ...string) JoinBuilder {
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
db *DB
limitPerRecord int
db *DB
}
func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q preloadBuilder) Select(columns ...string) PreloadBuilder {
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q preloadBuilder) Omit(columns ...string) PreloadBuilder {
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q preloadBuilder) Limit(limit int) PreloadBuilder {
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q preloadBuilder) Offset(offset int) PreloadBuilder {
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q preloadBuilder) Order(value interface{}) PreloadBuilder {
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
for _, fc := range scopes {
fc(q.db.Statement)
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
@ -295,7 +297,7 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
@ -390,10 +392,69 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(q); err != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
db.AddError(fmt.Errorf("relation %s not found", association))
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)
tx.AddError(err)
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selects := q.db.Statement.Selects
selectExpr := clause.CommaExpression{}
if len(selects) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
} else {
for _, column := range selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{
{Column: clause.PrimaryColumn, Desc: false},
},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{
Expression: selectExpr,
})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})

View File

@ -209,6 +209,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
}
case interface{ getInstance() *DB }:
cv := v.getInstance()
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if cv.Statement.SQL.Len() > 0 {
var (

View File

@ -277,7 +277,7 @@ func TestGenericsScopes(t *testing.T) {
}
}
func TestGenericsJoinsAndPreload(t *testing.T) {
func TestGenericsJoins(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
@ -374,20 +374,32 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
if err == nil {
t.Fatalf("Joins should got error, but got nil")
}
}
// Preload
result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
func TestGenericsPreloads(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
names := []string{u.Name, u2.Name, u3.Name}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
}
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
db.Where("name = ?", u.Company.Name)
return nil
}).Find(ctx)
}).Where("name in ?", names).Find(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
@ -403,10 +415,80 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
return errors.New("preload error")
}).Find(ctx)
}).Where("name in ?", names).Find(ctx)
if err == nil {
t.Fatalf("Preload should failed, but got nil")
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(5)
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
}
if DB.Dialector.Name() == "sqlserver" {
// sqlserver doesn't support order by in subquery
return
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.Order("name desc").LimitPerRecord(5)
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name < result.Pets[i].Name {
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.Order("name").LimitPerRecord(5)
return nil
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
db.Order("name")
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
if len(result.Friends) != len(u.Friends) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name > result.Pets[i].Name {
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name > result.Pets[i].Name {
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
}
}
func TestGenericsDistinct(t *testing.T) {