Support preloading
This commit is contained in:
		
							parent
							
								
									da8c7c1802
								
							
						
					
					
						commit
						b06542dc77
					
				
							
								
								
									
										12
									
								
								expecter.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								expecter.go
									
									
									
									
									
								
							@ -2,6 +2,8 @@ package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"regexp"
 | 
			
		||||
 | 
			
		||||
	"github.com/davecgh/go-spew/spew"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Recorder satisfies the logger interface
 | 
			
		||||
@ -38,6 +40,7 @@ func getStmtFromLog(values ...interface{}) Stmt {
 | 
			
		||||
// Print just sets the last recorded SQL statement
 | 
			
		||||
// TODO: find a better way to extract SQL from log messages
 | 
			
		||||
func (r *Recorder) Print(args ...interface{}) {
 | 
			
		||||
	spew.Dump(args...)
 | 
			
		||||
	statement := getStmtFromLog(args...)
 | 
			
		||||
 | 
			
		||||
	if statement.sql != "" {
 | 
			
		||||
@ -72,6 +75,7 @@ type Expecter struct {
 | 
			
		||||
	adapter  Adapter
 | 
			
		||||
	gorm     *DB
 | 
			
		||||
	recorder *Recorder
 | 
			
		||||
	preload  []string // records fields to be preloaded
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
 | 
			
		||||
@ -134,7 +138,9 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
 | 
			
		||||
// Find triggers a Query
 | 
			
		||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	var q ExpectedQuery
 | 
			
		||||
	var (
 | 
			
		||||
		stmts []string
 | 
			
		||||
	)
 | 
			
		||||
	h.gorm.Find(out, where...)
 | 
			
		||||
 | 
			
		||||
	if empty := h.recorder.IsEmpty(); empty {
 | 
			
		||||
@ -142,10 +148,10 @@ func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, stmt := range h.recorder.stmts {
 | 
			
		||||
		q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql))
 | 
			
		||||
		stmts = append(stmts, regexp.QuoteMeta(stmt.sql))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return q
 | 
			
		||||
	return h.adapter.ExpectQuery(stmts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Preload clones the expecter and sets a preload condition on gorm.DB
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ func init() {
 | 
			
		||||
// Adapter provides an abstract interface over concrete mock database
 | 
			
		||||
// implementations (e.g. go-sqlmock or go-testdb)
 | 
			
		||||
type Adapter interface {
 | 
			
		||||
	ExpectQuery(stmt string) ExpectedQuery
 | 
			
		||||
	ExpectQuery(stmts ...string) ExpectedQuery
 | 
			
		||||
	ExpectExec(stmt string) ExpectedExec
 | 
			
		||||
	AssertExpectations() error
 | 
			
		||||
}
 | 
			
		||||
@ -49,11 +49,15 @@ func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExpectQuery wraps the underlying mock method for setting a query
 | 
			
		||||
// expectation
 | 
			
		||||
func (a *SqlmockAdapter) ExpectQuery(stmt string) ExpectedQuery {
 | 
			
		||||
	q := a.mocker.ExpectQuery(stmt)
 | 
			
		||||
// expectation. It accepts multiple statements in the event of preloading
 | 
			
		||||
func (a *SqlmockAdapter) ExpectQuery(stmts ...string) ExpectedQuery {
 | 
			
		||||
	var queries []*sqlmock.ExpectedQuery
 | 
			
		||||
 | 
			
		||||
	return &SqlmockQuery{query: q}
 | 
			
		||||
	for _, stmt := range stmts {
 | 
			
		||||
		queries = append(queries, a.mocker.ExpectQuery(stmt))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &SqlmockQuery{queries: queries}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExpectExec wraps the underlying mock method for setting a exec
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ type ExpectedExec interface {
 | 
			
		||||
 | 
			
		||||
// SqlmockQuery implements Query for go-sqlmock
 | 
			
		||||
type SqlmockQuery struct {
 | 
			
		||||
	query *sqlmock.ExpectedQuery
 | 
			
		||||
	queries []*sqlmock.ExpectedQuery
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getRowForFields(fields []*Field) []driver.Value {
 | 
			
		||||
@ -49,6 +49,8 @@ func getRowForFields(fields []*Field) []driver.Value {
 | 
			
		||||
 | 
			
		||||
			if driver.IsValue(concreteVal) {
 | 
			
		||||
				values = append(values, concreteVal)
 | 
			
		||||
			} else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 {
 | 
			
		||||
				values = append(values, value.Int())
 | 
			
		||||
			} else if valuer, ok := concreteVal.(driver.Valuer); ok {
 | 
			
		||||
				if convertedValue, err := valuer.Value(); err == nil {
 | 
			
		||||
					values = append(values, convertedValue)
 | 
			
		||||
@ -60,13 +62,28 @@ func getRowForFields(fields []*Field) []driver.Value {
 | 
			
		||||
	return values
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
 | 
			
		||||
	var columns []string
 | 
			
		||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
 | 
			
		||||
	var (
 | 
			
		||||
		columns   []string
 | 
			
		||||
		relations = make(map[string]*Relationship)
 | 
			
		||||
		rowsSet   []*sqlmock.Rows
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
 | 
			
		||||
		// we get the primary model's columns here
 | 
			
		||||
		if field.IsNormal {
 | 
			
		||||
			columns = append(columns, field.DBName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// check relations
 | 
			
		||||
		if !field.IsNormal {
 | 
			
		||||
			relationVal := reflect.ValueOf(field.Relationship)
 | 
			
		||||
			isNil := relationVal.IsNil()
 | 
			
		||||
 | 
			
		||||
			if !isNil {
 | 
			
		||||
				relations[field.Name] = field.Relationship
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows := sqlmock.NewRows(columns)
 | 
			
		||||
@ -83,16 +100,50 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
 | 
			
		||||
			scope := &Scope{Value: outElem}
 | 
			
		||||
			row := getRowForFields(scope.Fields())
 | 
			
		||||
			rows = rows.AddRow(row...)
 | 
			
		||||
			rowsSet = append(rowsSet, rows)
 | 
			
		||||
		}
 | 
			
		||||
	} else if outVal.Kind() == reflect.Struct {
 | 
			
		||||
		scope := &Scope{Value: out}
 | 
			
		||||
		row := getRowForFields(scope.Fields())
 | 
			
		||||
		rows = rows.AddRow(row...)
 | 
			
		||||
		rowsSet = append(rowsSet, rows)
 | 
			
		||||
 | 
			
		||||
		for name, relation := range relations {
 | 
			
		||||
			switch relation.Kind {
 | 
			
		||||
			case "has_many":
 | 
			
		||||
				rVal := outVal.FieldByName(name)
 | 
			
		||||
				rType := rVal.Type().Elem()
 | 
			
		||||
				rScope := &Scope{Value: reflect.New(rType).Interface()}
 | 
			
		||||
				rColumns := []string{}
 | 
			
		||||
 | 
			
		||||
				for _, field := range rScope.GetModelStruct().StructFields {
 | 
			
		||||
					rColumns = append(rColumns, field.DBName)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				hasReturnRows := rVal.Len() > 0
 | 
			
		||||
 | 
			
		||||
				// in this case we definitely have a slice
 | 
			
		||||
				if hasReturnRows {
 | 
			
		||||
					rRows := sqlmock.NewRows(rColumns)
 | 
			
		||||
 | 
			
		||||
					for i := 0; i < rVal.Len(); i++ {
 | 
			
		||||
						scope := &Scope{Value: rVal.Index(i).Interface()}
 | 
			
		||||
						row := getRowForFields(scope.Fields())
 | 
			
		||||
						rRows = rRows.AddRow(row...)
 | 
			
		||||
						rowsSet = append(rowsSet, rRows)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			case "has_one":
 | 
			
		||||
			case "many2many":
 | 
			
		||||
			default:
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		panic(fmt.Errorf("Can only get rows for slice or struct"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rows
 | 
			
		||||
	return rowsSet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns accepts an out type which should either be a struct or slice. Under
 | 
			
		||||
@ -100,7 +151,10 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
 | 
			
		||||
// the underlying mock db
 | 
			
		||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
 | 
			
		||||
	rows := q.getRowsForOutType(out)
 | 
			
		||||
	q.query = q.query.WillReturnRows(rows)
 | 
			
		||||
 | 
			
		||||
	for i, query := range q.queries {
 | 
			
		||||
		query.WillReturnRows(rows[i])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return q
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user