From da8c7c18020d8aef8a59913403dbfdf56a43651b Mon Sep 17 00:00:00 2001 From: Ian Tan Date: Tue, 21 Nov 2017 17:10:10 +0800 Subject: [PATCH] Add noop db driver --- expecter.go | 112 +++++++++++++++------------ expecter_noop.go | 193 +++++++++++++++++++++++++++++++++++++++++++++++ expecter_test.go | 37 +++++++-- 3 files changed, 285 insertions(+), 57 deletions(-) create mode 100644 expecter_noop.go diff --git a/expecter.go b/expecter.go index 1dab967c..dadbe8fc 100644 --- a/expecter.go +++ b/expecter.go @@ -1,14 +1,12 @@ package gorm import ( - "database/sql" - "errors" "regexp" ) // Recorder satisfies the logger interface type Recorder struct { - stmt string + stmts []Stmt } // Stmt represents a sql statement. It can be an Exec or Query @@ -41,11 +39,29 @@ func getStmtFromLog(values ...interface{}) Stmt { // TODO: find a better way to extract SQL from log messages func (r *Recorder) Print(args ...interface{}) { statement := getStmtFromLog(args...) + if statement.sql != "" { - r.stmt = statement.sql + r.stmts = append(r.stmts, statement) } } +// GetFirst returns the first recorded sql statement logged. If there are no +// statements, false is returned +func (r *Recorder) GetFirst() (Stmt, bool) { + var stmt Stmt + if len(r.stmts) > 0 { + stmt = r.stmts[0] + return stmt, true + } + + return stmt, false +} + +// IsEmpty returns true if the statements slice is empty +func (r *Recorder) IsEmpty() bool { + return len(r.stmts) == 0 +} + // AdapterFactory is a generic interface for arbitrary adapters that satisfy // the interface. variadic args are passed to gorm.Open. type AdapterFactory func(dialect string, args ...interface{}) (*DB, Adapter, error) @@ -54,52 +70,10 @@ type AdapterFactory func(dialect string, args ...interface{}) (*DB, Adapter, err type Expecter struct { // globally scoped expecter adapter Adapter - noop SQLCommon gorm *DB recorder *Recorder } -// DefaultNoopDB is a noop db used to get generated sql from gorm.DB -type DefaultNoopDB struct{} - -// NoopResult is a noop struct that satisfies sql.Result -type NoopResult struct{} - -// LastInsertId is a noop method for satisfying drive.Result -func (r NoopResult) LastInsertId() (int64, error) { - return 1, nil -} - -// RowsAffected is a noop method for satisfying drive.Result -func (r NoopResult) RowsAffected() (int64, error) { - return 1, nil -} - -// NewNoopDB initialises a new DefaultNoopDB -func NewNoopDB() SQLCommon { - return &DefaultNoopDB{} -} - -// Exec simulates a sql.DB.Exec -func (r *DefaultNoopDB) Exec(query string, args ...interface{}) (sql.Result, error) { - return NoopResult{}, nil -} - -// Prepare simulates a sql.DB.Prepare -func (r *DefaultNoopDB) Prepare(query string) (*sql.Stmt, error) { - return &sql.Stmt{}, nil -} - -// Query simulates a sql.DB.Query -func (r *DefaultNoopDB) Query(query string, args ...interface{}) (*sql.Rows, error) { - return nil, errors.New("noop") -} - -// QueryRow simulates a sql.DB.QueryRow -func (r *DefaultNoopDB) QueryRow(query string, args ...interface{}) *sql.Row { - return &sql.Row{} -} - // NewDefaultExpecter returns a Expecter powered by go-sqlmock func NewDefaultExpecter() (*DB, *Expecter, error) { gormDb, adapter, err := NewSqlmockAdapter("sqlmock", "mock_gorm_dsn") @@ -109,7 +83,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { } recorder := &Recorder{} - noop := &DefaultNoopDB{} + noop, _ := NewNoopDB() gorm := &DB{ db: noop, logger: recorder, @@ -121,7 +95,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { gorm.parent = gorm - return gormDb, &Expecter{adapter: adapter, noop: noop, gorm: gorm, recorder: recorder}, nil + return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil } // NewExpecter returns an Expecter for arbitrary adapters @@ -144,12 +118,50 @@ func (h *Expecter) AssertExpectations() error { // First triggers a Query func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { + var q ExpectedQuery h.gorm.First(out, where...) - return h.adapter.ExpectQuery(regexp.QuoteMeta(h.recorder.stmt)) + + if empty := h.recorder.IsEmpty(); empty { + panic("No recorded statements") + } + + for _, stmt := range h.recorder.stmts { + q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql)) + } + + return q } // Find triggers a Query func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { + var q ExpectedQuery h.gorm.Find(out, where...) - return h.adapter.ExpectQuery(regexp.QuoteMeta(h.recorder.stmt)) + + if empty := h.recorder.IsEmpty(); empty { + panic("No recorded statements") + } + + for _, stmt := range h.recorder.stmts { + q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql)) + } + + return q +} + +// Preload clones the expecter and sets a preload condition on gorm.DB +func (h *Expecter) Preload(column string, conditions ...interface{}) *Expecter { + clone := h.clone() + clone.gorm = clone.gorm.Preload(column, conditions...) + + return clone +} + +/* PRIVATE METHODS */ + +func (h *Expecter) clone() *Expecter { + return &Expecter{ + adapter: h.adapter, + gorm: h.gorm, + recorder: h.recorder, + } } diff --git a/expecter_noop.go b/expecter_noop.go new file mode 100644 index 00000000..0ea9732d --- /dev/null +++ b/expecter_noop.go @@ -0,0 +1,193 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "io" + "sync" + + "github.com/davecgh/go-spew/spew" +) + +var pool *NoopDriver + +func init() { + pool = &NoopDriver{ + conns: make(map[string]*NoopConnection), + } + + sql.Register("noop", pool) +} + +// NoopDriver implements sql/driver.Driver +type NoopDriver struct { + sync.Mutex + counter int + conns map[string]*NoopConnection +} + +// Open implements sql/driver.Driver +func (d *NoopDriver) Open(dsn string) (driver.Conn, error) { + d.Lock() + defer d.Unlock() + + c, ok := d.conns[dsn] + + if !ok { + return c, fmt.Errorf("No connection available") + } + + c.opened++ + return c, nil +} + +// NoopResult is a noop struct that satisfies sql.Result +type NoopResult struct{} + +// LastInsertId is a noop method for satisfying drive.Result +func (r NoopResult) LastInsertId() (int64, error) { + return 1, nil +} + +// RowsAffected is a noop method for satisfying drive.Result +func (r NoopResult) RowsAffected() (int64, error) { + return 1, nil +} + +// NoopRows implements driver.Rows +type NoopRows struct { + pos int +} + +// Columns implements driver.Rows +func (r *NoopRows) Columns() []string { + return []string{"foo", "bar", "baz", "lol", "kek", "zzz"} +} + +// Close implements driver.Rows +func (r *NoopRows) Close() error { + return nil +} + +// Next implements driver.Rows and alwys returns only one row +func (r *NoopRows) Next(dest []driver.Value) error { + if r.pos == 1 { + return io.EOF + } + cols := []string{"foo", "bar", "baz", "lol", "kek", "zzz"} + + for i, col := range cols { + dest[i] = col + } + + r.pos++ + + return nil +} + +// NoopStmt implements driver.Stmt +type NoopStmt struct{} + +// Close implements driver.Stmt +func (s *NoopStmt) Close() error { + return nil +} + +// NumInput implements driver.Stmt +func (s *NoopStmt) NumInput() int { + return 1 +} + +// Exec implements driver.Stmt +func (s *NoopStmt) Exec(args []driver.Value) (driver.Result, error) { + return &NoopResult{}, nil +} + +// Query implements driver.Stmt +func (s *NoopStmt) Query(args []driver.Value) (driver.Rows, error) { + return &NoopRows{}, nil +} + +// NewNoopDB initialises a new DefaultNoopDB +func NewNoopDB() (SQLCommon, error) { + pool.Lock() + dsn := fmt.Sprintf("noop_db_%d", pool.counter) + pool.counter++ + + noop := &NoopConnection{dsn: dsn, drv: pool} + pool.conns[dsn] = noop + pool.Unlock() + + db, err := noop.open() + + return db, err +} + +// NoopConnection implements sql/driver.Conn +// for our purposes, the noop connection never returns an error, as we only +// require it for generating queries. It is necessary because eagerloading +// will fail if any operation returns an error +type NoopConnection struct { + dsn string + drv *NoopDriver + opened int +} + +func (c *NoopConnection) open() (*sql.DB, error) { + db, err := sql.Open("noop", c.dsn) + + if err != nil { + return db, err + } + + fmt.Println(db.Ping()) + + return db, db.Ping() +} + +// Close implements sql/driver.Conn +func (c *NoopConnection) Close() error { + c.drv.Lock() + defer c.drv.Unlock() + + c.opened-- + if c.opened == 0 { + delete(c.drv.conns, c.dsn) + } + + return nil +} + +// Begin implements sql/driver.Conn +func (c *NoopConnection) Begin() (driver.Tx, error) { + fmt.Println("Called Begin()") + return c, nil +} + +// Exec implements sql/driver.Conn +func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result, error) { + return NoopResult{}, nil +} + +// Prepare implements sql/driver.Conn +func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) { + spew.Dump(query) + return &NoopStmt{}, nil +} + +// Query implements sql/driver.Conn +func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) { + spew.Dump(args) + return &NoopRows{}, nil +} + +// Commit implements sql/driver.Conn +func (c *NoopConnection) Commit() error { + return nil +} + +// Rollback implements sql/driver.Conn +func (c *NoopConnection) Rollback() error { + return nil +} diff --git a/expecter_test.go b/expecter_test.go index ba8cb8b7..aab5b22d 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "fmt" "reflect" "testing" @@ -20,9 +21,7 @@ func TestNewDefaultExpecter(t *testing.T) { func TestNewCustomExpecter(t *testing.T) { db, _, err := gorm.NewExpecter(gorm.NewSqlmockAdapter, "sqlmock", "mock_gorm_dsn") - defer func() { - db.Close() - }() + defer db.Close() if err != nil { t.Fatal(err) @@ -31,16 +30,14 @@ func TestNewCustomExpecter(t *testing.T) { func TestQuery(t *testing.T) { db, expect, err := gorm.NewDefaultExpecter() - defer func() { - db.Close() - }() if err != nil { t.Fatal(err) } + fmt.Println("Got here") expect.First(&User{}) - db.LogMode(true).First(&User{}) + db.First(&User{}) if err := expect.AssertExpectations(); err != nil { t.Error(err) @@ -121,3 +118,29 @@ func TestFindSlice(t *testing.T) { t.Error("Expected equal slices") } } + +func TestMockPreload(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := User{Id: 1} + outEmails := []Email{Email{Id: 1, UserId: 1}, Email{Id: 2, UserId: 1}} + out := User{Id: 1, Emails: outEmails} + + expect.Preload("Emails").Find(&in).Returns(out) + db.Preload("Emails").Find(&in) + + // if err := expect.AssertExpectations(); err != nil { + // t.Error(err) + // } + + if !reflect.DeepEqual(in, out) { + t.Error("In and out are not equal") + } +}