Merge d630799e906b41467a47b801ec24b892e71a47e0 into 0a51f6cdc55d1650d9ed3b4c13026cfa9133b01e
This commit is contained in:
		
						commit
						3087813925
					
				@ -60,6 +60,7 @@ func preloadCallback(scope *Scope) {
 | 
			
		||||
						currentScope.handleBelongsToPreload(field, currentPreloadConditions)
 | 
			
		||||
					case "many_to_many":
 | 
			
		||||
						currentScope.handleManyToManyPreload(field, currentPreloadConditions)
 | 
			
		||||
 | 
			
		||||
					default:
 | 
			
		||||
						scope.Err(errors.New("unsupported relation"))
 | 
			
		||||
					}
 | 
			
		||||
@ -264,6 +265,8 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
 | 
			
		||||
 | 
			
		||||
// handleManyToManyPreload used to preload many to many associations
 | 
			
		||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
 | 
			
		||||
	// spew.Println("___ENTERING HANDLE MANY TO MANY___\r\n")
 | 
			
		||||
	// spew.Printf("___POPULATING %s___:\r\n%s\r\n", field.Name, spew.Sdump(field))
 | 
			
		||||
	var (
 | 
			
		||||
		relation         = field.Relationship
 | 
			
		||||
		joinTableHandler = relation.JoinTableHandler
 | 
			
		||||
@ -303,6 +306,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := preloadDB.Rows()
 | 
			
		||||
	// spew.Printf("___RETURNED ROWS___: \r\n%s\r\n", spew.Sdump(rows))
 | 
			
		||||
 | 
			
		||||
	if scope.Err(err) != nil {
 | 
			
		||||
		return
 | 
			
		||||
@ -312,6 +316,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 | 
			
		||||
	columns, _ := rows.Columns()
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var (
 | 
			
		||||
			// This is a Language zero value struct
 | 
			
		||||
			elem   = reflect.New(fieldType).Elem()
 | 
			
		||||
			fields = scope.New(elem.Addr().Interface()).Fields()
 | 
			
		||||
		)
 | 
			
		||||
@ -323,6 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		scope.scan(rows, columns, append(fields, joinTableFields...))
 | 
			
		||||
		// spew.Printf("___FIELDS___: \r\n%s\r\n", spew.Sdump(fields))
 | 
			
		||||
 | 
			
		||||
		var foreignKeys = make([]interface{}, len(sourceKeys))
 | 
			
		||||
		// generate hashed forkey keys in join table
 | 
			
		||||
@ -351,12 +357,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 | 
			
		||||
		foreignFieldNames  = []string{}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// spew.Printf("Foreign fields: %s", spew.Sdump(relation.ForeignFieldNames))
 | 
			
		||||
	for _, dbName := range relation.ForeignFieldNames {
 | 
			
		||||
		if field, ok := scope.FieldByName(dbName); ok {
 | 
			
		||||
			foreignFieldNames = append(foreignFieldNames, field.Name)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// spew.Printf("Scope value: %s", spew.Sdump(indirectScopeValue))
 | 
			
		||||
	if indirectScopeValue.Kind() == reflect.Slice {
 | 
			
		||||
		for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
			
		||||
			object := indirect(indirectScopeValue.Index(j))
 | 
			
		||||
@ -367,6 +375,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 | 
			
		||||
		key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
 | 
			
		||||
		fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// spew.Printf("Field source map: %s", spew.Sdump(fieldsSourceMap))
 | 
			
		||||
	// spew.Printf("Link hash: %s", spew.Sdump(linkHash))
 | 
			
		||||
	for source, link := range linkHash {
 | 
			
		||||
		for i, field := range fieldsSourceMap[source] {
 | 
			
		||||
			//If not 0 this means Value is a pointer and we already added preloaded models to it
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										225
									
								
								expecter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								expecter.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,225 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Recorder satisfies the logger interface
 | 
			
		||||
type Recorder struct {
 | 
			
		||||
	stmts   []Stmt
 | 
			
		||||
	preload []searchPreload // store it on Recorder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Stmt represents a sql statement. It can be an Exec, Query, or QueryRow
 | 
			
		||||
type Stmt struct {
 | 
			
		||||
	kind    string // can be Query, Exec, QueryRow
 | 
			
		||||
	preload string // contains schema if it is a preload query
 | 
			
		||||
	sql     string
 | 
			
		||||
	args    []interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func recordExecCallback(scope *Scope) {
 | 
			
		||||
	r, ok := scope.Get("gorm:recorder")
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic(fmt.Errorf("Expected a recorder to be set, but got none"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	stmt := Stmt{
 | 
			
		||||
		kind: "exec",
 | 
			
		||||
		sql:  scope.SQL,
 | 
			
		||||
		args: scope.SQLVars,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder := r.(*Recorder)
 | 
			
		||||
 | 
			
		||||
	recorder.Record(stmt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func recordQueryCallback(scope *Scope) {
 | 
			
		||||
	r, ok := scope.Get("gorm:recorder")
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic(fmt.Errorf("Expected a recorder to be set, but got none"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder := r.(*Recorder)
 | 
			
		||||
 | 
			
		||||
	stmt := Stmt{
 | 
			
		||||
		kind: "query",
 | 
			
		||||
		sql:  scope.SQL,
 | 
			
		||||
		args: scope.SQLVars,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(recorder.preload) > 0 {
 | 
			
		||||
		// this will cause the scope.SQL to mutate to the preload query
 | 
			
		||||
		stmt.preload = recorder.preload[0].schema
 | 
			
		||||
 | 
			
		||||
		// we just want to pop the first element off
 | 
			
		||||
		recorder.preload = recorder.preload[1:]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder.Record(stmt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func recordPreloadCallback(scope *Scope) {
 | 
			
		||||
	// this callback runs _before_ gorm:preload
 | 
			
		||||
	// it should record the next thing to be preloaded
 | 
			
		||||
	recorder, ok := scope.Get("gorm:recorder")
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic(fmt.Errorf("Expected a recorder to be set, but got none"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(scope.Search.preload) > 0 {
 | 
			
		||||
		// spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload))
 | 
			
		||||
		recorder.(*Recorder).preload = scope.Search.preload
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Record records a Stmt for use when SQL is finally executed
 | 
			
		||||
func (r *Recorder) Record(stmt Stmt) {
 | 
			
		||||
	r.stmts = append(r.stmts, stmt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetFirst returns the first recorded sql statement logged. If there are no
 | 
			
		||||
// statements, false is returned
 | 
			
		||||
func (r *Recorder) GetFirst() (Stmt, bool) {
 | 
			
		||||
	var stmt Stmt
 | 
			
		||||
	if len(r.stmts) > 0 {
 | 
			
		||||
		stmt = r.stmts[0]
 | 
			
		||||
		return stmt, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return stmt, false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsEmpty returns true if the statements slice is empty
 | 
			
		||||
func (r *Recorder) IsEmpty() bool {
 | 
			
		||||
	return len(r.stmts) == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AdapterFactory is a generic interface for arbitrary adapters that satisfy
 | 
			
		||||
// the interface. variadic args are passed to gorm.Open.
 | 
			
		||||
type AdapterFactory func(dialect string, args ...interface{}) (*DB, Adapter, error)
 | 
			
		||||
 | 
			
		||||
// Expecter is the exported struct used for setting expectations
 | 
			
		||||
type Expecter struct {
 | 
			
		||||
	// globally scoped expecter
 | 
			
		||||
	adapter  Adapter
 | 
			
		||||
	gorm     *DB
 | 
			
		||||
	recorder *Recorder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
 | 
			
		||||
func NewDefaultExpecter() (*DB, *Expecter, error) {
 | 
			
		||||
	gormDb, adapter, err := NewSqlmockAdapter("sqlmock", "mock_gorm_dsn")
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder := &Recorder{}
 | 
			
		||||
	noop, _ := NewNoopDB()
 | 
			
		||||
	gorm := &DB{
 | 
			
		||||
		db:        noop,
 | 
			
		||||
		logger:    defaultLogger,
 | 
			
		||||
		values:    map[string]interface{}{},
 | 
			
		||||
		callbacks: DefaultCallback,
 | 
			
		||||
		dialect:   newDialect("sqlmock", noop),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gorm.parent = gorm
 | 
			
		||||
	gorm = gorm.Set("gorm:recorder", recorder)
 | 
			
		||||
	gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordExecCallback)
 | 
			
		||||
	gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback)
 | 
			
		||||
	gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback)
 | 
			
		||||
	gorm.Callback().RowQuery().After("gorm:row_query").Register("gorm:record_query", recordQueryCallback)
 | 
			
		||||
	gorm.Callback().Update().After("gorm:update").Register("gorm:record_exec", recordExecCallback)
 | 
			
		||||
 | 
			
		||||
	return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewExpecter returns an Expecter for arbitrary adapters
 | 
			
		||||
func NewExpecter(fn AdapterFactory, dialect string, args ...interface{}) (*DB, *Expecter, error) {
 | 
			
		||||
	gormDb, adapter, err := fn(dialect, args...)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return gormDb, &Expecter{adapter: adapter}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* PUBLIC METHODS */
 | 
			
		||||
 | 
			
		||||
// AssertExpectations checks if all expected Querys and Execs were satisfied.
 | 
			
		||||
func (h *Expecter) AssertExpectations() error {
 | 
			
		||||
	return h.adapter.AssertExpectations()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Model sets scope.Value
 | 
			
		||||
func (h *Expecter) Model(model interface{}) *Expecter {
 | 
			
		||||
	h.gorm = h.gorm.Model(model)
 | 
			
		||||
	return h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* CREATE */
 | 
			
		||||
 | 
			
		||||
// Create mocks insertion of a model into the DB
 | 
			
		||||
func (h *Expecter) Create(model interface{}) ExpectedExec {
 | 
			
		||||
	h.gorm.Create(model)
 | 
			
		||||
	return h.adapter.ExpectExec(h.recorder.stmts[0])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* READ */
 | 
			
		||||
 | 
			
		||||
// First triggers a Query
 | 
			
		||||
func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	h.gorm.First(out, where...)
 | 
			
		||||
	return h.adapter.ExpectQuery(h.recorder.stmts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Find triggers a Query
 | 
			
		||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	h.gorm.Find(out, where...)
 | 
			
		||||
	return h.adapter.ExpectQuery(h.recorder.stmts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Preload clones the expecter and sets a preload condition on gorm.DB
 | 
			
		||||
func (h *Expecter) Preload(column string, conditions ...interface{}) *Expecter {
 | 
			
		||||
	clone := h.clone()
 | 
			
		||||
	clone.gorm = clone.gorm.Preload(column, conditions...)
 | 
			
		||||
 | 
			
		||||
	return clone
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* UPDATE */
 | 
			
		||||
 | 
			
		||||
// Save mocks updating a record in the DB and will trigger db.Exec()
 | 
			
		||||
func (h *Expecter) Save(model interface{}) ExpectedExec {
 | 
			
		||||
	h.gorm.Save(model)
 | 
			
		||||
	return h.adapter.ExpectExec(h.recorder.stmts[0])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Update mocks updating the given attributes in the DB
 | 
			
		||||
func (h *Expecter) Update(attrs ...interface{}) ExpectedExec {
 | 
			
		||||
	h.gorm.Update(attrs...)
 | 
			
		||||
	return h.adapter.ExpectExec(h.recorder.stmts[0])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Updates does the same thing as Update, but with map or struct
 | 
			
		||||
func (h *Expecter) Updates(values interface{}, ignoreProtectedAttrs ...bool) ExpectedExec {
 | 
			
		||||
	h.gorm.Updates(values, ignoreProtectedAttrs...)
 | 
			
		||||
	return h.adapter.ExpectExec(h.recorder.stmts[0])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* PRIVATE METHODS */
 | 
			
		||||
 | 
			
		||||
func (h *Expecter) clone() *Expecter {
 | 
			
		||||
	return &Expecter{
 | 
			
		||||
		adapter:  h.adapter,
 | 
			
		||||
		gorm:     h.gorm,
 | 
			
		||||
		recorder: h.recorder,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										68
									
								
								expecter_adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								expecter_adapter.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,68 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
 | 
			
		||||
	sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	db   *sql.DB
 | 
			
		||||
	mock sqlmock.Sqlmock
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	db, mock, err = sqlmock.NewWithDSN("mock_gorm_dsn")
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Adapter provides an abstract interface over concrete mock database
 | 
			
		||||
// implementations (e.g. go-sqlmock or go-testdb)
 | 
			
		||||
type Adapter interface {
 | 
			
		||||
	ExpectQuery(stmts ...Stmt) ExpectedQuery
 | 
			
		||||
	ExpectExec(stmt Stmt) ExpectedExec
 | 
			
		||||
	AssertExpectations() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SqlmockAdapter implemenets the Adapter interface using go-sqlmock
 | 
			
		||||
// it is the default Adapter
 | 
			
		||||
type SqlmockAdapter struct {
 | 
			
		||||
	db     *sql.DB
 | 
			
		||||
	mocker sqlmock.Sqlmock
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewSqlmockAdapter returns a mock gorm.DB and an Adapter backed by
 | 
			
		||||
// go-sqlmock
 | 
			
		||||
func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error) {
 | 
			
		||||
	gormDb, err := Open("sqlmock", "mock_gorm_dsn")
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return gormDb, &SqlmockAdapter{db: db, mocker: mock}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExpectQuery wraps the underlying mock method for setting a query
 | 
			
		||||
// expectation. It accepts multiple statements in the event of preloading
 | 
			
		||||
func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
 | 
			
		||||
	return &SqlmockQuery{mock: a.mocker, queries: queries}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExpectExec wraps the underlying mock method for setting a exec
 | 
			
		||||
// expectation
 | 
			
		||||
func (a *SqlmockAdapter) ExpectExec(exec Stmt) ExpectedExec {
 | 
			
		||||
	return &SqlmockExec{mock: a.mocker, exec: exec}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AssertExpectations asserts that _all_ expectations for a test have been met
 | 
			
		||||
// and returns an error specifying which have not if there are unmet
 | 
			
		||||
// expectations
 | 
			
		||||
func (a *SqlmockAdapter) AssertExpectations() error {
 | 
			
		||||
	return a.mocker.ExpectationsWereMet()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										186
									
								
								expecter_noop.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								expecter_noop.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,186 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var pool *NoopDriver
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	pool = &NoopDriver{
 | 
			
		||||
		conns: make(map[string]*NoopConnection),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sql.Register("noop", pool)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NoopDriver implements sql/driver.Driver
 | 
			
		||||
type NoopDriver struct {
 | 
			
		||||
	sync.Mutex
 | 
			
		||||
	counter int
 | 
			
		||||
	conns   map[string]*NoopConnection
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Open implements sql/driver.Driver
 | 
			
		||||
func (d *NoopDriver) Open(dsn string) (driver.Conn, error) {
 | 
			
		||||
	d.Lock()
 | 
			
		||||
	defer d.Unlock()
 | 
			
		||||
 | 
			
		||||
	c, ok := d.conns[dsn]
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return c, fmt.Errorf("No connection available")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.opened++
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NoopResult is a noop struct that satisfies sql.Result
 | 
			
		||||
type NoopResult struct{}
 | 
			
		||||
 | 
			
		||||
// LastInsertId is a noop method for satisfying drive.Result
 | 
			
		||||
func (r NoopResult) LastInsertId() (int64, error) {
 | 
			
		||||
	return 0, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RowsAffected is a noop method for satisfying drive.Result
 | 
			
		||||
func (r NoopResult) RowsAffected() (int64, error) {
 | 
			
		||||
	return 0, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NoopRows implements driver.Rows
 | 
			
		||||
type NoopRows struct {
 | 
			
		||||
	pos int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Columns implements driver.Rows
 | 
			
		||||
func (r *NoopRows) Columns() []string {
 | 
			
		||||
	return []string{"foo", "bar", "baz", "lol", "kek", "zzz"}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close implements driver.Rows
 | 
			
		||||
func (r *NoopRows) Close() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Next implements driver.Rows and alwys returns only one row
 | 
			
		||||
func (r *NoopRows) Next(dest []driver.Value) error {
 | 
			
		||||
	if r.pos == 1 {
 | 
			
		||||
		return io.EOF
 | 
			
		||||
	}
 | 
			
		||||
	cols := []string{"foo", "bar", "baz", "lol", "kek", "zzz"}
 | 
			
		||||
 | 
			
		||||
	for i, col := range cols {
 | 
			
		||||
		dest[i] = col
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.pos++
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NoopStmt implements driver.Stmt
 | 
			
		||||
type NoopStmt struct{}
 | 
			
		||||
 | 
			
		||||
// Close implements driver.Stmt
 | 
			
		||||
func (s *NoopStmt) Close() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NumInput implements driver.Stmt
 | 
			
		||||
func (s *NoopStmt) NumInput() int {
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Exec implements driver.Stmt
 | 
			
		||||
func (s *NoopStmt) Exec(args []driver.Value) (driver.Result, error) {
 | 
			
		||||
	return &NoopResult{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Query implements driver.Stmt
 | 
			
		||||
func (s *NoopStmt) Query(args []driver.Value) (driver.Rows, error) {
 | 
			
		||||
	return &NoopRows{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewNoopDB initialises a new DefaultNoopDB
 | 
			
		||||
func NewNoopDB() (SQLCommon, error) {
 | 
			
		||||
	pool.Lock()
 | 
			
		||||
	dsn := fmt.Sprintf("noop_db_%d", pool.counter)
 | 
			
		||||
	pool.counter++
 | 
			
		||||
 | 
			
		||||
	noop := &NoopConnection{dsn: dsn, drv: pool}
 | 
			
		||||
	pool.conns[dsn] = noop
 | 
			
		||||
	pool.Unlock()
 | 
			
		||||
 | 
			
		||||
	db, err := noop.open()
 | 
			
		||||
 | 
			
		||||
	return db, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NoopConnection implements sql/driver.Conn
 | 
			
		||||
// for our purposes, the noop connection never returns an error, as we only
 | 
			
		||||
// require it for generating queries. It is necessary because eagerloading
 | 
			
		||||
// will fail if any operation returns an error
 | 
			
		||||
type NoopConnection struct {
 | 
			
		||||
	dsn    string
 | 
			
		||||
	drv    *NoopDriver
 | 
			
		||||
	opened int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *NoopConnection) open() (*sql.DB, error) {
 | 
			
		||||
	db, err := sql.Open("noop", c.dsn)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return db, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db, db.Ping()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Close() error {
 | 
			
		||||
	c.drv.Lock()
 | 
			
		||||
	defer c.drv.Unlock()
 | 
			
		||||
 | 
			
		||||
	c.opened--
 | 
			
		||||
	if c.opened == 0 {
 | 
			
		||||
		delete(c.drv.conns, c.dsn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Begin implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Begin() (driver.Tx, error) {
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Exec implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result, error) {
 | 
			
		||||
	return NoopResult{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Prepare implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) {
 | 
			
		||||
	return &NoopStmt{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Query implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
 | 
			
		||||
	return &NoopRows{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Commit implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Commit() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Rollback implements sql/driver.Conn
 | 
			
		||||
func (c *NoopConnection) Rollback() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										259
									
								
								expecter_result.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										259
									
								
								expecter_result.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,259 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
 | 
			
		||||
	sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ExpectedQuery represents an expected query that will be executed and can
 | 
			
		||||
// return some rows. It presents a fluent API for chaining calls to other
 | 
			
		||||
// expectations
 | 
			
		||||
type ExpectedQuery interface {
 | 
			
		||||
	Returns(model interface{}) ExpectedQuery
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExpectedExec represents an expected exec that will be executed and can
 | 
			
		||||
// return a result. It presents a fluent API for chaining calls to other
 | 
			
		||||
// expectations
 | 
			
		||||
type ExpectedExec interface {
 | 
			
		||||
	WillSucceed(lastInsertID, rowsAffected int64) ExpectedExec
 | 
			
		||||
	WillFail(err error) ExpectedExec
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SqlmockQuery implements Query for go-sqlmock
 | 
			
		||||
type SqlmockQuery struct {
 | 
			
		||||
	mock    sqlmock.Sqlmock
 | 
			
		||||
	queries []Stmt
 | 
			
		||||
	scope   *Scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getRowForFields(fields []*Field) []driver.Value {
 | 
			
		||||
	var values []driver.Value
 | 
			
		||||
	for _, field := range fields {
 | 
			
		||||
		if field.IsNormal {
 | 
			
		||||
			value := field.Field
 | 
			
		||||
 | 
			
		||||
			// dereference pointers
 | 
			
		||||
			if field.Field.Kind() == reflect.Ptr {
 | 
			
		||||
				value = reflect.Indirect(field.Field)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// check if we have a zero Value
 | 
			
		||||
			// just append nil if it's not valid, so sqlmock won't complain
 | 
			
		||||
			if !value.IsValid() {
 | 
			
		||||
				values = append(values, nil)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			concreteVal := value.Interface()
 | 
			
		||||
			// spew.Printf("%v: %v\r\n", field.Name, concreteVal)
 | 
			
		||||
 | 
			
		||||
			if driver.IsValue(concreteVal) {
 | 
			
		||||
				values = append(values, concreteVal)
 | 
			
		||||
			} else if num, err := driver.DefaultParameterConverter.ConvertValue(concreteVal); err == nil {
 | 
			
		||||
				values = append(values, num)
 | 
			
		||||
			} else if valuer, ok := concreteVal.(driver.Valuer); ok {
 | 
			
		||||
				if convertedValue, err := valuer.Value(); err == nil {
 | 
			
		||||
					values = append(values, convertedValue)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return values
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
 | 
			
		||||
	var (
 | 
			
		||||
		rows    *sqlmock.Rows
 | 
			
		||||
		columns []string
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// we need to check for zero values
 | 
			
		||||
	if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
 | 
			
		||||
		// spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch relation.Kind {
 | 
			
		||||
	case "has_one":
 | 
			
		||||
		scope := &Scope{Value: rVal.Interface()}
 | 
			
		||||
 | 
			
		||||
		for _, field := range scope.GetModelStruct().StructFields {
 | 
			
		||||
			if field.IsNormal {
 | 
			
		||||
				columns = append(columns, field.DBName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rows = sqlmock.NewRows(columns)
 | 
			
		||||
 | 
			
		||||
		// we don't have a slice
 | 
			
		||||
		row := getRowForFields(scope.Fields())
 | 
			
		||||
		rows = rows.AddRow(row...)
 | 
			
		||||
 | 
			
		||||
		return rows, true
 | 
			
		||||
	case "has_many":
 | 
			
		||||
		elem := rVal.Type().Elem()
 | 
			
		||||
		scope := &Scope{Value: reflect.New(elem).Interface()}
 | 
			
		||||
 | 
			
		||||
		for _, field := range scope.GetModelStruct().StructFields {
 | 
			
		||||
			if field.IsNormal {
 | 
			
		||||
				columns = append(columns, field.DBName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rows = sqlmock.NewRows(columns)
 | 
			
		||||
 | 
			
		||||
		if rVal.Len() > 0 {
 | 
			
		||||
			for i := 0; i < rVal.Len(); i++ {
 | 
			
		||||
				scope := &Scope{Value: rVal.Index(i).Interface()}
 | 
			
		||||
				row := getRowForFields(scope.Fields())
 | 
			
		||||
				rows = rows.AddRow(row...)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return rows, true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil, false
 | 
			
		||||
	case "many_to_many":
 | 
			
		||||
		elem := rVal.Type().Elem()
 | 
			
		||||
		scope := &Scope{Value: reflect.New(elem).Interface()}
 | 
			
		||||
		joinTable := relation.JoinTableHandler.(*JoinTableHandler)
 | 
			
		||||
 | 
			
		||||
		for _, field := range scope.GetModelStruct().StructFields {
 | 
			
		||||
			if field.IsNormal {
 | 
			
		||||
				columns = append(columns, field.DBName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, key := range joinTable.Source.ForeignKeys {
 | 
			
		||||
			columns = append(columns, key.DBName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, key := range joinTable.Destination.ForeignKeys {
 | 
			
		||||
			columns = append(columns, key.DBName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rows = sqlmock.NewRows(columns)
 | 
			
		||||
 | 
			
		||||
		// in this case we definitely have a slice
 | 
			
		||||
		if rVal.Len() > 0 {
 | 
			
		||||
			for i := 0; i < rVal.Len(); i++ {
 | 
			
		||||
				scope := &Scope{Value: rVal.Index(i).Interface()}
 | 
			
		||||
				row := getRowForFields(scope.Fields())
 | 
			
		||||
 | 
			
		||||
				// need to append the values for join table keys
 | 
			
		||||
				sourcePk := q.scope.PrimaryKeyValue()
 | 
			
		||||
				destModelType := joinTable.Destination.ModelType
 | 
			
		||||
				destModelVal := reflect.New(destModelType).Interface()
 | 
			
		||||
				destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue()
 | 
			
		||||
 | 
			
		||||
				row = append(row, sourcePk, destPkVal)
 | 
			
		||||
 | 
			
		||||
				rows = rows.AddRow(row...)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return rows, true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil, false
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
 | 
			
		||||
	var columns []string
 | 
			
		||||
	for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
 | 
			
		||||
		if field.IsNormal {
 | 
			
		||||
			columns = append(columns, field.DBName)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows := sqlmock.NewRows(columns)
 | 
			
		||||
	outVal := indirect(reflect.ValueOf(out))
 | 
			
		||||
 | 
			
		||||
	// SELECT multiple columns
 | 
			
		||||
	if outVal.Kind() == reflect.Slice {
 | 
			
		||||
		outSlice := []interface{}{}
 | 
			
		||||
 | 
			
		||||
		for i := 0; i < outVal.Len(); i++ {
 | 
			
		||||
			outSlice = append(outSlice, outVal.Index(i).Interface())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, outElem := range outSlice {
 | 
			
		||||
			scope := &Scope{Value: outElem}
 | 
			
		||||
			row := getRowForFields(scope.Fields())
 | 
			
		||||
			rows = rows.AddRow(row...)
 | 
			
		||||
		}
 | 
			
		||||
	} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
 | 
			
		||||
		row := getRowForFields(q.scope.Fields())
 | 
			
		||||
		rows = rows.AddRow(row...)
 | 
			
		||||
	} else {
 | 
			
		||||
		panic(fmt.Errorf("Can only get rows for slice or struct"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rows
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns accepts an out type which should either be a struct or slice. Under
 | 
			
		||||
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
 | 
			
		||||
// the underlying mock db
 | 
			
		||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
 | 
			
		||||
	scope := (&Scope{}).New(out)
 | 
			
		||||
	q.scope = scope
 | 
			
		||||
	outVal := indirect(reflect.ValueOf(out))
 | 
			
		||||
 | 
			
		||||
	// rows := q.getRowsForOutType(out)
 | 
			
		||||
	destQuery := q.queries[0]
 | 
			
		||||
	subQueries := q.queries[1:]
 | 
			
		||||
 | 
			
		||||
	// main query always at the head of the slice
 | 
			
		||||
	q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)).
 | 
			
		||||
		WillReturnRows(q.getDestRows(out))
 | 
			
		||||
 | 
			
		||||
	// subqueries are preload
 | 
			
		||||
	for _, subQuery := range subQueries {
 | 
			
		||||
		if subQuery.preload != "" {
 | 
			
		||||
			if field, ok := scope.FieldByName(subQuery.preload); ok {
 | 
			
		||||
				expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql))
 | 
			
		||||
				rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship)
 | 
			
		||||
 | 
			
		||||
				if hasRows {
 | 
			
		||||
					expectation.WillReturnRows(rows)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return q
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SqlmockExec implements Exec for go-sqlmock
 | 
			
		||||
type SqlmockExec struct {
 | 
			
		||||
	exec  Stmt
 | 
			
		||||
	mock  sqlmock.Sqlmock
 | 
			
		||||
	scope *Scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WillSucceed accepts a two int64s. They are passed directly to the underlying
 | 
			
		||||
// mock db. Useful for checking DAO behaviour in the event that the incorrect
 | 
			
		||||
// number of rows are affected by an Exec
 | 
			
		||||
func (e *SqlmockExec) WillSucceed(lastReturnedID, rowsAffected int64) ExpectedExec {
 | 
			
		||||
	result := sqlmock.NewResult(lastReturnedID, rowsAffected)
 | 
			
		||||
	e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
 | 
			
		||||
 | 
			
		||||
	return e
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WillFail simulates returning an Error from an unsuccessful exec
 | 
			
		||||
func (e *SqlmockExec) WillFail(err error) ExpectedExec {
 | 
			
		||||
	result := sqlmock.NewErrorResult(err)
 | 
			
		||||
	e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
 | 
			
		||||
 | 
			
		||||
	return e
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										322
									
								
								expecter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										322
									
								
								expecter_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,322 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewDefaultExpecter(t *testing.T) {
 | 
			
		||||
	db, _, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	//lint:ignore SA5001 just a mock
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewCustomExpecter(t *testing.T) {
 | 
			
		||||
	db, _, err := gorm.NewExpecter(gorm.NewSqlmockAdapter, "sqlmock", "mock_gorm_dsn")
 | 
			
		||||
	//lint:ignore SA5001 just a mock
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQuery(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expect.First(&User{})
 | 
			
		||||
	db.First(&User{})
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryReturn(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		db.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := User{Id: 1}
 | 
			
		||||
	out := User{Id: 1, Name: "jinzhu"}
 | 
			
		||||
 | 
			
		||||
	expect.First(&in).Returns(out)
 | 
			
		||||
 | 
			
		||||
	db.First(&in)
 | 
			
		||||
 | 
			
		||||
	if e := expect.AssertExpectations(); e != nil {
 | 
			
		||||
		t.Error(e)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if in.Name != "jinzhu" {
 | 
			
		||||
		t.Errorf("Expected %s, got %s", out.Name, in.Name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ne := reflect.DeepEqual(in, out); !ne {
 | 
			
		||||
		t.Errorf("Not equal")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindStructDest(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		db.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := &User{Id: 1}
 | 
			
		||||
 | 
			
		||||
	expect.Find(in)
 | 
			
		||||
	db.Find(&User{Id: 1})
 | 
			
		||||
 | 
			
		||||
	if e := expect.AssertExpectations(); e != nil {
 | 
			
		||||
		t.Error(e)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindSlice(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := []User{}
 | 
			
		||||
	out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}}
 | 
			
		||||
 | 
			
		||||
	expect.Find(&in).Returns(&out)
 | 
			
		||||
	db.Find(&in)
 | 
			
		||||
 | 
			
		||||
	if e := expect.AssertExpectations(); e != nil {
 | 
			
		||||
		t.Error(e)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ne := reflect.DeepEqual(in, out); !ne {
 | 
			
		||||
		t.Error("Expected equal slices")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockPreloadHasMany(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := User{Id: 1}
 | 
			
		||||
	outEmails := []Email{Email{Id: 1, UserId: 1}, Email{Id: 2, UserId: 1}}
 | 
			
		||||
	out := User{Id: 1, Emails: outEmails}
 | 
			
		||||
 | 
			
		||||
	expect.Preload("Emails").Find(&in).Returns(out)
 | 
			
		||||
	db.Preload("Emails").Find(&in)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(in, out) {
 | 
			
		||||
		t.Error("In and out are not equal")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockPreloadHasOne(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := User{Id: 1}
 | 
			
		||||
	out := User{Id: 1, CreditCard: CreditCard{Number: "12345678"}}
 | 
			
		||||
 | 
			
		||||
	expect.Preload("CreditCard").Find(&in).Returns(out)
 | 
			
		||||
	db.Preload("CreditCard").Find(&in)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(in, out) {
 | 
			
		||||
		t.Error("In and out are not equal")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockPreloadMany2Many(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := User{Id: 1}
 | 
			
		||||
	languages := []Language{Language{Name: "ZH"}}
 | 
			
		||||
	out := User{Id: 1, Languages: languages}
 | 
			
		||||
 | 
			
		||||
	expect.Preload("Languages").Find(&in).Returns(out)
 | 
			
		||||
	db.Preload("Languages").Find(&in)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(in, out) {
 | 
			
		||||
		t.Error("In and out are not equal")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockPreloadMultiple(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	creditCard := CreditCard{Number: "12345678"}
 | 
			
		||||
	languages := []Language{Language{Name: "ZH"}}
 | 
			
		||||
 | 
			
		||||
	in := User{Id: 1}
 | 
			
		||||
	out := User{Id: 1, Languages: languages, CreditCard: creditCard}
 | 
			
		||||
 | 
			
		||||
	expect.Preload("Languages").Preload("CreditCard").Find(&in).Returns(out)
 | 
			
		||||
	db.Preload("Languages").Preload("CreditCard").Find(&in)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(in, out) {
 | 
			
		||||
		t.Error("In and out are not equal")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockCreateBasic(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user := User{Name: "jinzhu"}
 | 
			
		||||
	expect.Create(&user).WillSucceed(1, 1)
 | 
			
		||||
	rowsAffected := db.Create(&user).RowsAffected
 | 
			
		||||
 | 
			
		||||
	if rowsAffected != 1 {
 | 
			
		||||
		t.Errorf("Expected rows affected to be 1 but got %d", rowsAffected)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Id != 1 {
 | 
			
		||||
		t.Errorf("User id field should be 1, but got %d", user.Id)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockCreateError(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mockError := errors.New("Could not insert user")
 | 
			
		||||
 | 
			
		||||
	user := User{Name: "jinzhu"}
 | 
			
		||||
	expect.Create(&user).WillFail(mockError)
 | 
			
		||||
 | 
			
		||||
	dbError := db.Create(&user).Error
 | 
			
		||||
 | 
			
		||||
	if dbError == nil || dbError != mockError {
 | 
			
		||||
		t.Errorf("Expected *DB.Error to be set, but it was not")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockSaveBasic(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user := User{Name: "jinzhu"}
 | 
			
		||||
	expect.Save(&user).WillSucceed(1, 1)
 | 
			
		||||
	expected := db.Save(&user)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Errorf("Expectations were not met %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if expected.RowsAffected != 1 || user.Id != 1 {
 | 
			
		||||
		t.Errorf("Expected result was not returned")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockUpdateBasic(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newName := "uhznij"
 | 
			
		||||
	user := User{Name: "jinzhu"}
 | 
			
		||||
 | 
			
		||||
	expect.Model(&user).Update("name", newName).WillSucceed(1, 1)
 | 
			
		||||
	db.Model(&user).Update("name", newName)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Errorf("Expectations were not met %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Name != newName {
 | 
			
		||||
		t.Errorf("Should have name %s but got %s", newName, user.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMockUpdatesBasic(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user := User{Name: "jinzhu", Age: 18}
 | 
			
		||||
	updated := User{Name: "jinzhu", Age: 88}
 | 
			
		||||
 | 
			
		||||
	expect.Model(&user).Updates(updated).WillSucceed(1, 1)
 | 
			
		||||
	db.Model(&user).Updates(updated)
 | 
			
		||||
 | 
			
		||||
	if err := expect.AssertExpectations(); err != nil {
 | 
			
		||||
		t.Errorf("Expectations were not met %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Age != updated.Age {
 | 
			
		||||
		t.Errorf("Should have age %d but got %d", user.Age, updated.Age)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -139,6 +139,11 @@ func (role Role) IsAdmin() bool {
 | 
			
		||||
 | 
			
		||||
type Num int64
 | 
			
		||||
 | 
			
		||||
func (i Num) Value() (driver.Value, error) {
 | 
			
		||||
	// guaranteed ok
 | 
			
		||||
	return int64(i), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *Num) Scan(src interface{}) error {
 | 
			
		||||
	switch s := src.(type) {
 | 
			
		||||
	case []byte:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user