From 281d0eea1778b27d88cc0fe1cb4be51f7837c31e Mon Sep 17 00:00:00 2001 From: Sheng-Dean Chang Date: Thu, 6 Apr 2017 16:52:49 -0700 Subject: [PATCH] Add gorm.Gormw, an interface wrapper for gorm.DB. This commit adds an interface-equivalent of gorm.DB so that those who prefer mocking gorm.DB under test can easily do so. It also adds 2 functions Wrap() and Openw(), to facilitate creating gorm.Gormw instances. --- gormw.go | 394 ++++++++++++++++++++++++++++++++++++++++++++++++++ gormw_test.go | 83 +++++++++++ 2 files changed, 477 insertions(+) create mode 100644 gormw.go create mode 100644 gormw_test.go diff --git a/gormw.go b/gormw.go new file mode 100644 index 00000000..b80f80e2 --- /dev/null +++ b/gormw.go @@ -0,0 +1,394 @@ +package gorm + +import "database/sql" + +// Gormw is an interface which DB implements +type Gormw interface { + Close() error + DB() *sql.DB + New() Gormw + NewScope(value interface{}) *Scope + CommonDB() SQLCommon + Callback() *Callback + SetLogger(l Logger) + LogMode(enable bool) Gormw + SingularTable(enable bool) + Where(query interface{}, args ...interface{}) Gormw + Or(query interface{}, args ...interface{}) Gormw + Not(query interface{}, args ...interface{}) Gormw + Limit(value int) Gormw + Offset(value int) Gormw + Order(value string, reorder ...bool) Gormw + Select(query interface{}, args ...interface{}) Gormw + Omit(columns ...string) Gormw + Group(query string) Gormw + Having(query string, values ...interface{}) Gormw + Joins(query string, args ...interface{}) Gormw + Scopes(funcs ...func(*DB) *DB) Gormw + Unscoped() Gormw + Attrs(attrs ...interface{}) Gormw + Assign(attrs ...interface{}) Gormw + First(out interface{}, where ...interface{}) Gormw + Last(out interface{}, where ...interface{}) Gormw + Find(out interface{}, where ...interface{}) Gormw + Scan(dest interface{}) Gormw + Row() *sql.Row + Rows() (*sql.Rows, error) + ScanRows(rows *sql.Rows, result interface{}) error + Pluck(column string, value interface{}) Gormw + Count(value interface{}) Gormw + Related(value interface{}, foreignKeys ...string) Gormw + FirstOrInit(out interface{}, where ...interface{}) Gormw + FirstOrCreate(out interface{}, where ...interface{}) Gormw + Update(attrs ...interface{}) Gormw + Updates(values interface{}, ignoreProtectedAttrs ...bool) Gormw + UpdateColumn(attrs ...interface{}) Gormw + UpdateColumns(values interface{}) Gormw + Save(value interface{}) Gormw + Create(value interface{}) Gormw + Delete(value interface{}, where ...interface{}) Gormw + Raw(sql string, values ...interface{}) Gormw + Exec(sql string, values ...interface{}) Gormw + Model(value interface{}) Gormw + Table(name string) Gormw + Debug() Gormw + Begin() Gormw + Commit() Gormw + Rollback() Gormw + NewRecord(value interface{}) bool + RecordNotFound() bool + CreateTable(values ...interface{}) Gormw + DropTable(values ...interface{}) Gormw + DropTableIfExists(values ...interface{}) Gormw + HasTable(value interface{}) bool + AutoMigrate(values ...interface{}) Gormw + ModifyColumn(column string, typ string) Gormw + DropColumn(column string) Gormw + AddIndex(indexName string, column ...string) Gormw + AddUniqueIndex(indexName string, column ...string) Gormw + RemoveIndex(indexName string) Gormw + AddForeignKey(field string, dest string, onDelete string, onUpdate string) Gormw + Association(column string) *Association + Preload(column string, conditions ...interface{}) Gormw + Set(name string, value interface{}) Gormw + InstantSet(name string, value interface{}) Gormw + Get(name string) (value interface{}, ok bool) + SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) + AddError(err error) error + GetErrors() (errors []error) + + // extra + Error() error + RowsAffected() int64 +} + +type gormw struct { + w *DB +} + +// Openw is a drop-in replacement for Open() +func Openw(dialect string, args ...interface{}) (db Gormw, err error) { + gormdb, err := Open(dialect, args...) + return Wrap(gormdb), err +} + +// Wrap wraps gorm.DB in an interface +func Wrap(db *DB) Gormw { + return &gormw{db} +} + +func (it *gormw) Close() error { + return it.w.Close() +} + +func (it *gormw) DB() *sql.DB { + return it.w.DB() +} + +func (it *gormw) New() Gormw { + return Wrap(it.w.New()) +} + +func (it *gormw) NewScope(value interface{}) *Scope { + return it.w.NewScope(value) +} + +func (it *gormw) CommonDB() SQLCommon { + return it.w.CommonDB() +} + +func (it *gormw) Callback() *Callback { + return it.w.Callback() +} + +func (it *gormw) SetLogger(log Logger) { + it.w.SetLogger(log) +} + +func (it *gormw) LogMode(enable bool) Gormw { + return Wrap(it.w.LogMode(enable)) +} + +func (it *gormw) SingularTable(enable bool) { + it.w.SingularTable(enable) +} + +func (it *gormw) Where(query interface{}, args ...interface{}) Gormw { + return Wrap(it.w.Where(query, args...)) +} + +func (it *gormw) Or(query interface{}, args ...interface{}) Gormw { + return Wrap(it.w.Or(query, args...)) +} + +func (it *gormw) Not(query interface{}, args ...interface{}) Gormw { + return Wrap(it.w.Not(query, args...)) +} + +func (it *gormw) Limit(value int) Gormw { + return Wrap(it.w.Limit(value)) +} + +func (it *gormw) Offset(value int) Gormw { + return Wrap(it.w.Offset(value)) +} + +func (it *gormw) Order(value string, reorder ...bool) Gormw { + return Wrap(it.w.Order(value, reorder...)) +} + +func (it *gormw) Select(query interface{}, args ...interface{}) Gormw { + return Wrap(it.w.Select(query, args...)) +} + +func (it *gormw) Omit(columns ...string) Gormw { + return Wrap(it.w.Omit(columns...)) +} + +func (it *gormw) Group(query string) Gormw { + return Wrap(it.w.Group(query)) +} + +func (it *gormw) Having(query string, values ...interface{}) Gormw { + return Wrap(it.w.Having(query, values...)) +} + +func (it *gormw) Joins(query string, args ...interface{}) Gormw { + return Wrap(it.w.Joins(query, args...)) +} + +func (it *gormw) Scopes(funcs ...func(*DB) *DB) Gormw { + return Wrap(it.w.Scopes(funcs...)) +} + +func (it *gormw) Unscoped() Gormw { + return Wrap(it.w.Unscoped()) +} + +func (it *gormw) Attrs(attrs ...interface{}) Gormw { + return Wrap(it.w.Attrs(attrs...)) +} + +func (it *gormw) Assign(attrs ...interface{}) Gormw { + return Wrap(it.w.Assign(attrs...)) +} + +func (it *gormw) First(out interface{}, where ...interface{}) Gormw { + return Wrap(it.w.First(out, where...)) +} + +func (it *gormw) Last(out interface{}, where ...interface{}) Gormw { + return Wrap(it.w.Last(out, where...)) +} + +func (it *gormw) Find(out interface{}, where ...interface{}) Gormw { + return Wrap(it.w.Find(out, where...)) +} + +func (it *gormw) Scan(dest interface{}) Gormw { + return Wrap(it.w.Scan(dest)) +} + +func (it *gormw) Row() *sql.Row { + return it.w.Row() +} + +func (it *gormw) Rows() (*sql.Rows, error) { + return it.w.Rows() +} + +func (it *gormw) ScanRows(rows *sql.Rows, result interface{}) error { + return it.w.ScanRows(rows, result) +} + +func (it *gormw) Pluck(column string, value interface{}) Gormw { + return Wrap(it.w.Pluck(column, value)) +} + +func (it *gormw) Count(value interface{}) Gormw { + return Wrap(it.w.Count(value)) +} + +func (it *gormw) Related(value interface{}, foreignKeys ...string) Gormw { + return Wrap(it.w.Related(value, foreignKeys...)) +} + +func (it *gormw) FirstOrInit(out interface{}, where ...interface{}) Gormw { + return Wrap(it.w.FirstOrInit(out, where...)) +} + +func (it *gormw) FirstOrCreate(out interface{}, where ...interface{}) Gormw { + return Wrap(it.w.FirstOrCreate(out, where...)) +} + +func (it *gormw) Update(attrs ...interface{}) Gormw { + return Wrap(it.w.Update(attrs...)) +} + +func (it *gormw) Updates(values interface{}, ignoreProtectedAttrs ...bool) Gormw { + return Wrap(it.w.Updates(values, ignoreProtectedAttrs...)) +} + +func (it *gormw) UpdateColumn(attrs ...interface{}) Gormw { + return Wrap(it.w.UpdateColumn(attrs...)) +} + +func (it *gormw) UpdateColumns(values interface{}) Gormw { + return Wrap(it.w.UpdateColumns(values)) +} + +func (it *gormw) Save(value interface{}) Gormw { + return Wrap(it.w.Save(value)) +} + +func (it *gormw) Create(value interface{}) Gormw { + return Wrap(it.w.Create(value)) +} + +func (it *gormw) Delete(value interface{}, where ...interface{}) Gormw { + return Wrap(it.w.Delete(value, where...)) +} + +func (it *gormw) Raw(sql string, values ...interface{}) Gormw { + return Wrap(it.w.Raw(sql, values...)) +} + +func (it *gormw) Exec(sql string, values ...interface{}) Gormw { + return Wrap(it.w.Exec(sql, values...)) +} + +func (it *gormw) Model(value interface{}) Gormw { + return Wrap(it.w.Model(value)) +} + +func (it *gormw) Table(name string) Gormw { + return Wrap(it.w.Table(name)) +} + +func (it *gormw) Debug() Gormw { + return Wrap(it.w.Debug()) +} + +func (it *gormw) Begin() Gormw { + return Wrap(it.w.Begin()) +} + +func (it *gormw) Commit() Gormw { + return Wrap(it.w.Commit()) +} + +func (it *gormw) Rollback() Gormw { + return Wrap(it.w.Rollback()) +} + +func (it *gormw) NewRecord(value interface{}) bool { + return it.w.NewRecord(value) +} + +func (it *gormw) RecordNotFound() bool { + return it.w.RecordNotFound() +} + +func (it *gormw) CreateTable(values ...interface{}) Gormw { + return Wrap(it.w.CreateTable(values...)) +} + +func (it *gormw) DropTable(values ...interface{}) Gormw { + return Wrap(it.w.DropTable(values...)) +} + +func (it *gormw) DropTableIfExists(values ...interface{}) Gormw { + return Wrap(it.w.DropTableIfExists(values...)) +} + +func (it *gormw) HasTable(value interface{}) bool { + return it.w.HasTable(value) +} + +func (it *gormw) AutoMigrate(values ...interface{}) Gormw { + return Wrap(it.w.AutoMigrate(values...)) +} + +func (it *gormw) ModifyColumn(column string, typ string) Gormw { + return Wrap(it.w.ModifyColumn(column, typ)) +} + +func (it *gormw) DropColumn(column string) Gormw { + return Wrap(it.w.DropColumn(column)) +} + +func (it *gormw) AddIndex(indexName string, columns ...string) Gormw { + return Wrap(it.w.AddIndex(indexName, columns...)) +} + +func (it *gormw) AddUniqueIndex(indexName string, columns ...string) Gormw { + return Wrap(it.w.AddUniqueIndex(indexName, columns...)) +} + +func (it *gormw) RemoveIndex(indexName string) Gormw { + return Wrap(it.w.RemoveIndex(indexName)) +} + +func (it *gormw) Association(column string) *Association { + return it.w.Association(column) +} + +func (it *gormw) Preload(column string, conditions ...interface{}) Gormw { + return Wrap(it.w.Preload(column, conditions...)) +} + +func (it *gormw) Set(name string, value interface{}) Gormw { + return Wrap(it.w.Set(name, value)) +} + +func (it *gormw) InstantSet(name string, value interface{}) Gormw { + return Wrap(it.w.InstantSet(name, value)) +} + +func (it *gormw) Get(name string) (interface{}, bool) { + return it.w.Get(name) +} + +func (it *gormw) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { + it.w.SetJoinTableHandler(source, column, handler) +} + +func (it *gormw) AddForeignKey(field string, dest string, onDelete string, onUpdate string) Gormw { + return Wrap(it.w.AddForeignKey(field, dest, onDelete, onUpdate)) +} + +func (it *gormw) AddError(err error) error { + return it.w.AddError(err) +} + +func (it *gormw) GetErrors() (errors []error) { + return it.w.GetErrors() +} + +func (it *gormw) RowsAffected() int64 { + return it.w.RowsAffected +} + +func (it *gormw) Error() error { + return it.w.Error +} diff --git a/gormw_test.go b/gormw_test.go new file mode 100644 index 00000000..d51e3e26 --- /dev/null +++ b/gormw_test.go @@ -0,0 +1,83 @@ +package gorm_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mssql" + _ "github.com/jinzhu/gorm/dialects/mysql" + _ "github.com/jinzhu/gorm/dialects/sqlite" +) + +type Employee struct { + ID int + Name string + Salary int +} + +func setupGormw(t *testing.T) (db gorm.Gormw) { + db, err := gorm.Openw("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) + if err != nil { + t.Fatalf("failed to connect to database: %v", err) + } + if db == nil { + t.Fatal("db should not be nil") + } + return db +} + +func teardownGormw(t *testing.T, db gorm.Gormw) { + db.Close() + for _, err := range db.GetErrors() { + t.Error(err) + } +} + +func TestDDL(t *testing.T) { + db := setupGormw(t) + db.CreateTable(&Employee{}) + if !db.HasTable("employees") { + t.Error(`table "employees" should exist`) + } + db.DropTableIfExists(&Employee{}) + + db.SingularTable(true) + db.CreateTable(&Employee{}) + name := db.NewScope(&Employee{}).TableName() + if name != "employee" { + t.Errorf(`expected table name "employee"; got "%s"`, name) + } + db.DropTable(&Employee{}) + teardownGormw(t, db) +} + +func TestBasicDML(t *testing.T) { + db := setupGormw(t) + db.CreateTable(&Employee{}) + + emp := &Employee{1, "jinzhu", 1000000} + db.Create(emp) + emp1 := &Employee{0, "littledot", 0} + if !db.NewRecord(emp1) { + t.Errorf(`NewRecord() should return true`) + } + if !db.Where(emp1).First(emp1).RecordNotFound() { + t.Errorf(`non-existent row should not be found`) + } + + db.FirstOrInit(emp1, emp1) + emp1.Salary = 1 + affected := db.Model(emp1).Updates(emp1).RowsAffected() + if affected == 0 { + t.Errorf(`expected 1 affected row; got %d`, affected) + } + + if err := db.Raw("bad syntax burp").Scan(&Employee{}).Error(); err == nil { + t.Error(`expected error; got nil`) + } + + db.DropTable(&Employee{}) + teardownGormw(t, db) +}