Add LimitPerRecord for generic version Preload
This commit is contained in:
parent
91eb9477f2
commit
6307f69f18
@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
if len(values) != 0 {
|
||||
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
|
||||
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
tx = fc(tx)
|
||||
@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
||||
if len(inlineConds) > 0 {
|
||||
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
|
||||
}
|
||||
|
||||
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
103
generics.go
103
generics.go
@ -77,7 +77,7 @@ type PreloadBuilder interface {
|
||||
Limit(offset int) PreloadBuilder
|
||||
Offset(offset int) PreloadBuilder
|
||||
Order(value interface{}) PreloadBuilder
|
||||
Scopes(scopes ...func(db *Statement)) PreloadBuilder
|
||||
LimitPerRecord(num int) PreloadBuilder
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
@ -214,76 +214,78 @@ type joinBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q joinBuilder) Select(columns ...string) JoinBuilder {
|
||||
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
type preloadBuilder struct {
|
||||
db *DB
|
||||
limitPerRecord int
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||
q.db.Limit(limit)
|
||||
return q
|
||||
}
|
||||
func (q preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||
|
||||
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||
q.db.Offset(offset)
|
||||
return q
|
||||
}
|
||||
func (q preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||
|
||||
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||
q.db.Order(value)
|
||||
return q
|
||||
}
|
||||
func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
|
||||
for _, fc := range scopes {
|
||||
fc(q.db.Statement)
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
|
||||
q.limitPerRecord = num
|
||||
return q
|
||||
}
|
||||
|
||||
@ -295,7 +297,7 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable
|
||||
|
||||
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
||||
if on != nil {
|
||||
if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
@ -390,10 +392,69 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err
|
||||
return db.Preload(association, func(tx *DB) *DB {
|
||||
q := preloadBuilder{db: tx.getInstance()}
|
||||
if query != nil {
|
||||
if err := query(q); err != nil {
|
||||
if err := query(&q); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||
if !ok {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
}
|
||||
|
||||
if q.limitPerRecord > 0 {
|
||||
if relation.JoinTable != nil {
|
||||
err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
|
||||
refColumns := []clause.Column{}
|
||||
for _, rel := range relation.References {
|
||||
if rel.OwnPrimaryKey {
|
||||
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
|
||||
}
|
||||
}
|
||||
|
||||
if len(refColumns) != 0 {
|
||||
selects := q.db.Statement.Selects
|
||||
selectExpr := clause.CommaExpression{}
|
||||
if len(selects) == 0 {
|
||||
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
|
||||
} else {
|
||||
for _, column := range selects {
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
|
||||
}
|
||||
}
|
||||
|
||||
partitionBy := clause.CommaExpression{}
|
||||
for _, column := range refColumns {
|
||||
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
|
||||
}
|
||||
|
||||
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
|
||||
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
|
||||
vars := []interface{}{partitionBy}
|
||||
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
|
||||
vars = append(vars, orderBy)
|
||||
} else {
|
||||
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{
|
||||
{Column: clause.PrimaryColumn, Desc: false},
|
||||
},
|
||||
}})
|
||||
}
|
||||
vars = append(vars, rnnColumn)
|
||||
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
|
||||
|
||||
q.db.Clauses(clause.Select{
|
||||
Expression: selectExpr,
|
||||
})
|
||||
|
||||
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
|
||||
}
|
||||
}
|
||||
return q.db
|
||||
})
|
||||
})
|
||||
|
@ -209,6 +209,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
}
|
||||
case interface{ getInstance() *DB }:
|
||||
cv := v.getInstance()
|
||||
|
||||
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if cv.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
|
@ -277,7 +277,7 @@ func TestGenericsScopes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
func TestGenericsJoins(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
@ -374,20 +374,32 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatalf("Joins should got error, but got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Preload
|
||||
result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
|
||||
func TestGenericsPreloads(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
|
||||
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
|
||||
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
|
||||
names := []string{u.Name, u2.Name, u3.Name}
|
||||
|
||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||
|
||||
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
|
||||
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
db.Where("name = ?", u.Company.Name)
|
||||
return nil
|
||||
}).Find(ctx)
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
@ -403,10 +415,80 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
|
||||
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
return errors.New("preload error")
|
||||
}).Find(ctx)
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("Preload should failed, but got nil")
|
||||
}
|
||||
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(5)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
}
|
||||
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver doesn't support order by in subquery
|
||||
return
|
||||
}
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name desc").LimitPerRecord(5)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name < result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name").LimitPerRecord(5)
|
||||
return nil
|
||||
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name")
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
if len(result.Friends) != len(u.Friends) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsDistinct(t *testing.T) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user