Pass Stmt directly to SqlmockQuery

This commit is contained in:
Ian Tan 2017-11-23 17:27:59 +08:00
parent e89019d178
commit 505ecd17d3
2 changed files with 38 additions and 52 deletions

View File

@ -24,7 +24,7 @@ func init() {
// Adapter provides an abstract interface over concrete mock database
// implementations (e.g. go-sqlmock or go-testdb)
type Adapter interface {
ExpectQuery(stmts ...string) ExpectedQuery
ExpectQuery(stmts ...Stmt) ExpectedQuery
ExpectExec(stmt string) ExpectedExec
AssertExpectations() error
}
@ -50,14 +50,8 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
// ExpectQuery wraps the underlying mock method for setting a query
// expectation. It accepts multiple statements in the event of preloading
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery {
var queries []*sqlmock.ExpectedQuery
for _, stmt := range stmts {
queries = append(queries, a.mocker.ExpectQuery(stmt))
}
return &SqlmockQuery{queries: queries}
func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
return &SqlmockQuery{mock: a.mocker, queries: queries}
}
// ExpectExec wraps the underlying mock method for setting a exec

View File

@ -4,8 +4,8 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"github.com/davecgh/go-spew/spew"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
@ -25,7 +25,8 @@ type ExpectedExec interface {
// SqlmockQuery implements Query for go-sqlmock
type SqlmockQuery struct {
queries []*sqlmock.ExpectedQuery
mock sqlmock.Sqlmock
queries []Stmt
}
func getRowForFields(fields []*Field) []driver.Value {
@ -64,7 +65,7 @@ func getRowForFields(fields []*Field) []driver.Value {
return values
}
func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
var (
rows *sqlmock.Rows
columns []string
@ -72,7 +73,7 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
// we need to check for zero values
if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
// spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
return nil, false
}
@ -102,8 +103,8 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
columns = append(columns, field.DBName)
}
}
spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
// spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
// spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
columns = append(columns, "user_id", "language_id")
rows = sqlmock.NewRows(columns)
@ -126,36 +127,21 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
}
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
var (
columns []string
relations = make(map[string]*Relationship)
rowsSet []*sqlmock.Rows
)
func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
var columns []string
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
// we get the primary model's columns here
if field.IsNormal {
columns = append(columns, field.DBName)
}
// check relations
if !field.IsNormal {
relationVal := reflect.ValueOf(field.Relationship)
isNil := relationVal.IsNil()
if !isNil {
relations[field.Name] = field.Relationship
}
}
}
rows := sqlmock.NewRows(columns)
outVal := indirect(reflect.ValueOf(out))
// SELECT multiple columns
if outVal.Kind() == reflect.Slice {
outSlice := []interface{}{}
for i := 0; i < outVal.Len(); i++ {
outSlice = append(outSlice, outVal.Index(i).Interface())
}
@ -164,39 +150,45 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
scope := &Scope{Value: outElem}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
}
} else if outVal.Kind() == reflect.Struct {
} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
scope := &Scope{Value: out}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
for name, relation := range relations {
rVal := outVal.FieldByName(name)
relationRows, hasRows := getRelationRows(rVal, name, relation)
if hasRows {
spew.Printf("___GENERATED ROWS FOR %s___\r\n: %s\r\n", name, spew.Sdump(relationRows))
rowsSet = append(rowsSet, relationRows)
}
}
} else {
panic(fmt.Errorf("Can only get rows for slice or struct"))
}
return rowsSet
return rows
}
// Returns accepts an out type which should either be a struct or slice. Under
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
// the underlying mock db
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
rows := q.getRowsForOutType(out)
scope := (&Scope{}).New(out)
outVal := indirect(reflect.ValueOf(out))
for i, query := range q.queries {
query.WillReturnRows(rows[i])
spew.Printf("___SET RETURN ROW___: %s", spew.Sdump(rows[i]))
// rows := q.getRowsForOutType(out)
destQuery := q.queries[0]
subQueries := q.queries[1:]
// main query always at the head of the slice
q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)).
WillReturnRows(q.getDestRows(out))
// subqueries are preload
for _, subQuery := range subQueries {
if subQuery.preload != "" {
if field, ok := scope.FieldByName(subQuery.preload); ok {
expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql))
rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship)
if hasRows {
expectation.WillReturnRows(rows)
}
}
}
}
return q