diff --git a/callback_query.go b/callback_query.go index e3b3d534..db538540 100644 --- a/callback_query.go +++ b/callback_query.go @@ -26,8 +26,13 @@ func queryCallback(scope *Scope) { defer scope.trace(scope.db.nowFunc()) + if reflect.ValueOf(scope.Value).Kind() != reflect.Ptr { + panic("results argument must be a slice address") + } + var ( isSlice, isPtr bool + slicev reflect.Value resultType reflect.Type results = scope.IndirectValue() ) @@ -42,11 +47,18 @@ func queryCallback(scope *Scope) { results = indirect(reflect.ValueOf(value)) } - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) + if results.Kind() == reflect.Interface { + // for reflect params + slicev = results.Elem() + } else { + // for struct params + slicev = results + } + if kind := slicev.Kind(); kind == reflect.Slice { + isSlice = true + resultType = slicev.Type().Elem() + slicev = slicev.Slice(0, slicev.Cap()) if resultType.Kind() == reflect.Ptr { isPtr = true resultType = resultType.Elem() @@ -63,11 +75,11 @@ func queryCallback(scope *Scope) { if str, ok := scope.Get("gorm:query_option"); ok { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { defer rows.Close() columns, _ := rows.Columns() + i := 0 for rows.Next() { scope.db.RowsAffected++ @@ -75,18 +87,20 @@ func queryCallback(scope *Scope) { if isSlice { elem = reflect.New(resultType).Elem() } - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - if isSlice { if isPtr { - results.Set(reflect.Append(results, elem.Addr())) + slicev = reflect.Append(slicev, elem.Addr()) } else { - results.Set(reflect.Append(results, elem)) + slicev = reflect.Append(slicev, elem) } + slicev = slicev.Slice(0, slicev.Cap()) + i++ } } - + if isSlice { + results.Set(slicev.Slice(0, i)) + } if err := rows.Err(); err != nil { scope.Err(err) } else if scope.db.RowsAffected == 0 && !isSlice { diff --git a/model_struct.go b/model_struct.go index 5234b287..dcdf60ed 100644 --- a/model_struct.go +++ b/model_struct.go @@ -149,12 +149,24 @@ func getForeignField(column string, fields []*StructField) *StructField { // GetModelStruct get value's model struct, relationships based on struct and tag definition func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct + var reflectType reflect.Type // Scope value can't be nil if scope.Value == nil { return &modelStruct } - reflectType := reflect.ValueOf(scope.Value).Type() + if reflect.ValueOf(scope.Value).Kind() != reflect.Ptr { + panic("results argument must be a ptr or slice") + } + + if reflect.ValueOf(scope.Value).Elem().Kind() == reflect.Interface { + // for reflect params + reflectType = reflect.ValueOf(scope.Value).Elem().Elem().Type() + } else { + // for struct params + reflectType = reflect.ValueOf(scope.Value).Type() + } + for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { reflectType = reflectType.Elem() }