Support slice value -> sql.Rows
This commit is contained in:
parent
b4591d4db8
commit
afe269ec7d
@ -3,7 +3,6 @@ package gorm
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -148,6 +147,6 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
|
|||||||
|
|
||||||
// Find triggers a Query
|
// Find triggers a Query
|
||||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
||||||
fmt.Printf("Expecting query: %s\n", "some query involving Find")
|
h.gorm.Find(out, where...)
|
||||||
return h.adapter.ExpectQuery("some find condition")
|
return h.adapter.ExpectQuery(regexp.QuoteMeta(h.recorder.stmt))
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,10 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,36 +23,76 @@ type SqlmockQuery struct {
|
|||||||
query *sqlmock.ExpectedQuery
|
query *sqlmock.ExpectedQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
func getRowForFields(fields []*Field) []driver.Value {
|
||||||
var (
|
var values []driver.Value
|
||||||
columns []string
|
|
||||||
rows *sqlmock.Rows
|
|
||||||
values []driver.Value
|
|
||||||
)
|
|
||||||
|
|
||||||
q.scope = &Scope{Value: out}
|
|
||||||
fields := q.scope.Fields()
|
|
||||||
|
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
var (
|
value := field.Field
|
||||||
column = field.StructField.DBName
|
|
||||||
value = field.Field.Interface()
|
|
||||||
)
|
|
||||||
|
|
||||||
if isValue := driver.IsValue(value); isValue {
|
// dereference pointers
|
||||||
columns = append(columns, column)
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
values = append(values, value)
|
value = reflect.Indirect(field.Field)
|
||||||
} else if valuer, ok := value.(driver.Valuer); ok {
|
}
|
||||||
if underlyingValue, err := valuer.Value(); err == nil {
|
|
||||||
values = append(values, underlyingValue)
|
// check if we have a zero Value
|
||||||
columns = append(columns, field.StructField.DBName)
|
if !value.IsValid() {
|
||||||
|
values = append(values, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
concreteVal := value.Interface()
|
||||||
|
|
||||||
|
// if we already have a driver.Value, just append
|
||||||
|
_, isValuer := concreteVal.(driver.Valuer)
|
||||||
|
spew.Printf("%s: %v\r\n", field.DBName, isValuer)
|
||||||
|
|
||||||
|
if driver.IsValue(concreteVal) {
|
||||||
|
values = append(values, concreteVal)
|
||||||
|
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
|
||||||
|
if convertedValue, err := valuer.Value(); err == nil {
|
||||||
|
values = append(values, convertedValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows = sqlmock.NewRows(columns).AddRow(values...)
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
||||||
|
if field.IsNormal {
|
||||||
|
columns = append(columns, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows(columns)
|
||||||
|
|
||||||
|
outVal := indirect(reflect.ValueOf(out))
|
||||||
|
|
||||||
|
if outVal.Kind() == reflect.Slice {
|
||||||
|
outSlice := []interface{}{}
|
||||||
|
for i := 0; i < outVal.Len(); i++ {
|
||||||
|
outSlice = append(outSlice, outVal.Index(i).Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, outElem := range outSlice {
|
||||||
|
scope := &Scope{Value: outElem}
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
}
|
||||||
|
} else if outVal.Kind() == reflect.Struct {
|
||||||
|
scope := &Scope{Value: out}
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
||||||
|
}
|
||||||
|
|
||||||
|
spew.Dump(columns)
|
||||||
|
spew.Dump(rows)
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
}
|
}
|
||||||
|
@ -57,22 +57,67 @@ func TestQueryReturn(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
in := &User{Id: 1}
|
in := User{Id: 1}
|
||||||
expectedOut := User{Id: 1, Name: "jinzhu"}
|
out := User{Id: 1, Name: "jinzhu"}
|
||||||
|
|
||||||
expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
|
expect.First(&in).Returns(out)
|
||||||
|
|
||||||
db.First(in)
|
db.First(&in)
|
||||||
|
|
||||||
if e := expect.AssertExpectations(); e != nil {
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
t.Error(e)
|
t.Error(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
if in.Name != "jinzhu" {
|
if in.Name != "jinzhu" {
|
||||||
t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
|
t.Errorf("Expected %s, got %s", out.Name, in.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ne := reflect.DeepEqual(*in, expectedOut); !ne {
|
if ne := reflect.DeepEqual(in, out); !ne {
|
||||||
t.Errorf("Not equal")
|
t.Errorf("Not equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindStructDest(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer func() {
|
||||||
|
db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := &User{Id: 1}
|
||||||
|
|
||||||
|
expect.Find(in)
|
||||||
|
db.Find(&User{Id: 1})
|
||||||
|
|
||||||
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
|
t.Error(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindSlice(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer func() {
|
||||||
|
db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := []User{}
|
||||||
|
out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}}
|
||||||
|
|
||||||
|
expect.Find(&in).Returns(&out)
|
||||||
|
db.Find(&in)
|
||||||
|
|
||||||
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
|
t.Error(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ne := reflect.DeepEqual(in, out); !ne {
|
||||||
|
t.Error("Expected equal slices")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -139,9 +139,9 @@ func (role Role) IsAdmin() bool {
|
|||||||
|
|
||||||
type Num int64
|
type Num int64
|
||||||
|
|
||||||
func (i *Num) Value() (driver.Value, error) {
|
func (i Num) Value() (driver.Value, error) {
|
||||||
// guaranteed ok
|
// guaranteed ok
|
||||||
return int64(*i), nil
|
return int64(i), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Num) Scan(src interface{}) error {
|
func (i *Num) Scan(src interface{}) error {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user