diff --git a/expecter.go b/expecter.go index aa12f30c..cd746436 100644 --- a/expecter.go +++ b/expecter.go @@ -18,6 +18,24 @@ type Stmt struct { args []interface{} } +func recordCreateCallback(scope *Scope) { + r, ok := scope.Get("gorm:recorder") + + if !ok { + panic(fmt.Errorf("Expected a recorder to be set, but got none")) + } + + stmt := Stmt{ + kind: "exec", + sql: scope.SQL, + args: scope.SQLVars, + } + + recorder := r.(*Recorder) + + recorder.Record(stmt) +} + func recordQueryCallback(scope *Scope) { r, ok := scope.Get("gorm:recorder") @@ -38,9 +56,6 @@ func recordQueryCallback(scope *Scope) { scope.prepareQuerySQL() stmt.preload = recorder.preload[0].schema - // spew.Printf("_____PRELOADING_____\r\n%s\r\n", stmt.preload) - // spew.Printf("_____SQL_____\r\n%s\r\n", scope.SQL) - // we just want to pop the first element off recorder.preload = recorder.preload[1:] } @@ -67,35 +82,6 @@ func (r *Recorder) Record(stmt Stmt) { r.stmts = append(r.stmts, stmt) } -func getStmtFromLog(values ...interface{}) Stmt { - var statement Stmt - - if len(values) > 1 { - var ( - level = values[0] - ) - - if level == "sql" { - statement.args = values[4].([]interface{}) - statement.sql = values[3].(string) - } - - return statement - } - - return statement -} - -// Print just sets the last recorded SQL statement -// TODO: find a better way to extract SQL from log messages -func (r *Recorder) Print(args ...interface{}) { - statement := getStmtFromLog(args...) - - if statement.sql != "" { - r.stmts = append(r.stmts, statement) - } -} - // GetFirst returns the first recorded sql statement logged. If there are no // statements, false is returned func (r *Recorder) GetFirst() (Stmt, bool) { @@ -138,7 +124,6 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { gorm := &DB{ db: noop, logger: defaultLogger, - logMode: 2, values: map[string]interface{}{}, callbacks: DefaultCallback, dialect: newDialect("sqlmock", noop), @@ -146,6 +131,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { gorm.parent = gorm gorm = gorm.Set("gorm:recorder", recorder) + gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordCreateCallback) gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback) gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback) gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback) @@ -171,6 +157,16 @@ func (h *Expecter) AssertExpectations() error { return h.adapter.AssertExpectations() } +/* CREATE */ + +// Create mocks insertion of a model into the DB +func (h *Expecter) Create(model interface{}) ExpectedExec { + h.gorm.Create(model) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +/* READ */ + // First triggers a Query func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { h.gorm.First(out, where...) diff --git a/expecter_adapter.go b/expecter_adapter.go index 62929678..81d4cde0 100644 --- a/expecter_adapter.go +++ b/expecter_adapter.go @@ -25,7 +25,7 @@ func init() { // implementations (e.g. go-sqlmock or go-testdb) type Adapter interface { ExpectQuery(stmts ...Stmt) ExpectedQuery - ExpectExec(stmt string) ExpectedExec + ExpectExec(stmt Stmt) ExpectedExec AssertExpectations() error } @@ -56,10 +56,8 @@ func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery { // ExpectExec wraps the underlying mock method for setting a exec // expectation -func (a *SqlmockAdapter) ExpectExec(stmt string) ExpectedExec { - e := a.mocker.ExpectExec(stmt) - - return &SqlmockExec{exec: e} +func (a *SqlmockAdapter) ExpectExec(exec Stmt) ExpectedExec { + return &SqlmockExec{mock: a.mocker, exec: exec} } // AssertExpectations asserts that _all_ expectations for a test have been met diff --git a/expecter_noop.go b/expecter_noop.go index fc9c9d89..04d10881 100644 --- a/expecter_noop.go +++ b/expecter_noop.go @@ -45,12 +45,12 @@ type NoopResult struct{} // LastInsertId is a noop method for satisfying drive.Result func (r NoopResult) LastInsertId() (int64, error) { - return 1, nil + return 0, nil } // RowsAffected is a noop method for satisfying drive.Result func (r NoopResult) RowsAffected() (int64, error) { - return 1, nil + return 0, nil } // NoopRows implements driver.Rows diff --git a/expecter_result.go b/expecter_result.go index f68ec388..dac76276 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -20,7 +20,8 @@ type ExpectedQuery interface { // return a result. It presents a fluent API for chaining calls to other // expectations type ExpectedExec interface { - Returns(result driver.Result) ExpectedExec + WillSucceed(lastInsertID, rowsAffected int64) ExpectedExec + WillFail(err error) ExpectedExec } // SqlmockQuery implements Query for go-sqlmock @@ -234,14 +235,25 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { // SqlmockExec implements Exec for go-sqlmock type SqlmockExec struct { - exec *sqlmock.ExpectedExec + exec Stmt + mock sqlmock.Sqlmock + scope *Scope } -// Returns accepts a driver.Result. It is passed directly to the underlying +// WillSucceed accepts a two int64s. They are passed directly to the underlying // mock db. Useful for checking DAO behaviour in the event that the incorrect // number of rows are affected by an Exec -func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec { - e.exec = e.exec.WillReturnResult(result) +func (e *SqlmockExec) WillSucceed(lastReturnedID, rowsAffected int64) ExpectedExec { + result := sqlmock.NewResult(lastReturnedID, rowsAffected) + e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result) + + return e +} + +// WillFail simulates returning an Error from an unsuccessful exec +func (e *SqlmockExec) WillFail(err error) ExpectedExec { + result := sqlmock.NewErrorResult(err) + e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result) return e } diff --git a/expecter_test.go b/expecter_test.go index 315f20a5..9b56d16e 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "errors" "fmt" "reflect" "testing" @@ -222,10 +223,52 @@ func TestMockPreloadMultiple(t *testing.T) { t.Error(err) } - // spew.Printf("______IN______\r\n%s\r\n", spew.Sdump(in)) - // spew.Printf("______OUT______\r\n%s\r\n", spew.Sdump(out)) - if !reflect.DeepEqual(in, out) { t.Error("In and out are not equal") } } + +func TestMockCreateBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + user := User{Name: "jinzhu"} + expect.Create(&user).WillSucceed(1, 1) + rowsAffected := db.Create(&user).RowsAffected + + if rowsAffected != 1 { + t.Errorf("Expected rows affected to be 1 but got %d", rowsAffected) + } + + if user.Id != 1 { + t.Errorf("User id field should be 1, but got %d", user.Id) + } +} + +func TestMockCreateError(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + mockError := errors.New("Could not insert user") + + user := User{Name: "jinzhu"} + expect.Create(&user).WillFail(mockError) + + dbError := db.Create(&user).Error + + if dbError == nil || dbError != mockError { + t.Errorf("Expected *DB.Error to be set, but it was not") + } +}