Support preloading
This commit is contained in:
parent
da8c7c1802
commit
b06542dc77
12
expecter.go
12
expecter.go
@ -2,6 +2,8 @@ package gorm
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
// Recorder satisfies the logger interface
|
||||
@ -38,6 +40,7 @@ func getStmtFromLog(values ...interface{}) Stmt {
|
||||
// Print just sets the last recorded SQL statement
|
||||
// TODO: find a better way to extract SQL from log messages
|
||||
func (r *Recorder) Print(args ...interface{}) {
|
||||
spew.Dump(args...)
|
||||
statement := getStmtFromLog(args...)
|
||||
|
||||
if statement.sql != "" {
|
||||
@ -72,6 +75,7 @@ type Expecter struct {
|
||||
adapter Adapter
|
||||
gorm *DB
|
||||
recorder *Recorder
|
||||
preload []string // records fields to be preloaded
|
||||
}
|
||||
|
||||
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
|
||||
@ -134,7 +138,9 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
|
||||
|
||||
// Find triggers a Query
|
||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
||||
var q ExpectedQuery
|
||||
var (
|
||||
stmts []string
|
||||
)
|
||||
h.gorm.Find(out, where...)
|
||||
|
||||
if empty := h.recorder.IsEmpty(); empty {
|
||||
@ -142,10 +148,10 @@ func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
||||
}
|
||||
|
||||
for _, stmt := range h.recorder.stmts {
|
||||
q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql))
|
||||
stmts = append(stmts, regexp.QuoteMeta(stmt.sql))
|
||||
}
|
||||
|
||||
return q
|
||||
return h.adapter.ExpectQuery(stmts...)
|
||||
}
|
||||
|
||||
// Preload clones the expecter and sets a preload condition on gorm.DB
|
||||
|
@ -24,7 +24,7 @@ func init() {
|
||||
// Adapter provides an abstract interface over concrete mock database
|
||||
// implementations (e.g. go-sqlmock or go-testdb)
|
||||
type Adapter interface {
|
||||
ExpectQuery(stmt string) ExpectedQuery
|
||||
ExpectQuery(stmts ...string) ExpectedQuery
|
||||
ExpectExec(stmt string) ExpectedExec
|
||||
AssertExpectations() error
|
||||
}
|
||||
@ -49,11 +49,15 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
|
||||
}
|
||||
|
||||
// ExpectQuery wraps the underlying mock method for setting a query
|
||||
// expectation
|
||||
func (a *SqlmockAdapter) ExpectQuery(stmt string) ExpectedQuery {
|
||||
q := a.mocker.ExpectQuery(stmt)
|
||||
// expectation. It accepts multiple statements in the event of preloading
|
||||
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery {
|
||||
var queries []*sqlmock.ExpectedQuery
|
||||
|
||||
return &SqlmockQuery{query: q}
|
||||
for _, stmt := range stmts {
|
||||
queries = append(queries, a.mocker.ExpectQuery(stmt))
|
||||
}
|
||||
|
||||
return &SqlmockQuery{queries: queries}
|
||||
}
|
||||
|
||||
// ExpectExec wraps the underlying mock method for setting a exec
|
||||
|
@ -24,7 +24,7 @@ type ExpectedExec interface {
|
||||
|
||||
// SqlmockQuery implements Query for go-sqlmock
|
||||
type SqlmockQuery struct {
|
||||
query *sqlmock.ExpectedQuery
|
||||
queries []*sqlmock.ExpectedQuery
|
||||
}
|
||||
|
||||
func getRowForFields(fields []*Field) []driver.Value {
|
||||
@ -49,6 +49,8 @@ func getRowForFields(fields []*Field) []driver.Value {
|
||||
|
||||
if driver.IsValue(concreteVal) {
|
||||
values = append(values, concreteVal)
|
||||
} else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 {
|
||||
values = append(values, value.Int())
|
||||
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
|
||||
if convertedValue, err := valuer.Value(); err == nil {
|
||||
values = append(values, convertedValue)
|
||||
@ -60,13 +62,28 @@ func getRowForFields(fields []*Field) []driver.Value {
|
||||
return values
|
||||
}
|
||||
|
||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
||||
var columns []string
|
||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
|
||||
var (
|
||||
columns []string
|
||||
relations = make(map[string]*Relationship)
|
||||
rowsSet []*sqlmock.Rows
|
||||
)
|
||||
|
||||
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
||||
// we get the primary model's columns here
|
||||
if field.IsNormal {
|
||||
columns = append(columns, field.DBName)
|
||||
}
|
||||
|
||||
// check relations
|
||||
if !field.IsNormal {
|
||||
relationVal := reflect.ValueOf(field.Relationship)
|
||||
isNil := relationVal.IsNil()
|
||||
|
||||
if !isNil {
|
||||
relations[field.Name] = field.Relationship
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows(columns)
|
||||
@ -83,16 +100,50 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
||||
scope := &Scope{Value: outElem}
|
||||
row := getRowForFields(scope.Fields())
|
||||
rows = rows.AddRow(row...)
|
||||
rowsSet = append(rowsSet, rows)
|
||||
}
|
||||
} else if outVal.Kind() == reflect.Struct {
|
||||
scope := &Scope{Value: out}
|
||||
row := getRowForFields(scope.Fields())
|
||||
rows = rows.AddRow(row...)
|
||||
rowsSet = append(rowsSet, rows)
|
||||
|
||||
for name, relation := range relations {
|
||||
switch relation.Kind {
|
||||
case "has_many":
|
||||
rVal := outVal.FieldByName(name)
|
||||
rType := rVal.Type().Elem()
|
||||
rScope := &Scope{Value: reflect.New(rType).Interface()}
|
||||
rColumns := []string{}
|
||||
|
||||
for _, field := range rScope.GetModelStruct().StructFields {
|
||||
rColumns = append(rColumns, field.DBName)
|
||||
}
|
||||
|
||||
hasReturnRows := rVal.Len() > 0
|
||||
|
||||
// in this case we definitely have a slice
|
||||
if hasReturnRows {
|
||||
rRows := sqlmock.NewRows(rColumns)
|
||||
|
||||
for i := 0; i < rVal.Len(); i++ {
|
||||
scope := &Scope{Value: rVal.Index(i).Interface()}
|
||||
row := getRowForFields(scope.Fields())
|
||||
rRows = rRows.AddRow(row...)
|
||||
rowsSet = append(rowsSet, rRows)
|
||||
}
|
||||
}
|
||||
case "has_one":
|
||||
case "many2many":
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
||||
}
|
||||
|
||||
return rows
|
||||
return rowsSet
|
||||
}
|
||||
|
||||
// Returns accepts an out type which should either be a struct or slice. Under
|
||||
@ -100,7 +151,10 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
||||
// the underlying mock db
|
||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
||||
rows := q.getRowsForOutType(out)
|
||||
q.query = q.query.WillReturnRows(rows)
|
||||
|
||||
for i, query := range q.queries {
|
||||
query.WillReturnRows(rows[i])
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user