Fix FindInBatches to modify the query conditions, close #3734
This commit is contained in:
		
							parent
							
								
									a8db54afd6
								
							
						
					
					
						commit
						320f33061c
					
				| @ -140,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| // FindInBatches find records in batches
 | ||||
| func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { | ||||
| 	tx = db.Session(&Session{WithConditions: true}) | ||||
| 	rowsAffected := int64(0) | ||||
| 	batch := 0 | ||||
| func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { | ||||
| 	var ( | ||||
| 		tx = db.Order(clause.OrderByColumn{ | ||||
| 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		}).Session(&Session{WithConditions: true}) | ||||
| 		queryDB      = tx | ||||
| 		rowsAffected int64 | ||||
| 		batch        int | ||||
| 	) | ||||
| 
 | ||||
| 	for { | ||||
| 		result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) | ||||
| 		result := queryDB.Limit(batchSize).Find(dest) | ||||
| 		rowsAffected += result.RowsAffected | ||||
| 		batch++ | ||||
| 
 | ||||
| @ -156,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | ||||
| 
 | ||||
| 		if tx.Error != nil || int(result.RowsAffected) < batchSize { | ||||
| 			break | ||||
| 		} else { | ||||
| 			resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | ||||
| 			primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) | ||||
| 			queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	tx.RowsAffected = rowsAffected | ||||
| 	return | ||||
| 	return tx | ||||
| } | ||||
| 
 | ||||
| func (tx *DB) assignInterfacesToValue(values ...interface{}) { | ||||
|  | ||||
| @ -260,6 +260,13 @@ func TestFindInBatches(t *testing.T) { | ||||
| 	if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { | ||||
| 		totalBatch += batch | ||||
| 
 | ||||
| 		for idx := range results { | ||||
| 			results[idx].Name = results[idx].Name + "_new" | ||||
| 		} | ||||
| 		if err := tx.Save(results).Error; err != nil { | ||||
| 			t.Errorf("failed to save users, got error %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if tx.RowsAffected != 2 { | ||||
| 			t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) | ||||
| 		} | ||||
| @ -276,6 +283,12 @@ func TestFindInBatches(t *testing.T) { | ||||
| 	if totalBatch != 6 { | ||||
| 		t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) | ||||
| 	} | ||||
| 
 | ||||
| 	var count int64 | ||||
| 	DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count) | ||||
| 	if count != 6 { | ||||
| 		t.Errorf("incorrect count after update, expects: %v, got %v", 6, count) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFillSmallerStruct(t *testing.T) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu