diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce..d136aee4 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -60,6 +60,7 @@ func preloadCallback(scope *Scope) { currentScope.handleBelongsToPreload(field, currentPreloadConditions) case "many_to_many": currentScope.handleManyToManyPreload(field, currentPreloadConditions) + default: scope.Err(errors.New("unsupported relation")) } @@ -264,6 +265,8 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ // handleManyToManyPreload used to preload many to many associations func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { + // spew.Println("___ENTERING HANDLE MANY TO MANY___\r\n") + // spew.Printf("___POPULATING %s___:\r\n%s\r\n", field.Name, spew.Sdump(field)) var ( relation = field.Relationship joinTableHandler = relation.JoinTableHandler @@ -303,6 +306,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } rows, err := preloadDB.Rows() + // spew.Printf("___RETURNED ROWS___: \r\n%s\r\n", spew.Sdump(rows)) if scope.Err(err) != nil { return @@ -312,6 +316,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface columns, _ := rows.Columns() for rows.Next() { var ( + // This is a Language zero value struct elem = reflect.New(fieldType).Elem() fields = scope.New(elem.Addr().Interface()).Fields() ) @@ -323,6 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } scope.scan(rows, columns, append(fields, joinTableFields...)) + // spew.Printf("___FIELDS___: \r\n%s\r\n", spew.Sdump(fields)) var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table @@ -351,12 +357,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface foreignFieldNames = []string{} ) + // spew.Printf("Foreign fields: %s", spew.Sdump(relation.ForeignFieldNames)) for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { foreignFieldNames = append(foreignFieldNames, field.Name) } } + // spew.Printf("Scope value: %s", spew.Sdump(indirectScopeValue)) if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) @@ -367,6 +375,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } + + // spew.Printf("Field source map: %s", spew.Sdump(fieldsSourceMap)) + // spew.Printf("Link hash: %s", spew.Sdump(linkHash)) for source, link := range linkHash { for i, field := range fieldsSourceMap[source] { //If not 0 this means Value is a pointer and we already added preloaded models to it diff --git a/expecter.go b/expecter.go new file mode 100644 index 00000000..6dc28996 --- /dev/null +++ b/expecter.go @@ -0,0 +1,225 @@ +package gorm + +import ( + "fmt" +) + +// Recorder satisfies the logger interface +type Recorder struct { + stmts []Stmt + preload []searchPreload // store it on Recorder +} + +// Stmt represents a sql statement. It can be an Exec, Query, or QueryRow +type Stmt struct { + kind string // can be Query, Exec, QueryRow + preload string // contains schema if it is a preload query + sql string + args []interface{} +} + +func recordExecCallback(scope *Scope) { + r, ok := scope.Get("gorm:recorder") + + if !ok { + panic(fmt.Errorf("Expected a recorder to be set, but got none")) + } + + stmt := Stmt{ + kind: "exec", + sql: scope.SQL, + args: scope.SQLVars, + } + + recorder := r.(*Recorder) + + recorder.Record(stmt) +} + +func recordQueryCallback(scope *Scope) { + r, ok := scope.Get("gorm:recorder") + + if !ok { + panic(fmt.Errorf("Expected a recorder to be set, but got none")) + } + + recorder := r.(*Recorder) + + stmt := Stmt{ + kind: "query", + sql: scope.SQL, + args: scope.SQLVars, + } + + if len(recorder.preload) > 0 { + // this will cause the scope.SQL to mutate to the preload query + stmt.preload = recorder.preload[0].schema + + // we just want to pop the first element off + recorder.preload = recorder.preload[1:] + } + + recorder.Record(stmt) +} + +func recordPreloadCallback(scope *Scope) { + // this callback runs _before_ gorm:preload + // it should record the next thing to be preloaded + recorder, ok := scope.Get("gorm:recorder") + + if !ok { + panic(fmt.Errorf("Expected a recorder to be set, but got none")) + } + + if len(scope.Search.preload) > 0 { + // spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload)) + recorder.(*Recorder).preload = scope.Search.preload + } +} + +// Record records a Stmt for use when SQL is finally executed +func (r *Recorder) Record(stmt Stmt) { + r.stmts = append(r.stmts, stmt) +} + +// GetFirst returns the first recorded sql statement logged. If there are no +// statements, false is returned +func (r *Recorder) GetFirst() (Stmt, bool) { + var stmt Stmt + if len(r.stmts) > 0 { + stmt = r.stmts[0] + return stmt, true + } + + return stmt, false +} + +// IsEmpty returns true if the statements slice is empty +func (r *Recorder) IsEmpty() bool { + return len(r.stmts) == 0 +} + +// AdapterFactory is a generic interface for arbitrary adapters that satisfy +// the interface. variadic args are passed to gorm.Open. +type AdapterFactory func(dialect string, args ...interface{}) (*DB, Adapter, error) + +// Expecter is the exported struct used for setting expectations +type Expecter struct { + // globally scoped expecter + adapter Adapter + gorm *DB + recorder *Recorder +} + +// NewDefaultExpecter returns a Expecter powered by go-sqlmock +func NewDefaultExpecter() (*DB, *Expecter, error) { + gormDb, adapter, err := NewSqlmockAdapter("sqlmock", "mock_gorm_dsn") + + if err != nil { + return nil, nil, err + } + + recorder := &Recorder{} + noop, _ := NewNoopDB() + gorm := &DB{ + db: noop, + logger: defaultLogger, + values: map[string]interface{}{}, + callbacks: DefaultCallback, + dialect: newDialect("sqlmock", noop), + } + + gorm.parent = gorm + gorm = gorm.Set("gorm:recorder", recorder) + gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordExecCallback) + gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback) + gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback) + gorm.Callback().RowQuery().After("gorm:row_query").Register("gorm:record_query", recordQueryCallback) + gorm.Callback().Update().After("gorm:update").Register("gorm:record_exec", recordExecCallback) + + return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil +} + +// NewExpecter returns an Expecter for arbitrary adapters +func NewExpecter(fn AdapterFactory, dialect string, args ...interface{}) (*DB, *Expecter, error) { + gormDb, adapter, err := fn(dialect, args...) + + if err != nil { + return nil, nil, err + } + + return gormDb, &Expecter{adapter: adapter}, nil +} + +/* PUBLIC METHODS */ + +// AssertExpectations checks if all expected Querys and Execs were satisfied. +func (h *Expecter) AssertExpectations() error { + return h.adapter.AssertExpectations() +} + +// Model sets scope.Value +func (h *Expecter) Model(model interface{}) *Expecter { + h.gorm = h.gorm.Model(model) + return h +} + +/* CREATE */ + +// Create mocks insertion of a model into the DB +func (h *Expecter) Create(model interface{}) ExpectedExec { + h.gorm.Create(model) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +/* READ */ + +// First triggers a Query +func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { + h.gorm.First(out, where...) + return h.adapter.ExpectQuery(h.recorder.stmts...) +} + +// Find triggers a Query +func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { + h.gorm.Find(out, where...) + return h.adapter.ExpectQuery(h.recorder.stmts...) +} + +// Preload clones the expecter and sets a preload condition on gorm.DB +func (h *Expecter) Preload(column string, conditions ...interface{}) *Expecter { + clone := h.clone() + clone.gorm = clone.gorm.Preload(column, conditions...) + + return clone +} + +/* UPDATE */ + +// Save mocks updating a record in the DB and will trigger db.Exec() +func (h *Expecter) Save(model interface{}) ExpectedExec { + h.gorm.Save(model) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +// Update mocks updating the given attributes in the DB +func (h *Expecter) Update(attrs ...interface{}) ExpectedExec { + h.gorm.Update(attrs...) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +// Updates does the same thing as Update, but with map or struct +func (h *Expecter) Updates(values interface{}, ignoreProtectedAttrs ...bool) ExpectedExec { + h.gorm.Updates(values, ignoreProtectedAttrs...) + return h.adapter.ExpectExec(h.recorder.stmts[0]) +} + +/* PRIVATE METHODS */ + +func (h *Expecter) clone() *Expecter { + return &Expecter{ + adapter: h.adapter, + gorm: h.gorm, + recorder: h.recorder, + } +} diff --git a/expecter_adapter.go b/expecter_adapter.go new file mode 100644 index 00000000..81d4cde0 --- /dev/null +++ b/expecter_adapter.go @@ -0,0 +1,68 @@ +package gorm + +import ( + "database/sql" + + sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" +) + +var ( + db *sql.DB + mock sqlmock.Sqlmock +) + +func init() { + var err error + + db, mock, err = sqlmock.NewWithDSN("mock_gorm_dsn") + + if err != nil { + panic(err.Error()) + } +} + +// Adapter provides an abstract interface over concrete mock database +// implementations (e.g. go-sqlmock or go-testdb) +type Adapter interface { + ExpectQuery(stmts ...Stmt) ExpectedQuery + ExpectExec(stmt Stmt) ExpectedExec + AssertExpectations() error +} + +// SqlmockAdapter implemenets the Adapter interface using go-sqlmock +// it is the default Adapter +type SqlmockAdapter struct { + db *sql.DB + mocker sqlmock.Sqlmock +} + +// NewSqlmockAdapter returns a mock gorm.DB and an Adapter backed by +// go-sqlmock +func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error) { + gormDb, err := Open("sqlmock", "mock_gorm_dsn") + + if err != nil { + return nil, nil, err + } + + return gormDb, &SqlmockAdapter{db: db, mocker: mock}, nil +} + +// ExpectQuery wraps the underlying mock method for setting a query +// expectation. It accepts multiple statements in the event of preloading +func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery { + return &SqlmockQuery{mock: a.mocker, queries: queries} +} + +// ExpectExec wraps the underlying mock method for setting a exec +// expectation +func (a *SqlmockAdapter) ExpectExec(exec Stmt) ExpectedExec { + return &SqlmockExec{mock: a.mocker, exec: exec} +} + +// AssertExpectations asserts that _all_ expectations for a test have been met +// and returns an error specifying which have not if there are unmet +// expectations +func (a *SqlmockAdapter) AssertExpectations() error { + return a.mocker.ExpectationsWereMet() +} diff --git a/expecter_noop.go b/expecter_noop.go new file mode 100644 index 00000000..04d10881 --- /dev/null +++ b/expecter_noop.go @@ -0,0 +1,186 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "io" + "sync" +) + +var pool *NoopDriver + +func init() { + pool = &NoopDriver{ + conns: make(map[string]*NoopConnection), + } + + sql.Register("noop", pool) +} + +// NoopDriver implements sql/driver.Driver +type NoopDriver struct { + sync.Mutex + counter int + conns map[string]*NoopConnection +} + +// Open implements sql/driver.Driver +func (d *NoopDriver) Open(dsn string) (driver.Conn, error) { + d.Lock() + defer d.Unlock() + + c, ok := d.conns[dsn] + + if !ok { + return c, fmt.Errorf("No connection available") + } + + c.opened++ + return c, nil +} + +// NoopResult is a noop struct that satisfies sql.Result +type NoopResult struct{} + +// LastInsertId is a noop method for satisfying drive.Result +func (r NoopResult) LastInsertId() (int64, error) { + return 0, nil +} + +// RowsAffected is a noop method for satisfying drive.Result +func (r NoopResult) RowsAffected() (int64, error) { + return 0, nil +} + +// NoopRows implements driver.Rows +type NoopRows struct { + pos int +} + +// Columns implements driver.Rows +func (r *NoopRows) Columns() []string { + return []string{"foo", "bar", "baz", "lol", "kek", "zzz"} +} + +// Close implements driver.Rows +func (r *NoopRows) Close() error { + return nil +} + +// Next implements driver.Rows and alwys returns only one row +func (r *NoopRows) Next(dest []driver.Value) error { + if r.pos == 1 { + return io.EOF + } + cols := []string{"foo", "bar", "baz", "lol", "kek", "zzz"} + + for i, col := range cols { + dest[i] = col + } + + r.pos++ + + return nil +} + +// NoopStmt implements driver.Stmt +type NoopStmt struct{} + +// Close implements driver.Stmt +func (s *NoopStmt) Close() error { + return nil +} + +// NumInput implements driver.Stmt +func (s *NoopStmt) NumInput() int { + return 1 +} + +// Exec implements driver.Stmt +func (s *NoopStmt) Exec(args []driver.Value) (driver.Result, error) { + return &NoopResult{}, nil +} + +// Query implements driver.Stmt +func (s *NoopStmt) Query(args []driver.Value) (driver.Rows, error) { + return &NoopRows{}, nil +} + +// NewNoopDB initialises a new DefaultNoopDB +func NewNoopDB() (SQLCommon, error) { + pool.Lock() + dsn := fmt.Sprintf("noop_db_%d", pool.counter) + pool.counter++ + + noop := &NoopConnection{dsn: dsn, drv: pool} + pool.conns[dsn] = noop + pool.Unlock() + + db, err := noop.open() + + return db, err +} + +// NoopConnection implements sql/driver.Conn +// for our purposes, the noop connection never returns an error, as we only +// require it for generating queries. It is necessary because eagerloading +// will fail if any operation returns an error +type NoopConnection struct { + dsn string + drv *NoopDriver + opened int +} + +func (c *NoopConnection) open() (*sql.DB, error) { + db, err := sql.Open("noop", c.dsn) + + if err != nil { + return db, err + } + + return db, db.Ping() +} + +// Close implements sql/driver.Conn +func (c *NoopConnection) Close() error { + c.drv.Lock() + defer c.drv.Unlock() + + c.opened-- + if c.opened == 0 { + delete(c.drv.conns, c.dsn) + } + + return nil +} + +// Begin implements sql/driver.Conn +func (c *NoopConnection) Begin() (driver.Tx, error) { + return c, nil +} + +// Exec implements sql/driver.Conn +func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result, error) { + return NoopResult{}, nil +} + +// Prepare implements sql/driver.Conn +func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) { + return &NoopStmt{}, nil +} + +// Query implements sql/driver.Conn +func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) { + return &NoopRows{}, nil +} + +// Commit implements sql/driver.Conn +func (c *NoopConnection) Commit() error { + return nil +} + +// Rollback implements sql/driver.Conn +func (c *NoopConnection) Rollback() error { + return nil +} diff --git a/expecter_result.go b/expecter_result.go new file mode 100644 index 00000000..dac76276 --- /dev/null +++ b/expecter_result.go @@ -0,0 +1,259 @@ +package gorm + +import ( + "database/sql/driver" + "fmt" + "reflect" + "regexp" + + sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" +) + +// ExpectedQuery represents an expected query that will be executed and can +// return some rows. It presents a fluent API for chaining calls to other +// expectations +type ExpectedQuery interface { + Returns(model interface{}) ExpectedQuery +} + +// ExpectedExec represents an expected exec that will be executed and can +// return a result. It presents a fluent API for chaining calls to other +// expectations +type ExpectedExec interface { + WillSucceed(lastInsertID, rowsAffected int64) ExpectedExec + WillFail(err error) ExpectedExec +} + +// SqlmockQuery implements Query for go-sqlmock +type SqlmockQuery struct { + mock sqlmock.Sqlmock + queries []Stmt + scope *Scope +} + +func getRowForFields(fields []*Field) []driver.Value { + var values []driver.Value + for _, field := range fields { + if field.IsNormal { + value := field.Field + + // dereference pointers + if field.Field.Kind() == reflect.Ptr { + value = reflect.Indirect(field.Field) + } + + // check if we have a zero Value + // just append nil if it's not valid, so sqlmock won't complain + if !value.IsValid() { + values = append(values, nil) + continue + } + + concreteVal := value.Interface() + // spew.Printf("%v: %v\r\n", field.Name, concreteVal) + + if driver.IsValue(concreteVal) { + values = append(values, concreteVal) + } else if num, err := driver.DefaultParameterConverter.ConvertValue(concreteVal); err == nil { + values = append(values, num) + } else if valuer, ok := concreteVal.(driver.Valuer); ok { + if convertedValue, err := valuer.Value(); err == nil { + values = append(values, convertedValue) + } + } + } + } + + return values +} + +func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) { + var ( + rows *sqlmock.Rows + columns []string + ) + + // we need to check for zero values + if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) { + // spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName) + return nil, false + } + + switch relation.Kind { + case "has_one": + scope := &Scope{Value: rVal.Interface()} + + for _, field := range scope.GetModelStruct().StructFields { + if field.IsNormal { + columns = append(columns, field.DBName) + } + } + + rows = sqlmock.NewRows(columns) + + // we don't have a slice + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + + return rows, true + case "has_many": + elem := rVal.Type().Elem() + scope := &Scope{Value: reflect.New(elem).Interface()} + + for _, field := range scope.GetModelStruct().StructFields { + if field.IsNormal { + columns = append(columns, field.DBName) + } + } + + rows = sqlmock.NewRows(columns) + + if rVal.Len() > 0 { + for i := 0; i < rVal.Len(); i++ { + scope := &Scope{Value: rVal.Index(i).Interface()} + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + } + + return rows, true + } + + return nil, false + case "many_to_many": + elem := rVal.Type().Elem() + scope := &Scope{Value: reflect.New(elem).Interface()} + joinTable := relation.JoinTableHandler.(*JoinTableHandler) + + for _, field := range scope.GetModelStruct().StructFields { + if field.IsNormal { + columns = append(columns, field.DBName) + } + } + + for _, key := range joinTable.Source.ForeignKeys { + columns = append(columns, key.DBName) + } + + for _, key := range joinTable.Destination.ForeignKeys { + columns = append(columns, key.DBName) + } + + rows = sqlmock.NewRows(columns) + + // in this case we definitely have a slice + if rVal.Len() > 0 { + for i := 0; i < rVal.Len(); i++ { + scope := &Scope{Value: rVal.Index(i).Interface()} + row := getRowForFields(scope.Fields()) + + // need to append the values for join table keys + sourcePk := q.scope.PrimaryKeyValue() + destModelType := joinTable.Destination.ModelType + destModelVal := reflect.New(destModelType).Interface() + destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue() + + row = append(row, sourcePk, destPkVal) + + rows = rows.AddRow(row...) + } + + return rows, true + } + + return nil, false + default: + return nil, false + } +} + +func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows { + var columns []string + for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields { + if field.IsNormal { + columns = append(columns, field.DBName) + } + } + + rows := sqlmock.NewRows(columns) + outVal := indirect(reflect.ValueOf(out)) + + // SELECT multiple columns + if outVal.Kind() == reflect.Slice { + outSlice := []interface{}{} + + for i := 0; i < outVal.Len(); i++ { + outSlice = append(outSlice, outVal.Index(i).Interface()) + } + + for _, outElem := range outSlice { + scope := &Scope{Value: outElem} + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + } + } else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1 + row := getRowForFields(q.scope.Fields()) + rows = rows.AddRow(row...) + } else { + panic(fmt.Errorf("Can only get rows for slice or struct")) + } + + return rows +} + +// Returns accepts an out type which should either be a struct or slice. Under +// the hood, it converts a gorm model struct to sql.Rows that can be passed to +// the underlying mock db +func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { + scope := (&Scope{}).New(out) + q.scope = scope + outVal := indirect(reflect.ValueOf(out)) + + // rows := q.getRowsForOutType(out) + destQuery := q.queries[0] + subQueries := q.queries[1:] + + // main query always at the head of the slice + q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)). + WillReturnRows(q.getDestRows(out)) + + // subqueries are preload + for _, subQuery := range subQueries { + if subQuery.preload != "" { + if field, ok := scope.FieldByName(subQuery.preload); ok { + expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql)) + rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship) + + if hasRows { + expectation.WillReturnRows(rows) + } + } + } + } + + return q +} + +// SqlmockExec implements Exec for go-sqlmock +type SqlmockExec struct { + exec Stmt + mock sqlmock.Sqlmock + scope *Scope +} + +// WillSucceed accepts a two int64s. They are passed directly to the underlying +// mock db. Useful for checking DAO behaviour in the event that the incorrect +// number of rows are affected by an Exec +func (e *SqlmockExec) WillSucceed(lastReturnedID, rowsAffected int64) ExpectedExec { + result := sqlmock.NewResult(lastReturnedID, rowsAffected) + e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result) + + return e +} + +// WillFail simulates returning an Error from an unsuccessful exec +func (e *SqlmockExec) WillFail(err error) ExpectedExec { + result := sqlmock.NewErrorResult(err) + e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result) + + return e +} diff --git a/expecter_test.go b/expecter_test.go new file mode 100644 index 00000000..64dfc02d --- /dev/null +++ b/expecter_test.go @@ -0,0 +1,322 @@ +package gorm_test + +import ( + "errors" + "reflect" + "testing" + + "github.com/jinzhu/gorm" +) + +func TestNewDefaultExpecter(t *testing.T) { + db, _, err := gorm.NewDefaultExpecter() + //lint:ignore SA5001 just a mock + defer db.Close() + + if err != nil { + t.Fatal(err) + } +} + +func TestNewCustomExpecter(t *testing.T) { + db, _, err := gorm.NewExpecter(gorm.NewSqlmockAdapter, "sqlmock", "mock_gorm_dsn") + //lint:ignore SA5001 just a mock + defer db.Close() + + if err != nil { + t.Fatal(err) + } +} + +func TestQuery(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + + if err != nil { + t.Fatal(err) + } + + expect.First(&User{}) + db.First(&User{}) + + if err := expect.AssertExpectations(); err != nil { + t.Error(err) + } +} + +func TestQueryReturn(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := User{Id: 1} + out := User{Id: 1, Name: "jinzhu"} + + expect.First(&in).Returns(out) + + db.First(&in) + + if e := expect.AssertExpectations(); e != nil { + t.Error(e) + } + + if in.Name != "jinzhu" { + t.Errorf("Expected %s, got %s", out.Name, in.Name) + } + + if ne := reflect.DeepEqual(in, out); !ne { + t.Errorf("Not equal") + } +} + +func TestFindStructDest(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := &User{Id: 1} + + expect.Find(in) + db.Find(&User{Id: 1}) + + if e := expect.AssertExpectations(); e != nil { + t.Error(e) + } +} + +func TestFindSlice(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + in := []User{} + out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}} + + expect.Find(&in).Returns(&out) + db.Find(&in) + + if e := expect.AssertExpectations(); e != nil { + t.Error(e) + } + + if ne := reflect.DeepEqual(in, out); !ne { + t.Error("Expected equal slices") + } +} + +func TestMockPreloadHasMany(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + in := User{Id: 1} + outEmails := []Email{Email{Id: 1, UserId: 1}, Email{Id: 2, UserId: 1}} + out := User{Id: 1, Emails: outEmails} + + expect.Preload("Emails").Find(&in).Returns(out) + db.Preload("Emails").Find(&in) + + if err := expect.AssertExpectations(); err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(in, out) { + t.Error("In and out are not equal") + } +} + +func TestMockPreloadHasOne(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + in := User{Id: 1} + out := User{Id: 1, CreditCard: CreditCard{Number: "12345678"}} + + expect.Preload("CreditCard").Find(&in).Returns(out) + db.Preload("CreditCard").Find(&in) + + if err := expect.AssertExpectations(); err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(in, out) { + t.Error("In and out are not equal") + } +} + +func TestMockPreloadMany2Many(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + in := User{Id: 1} + languages := []Language{Language{Name: "ZH"}} + out := User{Id: 1, Languages: languages} + + expect.Preload("Languages").Find(&in).Returns(out) + db.Preload("Languages").Find(&in) + + if err := expect.AssertExpectations(); err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(in, out) { + t.Error("In and out are not equal") + } +} + +func TestMockPreloadMultiple(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + creditCard := CreditCard{Number: "12345678"} + languages := []Language{Language{Name: "ZH"}} + + in := User{Id: 1} + out := User{Id: 1, Languages: languages, CreditCard: creditCard} + + expect.Preload("Languages").Preload("CreditCard").Find(&in).Returns(out) + db.Preload("Languages").Preload("CreditCard").Find(&in) + + if err := expect.AssertExpectations(); err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(in, out) { + t.Error("In and out are not equal") + } +} + +func TestMockCreateBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + user := User{Name: "jinzhu"} + expect.Create(&user).WillSucceed(1, 1) + rowsAffected := db.Create(&user).RowsAffected + + if rowsAffected != 1 { + t.Errorf("Expected rows affected to be 1 but got %d", rowsAffected) + } + + if user.Id != 1 { + t.Errorf("User id field should be 1, but got %d", user.Id) + } +} + +func TestMockCreateError(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + mockError := errors.New("Could not insert user") + + user := User{Name: "jinzhu"} + expect.Create(&user).WillFail(mockError) + + dbError := db.Create(&user).Error + + if dbError == nil || dbError != mockError { + t.Errorf("Expected *DB.Error to be set, but it was not") + } +} + +func TestMockSaveBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + user := User{Name: "jinzhu"} + expect.Save(&user).WillSucceed(1, 1) + expected := db.Save(&user) + + if err := expect.AssertExpectations(); err != nil { + t.Errorf("Expectations were not met %s", err.Error()) + } + + if expected.RowsAffected != 1 || user.Id != 1 { + t.Errorf("Expected result was not returned") + } +} + +func TestMockUpdateBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + newName := "uhznij" + user := User{Name: "jinzhu"} + + expect.Model(&user).Update("name", newName).WillSucceed(1, 1) + db.Model(&user).Update("name", newName) + + if err := expect.AssertExpectations(); err != nil { + t.Errorf("Expectations were not met %s", err.Error()) + } + + if user.Name != newName { + t.Errorf("Should have name %s but got %s", newName, user.Name) + } +} + +func TestMockUpdatesBasic(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer db.Close() + + if err != nil { + t.Fatal(err) + } + + user := User{Name: "jinzhu", Age: 18} + updated := User{Name: "jinzhu", Age: 88} + + expect.Model(&user).Updates(updated).WillSucceed(1, 1) + db.Model(&user).Updates(updated) + + if err := expect.AssertExpectations(); err != nil { + t.Errorf("Expectations were not met %s", err.Error()) + } + + if user.Age != updated.Age { + t.Errorf("Should have age %d but got %d", user.Age, updated.Age) + } +} diff --git a/migration_test.go b/migration_test.go index 3f3a5c8f..0d389268 100644 --- a/migration_test.go +++ b/migration_test.go @@ -139,6 +139,11 @@ func (role Role) IsAdmin() bool { type Num int64 +func (i Num) Value() (driver.Value, error) { + // guaranteed ok + return int64(i), nil +} + func (i *Num) Scan(src interface{}) error { switch s := src.(type) { case []byte: