diff --git a/generics.go b/generics.go index 1fab7078..0b4d48b8 100644 --- a/generics.go +++ b/generics.go @@ -404,8 +404,7 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err if q.limitPerRecord > 0 { if relation.JoinTable != nil { - err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association) - tx.AddError(err) + tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)) return tx } @@ -417,14 +416,13 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err } if len(refColumns) != 0 { - selects := q.db.Statement.Selects selectExpr := clause.CommaExpression{} - if len(selects) == 0 { + for _, column := range q.db.Statement.Selects { + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) + } + + if len(selectExpr.Exprs) == 0 { selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} - } else { - for _, column := range selects { - selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) - } } partitionBy := clause.CommaExpression{} @@ -439,22 +437,19 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err vars = append(vars, orderBy) } else { vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ - Columns: []clause.OrderByColumn{ - {Column: clause.PrimaryColumn, Desc: false}, - }, + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, }}) } vars = append(vars, rnnColumn) selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) - q.db.Clauses(clause.Select{ - Expression: selectExpr, - }) + q.db.Clauses(clause.Select{Expression: selectExpr}) return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) } } + return q.db }) }) diff --git a/tests/generics_test.go b/tests/generics_test.go index 2f0f722b..32881ce5 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -6,8 +6,10 @@ import ( "fmt" "reflect" "sort" + "strings" "testing" + "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -420,6 +422,12 @@ func TestGenericsPreloads(t *testing.T) { t.Fatalf("Preload should failed, but got nil") } + if DB.Dialector.Name() == "mysql" { + // mysql 5.7 doesn't support row_number() + if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { + return + } + } results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { db.LimitPerRecord(5) return nil