diff --git a/finisher_api.go b/finisher_api.go index 552f5990..bc44dbd4 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -212,13 +212,12 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat } if totalSize > 0 { - if totalSize/batchSize == batch { - batchSize = totalSize % batchSize - } - if totalSize <= int(rowsAffected) { break } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } } // Optimize for-break diff --git a/tests/query_test.go b/tests/query_test.go index 24c53c1f..9dd61bab 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -336,6 +336,12 @@ func TestFindInBatchesWithOffsetLimit(t *testing.T) { }); result.Error != nil || result.RowsAffected != 5 { t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) } + + if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 4 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } } func TestFindInBatchesWithError(t *testing.T) {