diff --git a/finisher_api.go b/finisher_api.go index bef65ae5..b5cbfaa6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -190,6 +190,8 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if result.Error == nil && result.RowsAffected != 0 { tx.AddError(fc(result, batch)) + } else if result.Error != nil { + tx.AddError(result.Error) } if tx.Error != nil || int(result.RowsAffected) < batchSize { diff --git a/tests/query_test.go b/tests/query_test.go index ee157a13..489ac807 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,34 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithError(t *testing.T) { + var users = []User{ + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Table("wrong_table").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error == nil || result.RowsAffected > 0 { + t.Fatal("expected errors to have occurred, but nothing happened") + } + if totalBatch != 0 { + t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) + } +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user)