fix: FindInBatches with offset limit (#5255)
* fix: FindInBatches with offset limit * fix: break first * fix: FindInBatches Limit zero
This commit is contained in:
		
							parent
							
								
									e0ed3ce400
								
							
						
					
					
						commit
						b49ae84780
					
				| @ -181,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | |||||||
| 		batch        int | 		batch        int | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
|  | 	// user specified offset or limit
 | ||||||
|  | 	var totalSize int | ||||||
|  | 	if c, ok := tx.Statement.Clauses["LIMIT"]; ok { | ||||||
|  | 		if limit, ok := c.Expression.(clause.Limit); ok { | ||||||
|  | 			totalSize = limit.Limit | ||||||
|  | 
 | ||||||
|  | 			if totalSize > 0 && batchSize > totalSize { | ||||||
|  | 				batchSize = totalSize | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// reset to offset to 0 in next batch
 | ||||||
|  | 			tx = tx.Offset(-1).Session(&Session{}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for { | 	for { | ||||||
| 		result := queryDB.Limit(batchSize).Find(dest) | 		result := queryDB.Limit(batchSize).Find(dest) | ||||||
| 		rowsAffected += result.RowsAffected | 		rowsAffected += result.RowsAffected | ||||||
| @ -196,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | |||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if totalSize > 0 { | ||||||
|  | 			if totalSize <= int(rowsAffected) { | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			if totalSize/batchSize == batch { | ||||||
|  | 				batchSize = totalSize % batchSize | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		// Optimize for-break
 | 		// Optimize for-break
 | ||||||
| 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | ||||||
| 		if result.Statement.Schema.PrioritizedPrimaryField == nil { | 		if result.Statement.Schema.PrioritizedPrimaryField == nil { | ||||||
|  | |||||||
| @ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestFindInBatchesWithOffsetLimit(t *testing.T) { | ||||||
|  | 	users := []User{ | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 		*GetUser("find_in_batches_with_offset_limit", Config{}), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Create(&users) | ||||||
|  | 
 | ||||||
|  | 	var ( | ||||||
|  | 		sub, results []User | ||||||
|  | 		lastBatch    int | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// offset limit
 | ||||||
|  | 	if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { | ||||||
|  | 		results = append(results, sub...) | ||||||
|  | 		lastBatch = batch | ||||||
|  | 		return nil | ||||||
|  | 	}); result.Error != nil || result.RowsAffected != 5 { | ||||||
|  | 		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) | ||||||
|  | 	} | ||||||
|  | 	if lastBatch != 3 { | ||||||
|  | 		t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	targetUsers := users[3:8] | ||||||
|  | 	for i := 0; i < len(targetUsers); i++ { | ||||||
|  | 		AssertEqual(t, results[i], targetUsers[i]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var sub1 []User | ||||||
|  | 	// limit < batchSize
 | ||||||
|  | 	if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error { | ||||||
|  | 		return nil | ||||||
|  | 	}); result.Error != nil || result.RowsAffected != 5 { | ||||||
|  | 		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var sub2 []User | ||||||
|  | 	// only offset
 | ||||||
|  | 	if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error { | ||||||
|  | 		return nil | ||||||
|  | 	}); result.Error != nil || result.RowsAffected != 7 { | ||||||
|  | 		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var sub3 []User | ||||||
|  | 	if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error { | ||||||
|  | 		return nil | ||||||
|  | 	}); result.Error != nil || result.RowsAffected != 4 { | ||||||
|  | 		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestFindInBatchesWithError(t *testing.T) { | func TestFindInBatchesWithError(t *testing.T) { | ||||||
| 	if name := DB.Dialector.Name(); name == "sqlserver" { | 	if name := DB.Dialector.Name(); name == "sqlserver" { | ||||||
| 		t.Skip("skip sqlserver due to it will raise data race for invalid sql") | 		t.Skip("skip sqlserver due to it will raise data race for invalid sql") | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Cr
						Cr