Pass Stmt directly to SqlmockQuery
This commit is contained in:
parent
e89019d178
commit
505ecd17d3
@ -24,7 +24,7 @@ func init() {
|
|||||||
// Adapter provides an abstract interface over concrete mock database
|
// Adapter provides an abstract interface over concrete mock database
|
||||||
// implementations (e.g. go-sqlmock or go-testdb)
|
// implementations (e.g. go-sqlmock or go-testdb)
|
||||||
type Adapter interface {
|
type Adapter interface {
|
||||||
ExpectQuery(stmts ...string) ExpectedQuery
|
ExpectQuery(stmts ...Stmt) ExpectedQuery
|
||||||
ExpectExec(stmt string) ExpectedExec
|
ExpectExec(stmt string) ExpectedExec
|
||||||
AssertExpectations() error
|
AssertExpectations() error
|
||||||
}
|
}
|
||||||
@ -50,14 +50,8 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
|
|||||||
|
|
||||||
// ExpectQuery wraps the underlying mock method for setting a query
|
// ExpectQuery wraps the underlying mock method for setting a query
|
||||||
// expectation. It accepts multiple statements in the event of preloading
|
// expectation. It accepts multiple statements in the event of preloading
|
||||||
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery {
|
func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
|
||||||
var queries []*sqlmock.ExpectedQuery
|
return &SqlmockQuery{mock: a.mocker, queries: queries}
|
||||||
|
|
||||||
for _, stmt := range stmts {
|
|
||||||
queries = append(queries, a.mocker.ExpectQuery(stmt))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &SqlmockQuery{queries: queries}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpectExec wraps the underlying mock method for setting a exec
|
// ExpectExec wraps the underlying mock method for setting a exec
|
||||||
|
@ -4,8 +4,8 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,7 +25,8 @@ type ExpectedExec interface {
|
|||||||
|
|
||||||
// SqlmockQuery implements Query for go-sqlmock
|
// SqlmockQuery implements Query for go-sqlmock
|
||||||
type SqlmockQuery struct {
|
type SqlmockQuery struct {
|
||||||
queries []*sqlmock.ExpectedQuery
|
mock sqlmock.Sqlmock
|
||||||
|
queries []Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRowForFields(fields []*Field) []driver.Value {
|
func getRowForFields(fields []*Field) []driver.Value {
|
||||||
@ -64,7 +65,7 @@ func getRowForFields(fields []*Field) []driver.Value {
|
|||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
|
func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
|
||||||
var (
|
var (
|
||||||
rows *sqlmock.Rows
|
rows *sqlmock.Rows
|
||||||
columns []string
|
columns []string
|
||||||
@ -72,7 +73,7 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
|
|||||||
|
|
||||||
// we need to check for zero values
|
// we need to check for zero values
|
||||||
if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
|
if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
|
||||||
spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
|
// spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,8 +103,8 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
|
|||||||
columns = append(columns, field.DBName)
|
columns = append(columns, field.DBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
|
// spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
|
||||||
spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
|
// spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
|
||||||
columns = append(columns, "user_id", "language_id")
|
columns = append(columns, "user_id", "language_id")
|
||||||
|
|
||||||
rows = sqlmock.NewRows(columns)
|
rows = sqlmock.NewRows(columns)
|
||||||
@ -126,36 +127,21 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
|
func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
|
||||||
var (
|
var columns []string
|
||||||
columns []string
|
|
||||||
relations = make(map[string]*Relationship)
|
|
||||||
rowsSet []*sqlmock.Rows
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
||||||
// we get the primary model's columns here
|
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
columns = append(columns, field.DBName)
|
columns = append(columns, field.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check relations
|
|
||||||
if !field.IsNormal {
|
|
||||||
relationVal := reflect.ValueOf(field.Relationship)
|
|
||||||
isNil := relationVal.IsNil()
|
|
||||||
|
|
||||||
if !isNil {
|
|
||||||
relations[field.Name] = field.Relationship
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := sqlmock.NewRows(columns)
|
rows := sqlmock.NewRows(columns)
|
||||||
|
|
||||||
outVal := indirect(reflect.ValueOf(out))
|
outVal := indirect(reflect.ValueOf(out))
|
||||||
|
|
||||||
|
// SELECT multiple columns
|
||||||
if outVal.Kind() == reflect.Slice {
|
if outVal.Kind() == reflect.Slice {
|
||||||
outSlice := []interface{}{}
|
outSlice := []interface{}{}
|
||||||
|
|
||||||
for i := 0; i < outVal.Len(); i++ {
|
for i := 0; i < outVal.Len(); i++ {
|
||||||
outSlice = append(outSlice, outVal.Index(i).Interface())
|
outSlice = append(outSlice, outVal.Index(i).Interface())
|
||||||
}
|
}
|
||||||
@ -164,39 +150,45 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
|
|||||||
scope := &Scope{Value: outElem}
|
scope := &Scope{Value: outElem}
|
||||||
row := getRowForFields(scope.Fields())
|
row := getRowForFields(scope.Fields())
|
||||||
rows = rows.AddRow(row...)
|
rows = rows.AddRow(row...)
|
||||||
rowsSet = append(rowsSet, rows)
|
|
||||||
}
|
}
|
||||||
} else if outVal.Kind() == reflect.Struct {
|
} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
|
||||||
scope := &Scope{Value: out}
|
scope := &Scope{Value: out}
|
||||||
row := getRowForFields(scope.Fields())
|
row := getRowForFields(scope.Fields())
|
||||||
rows = rows.AddRow(row...)
|
rows = rows.AddRow(row...)
|
||||||
rowsSet = append(rowsSet, rows)
|
|
||||||
|
|
||||||
for name, relation := range relations {
|
|
||||||
rVal := outVal.FieldByName(name)
|
|
||||||
relationRows, hasRows := getRelationRows(rVal, name, relation)
|
|
||||||
|
|
||||||
if hasRows {
|
|
||||||
spew.Printf("___GENERATED ROWS FOR %s___\r\n: %s\r\n", name, spew.Sdump(relationRows))
|
|
||||||
rowsSet = append(rowsSet, relationRows)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return rowsSet
|
return rows
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns accepts an out type which should either be a struct or slice. Under
|
// Returns accepts an out type which should either be a struct or slice. Under
|
||||||
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
|
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
|
||||||
// the underlying mock db
|
// the underlying mock db
|
||||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
||||||
rows := q.getRowsForOutType(out)
|
scope := (&Scope{}).New(out)
|
||||||
|
outVal := indirect(reflect.ValueOf(out))
|
||||||
|
|
||||||
for i, query := range q.queries {
|
// rows := q.getRowsForOutType(out)
|
||||||
query.WillReturnRows(rows[i])
|
destQuery := q.queries[0]
|
||||||
spew.Printf("___SET RETURN ROW___: %s", spew.Sdump(rows[i]))
|
subQueries := q.queries[1:]
|
||||||
|
|
||||||
|
// main query always at the head of the slice
|
||||||
|
q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)).
|
||||||
|
WillReturnRows(q.getDestRows(out))
|
||||||
|
|
||||||
|
// subqueries are preload
|
||||||
|
for _, subQuery := range subQueries {
|
||||||
|
if subQuery.preload != "" {
|
||||||
|
if field, ok := scope.FieldByName(subQuery.preload); ok {
|
||||||
|
expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql))
|
||||||
|
rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship)
|
||||||
|
|
||||||
|
if hasRows {
|
||||||
|
expectation.WillReturnRows(rows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return q
|
return q
|
||||||
|
Loading…
x
Reference in New Issue
Block a user