From 84e93363bb8e45618835556872fa5b2542454b5f Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Fri, 8 Jul 2022 17:44:35 +0800 Subject: [PATCH] fix(FindInBatches): support composite primarykey --- finisher_api.go | 52 +++++++++++++++++++++++++++++++++++++----- tests/query_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 7a3f27ba..28ff0eae 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -173,7 +173,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { // FindInBatches find records in batches func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( - tx = db.Order(clause.OrderByColumn{ + tx = db.Session(&Session{}).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }).Session(&Session{}) queryDB = tx @@ -183,9 +183,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat // user specified offset or limit var totalSize int + var totalOffset int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { totalSize = limit.Limit + totalOffset = limit.Offset if totalSize > 0 && batchSize > totalSize { batchSize = totalSize @@ -222,19 +224,59 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - if result.Statement.Schema.PrioritizedPrimaryField == nil { + if result.Statement.Schema.PrioritizedPrimaryField != nil { + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + } else if len(result.Statement.Schema.PrimaryFields) > 0 { + offset := totalOffset + int(rowsAffected) + queryDB = getCPKBatchesQuery(db.Session(&Session{}), result.Statement.Schema, offset, batchSize) + } else { tx.AddError(ErrPrimaryKeyRequired) break } - - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected return tx } +// use subquery and offset limit to query composite primarykey +// so it not support Save in FindInBatches callback func like this case: +// https://github.com/go-gorm/gorm/blob/master/tests/query_test.go#L275 +// it will generate sql like this: +// SELECT table.* FROM table INNER JOIN (&{subquery}) sub ON sub.&{primarykey} = table.&{primarykey} +func getCPKBatchesQuery(subTx *DB, sch *schema.Schema, offset, limit int) *DB { + queryTx := subTx.Session(&Session{NewDB: true}) + // subquery conditions like DeletedAt + for _, c := range sch.QueryClauses { + subTx.Statement.AddClause(c) + } + + subqueryAlias := "sub" + // order conditions + subOrderConds := make([]clause.OrderByColumn, len(sch.PrimaryFieldDBNames)) + // on conditions + onConds := make([]clause.Expression, len(sch.PrimaryFieldDBNames)) + for i, pkname := range sch.PrimaryFieldDBNames { + onConds[i] = clause.Eq{ + Column: clause.Column{Table: sch.Table, Name: pkname}, + Value: clause.Column{Table: subqueryAlias, Name: pkname}, + } + subOrderConds[i] = clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: pkname}, + } + } + + return queryTx.Table(sch.Table).Select("?.*", clause.Expr{SQL: sch.Table}).Joins("INNER JOIN (?) ? ON ?", + // all conditions are judged in subquery + subTx.Table(sch.Table).Select(sch.PrimaryFieldDBNames).Offset(offset).Limit(limit).Clauses(clause.OrderBy{ + Columns: subOrderConds, + }), + clause.Expr{SQL: subqueryAlias}, + clause.And(onConds...), + ) +} + func (db *DB) assignInterfacesToValue(values ...interface{}) { for _, value := range values { switch v := value.(type) { diff --git a/tests/query_test.go b/tests/query_test.go index 253d8409..5996a6b0 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,61 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithCompositePrimarykey(t *testing.T) { + type CompositeModel struct { + Name string `gorm:"primaryKey"` + Gender int `gorm:"primaryKey"` + Cond string + DeletedAt gorm.DeletedAt `gorm:"index"` + } + DB.Migrator().DropTable(&CompositeModel{}) + DB.AutoMigrate(&CompositeModel{}) + + models := []CompositeModel{ + {Name: "Composite_0", Gender: 0, Cond: "test"}, + {Name: "Composite_0", Gender: 1, Cond: "test"}, + {Name: "Composite_0", Gender: 2, Cond: "test"}, + {Name: "Composite_1", Gender: 1, Cond: "test"}, + {Name: "Composite_1", Gender: 2, Cond: "test"}, + {Name: "Composite_1", Gender: 3, Cond: "test"}, + {Name: "Composite_2", Gender: 1, Cond: "test"}, + } + + DB.Create(&models) + DB.Delete(&models[0]) + + var ( + sub, results []CompositeModel + totalBatch int + ) + + if result := DB.Where("cond = ?", models[0].Cond).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(sub) != 2 { + t.Errorf("Incorrect users length, expects: 2, got %v", len(sub)) + } + + results = append(results, sub...) + return nil + }); result.Error != nil || result.RowsAffected != 6 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if totalBatch != 6 { + t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) + } + + targets := models[1:] + for i := 0; i < len(targets); i++ { + AssertEqual(t, results[i], targets[i]) + } +} + func TestFindInBatchesWithOffsetLimit(t *testing.T) { users := []User{ *GetUser("find_in_batches_with_offset_limit", Config{}),