added FindInBatches, FindEach
This commit is contained in:
parent
a667ab8427
commit
48d0897d3e
@ -16,6 +16,8 @@ var (
|
||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||
// ErrUnaddressable unaddressable value
|
||||
ErrUnaddressable = errors.New("using unaddressable value")
|
||||
// ErrUnsupportedType unsupported type
|
||||
ErrUnsupportedType = errors.New("using unsupported type")
|
||||
)
|
||||
|
||||
type errorsInterface interface {
|
||||
|
52
main.go
52
main.go
@ -264,6 +264,58 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// FindInBatches find records in batches for large tables and allow callbacks to operate
|
||||
// on each batch
|
||||
func (s *DB) FindInBatches(out interface{}, cb func(), where ...interface{}) *DB {
|
||||
var limit interface{}
|
||||
newDB := s.clone()
|
||||
if len(newDB.search.orders) != 0 {
|
||||
//ordering by non primary col not supported - can cause infinite loop
|
||||
newDB.search.orders = []interface{}{}
|
||||
}
|
||||
value := indirect(reflect.ValueOf(out))
|
||||
if value.Kind() != reflect.Array && value.Kind() != reflect.Slice {
|
||||
newDB.AddError(ErrUnsupportedType)
|
||||
return newDB
|
||||
}
|
||||
if newDB.search.limit == -1 {
|
||||
limit = 1000
|
||||
} else {
|
||||
limit = newDB.search.limit
|
||||
}
|
||||
err := newDB.Limit(limit).Find(out, where...).Error
|
||||
records := indirect(reflect.ValueOf(out))
|
||||
for err == nil && records.Len() > 0 {
|
||||
cb()
|
||||
lastRecord := records.Index(records.Len() - 1).Interface()
|
||||
scope := s.NewScope(lastRecord)
|
||||
err = newDB.Limit(limit).
|
||||
Where(fmt.Sprintf("`%v` > '%v'", scope.PrimaryKey(), scope.PrimaryKeyValue())).
|
||||
Find(out, where...).Error
|
||||
records = indirect(reflect.ValueOf(out))
|
||||
}
|
||||
if err != nil {
|
||||
newDB.AddError(err)
|
||||
}
|
||||
return newDB
|
||||
}
|
||||
|
||||
// FindEach find records in batches for large tables and allow callbacks to operate
|
||||
// on each of them between iterations
|
||||
func (s *DB) FindEach(out interface{}, cb func(), where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
v := indirect(reflect.ValueOf(&out)).Interface()
|
||||
slice := makeSlice(reflect.TypeOf(v))
|
||||
c = c.FindInBatches(slice, func() {
|
||||
valueOfSlice := indirect(reflect.ValueOf(slice))
|
||||
for i := 0; i < valueOfSlice.Len(); i++ {
|
||||
reflect.ValueOf(out).Elem().Set(indirect(valueOfSlice.Index(i)))
|
||||
cb()
|
||||
}
|
||||
}, where...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (s *DB) Scan(dest interface{}) *DB {
|
||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||
|
86
main_test.go
86
main_test.go
@ -370,6 +370,92 @@ func TestTransaction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindInBatches(t *testing.T) {
|
||||
type SomeProduct struct {
|
||||
gorm.Model
|
||||
Code int
|
||||
Price uint
|
||||
}
|
||||
type NoTableStruct struct {
|
||||
gorm.Model
|
||||
}
|
||||
DB.DropTable(&SomeProduct{})
|
||||
DB.AutoMigrate(&SomeProduct{})
|
||||
for i := 0; i < 10; i++ {
|
||||
m := SomeProduct{Code: i}
|
||||
err := DB.Save(&m).Error
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
var arr []SomeProduct
|
||||
var products []SomeProduct
|
||||
|
||||
DB.Limit(3).FindInBatches(&arr, func() {
|
||||
for i := 0; i < len(arr); i++ {
|
||||
products = append(products, arr[i])
|
||||
}
|
||||
})
|
||||
if len(products) != 10 {
|
||||
t.Errorf("not all products were returned in find in batches")
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
if products[i].Code != i {
|
||||
t.Errorf("product %d didn't match %d", products[i].Code, i)
|
||||
}
|
||||
}
|
||||
//test errors
|
||||
err := DB.FindInBatches(SomeProduct{}, func() {}).Error
|
||||
if err == nil {
|
||||
t.Errorf("FindInBatches was supposed to fail when given a struct")
|
||||
} else if err != gorm.ErrUnsupportedType {
|
||||
t.Errorf("FindInBatches was supposed to return an ErrUnsupportedType")
|
||||
}
|
||||
err = DB.FindInBatches(&[]NoTableStruct{}, func() {}).Error
|
||||
if err == nil {
|
||||
t.Errorf("FindInBatches was supposed to return the underlying error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindEach(t *testing.T) {
|
||||
type SomeProduct struct {
|
||||
gorm.Model
|
||||
Code int
|
||||
Price uint
|
||||
}
|
||||
type NoTableStruct struct {
|
||||
gorm.Model
|
||||
}
|
||||
DB.DropTable(&SomeProduct{})
|
||||
DB.AutoMigrate(&SomeProduct{})
|
||||
for i := 0; i < 10; i++ {
|
||||
m := SomeProduct{Code: i}
|
||||
err := DB.Save(&m).Error
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
var out SomeProduct
|
||||
var products []SomeProduct
|
||||
|
||||
DB.Limit(3).FindEach(&out, func() {
|
||||
products = append(products, out)
|
||||
})
|
||||
if len(products) != 10 {
|
||||
t.Errorf("not all products were returned in find in batches")
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
if products[i].Code != i {
|
||||
t.Errorf("product %d didn't match %d", products[i].Code, i)
|
||||
}
|
||||
}
|
||||
//test errors
|
||||
err := DB.FindEach(&NoTableStruct{}, func() {}).Error
|
||||
if err == nil {
|
||||
t.Errorf("find each was supposed to return the underlying error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRow(t *testing.T) {
|
||||
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||
|
Loading…
x
Reference in New Issue
Block a user