Merge d630799e906b41467a47b801ec24b892e71a47e0 into 0a51f6cdc55d1650d9ed3b4c13026cfa9133b01e
This commit is contained in:
commit
3087813925
@ -60,6 +60,7 @@ func preloadCallback(scope *Scope) {
|
|||||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||||
case "many_to_many":
|
case "many_to_many":
|
||||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
scope.Err(errors.New("unsupported relation"))
|
scope.Err(errors.New("unsupported relation"))
|
||||||
}
|
}
|
||||||
@ -264,6 +265,8 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
|||||||
|
|
||||||
// handleManyToManyPreload used to preload many to many associations
|
// handleManyToManyPreload used to preload many to many associations
|
||||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||||
|
// spew.Println("___ENTERING HANDLE MANY TO MANY___\r\n")
|
||||||
|
// spew.Printf("___POPULATING %s___:\r\n%s\r\n", field.Name, spew.Sdump(field))
|
||||||
var (
|
var (
|
||||||
relation = field.Relationship
|
relation = field.Relationship
|
||||||
joinTableHandler = relation.JoinTableHandler
|
joinTableHandler = relation.JoinTableHandler
|
||||||
@ -303,6 +306,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
}
|
}
|
||||||
|
|
||||||
rows, err := preloadDB.Rows()
|
rows, err := preloadDB.Rows()
|
||||||
|
// spew.Printf("___RETURNED ROWS___: \r\n%s\r\n", spew.Sdump(rows))
|
||||||
|
|
||||||
if scope.Err(err) != nil {
|
if scope.Err(err) != nil {
|
||||||
return
|
return
|
||||||
@ -312,6 +316,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
|
// This is a Language zero value struct
|
||||||
elem = reflect.New(fieldType).Elem()
|
elem = reflect.New(fieldType).Elem()
|
||||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||||
)
|
)
|
||||||
@ -323,6 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
}
|
}
|
||||||
|
|
||||||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||||
|
// spew.Printf("___FIELDS___: \r\n%s\r\n", spew.Sdump(fields))
|
||||||
|
|
||||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||||
// generate hashed forkey keys in join table
|
// generate hashed forkey keys in join table
|
||||||
@ -351,12 +357,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
foreignFieldNames = []string{}
|
foreignFieldNames = []string{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// spew.Printf("Foreign fields: %s", spew.Sdump(relation.ForeignFieldNames))
|
||||||
for _, dbName := range relation.ForeignFieldNames {
|
for _, dbName := range relation.ForeignFieldNames {
|
||||||
if field, ok := scope.FieldByName(dbName); ok {
|
if field, ok := scope.FieldByName(dbName); ok {
|
||||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// spew.Printf("Scope value: %s", spew.Sdump(indirectScopeValue))
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
object := indirect(indirectScopeValue.Index(j))
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
@ -367,6 +375,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
||||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// spew.Printf("Field source map: %s", spew.Sdump(fieldsSourceMap))
|
||||||
|
// spew.Printf("Link hash: %s", spew.Sdump(linkHash))
|
||||||
for source, link := range linkHash {
|
for source, link := range linkHash {
|
||||||
for i, field := range fieldsSourceMap[source] {
|
for i, field := range fieldsSourceMap[source] {
|
||||||
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
||||||
|
225
expecter.go
Normal file
225
expecter.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recorder satisfies the logger interface
|
||||||
|
type Recorder struct {
|
||||||
|
stmts []Stmt
|
||||||
|
preload []searchPreload // store it on Recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stmt represents a sql statement. It can be an Exec, Query, or QueryRow
|
||||||
|
type Stmt struct {
|
||||||
|
kind string // can be Query, Exec, QueryRow
|
||||||
|
preload string // contains schema if it is a preload query
|
||||||
|
sql string
|
||||||
|
args []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordExecCallback(scope *Scope) {
|
||||||
|
r, ok := scope.Get("gorm:recorder")
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := Stmt{
|
||||||
|
kind: "exec",
|
||||||
|
sql: scope.SQL,
|
||||||
|
args: scope.SQLVars,
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := r.(*Recorder)
|
||||||
|
|
||||||
|
recorder.Record(stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordQueryCallback(scope *Scope) {
|
||||||
|
r, ok := scope.Get("gorm:recorder")
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := r.(*Recorder)
|
||||||
|
|
||||||
|
stmt := Stmt{
|
||||||
|
kind: "query",
|
||||||
|
sql: scope.SQL,
|
||||||
|
args: scope.SQLVars,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(recorder.preload) > 0 {
|
||||||
|
// this will cause the scope.SQL to mutate to the preload query
|
||||||
|
stmt.preload = recorder.preload[0].schema
|
||||||
|
|
||||||
|
// we just want to pop the first element off
|
||||||
|
recorder.preload = recorder.preload[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder.Record(stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordPreloadCallback(scope *Scope) {
|
||||||
|
// this callback runs _before_ gorm:preload
|
||||||
|
// it should record the next thing to be preloaded
|
||||||
|
recorder, ok := scope.Get("gorm:recorder")
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(scope.Search.preload) > 0 {
|
||||||
|
// spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload))
|
||||||
|
recorder.(*Recorder).preload = scope.Search.preload
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record records a Stmt for use when SQL is finally executed
|
||||||
|
func (r *Recorder) Record(stmt Stmt) {
|
||||||
|
r.stmts = append(r.stmts, stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFirst returns the first recorded sql statement logged. If there are no
|
||||||
|
// statements, false is returned
|
||||||
|
func (r *Recorder) GetFirst() (Stmt, bool) {
|
||||||
|
var stmt Stmt
|
||||||
|
if len(r.stmts) > 0 {
|
||||||
|
stmt = r.stmts[0]
|
||||||
|
return stmt, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return stmt, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEmpty returns true if the statements slice is empty
|
||||||
|
func (r *Recorder) IsEmpty() bool {
|
||||||
|
return len(r.stmts) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdapterFactory is a generic interface for arbitrary adapters that satisfy
|
||||||
|
// the interface. variadic args are passed to gorm.Open.
|
||||||
|
type AdapterFactory func(dialect string, args ...interface{}) (*DB, Adapter, error)
|
||||||
|
|
||||||
|
// Expecter is the exported struct used for setting expectations
|
||||||
|
type Expecter struct {
|
||||||
|
// globally scoped expecter
|
||||||
|
adapter Adapter
|
||||||
|
gorm *DB
|
||||||
|
recorder *Recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
|
||||||
|
func NewDefaultExpecter() (*DB, *Expecter, error) {
|
||||||
|
gormDb, adapter, err := NewSqlmockAdapter("sqlmock", "mock_gorm_dsn")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := &Recorder{}
|
||||||
|
noop, _ := NewNoopDB()
|
||||||
|
gorm := &DB{
|
||||||
|
db: noop,
|
||||||
|
logger: defaultLogger,
|
||||||
|
values: map[string]interface{}{},
|
||||||
|
callbacks: DefaultCallback,
|
||||||
|
dialect: newDialect("sqlmock", noop),
|
||||||
|
}
|
||||||
|
|
||||||
|
gorm.parent = gorm
|
||||||
|
gorm = gorm.Set("gorm:recorder", recorder)
|
||||||
|
gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordExecCallback)
|
||||||
|
gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback)
|
||||||
|
gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback)
|
||||||
|
gorm.Callback().RowQuery().After("gorm:row_query").Register("gorm:record_query", recordQueryCallback)
|
||||||
|
gorm.Callback().Update().After("gorm:update").Register("gorm:record_exec", recordExecCallback)
|
||||||
|
|
||||||
|
return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewExpecter returns an Expecter for arbitrary adapters
|
||||||
|
func NewExpecter(fn AdapterFactory, dialect string, args ...interface{}) (*DB, *Expecter, error) {
|
||||||
|
gormDb, adapter, err := fn(dialect, args...)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return gormDb, &Expecter{adapter: adapter}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/* PUBLIC METHODS */
|
||||||
|
|
||||||
|
// AssertExpectations checks if all expected Querys and Execs were satisfied.
|
||||||
|
func (h *Expecter) AssertExpectations() error {
|
||||||
|
return h.adapter.AssertExpectations()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model sets scope.Value
|
||||||
|
func (h *Expecter) Model(model interface{}) *Expecter {
|
||||||
|
h.gorm = h.gorm.Model(model)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
/* CREATE */
|
||||||
|
|
||||||
|
// Create mocks insertion of a model into the DB
|
||||||
|
func (h *Expecter) Create(model interface{}) ExpectedExec {
|
||||||
|
h.gorm.Create(model)
|
||||||
|
return h.adapter.ExpectExec(h.recorder.stmts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
/* READ */
|
||||||
|
|
||||||
|
// First triggers a Query
|
||||||
|
func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
|
||||||
|
h.gorm.First(out, where...)
|
||||||
|
return h.adapter.ExpectQuery(h.recorder.stmts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find triggers a Query
|
||||||
|
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
|
||||||
|
h.gorm.Find(out, where...)
|
||||||
|
return h.adapter.ExpectQuery(h.recorder.stmts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preload clones the expecter and sets a preload condition on gorm.DB
|
||||||
|
func (h *Expecter) Preload(column string, conditions ...interface{}) *Expecter {
|
||||||
|
clone := h.clone()
|
||||||
|
clone.gorm = clone.gorm.Preload(column, conditions...)
|
||||||
|
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
/* UPDATE */
|
||||||
|
|
||||||
|
// Save mocks updating a record in the DB and will trigger db.Exec()
|
||||||
|
func (h *Expecter) Save(model interface{}) ExpectedExec {
|
||||||
|
h.gorm.Save(model)
|
||||||
|
return h.adapter.ExpectExec(h.recorder.stmts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update mocks updating the given attributes in the DB
|
||||||
|
func (h *Expecter) Update(attrs ...interface{}) ExpectedExec {
|
||||||
|
h.gorm.Update(attrs...)
|
||||||
|
return h.adapter.ExpectExec(h.recorder.stmts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updates does the same thing as Update, but with map or struct
|
||||||
|
func (h *Expecter) Updates(values interface{}, ignoreProtectedAttrs ...bool) ExpectedExec {
|
||||||
|
h.gorm.Updates(values, ignoreProtectedAttrs...)
|
||||||
|
return h.adapter.ExpectExec(h.recorder.stmts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
/* PRIVATE METHODS */
|
||||||
|
|
||||||
|
func (h *Expecter) clone() *Expecter {
|
||||||
|
return &Expecter{
|
||||||
|
adapter: h.adapter,
|
||||||
|
gorm: h.gorm,
|
||||||
|
recorder: h.recorder,
|
||||||
|
}
|
||||||
|
}
|
68
expecter_adapter.go
Normal file
68
expecter_adapter.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
db *sql.DB
|
||||||
|
mock sqlmock.Sqlmock
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
db, mock, err = sqlmock.NewWithDSN("mock_gorm_dsn")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adapter provides an abstract interface over concrete mock database
|
||||||
|
// implementations (e.g. go-sqlmock or go-testdb)
|
||||||
|
type Adapter interface {
|
||||||
|
ExpectQuery(stmts ...Stmt) ExpectedQuery
|
||||||
|
ExpectExec(stmt Stmt) ExpectedExec
|
||||||
|
AssertExpectations() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlmockAdapter implemenets the Adapter interface using go-sqlmock
|
||||||
|
// it is the default Adapter
|
||||||
|
type SqlmockAdapter struct {
|
||||||
|
db *sql.DB
|
||||||
|
mocker sqlmock.Sqlmock
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSqlmockAdapter returns a mock gorm.DB and an Adapter backed by
|
||||||
|
// go-sqlmock
|
||||||
|
func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error) {
|
||||||
|
gormDb, err := Open("sqlmock", "mock_gorm_dsn")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return gormDb, &SqlmockAdapter{db: db, mocker: mock}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectQuery wraps the underlying mock method for setting a query
|
||||||
|
// expectation. It accepts multiple statements in the event of preloading
|
||||||
|
func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
|
||||||
|
return &SqlmockQuery{mock: a.mocker, queries: queries}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectExec wraps the underlying mock method for setting a exec
|
||||||
|
// expectation
|
||||||
|
func (a *SqlmockAdapter) ExpectExec(exec Stmt) ExpectedExec {
|
||||||
|
return &SqlmockExec{mock: a.mocker, exec: exec}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertExpectations asserts that _all_ expectations for a test have been met
|
||||||
|
// and returns an error specifying which have not if there are unmet
|
||||||
|
// expectations
|
||||||
|
func (a *SqlmockAdapter) AssertExpectations() error {
|
||||||
|
return a.mocker.ExpectationsWereMet()
|
||||||
|
}
|
186
expecter_noop.go
Normal file
186
expecter_noop.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var pool *NoopDriver
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
pool = &NoopDriver{
|
||||||
|
conns: make(map[string]*NoopConnection),
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("noop", pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopDriver implements sql/driver.Driver
|
||||||
|
type NoopDriver struct {
|
||||||
|
sync.Mutex
|
||||||
|
counter int
|
||||||
|
conns map[string]*NoopConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open implements sql/driver.Driver
|
||||||
|
func (d *NoopDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
|
||||||
|
c, ok := d.conns[dsn]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return c, fmt.Errorf("No connection available")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.opened++
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopResult is a noop struct that satisfies sql.Result
|
||||||
|
type NoopResult struct{}
|
||||||
|
|
||||||
|
// LastInsertId is a noop method for satisfying drive.Result
|
||||||
|
func (r NoopResult) LastInsertId() (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowsAffected is a noop method for satisfying drive.Result
|
||||||
|
func (r NoopResult) RowsAffected() (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopRows implements driver.Rows
|
||||||
|
type NoopRows struct {
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Columns implements driver.Rows
|
||||||
|
func (r *NoopRows) Columns() []string {
|
||||||
|
return []string{"foo", "bar", "baz", "lol", "kek", "zzz"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements driver.Rows
|
||||||
|
func (r *NoopRows) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next implements driver.Rows and alwys returns only one row
|
||||||
|
func (r *NoopRows) Next(dest []driver.Value) error {
|
||||||
|
if r.pos == 1 {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
cols := []string{"foo", "bar", "baz", "lol", "kek", "zzz"}
|
||||||
|
|
||||||
|
for i, col := range cols {
|
||||||
|
dest[i] = col
|
||||||
|
}
|
||||||
|
|
||||||
|
r.pos++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopStmt implements driver.Stmt
|
||||||
|
type NoopStmt struct{}
|
||||||
|
|
||||||
|
// Close implements driver.Stmt
|
||||||
|
func (s *NoopStmt) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumInput implements driver.Stmt
|
||||||
|
func (s *NoopStmt) NumInput() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec implements driver.Stmt
|
||||||
|
func (s *NoopStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
|
return &NoopResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query implements driver.Stmt
|
||||||
|
func (s *NoopStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
|
return &NoopRows{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNoopDB initialises a new DefaultNoopDB
|
||||||
|
func NewNoopDB() (SQLCommon, error) {
|
||||||
|
pool.Lock()
|
||||||
|
dsn := fmt.Sprintf("noop_db_%d", pool.counter)
|
||||||
|
pool.counter++
|
||||||
|
|
||||||
|
noop := &NoopConnection{dsn: dsn, drv: pool}
|
||||||
|
pool.conns[dsn] = noop
|
||||||
|
pool.Unlock()
|
||||||
|
|
||||||
|
db, err := noop.open()
|
||||||
|
|
||||||
|
return db, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopConnection implements sql/driver.Conn
|
||||||
|
// for our purposes, the noop connection never returns an error, as we only
|
||||||
|
// require it for generating queries. It is necessary because eagerloading
|
||||||
|
// will fail if any operation returns an error
|
||||||
|
type NoopConnection struct {
|
||||||
|
dsn string
|
||||||
|
drv *NoopDriver
|
||||||
|
opened int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NoopConnection) open() (*sql.DB, error) {
|
||||||
|
db, err := sql.Open("noop", c.dsn)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return db, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, db.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Close() error {
|
||||||
|
c.drv.Lock()
|
||||||
|
defer c.drv.Unlock()
|
||||||
|
|
||||||
|
c.opened--
|
||||||
|
if c.opened == 0 {
|
||||||
|
delete(c.drv.conns, c.dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Begin() (driver.Tx, error) {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||||
|
return NoopResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) {
|
||||||
|
return &NoopStmt{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||||
|
return &NoopRows{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Commit() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rollback implements sql/driver.Conn
|
||||||
|
func (c *NoopConnection) Rollback() error {
|
||||||
|
return nil
|
||||||
|
}
|
259
expecter_result.go
Normal file
259
expecter_result.go
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExpectedQuery represents an expected query that will be executed and can
|
||||||
|
// return some rows. It presents a fluent API for chaining calls to other
|
||||||
|
// expectations
|
||||||
|
type ExpectedQuery interface {
|
||||||
|
Returns(model interface{}) ExpectedQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedExec represents an expected exec that will be executed and can
|
||||||
|
// return a result. It presents a fluent API for chaining calls to other
|
||||||
|
// expectations
|
||||||
|
type ExpectedExec interface {
|
||||||
|
WillSucceed(lastInsertID, rowsAffected int64) ExpectedExec
|
||||||
|
WillFail(err error) ExpectedExec
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlmockQuery implements Query for go-sqlmock
|
||||||
|
type SqlmockQuery struct {
|
||||||
|
mock sqlmock.Sqlmock
|
||||||
|
queries []Stmt
|
||||||
|
scope *Scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRowForFields(fields []*Field) []driver.Value {
|
||||||
|
var values []driver.Value
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.IsNormal {
|
||||||
|
value := field.Field
|
||||||
|
|
||||||
|
// dereference pointers
|
||||||
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
|
value = reflect.Indirect(field.Field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if we have a zero Value
|
||||||
|
// just append nil if it's not valid, so sqlmock won't complain
|
||||||
|
if !value.IsValid() {
|
||||||
|
values = append(values, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
concreteVal := value.Interface()
|
||||||
|
// spew.Printf("%v: %v\r\n", field.Name, concreteVal)
|
||||||
|
|
||||||
|
if driver.IsValue(concreteVal) {
|
||||||
|
values = append(values, concreteVal)
|
||||||
|
} else if num, err := driver.DefaultParameterConverter.ConvertValue(concreteVal); err == nil {
|
||||||
|
values = append(values, num)
|
||||||
|
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
|
||||||
|
if convertedValue, err := valuer.Value(); err == nil {
|
||||||
|
values = append(values, convertedValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
|
||||||
|
var (
|
||||||
|
rows *sqlmock.Rows
|
||||||
|
columns []string
|
||||||
|
)
|
||||||
|
|
||||||
|
// we need to check for zero values
|
||||||
|
if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
|
||||||
|
// spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch relation.Kind {
|
||||||
|
case "has_one":
|
||||||
|
scope := &Scope{Value: rVal.Interface()}
|
||||||
|
|
||||||
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
if field.IsNormal {
|
||||||
|
columns = append(columns, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows = sqlmock.NewRows(columns)
|
||||||
|
|
||||||
|
// we don't have a slice
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
|
||||||
|
return rows, true
|
||||||
|
case "has_many":
|
||||||
|
elem := rVal.Type().Elem()
|
||||||
|
scope := &Scope{Value: reflect.New(elem).Interface()}
|
||||||
|
|
||||||
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
if field.IsNormal {
|
||||||
|
columns = append(columns, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows = sqlmock.NewRows(columns)
|
||||||
|
|
||||||
|
if rVal.Len() > 0 {
|
||||||
|
for i := 0; i < rVal.Len(); i++ {
|
||||||
|
scope := &Scope{Value: rVal.Index(i).Interface()}
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
case "many_to_many":
|
||||||
|
elem := rVal.Type().Elem()
|
||||||
|
scope := &Scope{Value: reflect.New(elem).Interface()}
|
||||||
|
joinTable := relation.JoinTableHandler.(*JoinTableHandler)
|
||||||
|
|
||||||
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
if field.IsNormal {
|
||||||
|
columns = append(columns, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range joinTable.Source.ForeignKeys {
|
||||||
|
columns = append(columns, key.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range joinTable.Destination.ForeignKeys {
|
||||||
|
columns = append(columns, key.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows = sqlmock.NewRows(columns)
|
||||||
|
|
||||||
|
// in this case we definitely have a slice
|
||||||
|
if rVal.Len() > 0 {
|
||||||
|
for i := 0; i < rVal.Len(); i++ {
|
||||||
|
scope := &Scope{Value: rVal.Index(i).Interface()}
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
|
||||||
|
// need to append the values for join table keys
|
||||||
|
sourcePk := q.scope.PrimaryKeyValue()
|
||||||
|
destModelType := joinTable.Destination.ModelType
|
||||||
|
destModelVal := reflect.New(destModelType).Interface()
|
||||||
|
destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue()
|
||||||
|
|
||||||
|
row = append(row, sourcePk, destPkVal)
|
||||||
|
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
|
||||||
|
var columns []string
|
||||||
|
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
||||||
|
if field.IsNormal {
|
||||||
|
columns = append(columns, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows(columns)
|
||||||
|
outVal := indirect(reflect.ValueOf(out))
|
||||||
|
|
||||||
|
// SELECT multiple columns
|
||||||
|
if outVal.Kind() == reflect.Slice {
|
||||||
|
outSlice := []interface{}{}
|
||||||
|
|
||||||
|
for i := 0; i < outVal.Len(); i++ {
|
||||||
|
outSlice = append(outSlice, outVal.Index(i).Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, outElem := range outSlice {
|
||||||
|
scope := &Scope{Value: outElem}
|
||||||
|
row := getRowForFields(scope.Fields())
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
}
|
||||||
|
} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
|
||||||
|
row := getRowForFields(q.scope.Fields())
|
||||||
|
rows = rows.AddRow(row...)
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns accepts an out type which should either be a struct or slice. Under
|
||||||
|
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
|
||||||
|
// the underlying mock db
|
||||||
|
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
||||||
|
scope := (&Scope{}).New(out)
|
||||||
|
q.scope = scope
|
||||||
|
outVal := indirect(reflect.ValueOf(out))
|
||||||
|
|
||||||
|
// rows := q.getRowsForOutType(out)
|
||||||
|
destQuery := q.queries[0]
|
||||||
|
subQueries := q.queries[1:]
|
||||||
|
|
||||||
|
// main query always at the head of the slice
|
||||||
|
q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)).
|
||||||
|
WillReturnRows(q.getDestRows(out))
|
||||||
|
|
||||||
|
// subqueries are preload
|
||||||
|
for _, subQuery := range subQueries {
|
||||||
|
if subQuery.preload != "" {
|
||||||
|
if field, ok := scope.FieldByName(subQuery.preload); ok {
|
||||||
|
expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql))
|
||||||
|
rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship)
|
||||||
|
|
||||||
|
if hasRows {
|
||||||
|
expectation.WillReturnRows(rows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlmockExec implements Exec for go-sqlmock
|
||||||
|
type SqlmockExec struct {
|
||||||
|
exec Stmt
|
||||||
|
mock sqlmock.Sqlmock
|
||||||
|
scope *Scope
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillSucceed accepts a two int64s. They are passed directly to the underlying
|
||||||
|
// mock db. Useful for checking DAO behaviour in the event that the incorrect
|
||||||
|
// number of rows are affected by an Exec
|
||||||
|
func (e *SqlmockExec) WillSucceed(lastReturnedID, rowsAffected int64) ExpectedExec {
|
||||||
|
result := sqlmock.NewResult(lastReturnedID, rowsAffected)
|
||||||
|
e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
|
||||||
|
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillFail simulates returning an Error from an unsuccessful exec
|
||||||
|
func (e *SqlmockExec) WillFail(err error) ExpectedExec {
|
||||||
|
result := sqlmock.NewErrorResult(err)
|
||||||
|
e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
|
||||||
|
|
||||||
|
return e
|
||||||
|
}
|
322
expecter_test.go
Normal file
322
expecter_test.go
Normal file
@ -0,0 +1,322 @@
|
|||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDefaultExpecter(t *testing.T) {
|
||||||
|
db, _, err := gorm.NewDefaultExpecter()
|
||||||
|
//lint:ignore SA5001 just a mock
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCustomExpecter(t *testing.T) {
|
||||||
|
db, _, err := gorm.NewExpecter(gorm.NewSqlmockAdapter, "sqlmock", "mock_gorm_dsn")
|
||||||
|
//lint:ignore SA5001 just a mock
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expect.First(&User{})
|
||||||
|
db.First(&User{})
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryReturn(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer func() {
|
||||||
|
db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := User{Id: 1}
|
||||||
|
out := User{Id: 1, Name: "jinzhu"}
|
||||||
|
|
||||||
|
expect.First(&in).Returns(out)
|
||||||
|
|
||||||
|
db.First(&in)
|
||||||
|
|
||||||
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
|
t.Error(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Name != "jinzhu" {
|
||||||
|
t.Errorf("Expected %s, got %s", out.Name, in.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ne := reflect.DeepEqual(in, out); !ne {
|
||||||
|
t.Errorf("Not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindStructDest(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer func() {
|
||||||
|
db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := &User{Id: 1}
|
||||||
|
|
||||||
|
expect.Find(in)
|
||||||
|
db.Find(&User{Id: 1})
|
||||||
|
|
||||||
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
|
t.Error(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindSlice(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := []User{}
|
||||||
|
out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}}
|
||||||
|
|
||||||
|
expect.Find(&in).Returns(&out)
|
||||||
|
db.Find(&in)
|
||||||
|
|
||||||
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
|
t.Error(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ne := reflect.DeepEqual(in, out); !ne {
|
||||||
|
t.Error("Expected equal slices")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockPreloadHasMany(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := User{Id: 1}
|
||||||
|
outEmails := []Email{Email{Id: 1, UserId: 1}, Email{Id: 2, UserId: 1}}
|
||||||
|
out := User{Id: 1, Emails: outEmails}
|
||||||
|
|
||||||
|
expect.Preload("Emails").Find(&in).Returns(out)
|
||||||
|
db.Preload("Emails").Find(&in)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Error("In and out are not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockPreloadHasOne(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := User{Id: 1}
|
||||||
|
out := User{Id: 1, CreditCard: CreditCard{Number: "12345678"}}
|
||||||
|
|
||||||
|
expect.Preload("CreditCard").Find(&in).Returns(out)
|
||||||
|
db.Preload("CreditCard").Find(&in)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Error("In and out are not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockPreloadMany2Many(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := User{Id: 1}
|
||||||
|
languages := []Language{Language{Name: "ZH"}}
|
||||||
|
out := User{Id: 1, Languages: languages}
|
||||||
|
|
||||||
|
expect.Preload("Languages").Find(&in).Returns(out)
|
||||||
|
db.Preload("Languages").Find(&in)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Error("In and out are not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockPreloadMultiple(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
creditCard := CreditCard{Number: "12345678"}
|
||||||
|
languages := []Language{Language{Name: "ZH"}}
|
||||||
|
|
||||||
|
in := User{Id: 1}
|
||||||
|
out := User{Id: 1, Languages: languages, CreditCard: creditCard}
|
||||||
|
|
||||||
|
expect.Preload("Languages").Preload("CreditCard").Find(&in).Returns(out)
|
||||||
|
db.Preload("Languages").Preload("CreditCard").Find(&in)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Error("In and out are not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockCreateBasic(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{Name: "jinzhu"}
|
||||||
|
expect.Create(&user).WillSucceed(1, 1)
|
||||||
|
rowsAffected := db.Create(&user).RowsAffected
|
||||||
|
|
||||||
|
if rowsAffected != 1 {
|
||||||
|
t.Errorf("Expected rows affected to be 1 but got %d", rowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Id != 1 {
|
||||||
|
t.Errorf("User id field should be 1, but got %d", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockCreateError(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockError := errors.New("Could not insert user")
|
||||||
|
|
||||||
|
user := User{Name: "jinzhu"}
|
||||||
|
expect.Create(&user).WillFail(mockError)
|
||||||
|
|
||||||
|
dbError := db.Create(&user).Error
|
||||||
|
|
||||||
|
if dbError == nil || dbError != mockError {
|
||||||
|
t.Errorf("Expected *DB.Error to be set, but it was not")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockSaveBasic(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{Name: "jinzhu"}
|
||||||
|
expect.Save(&user).WillSucceed(1, 1)
|
||||||
|
expected := db.Save(&user)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Errorf("Expectations were not met %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected.RowsAffected != 1 || user.Id != 1 {
|
||||||
|
t.Errorf("Expected result was not returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockUpdateBasic(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newName := "uhznij"
|
||||||
|
user := User{Name: "jinzhu"}
|
||||||
|
|
||||||
|
expect.Model(&user).Update("name", newName).WillSucceed(1, 1)
|
||||||
|
db.Model(&user).Update("name", newName)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Errorf("Expectations were not met %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Name != newName {
|
||||||
|
t.Errorf("Should have name %s but got %s", newName, user.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockUpdatesBasic(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{Name: "jinzhu", Age: 18}
|
||||||
|
updated := User{Name: "jinzhu", Age: 88}
|
||||||
|
|
||||||
|
expect.Model(&user).Updates(updated).WillSucceed(1, 1)
|
||||||
|
db.Model(&user).Updates(updated)
|
||||||
|
|
||||||
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
|
t.Errorf("Expectations were not met %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Age != updated.Age {
|
||||||
|
t.Errorf("Should have age %d but got %d", user.Age, updated.Age)
|
||||||
|
}
|
||||||
|
}
|
@ -139,6 +139,11 @@ func (role Role) IsAdmin() bool {
|
|||||||
|
|
||||||
type Num int64
|
type Num int64
|
||||||
|
|
||||||
|
func (i Num) Value() (driver.Value, error) {
|
||||||
|
// guaranteed ok
|
||||||
|
return int64(i), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (i *Num) Scan(src interface{}) error {
|
func (i *Num) Scan(src interface{}) error {
|
||||||
switch s := src.(type) {
|
switch s := src.(type) {
|
||||||
case []byte:
|
case []byte:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user