From 48d0897d3e391ba6a0e3b814c9d1b546b00b4c2e Mon Sep 17 00:00:00 2001 From: ronnax Date: Thu, 20 Oct 2016 17:30:40 +0000 Subject: [PATCH] added FindInBatches, FindEach --- errors.go | 2 ++ main.go | 52 +++++++++++++++++++++++++++++++ main_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+) diff --git a/errors.go b/errors.go index ce3a25c0..716568ba 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,8 @@ var ( ErrCantStartTransaction = errors.New("can't start transaction") // ErrUnaddressable unaddressable value ErrUnaddressable = errors.New("using unaddressable value") + // ErrUnsupportedType unsupported type + ErrUnsupportedType = errors.New("using unsupported type") ) type errorsInterface interface { diff --git a/main.go b/main.go index 52a536d0..8863a636 100644 --- a/main.go +++ b/main.go @@ -264,6 +264,58 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// FindInBatches find records in batches for large tables and allow callbacks to operate +// on each batch +func (s *DB) FindInBatches(out interface{}, cb func(), where ...interface{}) *DB { + var limit interface{} + newDB := s.clone() + if len(newDB.search.orders) != 0 { + //ordering by non primary col not supported - can cause infinite loop + newDB.search.orders = []interface{}{} + } + value := indirect(reflect.ValueOf(out)) + if value.Kind() != reflect.Array && value.Kind() != reflect.Slice { + newDB.AddError(ErrUnsupportedType) + return newDB + } + if newDB.search.limit == -1 { + limit = 1000 + } else { + limit = newDB.search.limit + } + err := newDB.Limit(limit).Find(out, where...).Error + records := indirect(reflect.ValueOf(out)) + for err == nil && records.Len() > 0 { + cb() + lastRecord := records.Index(records.Len() - 1).Interface() + scope := s.NewScope(lastRecord) + err = newDB.Limit(limit). + Where(fmt.Sprintf("`%v` > '%v'", scope.PrimaryKey(), scope.PrimaryKeyValue())). + Find(out, where...).Error + records = indirect(reflect.ValueOf(out)) + } + if err != nil { + newDB.AddError(err) + } + return newDB +} + +// FindEach find records in batches for large tables and allow callbacks to operate +// on each of them between iterations +func (s *DB) FindEach(out interface{}, cb func(), where ...interface{}) *DB { + c := s.clone() + v := indirect(reflect.ValueOf(&out)).Interface() + slice := makeSlice(reflect.TypeOf(v)) + c = c.FindInBatches(slice, func() { + valueOfSlice := indirect(reflect.ValueOf(slice)) + for i := 0; i < valueOfSlice.Len(); i++ { + reflect.ValueOf(out).Elem().Set(indirect(valueOfSlice.Index(i))) + cb() + } + }, where...) + return c +} + // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db diff --git a/main_test.go b/main_test.go index 3084b733..2d089918 100644 --- a/main_test.go +++ b/main_test.go @@ -370,6 +370,92 @@ func TestTransaction(t *testing.T) { } } +func TestFindInBatches(t *testing.T) { + type SomeProduct struct { + gorm.Model + Code int + Price uint + } + type NoTableStruct struct { + gorm.Model + } + DB.DropTable(&SomeProduct{}) + DB.AutoMigrate(&SomeProduct{}) + for i := 0; i < 10; i++ { + m := SomeProduct{Code: i} + err := DB.Save(&m).Error + if err != nil { + panic(err) + } + } + var arr []SomeProduct + var products []SomeProduct + + DB.Limit(3).FindInBatches(&arr, func() { + for i := 0; i < len(arr); i++ { + products = append(products, arr[i]) + } + }) + if len(products) != 10 { + t.Errorf("not all products were returned in find in batches") + } + for i := 0; i < 10; i++ { + if products[i].Code != i { + t.Errorf("product %d didn't match %d", products[i].Code, i) + } + } + //test errors + err := DB.FindInBatches(SomeProduct{}, func() {}).Error + if err == nil { + t.Errorf("FindInBatches was supposed to fail when given a struct") + } else if err != gorm.ErrUnsupportedType { + t.Errorf("FindInBatches was supposed to return an ErrUnsupportedType") + } + err = DB.FindInBatches(&[]NoTableStruct{}, func() {}).Error + if err == nil { + t.Errorf("FindInBatches was supposed to return the underlying error") + } +} + +func TestFindEach(t *testing.T) { + type SomeProduct struct { + gorm.Model + Code int + Price uint + } + type NoTableStruct struct { + gorm.Model + } + DB.DropTable(&SomeProduct{}) + DB.AutoMigrate(&SomeProduct{}) + for i := 0; i < 10; i++ { + m := SomeProduct{Code: i} + err := DB.Save(&m).Error + if err != nil { + panic(err) + } + } + var out SomeProduct + var products []SomeProduct + + DB.Limit(3).FindEach(&out, func() { + products = append(products, out) + }) + if len(products) != 10 { + t.Errorf("not all products were returned in find in batches") + } + for i := 0; i < 10; i++ { + if products[i].Code != i { + t.Errorf("product %d didn't match %d", products[i].Code, i) + } + } + //test errors + err := DB.FindEach(&NoTableStruct{}, func() {}).Error + if err == nil { + t.Errorf("find each was supposed to return the underlying error") + } +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}