fixed: queryCallback error if params to Find is created by reflection

This commit is contained in:
double 2019-06-16 15:15:41 +08:00
parent 01b6601142
commit b8890dd52c
2 changed files with 37 additions and 11 deletions

View File

@ -26,8 +26,13 @@ func queryCallback(scope *Scope) {
defer scope.trace(scope.db.nowFunc()) defer scope.trace(scope.db.nowFunc())
if reflect.ValueOf(scope.Value).Kind() != reflect.Ptr {
panic("results argument must be a slice address")
}
var ( var (
isSlice, isPtr bool isSlice, isPtr bool
slicev reflect.Value
resultType reflect.Type resultType reflect.Type
results = scope.IndirectValue() results = scope.IndirectValue()
) )
@ -42,11 +47,18 @@ func queryCallback(scope *Scope) {
results = indirect(reflect.ValueOf(value)) results = indirect(reflect.ValueOf(value))
} }
if kind := results.Kind(); kind == reflect.Slice { if results.Kind() == reflect.Interface {
isSlice = true // for reflect params
resultType = results.Type().Elem() slicev = results.Elem()
results.Set(reflect.MakeSlice(results.Type(), 0, 0)) } else {
// for struct params
slicev = results
}
if kind := slicev.Kind(); kind == reflect.Slice {
isSlice = true
resultType = slicev.Type().Elem()
slicev = slicev.Slice(0, slicev.Cap())
if resultType.Kind() == reflect.Ptr { if resultType.Kind() == reflect.Ptr {
isPtr = true isPtr = true
resultType = resultType.Elem() resultType = resultType.Elem()
@ -63,11 +75,11 @@ func queryCallback(scope *Scope) {
if str, ok := scope.Get("gorm:query_option"); ok { if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
} }
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close() defer rows.Close()
columns, _ := rows.Columns() columns, _ := rows.Columns()
i := 0
for rows.Next() { for rows.Next() {
scope.db.RowsAffected++ scope.db.RowsAffected++
@ -75,18 +87,20 @@ func queryCallback(scope *Scope) {
if isSlice { if isSlice {
elem = reflect.New(resultType).Elem() elem = reflect.New(resultType).Elem()
} }
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
if isSlice { if isSlice {
if isPtr { if isPtr {
results.Set(reflect.Append(results, elem.Addr())) slicev = reflect.Append(slicev, elem.Addr())
} else { } else {
results.Set(reflect.Append(results, elem)) slicev = reflect.Append(slicev, elem)
} }
slicev = slicev.Slice(0, slicev.Cap())
i++
} }
} }
if isSlice {
results.Set(slicev.Slice(0, i))
}
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
scope.Err(err) scope.Err(err)
} else if scope.db.RowsAffected == 0 && !isSlice { } else if scope.db.RowsAffected == 0 && !isSlice {

View File

@ -149,12 +149,24 @@ func getForeignField(column string, fields []*StructField) *StructField {
// GetModelStruct get value's model struct, relationships based on struct and tag definition // GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct { func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct var modelStruct ModelStruct
var reflectType reflect.Type
// Scope value can't be nil // Scope value can't be nil
if scope.Value == nil { if scope.Value == nil {
return &modelStruct return &modelStruct
} }
reflectType := reflect.ValueOf(scope.Value).Type() if reflect.ValueOf(scope.Value).Kind() != reflect.Ptr {
panic("results argument must be a ptr or slice")
}
if reflect.ValueOf(scope.Value).Elem().Kind() == reflect.Interface {
// for reflect params
reflectType = reflect.ValueOf(scope.Value).Elem().Elem().Type()
} else {
// for struct params
reflectType = reflect.ValueOf(scope.Value).Type()
}
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem() reflectType = reflectType.Elem()
} }