Support preloading

This commit is contained in:
Ian Tan 2017-11-21 19:02:54 +08:00
parent da8c7c1802
commit b06542dc77
3 changed files with 77 additions and 13 deletions

View File

@ -2,6 +2,8 @@ package gorm
import (
"regexp"
"github.com/davecgh/go-spew/spew"
)
// Recorder satisfies the logger interface
@ -38,6 +40,7 @@ func getStmtFromLog(values ...interface{}) Stmt {
// Print just sets the last recorded SQL statement
// TODO: find a better way to extract SQL from log messages
func (r *Recorder) Print(args ...interface{}) {
spew.Dump(args...)
statement := getStmtFromLog(args...)
if statement.sql != "" {
@ -72,6 +75,7 @@ type Expecter struct {
adapter Adapter
gorm *DB
recorder *Recorder
preload []string // records fields to be preloaded
}
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
@ -134,7 +138,9 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
// Find triggers a Query
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
var q ExpectedQuery
var (
stmts []string
)
h.gorm.Find(out, where...)
if empty := h.recorder.IsEmpty(); empty {
@ -142,10 +148,10 @@ func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
}
for _, stmt := range h.recorder.stmts {
q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql))
stmts = append(stmts, regexp.QuoteMeta(stmt.sql))
}
return q
return h.adapter.ExpectQuery(stmts...)
}
// Preload clones the expecter and sets a preload condition on gorm.DB

View File

@ -24,7 +24,7 @@ func init() {
// Adapter provides an abstract interface over concrete mock database
// implementations (e.g. go-sqlmock or go-testdb)
type Adapter interface {
ExpectQuery(stmt string) ExpectedQuery
ExpectQuery(stmts ...string) ExpectedQuery
ExpectExec(stmt string) ExpectedExec
AssertExpectations() error
}
@ -49,11 +49,15 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
}
// ExpectQuery wraps the underlying mock method for setting a query
// expectation
func (a *SqlmockAdapter) ExpectQuery(stmt string) ExpectedQuery {
q := a.mocker.ExpectQuery(stmt)
// expectation. It accepts multiple statements in the event of preloading
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery {
var queries []*sqlmock.ExpectedQuery
return &SqlmockQuery{query: q}
for _, stmt := range stmts {
queries = append(queries, a.mocker.ExpectQuery(stmt))
}
return &SqlmockQuery{queries: queries}
}
// ExpectExec wraps the underlying mock method for setting a exec

View File

@ -24,7 +24,7 @@ type ExpectedExec interface {
// SqlmockQuery implements Query for go-sqlmock
type SqlmockQuery struct {
query *sqlmock.ExpectedQuery
queries []*sqlmock.ExpectedQuery
}
func getRowForFields(fields []*Field) []driver.Value {
@ -49,6 +49,8 @@ func getRowForFields(fields []*Field) []driver.Value {
if driver.IsValue(concreteVal) {
values = append(values, concreteVal)
} else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 {
values = append(values, value.Int())
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
if convertedValue, err := valuer.Value(); err == nil {
values = append(values, convertedValue)
@ -60,13 +62,28 @@ func getRowForFields(fields []*Field) []driver.Value {
return values
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
var columns []string
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
var (
columns []string
relations = make(map[string]*Relationship)
rowsSet []*sqlmock.Rows
)
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
// we get the primary model's columns here
if field.IsNormal {
columns = append(columns, field.DBName)
}
// check relations
if !field.IsNormal {
relationVal := reflect.ValueOf(field.Relationship)
isNil := relationVal.IsNil()
if !isNil {
relations[field.Name] = field.Relationship
}
}
}
rows := sqlmock.NewRows(columns)
@ -83,16 +100,50 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
scope := &Scope{Value: outElem}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
}
} else if outVal.Kind() == reflect.Struct {
scope := &Scope{Value: out}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
for name, relation := range relations {
switch relation.Kind {
case "has_many":
rVal := outVal.FieldByName(name)
rType := rVal.Type().Elem()
rScope := &Scope{Value: reflect.New(rType).Interface()}
rColumns := []string{}
for _, field := range rScope.GetModelStruct().StructFields {
rColumns = append(rColumns, field.DBName)
}
hasReturnRows := rVal.Len() > 0
// in this case we definitely have a slice
if hasReturnRows {
rRows := sqlmock.NewRows(rColumns)
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rRows = rRows.AddRow(row...)
rowsSet = append(rowsSet, rRows)
}
}
case "has_one":
case "many2many":
default:
continue
}
}
} else {
panic(fmt.Errorf("Can only get rows for slice or struct"))
}
return rows
return rowsSet
}
// Returns accepts an out type which should either be a struct or slice. Under
@ -100,7 +151,10 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
// the underlying mock db
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
rows := q.getRowsForOutType(out)
q.query = q.query.WillReturnRows(rows)
for i, query := range q.queries {
query.WillReturnRows(rows[i])
}
return q
}