diff --git a/expecter_adapter.go b/expecter_adapter.go index 5ccc7271..62929678 100644 --- a/expecter_adapter.go +++ b/expecter_adapter.go @@ -24,7 +24,7 @@ func init() { // Adapter provides an abstract interface over concrete mock database // implementations (e.g. go-sqlmock or go-testdb) type Adapter interface { - ExpectQuery(stmts ...string) ExpectedQuery + ExpectQuery(stmts ...Stmt) ExpectedQuery ExpectExec(stmt string) ExpectedExec AssertExpectations() error } @@ -50,14 +50,8 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error // ExpectQuery wraps the underlying mock method for setting a query // expectation. It accepts multiple statements in the event of preloading -func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery { - var queries []*sqlmock.ExpectedQuery - - for _, stmt := range stmts { - queries = append(queries, a.mocker.ExpectQuery(stmt)) - } - - return &SqlmockQuery{queries: queries} +func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery { + return &SqlmockQuery{mock: a.mocker, queries: queries} } // ExpectExec wraps the underlying mock method for setting a exec diff --git a/expecter_result.go b/expecter_result.go index f5c35c93..f8faab5b 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -4,8 +4,8 @@ import ( "database/sql/driver" "fmt" "reflect" + "regexp" - "github.com/davecgh/go-spew/spew" sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" ) @@ -25,7 +25,8 @@ type ExpectedExec interface { // SqlmockQuery implements Query for go-sqlmock type SqlmockQuery struct { - queries []*sqlmock.ExpectedQuery + mock sqlmock.Sqlmock + queries []Stmt } func getRowForFields(fields []*Field) []driver.Value { @@ -64,7 +65,7 @@ func getRowForFields(fields []*Field) []driver.Value { return values } -func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) { +func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) { var ( rows *sqlmock.Rows columns []string @@ -72,7 +73,7 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi // we need to check for zero values if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) { - spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName) + // spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName) return nil, false } @@ -102,8 +103,8 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi columns = append(columns, field.DBName) } } - spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName) - spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns)) + // spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName) + // spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns)) columns = append(columns, "user_id", "language_id") rows = sqlmock.NewRows(columns) @@ -126,36 +127,21 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi } } -func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { - var ( - columns []string - relations = make(map[string]*Relationship) - rowsSet []*sqlmock.Rows - ) - +func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows { + var columns []string for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields { - // we get the primary model's columns here if field.IsNormal { columns = append(columns, field.DBName) } - - // check relations - if !field.IsNormal { - relationVal := reflect.ValueOf(field.Relationship) - isNil := relationVal.IsNil() - - if !isNil { - relations[field.Name] = field.Relationship - } - } } rows := sqlmock.NewRows(columns) - outVal := indirect(reflect.ValueOf(out)) + // SELECT multiple columns if outVal.Kind() == reflect.Slice { outSlice := []interface{}{} + for i := 0; i < outVal.Len(); i++ { outSlice = append(outSlice, outVal.Index(i).Interface()) } @@ -164,39 +150,45 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { scope := &Scope{Value: outElem} row := getRowForFields(scope.Fields()) rows = rows.AddRow(row...) - rowsSet = append(rowsSet, rows) } - } else if outVal.Kind() == reflect.Struct { + } else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1 scope := &Scope{Value: out} row := getRowForFields(scope.Fields()) rows = rows.AddRow(row...) - rowsSet = append(rowsSet, rows) - - for name, relation := range relations { - rVal := outVal.FieldByName(name) - relationRows, hasRows := getRelationRows(rVal, name, relation) - - if hasRows { - spew.Printf("___GENERATED ROWS FOR %s___\r\n: %s\r\n", name, spew.Sdump(relationRows)) - rowsSet = append(rowsSet, relationRows) - } - } } else { panic(fmt.Errorf("Can only get rows for slice or struct")) } - return rowsSet + return rows } // Returns accepts an out type which should either be a struct or slice. Under // the hood, it converts a gorm model struct to sql.Rows that can be passed to // the underlying mock db func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { - rows := q.getRowsForOutType(out) + scope := (&Scope{}).New(out) + outVal := indirect(reflect.ValueOf(out)) - for i, query := range q.queries { - query.WillReturnRows(rows[i]) - spew.Printf("___SET RETURN ROW___: %s", spew.Sdump(rows[i])) + // rows := q.getRowsForOutType(out) + destQuery := q.queries[0] + subQueries := q.queries[1:] + + // main query always at the head of the slice + q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)). + WillReturnRows(q.getDestRows(out)) + + // subqueries are preload + for _, subQuery := range subQueries { + if subQuery.preload != "" { + if field, ok := scope.FieldByName(subQuery.preload); ok { + expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql)) + rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship) + + if hasRows { + expectation.WillReturnRows(rows) + } + } + } } return q