Support slice value -> sql.Rows

This commit is contained in:
Ian Tan 2017-11-20 18:38:57 +08:00
parent b4591d4db8
commit afe269ec7d
4 changed files with 120 additions and 33 deletions

View File

@ -3,7 +3,6 @@ package gorm
import (
"database/sql"
"errors"
"fmt"
"regexp"
)
@ -148,6 +147,6 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
// Find triggers a Query
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
fmt.Printf("Expecting query: %s\n", "some query involving Find")
return h.adapter.ExpectQuery("some find condition")
h.gorm.Find(out, where...)
return h.adapter.ExpectQuery(regexp.QuoteMeta(h.recorder.stmt))
}

View File

@ -2,7 +2,10 @@ package gorm
import (
"database/sql/driver"
"fmt"
"reflect"
"github.com/davecgh/go-spew/spew"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
@ -20,36 +23,76 @@ type SqlmockQuery struct {
query *sqlmock.ExpectedQuery
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
var (
columns []string
rows *sqlmock.Rows
values []driver.Value
)
q.scope = &Scope{Value: out}
fields := q.scope.Fields()
func getRowForFields(fields []*Field) []driver.Value {
var values []driver.Value
for _, field := range fields {
if field.IsNormal {
var (
column = field.StructField.DBName
value = field.Field.Interface()
)
value := field.Field
if isValue := driver.IsValue(value); isValue {
columns = append(columns, column)
values = append(values, value)
} else if valuer, ok := value.(driver.Valuer); ok {
if underlyingValue, err := valuer.Value(); err == nil {
values = append(values, underlyingValue)
columns = append(columns, field.StructField.DBName)
// dereference pointers
if field.Field.Kind() == reflect.Ptr {
value = reflect.Indirect(field.Field)
}
// check if we have a zero Value
if !value.IsValid() {
values = append(values, nil)
continue
}
concreteVal := value.Interface()
// if we already have a driver.Value, just append
_, isValuer := concreteVal.(driver.Valuer)
spew.Printf("%s: %v\r\n", field.DBName, isValuer)
if driver.IsValue(concreteVal) {
values = append(values, concreteVal)
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
if convertedValue, err := valuer.Value(); err == nil {
values = append(values, convertedValue)
}
}
}
}
rows = sqlmock.NewRows(columns).AddRow(values...)
return values
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
var columns []string
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
if field.IsNormal {
columns = append(columns, field.DBName)
}
}
rows := sqlmock.NewRows(columns)
outVal := indirect(reflect.ValueOf(out))
if outVal.Kind() == reflect.Slice {
outSlice := []interface{}{}
for i := 0; i < outVal.Len(); i++ {
outSlice = append(outSlice, outVal.Index(i).Interface())
}
for _, outElem := range outSlice {
scope := &Scope{Value: outElem}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
}
} else if outVal.Kind() == reflect.Struct {
scope := &Scope{Value: out}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
} else {
panic(fmt.Errorf("Can only get rows for slice or struct"))
}
spew.Dump(columns)
spew.Dump(rows)
return rows
}

View File

@ -57,22 +57,67 @@ func TestQueryReturn(t *testing.T) {
t.Fatal(err)
}
in := &User{Id: 1}
expectedOut := User{Id: 1, Name: "jinzhu"}
in := User{Id: 1}
out := User{Id: 1, Name: "jinzhu"}
expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
expect.First(&in).Returns(out)
db.First(in)
db.First(&in)
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
if in.Name != "jinzhu" {
t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
t.Errorf("Expected %s, got %s", out.Name, in.Name)
}
if ne := reflect.DeepEqual(*in, expectedOut); !ne {
if ne := reflect.DeepEqual(in, out); !ne {
t.Errorf("Not equal")
}
}
func TestFindStructDest(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
in := &User{Id: 1}
expect.Find(in)
db.Find(&User{Id: 1})
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
}
func TestFindSlice(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
in := []User{}
out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}}
expect.Find(&in).Returns(&out)
db.Find(&in)
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
if ne := reflect.DeepEqual(in, out); !ne {
t.Error("Expected equal slices")
}
}

View File

@ -139,9 +139,9 @@ func (role Role) IsAdmin() bool {
type Num int64
func (i *Num) Value() (driver.Value, error) {
func (i Num) Value() (driver.Value, error) {
// guaranteed ok
return int64(*i), nil
return int64(i), nil
}
func (i *Num) Scan(src interface{}) error {