Add LimitPerRecord for generic version Preload
This commit is contained in:
parent
91eb9477f2
commit
6307f69f18
@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||||
|
|
||||||
if len(values) != 0 {
|
if len(values) != 0 {
|
||||||
|
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
|
||||||
|
|
||||||
for _, cond := range conds {
|
for _, cond := range conds {
|
||||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||||
tx = fc(tx)
|
tx = fc(tx)
|
||||||
@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
if len(inlineConds) > 0 {
|
||||||
|
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
101
generics.go
101
generics.go
@ -77,7 +77,7 @@ type PreloadBuilder interface {
|
|||||||
Limit(offset int) PreloadBuilder
|
Limit(offset int) PreloadBuilder
|
||||||
Offset(offset int) PreloadBuilder
|
Offset(offset int) PreloadBuilder
|
||||||
Order(value interface{}) PreloadBuilder
|
Order(value interface{}) PreloadBuilder
|
||||||
Scopes(scopes ...func(db *Statement)) PreloadBuilder
|
LimitPerRecord(num int) PreloadBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
type op func(*DB) *DB
|
type op func(*DB) *DB
|
||||||
@ -214,76 +214,78 @@ type joinBuilder struct {
|
|||||||
db *DB
|
db *DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q joinBuilder) Select(columns ...string) JoinBuilder {
|
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
|
||||||
q.db.Select(columns)
|
q.db.Select(columns)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q joinBuilder) Omit(columns ...string) JoinBuilder {
|
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||||
q.db.Omit(columns...)
|
q.db.Omit(columns...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
type preloadBuilder struct {
|
type preloadBuilder struct {
|
||||||
|
limitPerRecord int
|
||||||
db *DB
|
db *DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q preloadBuilder) Select(columns ...string) PreloadBuilder {
|
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||||
q.db.Select(columns)
|
q.db.Select(columns)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||||
q.db.Omit(columns...)
|
q.db.Omit(columns...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q preloadBuilder) Limit(limit int) PreloadBuilder {
|
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||||
q.db.Limit(limit)
|
q.db.Limit(limit)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
func (q preloadBuilder) Offset(offset int) PreloadBuilder {
|
|
||||||
|
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||||
q.db.Offset(offset)
|
q.db.Offset(offset)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
func (q preloadBuilder) Order(value interface{}) PreloadBuilder {
|
|
||||||
|
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||||
q.db.Order(value)
|
q.db.Order(value)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
|
|
||||||
for _, fc := range scopes {
|
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
|
||||||
fc(q.db.Statement)
|
q.limitPerRecord = num
|
||||||
}
|
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,7 +297,7 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable
|
|||||||
|
|
||||||
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
||||||
if on != nil {
|
if on != nil {
|
||||||
if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -390,10 +392,69 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err
|
|||||||
return db.Preload(association, func(tx *DB) *DB {
|
return db.Preload(association, func(tx *DB) *DB {
|
||||||
q := preloadBuilder{db: tx.getInstance()}
|
q := preloadBuilder{db: tx.getInstance()}
|
||||||
if query != nil {
|
if query != nil {
|
||||||
if err := query(q); err != nil {
|
if err := query(&q); err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||||
|
if !ok {
|
||||||
|
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.limitPerRecord > 0 {
|
||||||
|
if relation.JoinTable != nil {
|
||||||
|
err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)
|
||||||
|
tx.AddError(err)
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
refColumns := []clause.Column{}
|
||||||
|
for _, rel := range relation.References {
|
||||||
|
if rel.OwnPrimaryKey {
|
||||||
|
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(refColumns) != 0 {
|
||||||
|
selects := q.db.Statement.Selects
|
||||||
|
selectExpr := clause.CommaExpression{}
|
||||||
|
if len(selects) == 0 {
|
||||||
|
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
|
||||||
|
} else {
|
||||||
|
for _, column := range selects {
|
||||||
|
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
partitionBy := clause.CommaExpression{}
|
||||||
|
for _, column := range refColumns {
|
||||||
|
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
|
||||||
|
}
|
||||||
|
|
||||||
|
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
|
||||||
|
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
|
||||||
|
vars := []interface{}{partitionBy}
|
||||||
|
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
|
||||||
|
vars = append(vars, orderBy)
|
||||||
|
} else {
|
||||||
|
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
|
||||||
|
Columns: []clause.OrderByColumn{
|
||||||
|
{Column: clause.PrimaryColumn, Desc: false},
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
vars = append(vars, rnnColumn)
|
||||||
|
|
||||||
|
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
|
||||||
|
|
||||||
|
q.db.Clauses(clause.Select{
|
||||||
|
Expression: selectExpr,
|
||||||
|
})
|
||||||
|
|
||||||
|
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
|
||||||
|
}
|
||||||
|
}
|
||||||
return q.db
|
return q.db
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -209,6 +209,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||||||
}
|
}
|
||||||
case interface{ getInstance() *DB }:
|
case interface{ getInstance() *DB }:
|
||||||
cv := v.getInstance()
|
cv := v.getInstance()
|
||||||
|
|
||||||
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||||
if cv.Statement.SQL.Len() > 0 {
|
if cv.Statement.SQL.Len() > 0 {
|
||||||
var (
|
var (
|
||||||
|
@ -277,7 +277,7 @@ func TestGenericsScopes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenericsJoinsAndPreload(t *testing.T) {
|
func TestGenericsJoins(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db := gorm.G[User](DB)
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
@ -374,20 +374,32 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Joins should got error, but got nil")
|
t.Fatalf("Joins should got error, but got nil")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Preload
|
func TestGenericsPreloads(t *testing.T) {
|
||||||
result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
|
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
|
||||||
|
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
|
||||||
|
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
|
||||||
|
names := []string{u.Name, u2.Name, u3.Name}
|
||||||
|
|
||||||
|
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||||
|
|
||||||
|
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Preload failed: %v", err)
|
t.Fatalf("Preload failed: %v", err)
|
||||||
}
|
}
|
||||||
if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
|
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
|
||||||
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||||
db.Where("name = ?", u.Company.Name)
|
db.Where("name = ?", u.Company.Name)
|
||||||
return nil
|
return nil
|
||||||
}).Find(ctx)
|
}).Where("name in ?", names).Find(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Preload failed: %v", err)
|
t.Fatalf("Preload failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -403,10 +415,80 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
|
|
||||||
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||||
return errors.New("preload error")
|
return errors.New("preload error")
|
||||||
}).Find(ctx)
|
}).Where("name in ?", names).Find(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Preload should failed, but got nil")
|
t.Fatalf("Preload should failed, but got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.LimitPerRecord(5)
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
} else if len(result.Pets) != 5 {
|
||||||
|
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Dialector.Name() == "sqlserver" {
|
||||||
|
// sqlserver doesn't support order by in subquery
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Order("name desc").LimitPerRecord(5)
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
} else if len(result.Pets) != 5 {
|
||||||
|
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||||
|
}
|
||||||
|
for i := 1; i < len(result.Pets); i++ {
|
||||||
|
if result.Pets[i-1].Name < result.Pets[i].Name {
|
||||||
|
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Order("name").LimitPerRecord(5)
|
||||||
|
return nil
|
||||||
|
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Order("name")
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
if len(result.Friends) != len(u.Friends) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
|
||||||
|
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||||
|
}
|
||||||
|
for i := 1; i < len(result.Pets); i++ {
|
||||||
|
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||||
|
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 1; i < len(result.Pets); i++ {
|
||||||
|
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||||
|
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenericsDistinct(t *testing.T) {
|
func TestGenericsDistinct(t *testing.T) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user