diff --git a/expecter.go b/expecter.go index dadbe8fc..5da09d09 100644 --- a/expecter.go +++ b/expecter.go @@ -2,6 +2,8 @@ package gorm import ( "regexp" + + "github.com/davecgh/go-spew/spew" ) // Recorder satisfies the logger interface @@ -38,6 +40,7 @@ func getStmtFromLog(values ...interface{}) Stmt { // Print just sets the last recorded SQL statement // TODO: find a better way to extract SQL from log messages func (r *Recorder) Print(args ...interface{}) { + spew.Dump(args...) statement := getStmtFromLog(args...) if statement.sql != "" { @@ -72,6 +75,7 @@ type Expecter struct { adapter Adapter gorm *DB recorder *Recorder + preload []string // records fields to be preloaded } // NewDefaultExpecter returns a Expecter powered by go-sqlmock @@ -134,7 +138,9 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { // Find triggers a Query func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { - var q ExpectedQuery + var ( + stmts []string + ) h.gorm.Find(out, where...) if empty := h.recorder.IsEmpty(); empty { @@ -142,10 +148,10 @@ func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { } for _, stmt := range h.recorder.stmts { - q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql)) + stmts = append(stmts, regexp.QuoteMeta(stmt.sql)) } - return q + return h.adapter.ExpectQuery(stmts...) } // Preload clones the expecter and sets a preload condition on gorm.DB diff --git a/expecter_adapter.go b/expecter_adapter.go index 7e3d3da9..5ccc7271 100644 --- a/expecter_adapter.go +++ b/expecter_adapter.go @@ -24,7 +24,7 @@ func init() { // Adapter provides an abstract interface over concrete mock database // implementations (e.g. go-sqlmock or go-testdb) type Adapter interface { - ExpectQuery(stmt string) ExpectedQuery + ExpectQuery(stmts ...string) ExpectedQuery ExpectExec(stmt string) ExpectedExec AssertExpectations() error } @@ -49,11 +49,15 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error } // ExpectQuery wraps the underlying mock method for setting a query -// expectation -func (a *SqlmockAdapter) ExpectQuery(stmt string) ExpectedQuery { - q := a.mocker.ExpectQuery(stmt) +// expectation. It accepts multiple statements in the event of preloading +func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery { + var queries []*sqlmock.ExpectedQuery - return &SqlmockQuery{query: q} + for _, stmt := range stmts { + queries = append(queries, a.mocker.ExpectQuery(stmt)) + } + + return &SqlmockQuery{queries: queries} } // ExpectExec wraps the underlying mock method for setting a exec diff --git a/expecter_result.go b/expecter_result.go index 0f2ecc6e..d659f97e 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -24,7 +24,7 @@ type ExpectedExec interface { // SqlmockQuery implements Query for go-sqlmock type SqlmockQuery struct { - query *sqlmock.ExpectedQuery + queries []*sqlmock.ExpectedQuery } func getRowForFields(fields []*Field) []driver.Value { @@ -49,6 +49,8 @@ func getRowForFields(fields []*Field) []driver.Value { if driver.IsValue(concreteVal) { values = append(values, concreteVal) + } else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 { + values = append(values, value.Int()) } else if valuer, ok := concreteVal.(driver.Valuer); ok { if convertedValue, err := valuer.Value(); err == nil { values = append(values, convertedValue) @@ -60,13 +62,28 @@ func getRowForFields(fields []*Field) []driver.Value { return values } -func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { - var columns []string +func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { + var ( + columns []string + relations = make(map[string]*Relationship) + rowsSet []*sqlmock.Rows + ) for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields { + // we get the primary model's columns here if field.IsNormal { columns = append(columns, field.DBName) } + + // check relations + if !field.IsNormal { + relationVal := reflect.ValueOf(field.Relationship) + isNil := relationVal.IsNil() + + if !isNil { + relations[field.Name] = field.Relationship + } + } } rows := sqlmock.NewRows(columns) @@ -83,16 +100,50 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { scope := &Scope{Value: outElem} row := getRowForFields(scope.Fields()) rows = rows.AddRow(row...) + rowsSet = append(rowsSet, rows) } } else if outVal.Kind() == reflect.Struct { scope := &Scope{Value: out} row := getRowForFields(scope.Fields()) rows = rows.AddRow(row...) + rowsSet = append(rowsSet, rows) + + for name, relation := range relations { + switch relation.Kind { + case "has_many": + rVal := outVal.FieldByName(name) + rType := rVal.Type().Elem() + rScope := &Scope{Value: reflect.New(rType).Interface()} + rColumns := []string{} + + for _, field := range rScope.GetModelStruct().StructFields { + rColumns = append(rColumns, field.DBName) + } + + hasReturnRows := rVal.Len() > 0 + + // in this case we definitely have a slice + if hasReturnRows { + rRows := sqlmock.NewRows(rColumns) + + for i := 0; i < rVal.Len(); i++ { + scope := &Scope{Value: rVal.Index(i).Interface()} + row := getRowForFields(scope.Fields()) + rRows = rRows.AddRow(row...) + rowsSet = append(rowsSet, rRows) + } + } + case "has_one": + case "many2many": + default: + continue + } + } } else { panic(fmt.Errorf("Can only get rows for slice or struct")) } - return rows + return rowsSet } // Returns accepts an out type which should either be a struct or slice. Under @@ -100,7 +151,10 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { // the underlying mock db func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { rows := q.getRowsForOutType(out) - q.query = q.query.WillReturnRows(rows) + + for i, query := range q.queries { + query.WillReturnRows(rows[i]) + } return q }