Merge 48d0897d3e391ba6a0e3b814c9d1b546b00b4c2e into 9acaa33324bbcc78239a1c913d4f1292c12177b9

Ronna Steinberg 2017-05-06 20:53:53 +00:00 committed by GitHub
commit 6cbd6648dc
3 changed files with 140 additions and 0 deletions


@@ -16,6 +16,8 @@ var (
	ErrCantStartTransaction = errors.New("can't start transaction")
	// ErrUnaddressable unaddressable value
	ErrUnaddressable = errors.New("using unaddressable value")
	// ErrUnsupportedType unsupported type
	ErrUnsupportedType = errors.New("using unsupported type")
)

// Errors contains all happened errors
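The new sentinel is exported, so callers can tell a misuse of the batch finders added in main.go below from a genuine query failure. A minimal sketch, not part of the diff, assuming an opened *gorm.DB named db and some query destination dest:

if err := db.FindInBatches(dest, func() {}).Error; err == gorm.ErrUnsupportedType {
	// dest is not (a pointer to) a slice or array; fix the destination rather than retrying
} else if err != nil {
	// any other error bubbled up from the underlying queries
}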

main.go (+52)

@@ -284,6 +284,58 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
	return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

// FindInBatches finds records in batches for large tables and allows a callback
// to operate on each batch
func (s *DB) FindInBatches(out interface{}, cb func(), where ...interface{}) *DB {
	var limit interface{}
	newDB := s.clone()
	if len(newDB.search.orders) != 0 {
		// ordering by a non-primary column is not supported - it can cause an infinite loop
		newDB.search.orders = []interface{}{}
	}
	value := indirect(reflect.ValueOf(out))
	if value.Kind() != reflect.Array && value.Kind() != reflect.Slice {
		newDB.AddError(ErrUnsupportedType)
		return newDB
	}
	if newDB.search.limit == -1 {
		limit = 1000
	} else {
		limit = newDB.search.limit
	}
	err := newDB.Limit(limit).Find(out, where...).Error
	records := indirect(reflect.ValueOf(out))
	for err == nil && records.Len() > 0 {
		cb()
		// keyset pagination: fetch the next batch of records whose primary key
		// follows the last record of the current batch
		lastRecord := records.Index(records.Len() - 1).Interface()
		scope := s.NewScope(lastRecord)
		err = newDB.Limit(limit).
			Where(fmt.Sprintf("%v > ?", scope.Quote(scope.PrimaryKey())), scope.PrimaryKeyValue()).
			Find(out, where...).Error
		records = indirect(reflect.ValueOf(out))
	}
	if err != nil {
		newDB.AddError(err)
	}
	return newDB
}

// FindEach finds records in batches for large tables and allows a callback to
// operate on each record between iterations
func (s *DB) FindEach(out interface{}, cb func(), where ...interface{}) *DB {
	c := s.clone()
	v := indirect(reflect.ValueOf(&out)).Interface()
	// build an addressable slice destination whose element type matches out
	slice := makeSlice(reflect.TypeOf(v))
	c = c.FindInBatches(slice, func() {
		valueOfSlice := indirect(reflect.ValueOf(slice))
		for i := 0; i < valueOfSlice.Len(); i++ {
			// copy the current record into out, then hand control to the callback
			reflect.ValueOf(out).Elem().Set(indirect(valueOfSlice.Index(i)))
			cb()
		}
	}, where...)
	return c
}

// Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB {
	return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db

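Before the accompanying tests, a minimal caller-side sketch of the batch API above. It is not part of the diff: db stands for an opened *gorm.DB, Product for any migrated model, and the Limit call sets the batch size (the helper falls back to batches of 1000 when no limit is given):

package example

import "github.com/jinzhu/gorm"

type Product struct {
	gorm.Model
	Code int
}

// sumCodes pages through the products table 100 records at a time and
// folds every batch into a running total.
func sumCodes(db *gorm.DB) (int, error) {
	var batch []Product
	total := 0
	res := db.Limit(100).FindInBatches(&batch, func() {
		// the callback runs once per batch; batch has already been
		// refilled with the next page of records
		for _, p := range batch {
			total += p.Code
		}
	})
	return total, res.Error
}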

@@ -381,6 +381,92 @@ func TestTransaction(t *testing.T) {
	}
}

func TestFindInBatches(t *testing.T) {
	type SomeProduct struct {
		gorm.Model
		Code  int
		Price uint
	}

	type NoTableStruct struct {
		gorm.Model
	}

	DB.DropTable(&SomeProduct{})
	DB.AutoMigrate(&SomeProduct{})

	for i := 0; i < 10; i++ {
		m := SomeProduct{Code: i}
		if err := DB.Save(&m).Error; err != nil {
			t.Fatal(err)
		}
	}

	var arr []SomeProduct
	var products []SomeProduct
	DB.Limit(3).FindInBatches(&arr, func() {
		products = append(products, arr...)
	})

	if len(products) != 10 {
		t.Errorf("expected 10 products from FindInBatches, got %d", len(products))
	}

	for i := 0; i < 10; i++ {
		if products[i].Code != i {
			t.Errorf("product code %d didn't match %d", products[i].Code, i)
		}
	}

	// test errors
	err := DB.FindInBatches(SomeProduct{}, func() {}).Error
	if err == nil {
		t.Errorf("FindInBatches was supposed to fail when given a struct")
	} else if err != gorm.ErrUnsupportedType {
		t.Errorf("FindInBatches was supposed to return ErrUnsupportedType, got %v", err)
	}

	err = DB.FindInBatches(&[]NoTableStruct{}, func() {}).Error
	if err == nil {
		t.Errorf("FindInBatches was supposed to return the underlying error")
	}
}

func TestFindEach(t *testing.T) {
	type SomeProduct struct {
		gorm.Model
		Code  int
		Price uint
	}

	type NoTableStruct struct {
		gorm.Model
	}

	DB.DropTable(&SomeProduct{})
	DB.AutoMigrate(&SomeProduct{})

	for i := 0; i < 10; i++ {
		m := SomeProduct{Code: i}
		if err := DB.Save(&m).Error; err != nil {
			t.Fatal(err)
		}
	}

	var out SomeProduct
	var products []SomeProduct
	DB.Limit(3).FindEach(&out, func() {
		products = append(products, out)
	})

	if len(products) != 10 {
		t.Errorf("expected 10 products from FindEach, got %d", len(products))
	}

	for i := 0; i < 10; i++ {
		if products[i].Code != i {
			t.Errorf("product code %d didn't match %d", products[i].Code, i)
		}
	}

	// test errors
	err := DB.FindEach(&NoTableStruct{}, func() {}).Error
	if err == nil {
		t.Errorf("FindEach was supposed to return the underlying error")
	}
}

func TestRow(t *testing.T) {
	user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
	user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
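
And the per-record counterpart, continuing the sketch from the main.go section (same placeholder db and Product; add "fmt" to the imports). FindEach drives FindInBatches internally, so records are still fetched in pages even though the callback sees one record at a time:

// printCodes visits every product one record at a time; out is overwritten
// with the next record before each invocation of the callback.
func printCodes(db *gorm.DB) error {
	var out Product
	res := db.Limit(100).FindEach(&out, func() {
		fmt.Println(out.Code)
	})
	return res.Error
}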