Support preloading
This commit is contained in:
parent
da8c7c1802
commit
b06542dc77
12
expecter.go
12
expecter.go
@ -2,6 +2,8 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Recorder satisfies the logger interface
|
// Recorder satisfies the logger interface
|
||||||
@ -38,6 +40,7 @@ func getStmtFromLog(values ...interface{}) Stmt {
|
|||||||
// Print just sets the last recorded SQL statement
|
// Print just sets the last recorded SQL statement
|
||||||
// TODO: find a better way to extract SQL from log messages
|
// TODO: find a better way to extract SQL from log messages
|
||||||
func (r *Recorder) Print(args ...interface{}) {
|
func (r *Recorder) Print(args ...interface{}) {
|
||||||
|
spew.Dump(args...)
|
||||||
statement := getStmtFromLog(args...)
|
statement := getStmtFromLog(args...)
|
||||||
|
|
||||||
if statement.sql != "" {
|
if statement.sql != "" {
|
||||||
@ -72,6 +75,7 @@ type Expecter struct {
|
|||||||
adapter Adapter
|
adapter Adapter
|
||||||
gorm *DB
|
gorm *DB
|
||||||
recorder *Recorder
|
recorder *Recorder
|
||||||
|
preload []string // records fields to be preloaded
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
|
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
|
||||||
@ -134,7 +138,9 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
|
|||||||
|
|
||||||
// Find triggers a Query
|
// Find triggers a Query
|
||||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
||||||
var q ExpectedQuery
|
var (
|
||||||
|
stmts []string
|
||||||
|
)
|
||||||
h.gorm.Find(out, where...)
|
h.gorm.Find(out, where...)
|
||||||
|
|
||||||
if empty := h.recorder.IsEmpty(); empty {
|
if empty := h.recorder.IsEmpty(); empty {
|
||||||
@ -142,10 +148,10 @@ func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, stmt := range h.recorder.stmts {
|
for _, stmt := range h.recorder.stmts {
|
||||||
q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql))
|
stmts = append(stmts, regexp.QuoteMeta(stmt.sql))
|
||||||
}
|
}
|
||||||
|
|
||||||
return q
|
return h.adapter.ExpectQuery(stmts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Preload clones the expecter and sets a preload condition on gorm.DB
|
// Preload clones the expecter and sets a preload condition on gorm.DB
|
||||||
|
@ -24,7 +24,7 @@ func init() {
|
|||||||
// Adapter provides an abstract interface over concrete mock database
|
// Adapter provides an abstract interface over concrete mock database
|
||||||
// implementations (e.g. go-sqlmock or go-testdb)
|
// implementations (e.g. go-sqlmock or go-testdb)
|
||||||
type Adapter interface {
|
type Adapter interface {
|
||||||
ExpectQuery(stmt string) ExpectedQuery
|
ExpectQuery(stmts ...string) ExpectedQuery
|
||||||
ExpectExec(stmt string) ExpectedExec
|
ExpectExec(stmt string) ExpectedExec
|
||||||
AssertExpectations() error
|
AssertExpectations() error
|
||||||
}
|
}
|
||||||
@ -49,11 +49,15 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExpectQuery wraps the underlying mock method for setting a query
|
// ExpectQuery wraps the underlying mock method for setting a query
|
||||||
// expectation
|
// expectation. It accepts multiple statements in the event of preloading
|
||||||
func (a *SqlmockAdapter) ExpectQuery(stmt string) ExpectedQuery {
|
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery {
|
||||||
q := a.mocker.ExpectQuery(stmt)
|
var queries []*sqlmock.ExpectedQuery
|
||||||
|
|
||||||
return &SqlmockQuery{query: q}
|
for _, stmt := range stmts {
|
||||||
|
queries = append(queries, a.mocker.ExpectQuery(stmt))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SqlmockQuery{queries: queries}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpectExec wraps the underlying mock method for setting a exec
|
// ExpectExec wraps the underlying mock method for setting a exec
|
||||||
|
@ -24,7 +24,7 @@ type ExpectedExec interface {
|
|||||||
|
|
||||||
// SqlmockQuery implements Query for go-sqlmock
|
// SqlmockQuery implements Query for go-sqlmock
|
||||||
type SqlmockQuery struct {
|
type SqlmockQuery struct {
|
||||||
query *sqlmock.ExpectedQuery
|
queries []*sqlmock.ExpectedQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRowForFields(fields []*Field) []driver.Value {
|
func getRowForFields(fields []*Field) []driver.Value {
|
||||||
@ -49,6 +49,8 @@ func getRowForFields(fields []*Field) []driver.Value {
|
|||||||
|
|
||||||
if driver.IsValue(concreteVal) {
|
if driver.IsValue(concreteVal) {
|
||||||
values = append(values, concreteVal)
|
values = append(values, concreteVal)
|
||||||
|
} else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 {
|
||||||
|
values = append(values, value.Int())
|
||||||
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
|
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
|
||||||
if convertedValue, err := valuer.Value(); err == nil {
|
if convertedValue, err := valuer.Value(); err == nil {
|
||||||
values = append(values, convertedValue)
|
values = append(values, convertedValue)
|
||||||
@ -60,13 +62,28 @@ func getRowForFields(fields []*Field) []driver.Value {
|
|||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
|
||||||
var columns []string
|
var (
|
||||||
|
columns []string
|
||||||
|
relations = make(map[string]*Relationship)
|
||||||
|
rowsSet []*sqlmock.Rows
|
||||||
|
)
|
||||||
|
|
||||||
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
||||||
|
// we get the primary model's columns here
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
columns = append(columns, field.DBName)
|
columns = append(columns, field.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check relations
|
||||||
|
if !field.IsNormal {
|
||||||
|
relationVal := reflect.ValueOf(field.Relationship)
|
||||||
|
isNil := relationVal.IsNil()
|
||||||
|
|
||||||
|
if !isNil {
|
||||||
|
relations[field.Name] = field.Relationship
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := sqlmock.NewRows(columns)
|
rows := sqlmock.NewRows(columns)
|
||||||
@ -83,16 +100,50 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
|||||||
scope := &Scope{Value: outElem}
|
scope := &Scope{Value: outElem}
|
||||||
row := getRowForFields(scope.Fields())
|
row := getRowForFields(scope.Fields())
|
||||||
rows = rows.AddRow(row...)
|
rows = rows.AddRow(row...)
|
||||||
|
rowsSet = append(rowsSet, rows)
|
||||||
}
|
}
|
||||||
} else if outVal.Kind() == reflect.Struct {
|
} else if outVal.Kind() == reflect.Struct {
|
||||||
scope := &Scope{Value: out}
|
scope := &Scope{Value: out}
|
||||||
row := getRowForFields(scope.Fields())
|
row := getRowForFields(scope.Fields())
|
||||||
rows = rows.AddRow(row...)
|
rows = rows.AddRow(row...)
|
||||||
|
rowsSet = append(rowsSet, rows)
|
||||||
|
|
||||||
|
for name, relation := range relations {
|
||||||
|
switch relation.Kind {
|
||||||
|
case "has_many":
|
||||||
|
rVal := outVal.FieldByName(name)
|
||||||
|
rType := rVal.Type().Elem()
|
||||||
|
rScope := &Scope{Value: reflect.New(rType).Interface()}
|
||||||
|
rColumns := []string{}
|
||||||
|
|
||||||
|
for _, field := range rScope.GetModelStruct().StructFields {
|
||||||
|
rColumns = append(rColumns, field.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasReturnRows := rVal.Len() > 0
|
||||||
|
|
||||||
|
// in this case we definitely have a slice
|
||||||
|
if hasReturnRows {
|
||||||
|
rRows := sqlmock.NewRows(rColumns)
|
||||||
|
|
||||||
|
for i := 0; i < rVal.Len(); i++ {
|
||||||
|
scope := &Scope{Value: rVal.Index(i).Interface()}
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
rRows = rRows.AddRow(row...)
|
||||||
|
rowsSet = append(rowsSet, rRows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "has_one":
|
||||||
|
case "many2many":
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return rows
|
return rowsSet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns accepts an out type which should either be a struct or slice. Under
|
// Returns accepts an out type which should either be a struct or slice. Under
|
||||||
@ -100,7 +151,10 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
|||||||
// the underlying mock db
|
// the underlying mock db
|
||||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
||||||
rows := q.getRowsForOutType(out)
|
rows := q.getRowsForOutType(out)
|
||||||
q.query = q.query.WillReturnRows(rows)
|
|
||||||
|
for i, query := range q.queries {
|
||||||
|
query.WillReturnRows(rows[i])
|
||||||
|
}
|
||||||
|
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user