Support basic mock Create use cases

This commit is contained in:
Ian Tan 2017-11-24 15:32:43 +08:00
parent 4128722761
commit 2368c373ae
5 changed files with 97 additions and 48 deletions

View File

@ -18,6 +18,24 @@ type Stmt struct {
args []interface{} args []interface{}
} }
func recordCreateCallback(scope *Scope) {
r, ok := scope.Get("gorm:recorder")
if !ok {
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
}
stmt := Stmt{
kind: "exec",
sql: scope.SQL,
args: scope.SQLVars,
}
recorder := r.(*Recorder)
recorder.Record(stmt)
}
func recordQueryCallback(scope *Scope) { func recordQueryCallback(scope *Scope) {
r, ok := scope.Get("gorm:recorder") r, ok := scope.Get("gorm:recorder")
@ -38,9 +56,6 @@ func recordQueryCallback(scope *Scope) {
scope.prepareQuerySQL() scope.prepareQuerySQL()
stmt.preload = recorder.preload[0].schema stmt.preload = recorder.preload[0].schema
// spew.Printf("_____PRELOADING_____\r\n%s\r\n", stmt.preload)
// spew.Printf("_____SQL_____\r\n%s\r\n", scope.SQL)
// we just want to pop the first element off // we just want to pop the first element off
recorder.preload = recorder.preload[1:] recorder.preload = recorder.preload[1:]
} }
@ -67,35 +82,6 @@ func (r *Recorder) Record(stmt Stmt) {
r.stmts = append(r.stmts, stmt) r.stmts = append(r.stmts, stmt)
} }
func getStmtFromLog(values ...interface{}) Stmt {
var statement Stmt
if len(values) > 1 {
var (
level = values[0]
)
if level == "sql" {
statement.args = values[4].([]interface{})
statement.sql = values[3].(string)
}
return statement
}
return statement
}
// Print just sets the last recorded SQL statement
// TODO: find a better way to extract SQL from log messages
func (r *Recorder) Print(args ...interface{}) {
statement := getStmtFromLog(args...)
if statement.sql != "" {
r.stmts = append(r.stmts, statement)
}
}
// GetFirst returns the first recorded sql statement logged. If there are no // GetFirst returns the first recorded sql statement logged. If there are no
// statements, false is returned // statements, false is returned
func (r *Recorder) GetFirst() (Stmt, bool) { func (r *Recorder) GetFirst() (Stmt, bool) {
@ -138,7 +124,6 @@ func NewDefaultExpecter() (*DB, *Expecter, error) {
gorm := &DB{ gorm := &DB{
db: noop, db: noop,
logger: defaultLogger, logger: defaultLogger,
logMode: 2,
values: map[string]interface{}{}, values: map[string]interface{}{},
callbacks: DefaultCallback, callbacks: DefaultCallback,
dialect: newDialect("sqlmock", noop), dialect: newDialect("sqlmock", noop),
@ -146,6 +131,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) {
gorm.parent = gorm gorm.parent = gorm
gorm = gorm.Set("gorm:recorder", recorder) gorm = gorm.Set("gorm:recorder", recorder)
gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordCreateCallback)
gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback) gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback)
gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback) gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback)
gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback) gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback)
@ -171,6 +157,16 @@ func (h *Expecter) AssertExpectations() error {
return h.adapter.AssertExpectations() return h.adapter.AssertExpectations()
} }
/* CREATE */
// Create mocks insertion of a model into the DB
func (h *Expecter) Create(model interface{}) ExpectedExec {
h.gorm.Create(model)
return h.adapter.ExpectExec(h.recorder.stmts[0])
}
/* READ */
// First triggers a Query // First triggers a Query
func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
h.gorm.First(out, where...) h.gorm.First(out, where...)

View File

@ -25,7 +25,7 @@ func init() {
// implementations (e.g. go-sqlmock or go-testdb) // implementations (e.g. go-sqlmock or go-testdb)
type Adapter interface { type Adapter interface {
ExpectQuery(stmts ...Stmt) ExpectedQuery ExpectQuery(stmts ...Stmt) ExpectedQuery
ExpectExec(stmt string) ExpectedExec ExpectExec(stmt Stmt) ExpectedExec
AssertExpectations() error AssertExpectations() error
} }
@ -56,10 +56,8 @@ func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
// ExpectExec wraps the underlying mock method for setting a exec // ExpectExec wraps the underlying mock method for setting a exec
// expectation // expectation
func (a *SqlmockAdapter) ExpectExec(stmt string) ExpectedExec { func (a *SqlmockAdapter) ExpectExec(exec Stmt) ExpectedExec {
e := a.mocker.ExpectExec(stmt) return &SqlmockExec{mock: a.mocker, exec: exec}
return &SqlmockExec{exec: e}
} }
// AssertExpectations asserts that _all_ expectations for a test have been met // AssertExpectations asserts that _all_ expectations for a test have been met

View File

@ -45,12 +45,12 @@ type NoopResult struct{}
// LastInsertId is a noop method for satisfying drive.Result // LastInsertId is a noop method for satisfying drive.Result
func (r NoopResult) LastInsertId() (int64, error) { func (r NoopResult) LastInsertId() (int64, error) {
return 1, nil return 0, nil
} }
// RowsAffected is a noop method for satisfying drive.Result // RowsAffected is a noop method for satisfying drive.Result
func (r NoopResult) RowsAffected() (int64, error) { func (r NoopResult) RowsAffected() (int64, error) {
return 1, nil return 0, nil
} }
// NoopRows implements driver.Rows // NoopRows implements driver.Rows

View File

@ -20,7 +20,8 @@ type ExpectedQuery interface {
// return a result. It presents a fluent API for chaining calls to other // return a result. It presents a fluent API for chaining calls to other
// expectations // expectations
type ExpectedExec interface { type ExpectedExec interface {
Returns(result driver.Result) ExpectedExec WillSucceed(lastInsertID, rowsAffected int64) ExpectedExec
WillFail(err error) ExpectedExec
} }
// SqlmockQuery implements Query for go-sqlmock // SqlmockQuery implements Query for go-sqlmock
@ -234,14 +235,25 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
// SqlmockExec implements Exec for go-sqlmock // SqlmockExec implements Exec for go-sqlmock
type SqlmockExec struct { type SqlmockExec struct {
exec *sqlmock.ExpectedExec exec Stmt
mock sqlmock.Sqlmock
scope *Scope
} }
// Returns accepts a driver.Result. It is passed directly to the underlying // WillSucceed accepts a two int64s. They are passed directly to the underlying
// mock db. Useful for checking DAO behaviour in the event that the incorrect // mock db. Useful for checking DAO behaviour in the event that the incorrect
// number of rows are affected by an Exec // number of rows are affected by an Exec
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec { func (e *SqlmockExec) WillSucceed(lastReturnedID, rowsAffected int64) ExpectedExec {
e.exec = e.exec.WillReturnResult(result) result := sqlmock.NewResult(lastReturnedID, rowsAffected)
e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
return e
}
// WillFail simulates returning an Error from an unsuccessful exec
func (e *SqlmockExec) WillFail(err error) ExpectedExec {
result := sqlmock.NewErrorResult(err)
e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
return e return e
} }

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -222,10 +223,52 @@ func TestMockPreloadMultiple(t *testing.T) {
t.Error(err) t.Error(err)
} }
// spew.Printf("______IN______\r\n%s\r\n", spew.Sdump(in))
// spew.Printf("______OUT______\r\n%s\r\n", spew.Sdump(out))
if !reflect.DeepEqual(in, out) { if !reflect.DeepEqual(in, out) {
t.Error("In and out are not equal") t.Error("In and out are not equal")
} }
} }
func TestMockCreateBasic(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
user := User{Name: "jinzhu"}
expect.Create(&user).WillSucceed(1, 1)
rowsAffected := db.Create(&user).RowsAffected
if rowsAffected != 1 {
t.Errorf("Expected rows affected to be 1 but got %d", rowsAffected)
}
if user.Id != 1 {
t.Errorf("User id field should be 1, but got %d", user.Id)
}
}
func TestMockCreateError(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
mockError := errors.New("Could not insert user")
user := User{Name: "jinzhu"}
expect.Create(&user).WillFail(mockError)
dbError := db.Create(&user).Error
if dbError == nil || dbError != mockError {
t.Errorf("Expected *DB.Error to be set, but it was not")
}
}