From 607e8f60e498588fb6715f1ff94b2a3c3d135669 Mon Sep 17 00:00:00 2001 From: Ian Tan Date: Fri, 24 Nov 2017 17:33:45 +0800 Subject: [PATCH] Support create and updates --- expecter.go | 31 ++++++++++++++++++-- expecter_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/expecter.go b/expecter.go index cd746436..0bcfb09d 100644 --- a/expecter.go +++ b/expecter.go @@ -18,7 +18,7 @@ type Stmt struct { args []interface{} } -func recordCreateCallback(scope *Scope) { +func recordExecCallback(scope *Scope) { r, ok := scope.Get("gorm:recorder") if !ok { @@ -131,10 +131,11 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { gorm.parent = gorm gorm = gorm.Set("gorm:recorder", recorder) - gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordCreateCallback) + gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordExecCallback) gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback) gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback) gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback) + gorm.Callback().Update().After("gorm:update").Register("gorm:record_exec", recordExecCallback) return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil } @@ -157,6 +158,12 @@ func (h *Expecter) AssertExpectations() error { return h.adapter.AssertExpectations() } +// Model sets scope.Value +func (h *Expecter) Model(model interface{}) *Expecter { + h.gorm = h.gorm.Model(model) + return h +} + /* CREATE */ // Create mocks insertion of a model into the DB @@ -187,6 +194,26 @@ func (h *Expecter) Preload(column string, conditions ...interface{}) *Expecter { return clone } +/* UPDATE */ + +// Save mocks updating a record in the DB and will trigger db.Exec() +func (h *Expecter) Save(model interface{}) ExpectedExec { + h.gorm.Save(model) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +// Update mocks updating the given attributes in the DB +func (h *Expecter) Update(attrs ...interface{}) ExpectedExec { + h.gorm.Update(attrs...) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +// Updates does the same thing as Update, but with map or struct +func (h *Expecter) Updates(values interface{}, ignoreProtectedAttrs ...bool) ExpectedExec { + h.gorm.Updates(values, ignoreProtectedAttrs...) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + /* PRIVATE METHODS */ func (h *Expecter) clone() *Expecter { diff --git a/expecter_test.go b/expecter_test.go index 9b56d16e..e6ac58e6 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -272,3 +272,76 @@ func TestMockCreateError(t *testing.T) { t.Errorf("Expected *DB.Error to be set, but it was not") } } + +func TestMockSaveBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + user := User{Name: "jinzhu"} + expect.Save(&user).WillSucceed(1, 1) + expected := db.Save(&user) + + if err := expect.AssertExpectations(); err != nil { + t.Errorf("Expectations were not met %s", err.Error()) + } + + if expected.RowsAffected != 1 || user.Id != 1 { + t.Errorf("Expected result was not returned") + } +} + +func TestMockUpdateBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + newName := "uhznij" + user := User{Name: "jinzhu"} + + expect.Model(&user).Update("name", newName).WillSucceed(1, 1) + db.Model(&user).Update("name", newName) + + if err := expect.AssertExpectations(); err != nil { + t.Errorf("Expectations were not met %s", err.Error()) + } + + if user.Name != newName { + t.Errorf("Should have name %s but got %s", newName, user.Name) + } +} + +func TestMockUpdatesBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + user := User{Name: "jinzhu", Age: 18} + updated := User{Name: "jinzhu", Age: 88} + + expect.Model(&user).Updates(updated).WillSucceed(1, 1) + db.Model(&user).Updates(updated) + + if err := expect.AssertExpectations(); err != nil { + t.Errorf("Expectations were not met %s", err.Error()) + } + + if user.Age != updated.Age { + t.Errorf("Should have age %d but got %d", user.Age, updated.Age) + } +}