Pass Stmt directly to SqlmockQuery

This commit is contained in:
Ian Tan 2017-11-23 17:27:59 +08:00
parent e89019d178
commit 505ecd17d3
2 changed files with 38 additions and 52 deletions

View File

@ -24,7 +24,7 @@ func init() {
// Adapter provides an abstract interface over concrete mock database // Adapter provides an abstract interface over concrete mock database
// implementations (e.g. go-sqlmock or go-testdb) // implementations (e.g. go-sqlmock or go-testdb)
type Adapter interface { type Adapter interface {
ExpectQuery(stmts ...string) ExpectedQuery ExpectQuery(stmts ...Stmt) ExpectedQuery
ExpectExec(stmt string) ExpectedExec ExpectExec(stmt string) ExpectedExec
AssertExpectations() error AssertExpectations() error
} }
@ -50,14 +50,8 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
// ExpectQuery wraps the underlying mock method for setting a query // ExpectQuery wraps the underlying mock method for setting a query
// expectation. It accepts multiple statements in the event of preloading // expectation. It accepts multiple statements in the event of preloading
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery { func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
var queries []*sqlmock.ExpectedQuery return &SqlmockQuery{mock: a.mocker, queries: queries}
for _, stmt := range stmts {
queries = append(queries, a.mocker.ExpectQuery(stmt))
}
return &SqlmockQuery{queries: queries}
} }
// ExpectExec wraps the underlying mock method for setting a exec // ExpectExec wraps the underlying mock method for setting a exec

View File

@ -4,8 +4,8 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"github.com/davecgh/go-spew/spew"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
) )
@ -25,7 +25,8 @@ type ExpectedExec interface {
// SqlmockQuery implements Query for go-sqlmock // SqlmockQuery implements Query for go-sqlmock
type SqlmockQuery struct { type SqlmockQuery struct {
queries []*sqlmock.ExpectedQuery mock sqlmock.Sqlmock
queries []Stmt
} }
func getRowForFields(fields []*Field) []driver.Value { func getRowForFields(fields []*Field) []driver.Value {
@ -64,7 +65,7 @@ func getRowForFields(fields []*Field) []driver.Value {
return values return values
} }
func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) { func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
var ( var (
rows *sqlmock.Rows rows *sqlmock.Rows
columns []string columns []string
@ -72,7 +73,7 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
// we need to check for zero values // we need to check for zero values
if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) { if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName) // spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
return nil, false return nil, false
} }
@ -102,8 +103,8 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
columns = append(columns, field.DBName) columns = append(columns, field.DBName)
} }
} }
spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName) // spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns)) // spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
columns = append(columns, "user_id", "language_id") columns = append(columns, "user_id", "language_id")
rows = sqlmock.NewRows(columns) rows = sqlmock.NewRows(columns)
@ -126,36 +127,21 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
} }
} }
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
var ( var columns []string
columns []string
relations = make(map[string]*Relationship)
rowsSet []*sqlmock.Rows
)
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields { for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
// we get the primary model's columns here
if field.IsNormal { if field.IsNormal {
columns = append(columns, field.DBName) columns = append(columns, field.DBName)
} }
// check relations
if !field.IsNormal {
relationVal := reflect.ValueOf(field.Relationship)
isNil := relationVal.IsNil()
if !isNil {
relations[field.Name] = field.Relationship
}
}
} }
rows := sqlmock.NewRows(columns) rows := sqlmock.NewRows(columns)
outVal := indirect(reflect.ValueOf(out)) outVal := indirect(reflect.ValueOf(out))
// SELECT multiple columns
if outVal.Kind() == reflect.Slice { if outVal.Kind() == reflect.Slice {
outSlice := []interface{}{} outSlice := []interface{}{}
for i := 0; i < outVal.Len(); i++ { for i := 0; i < outVal.Len(); i++ {
outSlice = append(outSlice, outVal.Index(i).Interface()) outSlice = append(outSlice, outVal.Index(i).Interface())
} }
@ -164,39 +150,45 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
scope := &Scope{Value: outElem} scope := &Scope{Value: outElem}
row := getRowForFields(scope.Fields()) row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...) rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
} }
} else if outVal.Kind() == reflect.Struct { } else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
scope := &Scope{Value: out} scope := &Scope{Value: out}
row := getRowForFields(scope.Fields()) row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...) rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
for name, relation := range relations {
rVal := outVal.FieldByName(name)
relationRows, hasRows := getRelationRows(rVal, name, relation)
if hasRows {
spew.Printf("___GENERATED ROWS FOR %s___\r\n: %s\r\n", name, spew.Sdump(relationRows))
rowsSet = append(rowsSet, relationRows)
}
}
} else { } else {
panic(fmt.Errorf("Can only get rows for slice or struct")) panic(fmt.Errorf("Can only get rows for slice or struct"))
} }
return rowsSet return rows
} }
// Returns accepts an out type which should either be a struct or slice. Under // Returns accepts an out type which should either be a struct or slice. Under
// the hood, it converts a gorm model struct to sql.Rows that can be passed to // the hood, it converts a gorm model struct to sql.Rows that can be passed to
// the underlying mock db // the underlying mock db
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
rows := q.getRowsForOutType(out) scope := (&Scope{}).New(out)
outVal := indirect(reflect.ValueOf(out))
for i, query := range q.queries { // rows := q.getRowsForOutType(out)
query.WillReturnRows(rows[i]) destQuery := q.queries[0]
spew.Printf("___SET RETURN ROW___: %s", spew.Sdump(rows[i])) subQueries := q.queries[1:]
// main query always at the head of the slice
q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)).
WillReturnRows(q.getDestRows(out))
// subqueries are preload
for _, subQuery := range subQueries {
if subQuery.preload != "" {
if field, ok := scope.FieldByName(subQuery.preload); ok {
expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql))
rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship)
if hasRows {
expectation.WillReturnRows(rows)
}
}
}
} }
return q return q