Merge d630799e906b41467a47b801ec24b892e71a47e0 into 0a51f6cdc55d1650d9ed3b4c13026cfa9133b01e

This commit is contained in:
Ian Tan 2017-11-26 07:14:44 +00:00 committed by GitHub
commit 3087813925
7 changed files with 1076 additions and 0 deletions

View File

@ -60,6 +60,7 @@ func preloadCallback(scope *Scope) {
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
@ -264,6 +265,8 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
// spew.Println("___ENTERING HANDLE MANY TO MANY___\r\n")
// spew.Printf("___POPULATING %s___:\r\n%s\r\n", field.Name, spew.Sdump(field))
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
@ -303,6 +306,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
}
rows, err := preloadDB.Rows()
// spew.Printf("___RETURNED ROWS___: \r\n%s\r\n", spew.Sdump(rows))
if scope.Err(err) != nil {
return
@ -312,6 +316,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
columns, _ := rows.Columns()
for rows.Next() {
var (
// This is a Language zero value struct
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
)
@ -323,6 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
}
scope.scan(rows, columns, append(fields, joinTableFields...))
// spew.Printf("___FIELDS___: \r\n%s\r\n", spew.Sdump(fields))
var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
@ -351,12 +357,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
foreignFieldNames = []string{}
)
// spew.Printf("Foreign fields: %s", spew.Sdump(relation.ForeignFieldNames))
for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
// spew.Printf("Scope value: %s", spew.Sdump(indirectScopeValue))
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
@ -367,6 +375,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
}
// spew.Printf("Field source map: %s", spew.Sdump(fieldsSourceMap))
// spew.Printf("Link hash: %s", spew.Sdump(linkHash))
for source, link := range linkHash {
for i, field := range fieldsSourceMap[source] {
//If not 0 this means Value is a pointer and we already added preloaded models to it

225
expecter.go Normal file
View File

@ -0,0 +1,225 @@
package gorm
import (
"fmt"
)
// Recorder satisfies the logger interface
type Recorder struct {
stmts []Stmt
preload []searchPreload // store it on Recorder
}
// Stmt represents a sql statement. It can be an Exec, Query, or QueryRow
type Stmt struct {
kind string // can be Query, Exec, QueryRow
preload string // contains schema if it is a preload query
sql string
args []interface{}
}
func recordExecCallback(scope *Scope) {
r, ok := scope.Get("gorm:recorder")
if !ok {
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
}
stmt := Stmt{
kind: "exec",
sql: scope.SQL,
args: scope.SQLVars,
}
recorder := r.(*Recorder)
recorder.Record(stmt)
}
func recordQueryCallback(scope *Scope) {
r, ok := scope.Get("gorm:recorder")
if !ok {
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
}
recorder := r.(*Recorder)
stmt := Stmt{
kind: "query",
sql: scope.SQL,
args: scope.SQLVars,
}
if len(recorder.preload) > 0 {
// this will cause the scope.SQL to mutate to the preload query
stmt.preload = recorder.preload[0].schema
// we just want to pop the first element off
recorder.preload = recorder.preload[1:]
}
recorder.Record(stmt)
}
func recordPreloadCallback(scope *Scope) {
// this callback runs _before_ gorm:preload
// it should record the next thing to be preloaded
recorder, ok := scope.Get("gorm:recorder")
if !ok {
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
}
if len(scope.Search.preload) > 0 {
// spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload))
recorder.(*Recorder).preload = scope.Search.preload
}
}
// Record records a Stmt for use when SQL is finally executed
func (r *Recorder) Record(stmt Stmt) {
r.stmts = append(r.stmts, stmt)
}
// GetFirst returns the first recorded sql statement logged. If there are no
// statements, false is returned
func (r *Recorder) GetFirst() (Stmt, bool) {
var stmt Stmt
if len(r.stmts) > 0 {
stmt = r.stmts[0]
return stmt, true
}
return stmt, false
}
// IsEmpty returns true if the statements slice is empty
func (r *Recorder) IsEmpty() bool {
return len(r.stmts) == 0
}
// AdapterFactory is a generic interface for arbitrary adapters that satisfy
// the interface. variadic args are passed to gorm.Open.
type AdapterFactory func(dialect string, args ...interface{}) (*DB, Adapter, error)
// Expecter is the exported struct used for setting expectations
type Expecter struct {
// globally scoped expecter
adapter Adapter
gorm *DB
recorder *Recorder
}
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
func NewDefaultExpecter() (*DB, *Expecter, error) {
gormDb, adapter, err := NewSqlmockAdapter("sqlmock", "mock_gorm_dsn")
if err != nil {
return nil, nil, err
}
recorder := &Recorder{}
noop, _ := NewNoopDB()
gorm := &DB{
db: noop,
logger: defaultLogger,
values: map[string]interface{}{},
callbacks: DefaultCallback,
dialect: newDialect("sqlmock", noop),
}
gorm.parent = gorm
gorm = gorm.Set("gorm:recorder", recorder)
gorm.Callback().Create().After("gorm:create").Register("gorm:record_exec", recordExecCallback)
gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback)
gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback)
gorm.Callback().RowQuery().After("gorm:row_query").Register("gorm:record_query", recordQueryCallback)
gorm.Callback().Update().After("gorm:update").Register("gorm:record_exec", recordExecCallback)
return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil
}
// NewExpecter returns an Expecter for arbitrary adapters
func NewExpecter(fn AdapterFactory, dialect string, args ...interface{}) (*DB, *Expecter, error) {
gormDb, adapter, err := fn(dialect, args...)
if err != nil {
return nil, nil, err
}
return gormDb, &Expecter{adapter: adapter}, nil
}
/* PUBLIC METHODS */
// AssertExpectations checks if all expected Querys and Execs were satisfied.
func (h *Expecter) AssertExpectations() error {
return h.adapter.AssertExpectations()
}
// Model sets scope.Value
func (h *Expecter) Model(model interface{}) *Expecter {
h.gorm = h.gorm.Model(model)
return h
}
/* CREATE */
// Create mocks insertion of a model into the DB
func (h *Expecter) Create(model interface{}) ExpectedExec {
h.gorm.Create(model)
return h.adapter.ExpectExec(h.recorder.stmts[0])
}
/* READ */
// First triggers a Query
func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
h.gorm.First(out, where...)
return h.adapter.ExpectQuery(h.recorder.stmts...)
}
// Find triggers a Query
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
h.gorm.Find(out, where...)
return h.adapter.ExpectQuery(h.recorder.stmts...)
}
// Preload clones the expecter and sets a preload condition on gorm.DB
func (h *Expecter) Preload(column string, conditions ...interface{}) *Expecter {
clone := h.clone()
clone.gorm = clone.gorm.Preload(column, conditions...)
return clone
}
/* UPDATE */
// Save mocks updating a record in the DB and will trigger db.Exec()
func (h *Expecter) Save(model interface{}) ExpectedExec {
h.gorm.Save(model)
return h.adapter.ExpectExec(h.recorder.stmts[0])
}
// Update mocks updating the given attributes in the DB
func (h *Expecter) Update(attrs ...interface{}) ExpectedExec {
h.gorm.Update(attrs...)
return h.adapter.ExpectExec(h.recorder.stmts[0])
}
// Updates does the same thing as Update, but with map or struct
func (h *Expecter) Updates(values interface{}, ignoreProtectedAttrs ...bool) ExpectedExec {
h.gorm.Updates(values, ignoreProtectedAttrs...)
return h.adapter.ExpectExec(h.recorder.stmts[0])
}
/* PRIVATE METHODS */
func (h *Expecter) clone() *Expecter {
return &Expecter{
adapter: h.adapter,
gorm: h.gorm,
recorder: h.recorder,
}
}

68
expecter_adapter.go Normal file
View File

@ -0,0 +1,68 @@
package gorm
import (
"database/sql"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
var (
db *sql.DB
mock sqlmock.Sqlmock
)
func init() {
var err error
db, mock, err = sqlmock.NewWithDSN("mock_gorm_dsn")
if err != nil {
panic(err.Error())
}
}
// Adapter provides an abstract interface over concrete mock database
// implementations (e.g. go-sqlmock or go-testdb)
type Adapter interface {
ExpectQuery(stmts ...Stmt) ExpectedQuery
ExpectExec(stmt Stmt) ExpectedExec
AssertExpectations() error
}
// SqlmockAdapter implemenets the Adapter interface using go-sqlmock
// it is the default Adapter
type SqlmockAdapter struct {
db *sql.DB
mocker sqlmock.Sqlmock
}
// NewSqlmockAdapter returns a mock gorm.DB and an Adapter backed by
// go-sqlmock
func NewSqlmockAdapter(dialect string, args ...interface{}) (*DB, Adapter, error) {
gormDb, err := Open("sqlmock", "mock_gorm_dsn")
if err != nil {
return nil, nil, err
}
return gormDb, &SqlmockAdapter{db: db, mocker: mock}, nil
}
// ExpectQuery wraps the underlying mock method for setting a query
// expectation. It accepts multiple statements in the event of preloading
func (a *SqlmockAdapter) ExpectQuery(queries ...Stmt) ExpectedQuery {
return &SqlmockQuery{mock: a.mocker, queries: queries}
}
// ExpectExec wraps the underlying mock method for setting a exec
// expectation
func (a *SqlmockAdapter) ExpectExec(exec Stmt) ExpectedExec {
return &SqlmockExec{mock: a.mocker, exec: exec}
}
// AssertExpectations asserts that _all_ expectations for a test have been met
// and returns an error specifying which have not if there are unmet
// expectations
func (a *SqlmockAdapter) AssertExpectations() error {
return a.mocker.ExpectationsWereMet()
}

186
expecter_noop.go Normal file
View File

@ -0,0 +1,186 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"fmt"
"io"
"sync"
)
var pool *NoopDriver
func init() {
pool = &NoopDriver{
conns: make(map[string]*NoopConnection),
}
sql.Register("noop", pool)
}
// NoopDriver implements sql/driver.Driver
type NoopDriver struct {
sync.Mutex
counter int
conns map[string]*NoopConnection
}
// Open implements sql/driver.Driver
func (d *NoopDriver) Open(dsn string) (driver.Conn, error) {
d.Lock()
defer d.Unlock()
c, ok := d.conns[dsn]
if !ok {
return c, fmt.Errorf("No connection available")
}
c.opened++
return c, nil
}
// NoopResult is a noop struct that satisfies sql.Result
type NoopResult struct{}
// LastInsertId is a noop method for satisfying drive.Result
func (r NoopResult) LastInsertId() (int64, error) {
return 0, nil
}
// RowsAffected is a noop method for satisfying drive.Result
func (r NoopResult) RowsAffected() (int64, error) {
return 0, nil
}
// NoopRows implements driver.Rows
type NoopRows struct {
pos int
}
// Columns implements driver.Rows
func (r *NoopRows) Columns() []string {
return []string{"foo", "bar", "baz", "lol", "kek", "zzz"}
}
// Close implements driver.Rows
func (r *NoopRows) Close() error {
return nil
}
// Next implements driver.Rows and alwys returns only one row
func (r *NoopRows) Next(dest []driver.Value) error {
if r.pos == 1 {
return io.EOF
}
cols := []string{"foo", "bar", "baz", "lol", "kek", "zzz"}
for i, col := range cols {
dest[i] = col
}
r.pos++
return nil
}
// NoopStmt implements driver.Stmt
type NoopStmt struct{}
// Close implements driver.Stmt
func (s *NoopStmt) Close() error {
return nil
}
// NumInput implements driver.Stmt
func (s *NoopStmt) NumInput() int {
return 1
}
// Exec implements driver.Stmt
func (s *NoopStmt) Exec(args []driver.Value) (driver.Result, error) {
return &NoopResult{}, nil
}
// Query implements driver.Stmt
func (s *NoopStmt) Query(args []driver.Value) (driver.Rows, error) {
return &NoopRows{}, nil
}
// NewNoopDB initialises a new DefaultNoopDB
func NewNoopDB() (SQLCommon, error) {
pool.Lock()
dsn := fmt.Sprintf("noop_db_%d", pool.counter)
pool.counter++
noop := &NoopConnection{dsn: dsn, drv: pool}
pool.conns[dsn] = noop
pool.Unlock()
db, err := noop.open()
return db, err
}
// NoopConnection implements sql/driver.Conn
// for our purposes, the noop connection never returns an error, as we only
// require it for generating queries. It is necessary because eagerloading
// will fail if any operation returns an error
type NoopConnection struct {
dsn string
drv *NoopDriver
opened int
}
func (c *NoopConnection) open() (*sql.DB, error) {
db, err := sql.Open("noop", c.dsn)
if err != nil {
return db, err
}
return db, db.Ping()
}
// Close implements sql/driver.Conn
func (c *NoopConnection) Close() error {
c.drv.Lock()
defer c.drv.Unlock()
c.opened--
if c.opened == 0 {
delete(c.drv.conns, c.dsn)
}
return nil
}
// Begin implements sql/driver.Conn
func (c *NoopConnection) Begin() (driver.Tx, error) {
return c, nil
}
// Exec implements sql/driver.Conn
func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result, error) {
return NoopResult{}, nil
}
// Prepare implements sql/driver.Conn
func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) {
return &NoopStmt{}, nil
}
// Query implements sql/driver.Conn
func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
return &NoopRows{}, nil
}
// Commit implements sql/driver.Conn
func (c *NoopConnection) Commit() error {
return nil
}
// Rollback implements sql/driver.Conn
func (c *NoopConnection) Rollback() error {
return nil
}

259
expecter_result.go Normal file
View File

@ -0,0 +1,259 @@
package gorm
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
// ExpectedQuery represents an expected query that will be executed and can
// return some rows. It presents a fluent API for chaining calls to other
// expectations
type ExpectedQuery interface {
Returns(model interface{}) ExpectedQuery
}
// ExpectedExec represents an expected exec that will be executed and can
// return a result. It presents a fluent API for chaining calls to other
// expectations
type ExpectedExec interface {
WillSucceed(lastInsertID, rowsAffected int64) ExpectedExec
WillFail(err error) ExpectedExec
}
// SqlmockQuery implements Query for go-sqlmock
type SqlmockQuery struct {
mock sqlmock.Sqlmock
queries []Stmt
scope *Scope
}
func getRowForFields(fields []*Field) []driver.Value {
var values []driver.Value
for _, field := range fields {
if field.IsNormal {
value := field.Field
// dereference pointers
if field.Field.Kind() == reflect.Ptr {
value = reflect.Indirect(field.Field)
}
// check if we have a zero Value
// just append nil if it's not valid, so sqlmock won't complain
if !value.IsValid() {
values = append(values, nil)
continue
}
concreteVal := value.Interface()
// spew.Printf("%v: %v\r\n", field.Name, concreteVal)
if driver.IsValue(concreteVal) {
values = append(values, concreteVal)
} else if num, err := driver.DefaultParameterConverter.ConvertValue(concreteVal); err == nil {
values = append(values, num)
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
if convertedValue, err := valuer.Value(); err == nil {
values = append(values, convertedValue)
}
}
}
}
return values
}
func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
var (
rows *sqlmock.Rows
columns []string
)
// we need to check for zero values
if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) {
// spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName)
return nil, false
}
switch relation.Kind {
case "has_one":
scope := &Scope{Value: rVal.Interface()}
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
columns = append(columns, field.DBName)
}
}
rows = sqlmock.NewRows(columns)
// we don't have a slice
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
return rows, true
case "has_many":
elem := rVal.Type().Elem()
scope := &Scope{Value: reflect.New(elem).Interface()}
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
columns = append(columns, field.DBName)
}
}
rows = sqlmock.NewRows(columns)
if rVal.Len() > 0 {
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
}
return rows, true
}
return nil, false
case "many_to_many":
elem := rVal.Type().Elem()
scope := &Scope{Value: reflect.New(elem).Interface()}
joinTable := relation.JoinTableHandler.(*JoinTableHandler)
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
columns = append(columns, field.DBName)
}
}
for _, key := range joinTable.Source.ForeignKeys {
columns = append(columns, key.DBName)
}
for _, key := range joinTable.Destination.ForeignKeys {
columns = append(columns, key.DBName)
}
rows = sqlmock.NewRows(columns)
// in this case we definitely have a slice
if rVal.Len() > 0 {
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
// need to append the values for join table keys
sourcePk := q.scope.PrimaryKeyValue()
destModelType := joinTable.Destination.ModelType
destModelVal := reflect.New(destModelType).Interface()
destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue()
row = append(row, sourcePk, destPkVal)
rows = rows.AddRow(row...)
}
return rows, true
}
return nil, false
default:
return nil, false
}
}
func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
var columns []string
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
if field.IsNormal {
columns = append(columns, field.DBName)
}
}
rows := sqlmock.NewRows(columns)
outVal := indirect(reflect.ValueOf(out))
// SELECT multiple columns
if outVal.Kind() == reflect.Slice {
outSlice := []interface{}{}
for i := 0; i < outVal.Len(); i++ {
outSlice = append(outSlice, outVal.Index(i).Interface())
}
for _, outElem := range outSlice {
scope := &Scope{Value: outElem}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
}
} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
row := getRowForFields(q.scope.Fields())
rows = rows.AddRow(row...)
} else {
panic(fmt.Errorf("Can only get rows for slice or struct"))
}
return rows
}
// Returns accepts an out type which should either be a struct or slice. Under
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
// the underlying mock db
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
scope := (&Scope{}).New(out)
q.scope = scope
outVal := indirect(reflect.ValueOf(out))
// rows := q.getRowsForOutType(out)
destQuery := q.queries[0]
subQueries := q.queries[1:]
// main query always at the head of the slice
q.mock.ExpectQuery(regexp.QuoteMeta(destQuery.sql)).
WillReturnRows(q.getDestRows(out))
// subqueries are preload
for _, subQuery := range subQueries {
if subQuery.preload != "" {
if field, ok := scope.FieldByName(subQuery.preload); ok {
expectation := q.mock.ExpectQuery(regexp.QuoteMeta(subQuery.sql))
rows, hasRows := q.getRelationRows(outVal.FieldByName(subQuery.preload), subQuery.preload, field.Relationship)
if hasRows {
expectation.WillReturnRows(rows)
}
}
}
}
return q
}
// SqlmockExec implements Exec for go-sqlmock
type SqlmockExec struct {
exec Stmt
mock sqlmock.Sqlmock
scope *Scope
}
// WillSucceed accepts a two int64s. They are passed directly to the underlying
// mock db. Useful for checking DAO behaviour in the event that the incorrect
// number of rows are affected by an Exec
func (e *SqlmockExec) WillSucceed(lastReturnedID, rowsAffected int64) ExpectedExec {
result := sqlmock.NewResult(lastReturnedID, rowsAffected)
e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
return e
}
// WillFail simulates returning an Error from an unsuccessful exec
func (e *SqlmockExec) WillFail(err error) ExpectedExec {
result := sqlmock.NewErrorResult(err)
e.mock.ExpectExec(regexp.QuoteMeta(e.exec.sql)).WillReturnResult(result)
return e
}

322
expecter_test.go Normal file
View File

@ -0,0 +1,322 @@
package gorm_test
import (
"errors"
"reflect"
"testing"
"github.com/jinzhu/gorm"
)
func TestNewDefaultExpecter(t *testing.T) {
db, _, err := gorm.NewDefaultExpecter()
//lint:ignore SA5001 just a mock
defer db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestNewCustomExpecter(t *testing.T) {
db, _, err := gorm.NewExpecter(gorm.NewSqlmockAdapter, "sqlmock", "mock_gorm_dsn")
//lint:ignore SA5001 just a mock
defer db.Close()
if err != nil {
t.Fatal(err)
}
}
func TestQuery(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
if err != nil {
t.Fatal(err)
}
expect.First(&User{})
db.First(&User{})
if err := expect.AssertExpectations(); err != nil {
t.Error(err)
}
}
func TestQueryReturn(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
in := User{Id: 1}
out := User{Id: 1, Name: "jinzhu"}
expect.First(&in).Returns(out)
db.First(&in)
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
if in.Name != "jinzhu" {
t.Errorf("Expected %s, got %s", out.Name, in.Name)
}
if ne := reflect.DeepEqual(in, out); !ne {
t.Errorf("Not equal")
}
}
func TestFindStructDest(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
in := &User{Id: 1}
expect.Find(in)
db.Find(&User{Id: 1})
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
}
func TestFindSlice(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
in := []User{}
out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}}
expect.Find(&in).Returns(&out)
db.Find(&in)
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
if ne := reflect.DeepEqual(in, out); !ne {
t.Error("Expected equal slices")
}
}
func TestMockPreloadHasMany(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
in := User{Id: 1}
outEmails := []Email{Email{Id: 1, UserId: 1}, Email{Id: 2, UserId: 1}}
out := User{Id: 1, Emails: outEmails}
expect.Preload("Emails").Find(&in).Returns(out)
db.Preload("Emails").Find(&in)
if err := expect.AssertExpectations(); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(in, out) {
t.Error("In and out are not equal")
}
}
func TestMockPreloadHasOne(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
in := User{Id: 1}
out := User{Id: 1, CreditCard: CreditCard{Number: "12345678"}}
expect.Preload("CreditCard").Find(&in).Returns(out)
db.Preload("CreditCard").Find(&in)
if err := expect.AssertExpectations(); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(in, out) {
t.Error("In and out are not equal")
}
}
func TestMockPreloadMany2Many(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
in := User{Id: 1}
languages := []Language{Language{Name: "ZH"}}
out := User{Id: 1, Languages: languages}
expect.Preload("Languages").Find(&in).Returns(out)
db.Preload("Languages").Find(&in)
if err := expect.AssertExpectations(); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(in, out) {
t.Error("In and out are not equal")
}
}
func TestMockPreloadMultiple(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
creditCard := CreditCard{Number: "12345678"}
languages := []Language{Language{Name: "ZH"}}
in := User{Id: 1}
out := User{Id: 1, Languages: languages, CreditCard: creditCard}
expect.Preload("Languages").Preload("CreditCard").Find(&in).Returns(out)
db.Preload("Languages").Preload("CreditCard").Find(&in)
if err := expect.AssertExpectations(); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(in, out) {
t.Error("In and out are not equal")
}
}
func TestMockCreateBasic(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
user := User{Name: "jinzhu"}
expect.Create(&user).WillSucceed(1, 1)
rowsAffected := db.Create(&user).RowsAffected
if rowsAffected != 1 {
t.Errorf("Expected rows affected to be 1 but got %d", rowsAffected)
}
if user.Id != 1 {
t.Errorf("User id field should be 1, but got %d", user.Id)
}
}
func TestMockCreateError(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
mockError := errors.New("Could not insert user")
user := User{Name: "jinzhu"}
expect.Create(&user).WillFail(mockError)
dbError := db.Create(&user).Error
if dbError == nil || dbError != mockError {
t.Errorf("Expected *DB.Error to be set, but it was not")
}
}
func TestMockSaveBasic(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
user := User{Name: "jinzhu"}
expect.Save(&user).WillSucceed(1, 1)
expected := db.Save(&user)
if err := expect.AssertExpectations(); err != nil {
t.Errorf("Expectations were not met %s", err.Error())
}
if expected.RowsAffected != 1 || user.Id != 1 {
t.Errorf("Expected result was not returned")
}
}
func TestMockUpdateBasic(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
newName := "uhznij"
user := User{Name: "jinzhu"}
expect.Model(&user).Update("name", newName).WillSucceed(1, 1)
db.Model(&user).Update("name", newName)
if err := expect.AssertExpectations(); err != nil {
t.Errorf("Expectations were not met %s", err.Error())
}
if user.Name != newName {
t.Errorf("Should have name %s but got %s", newName, user.Name)
}
}
func TestMockUpdatesBasic(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer db.Close()
if err != nil {
t.Fatal(err)
}
user := User{Name: "jinzhu", Age: 18}
updated := User{Name: "jinzhu", Age: 88}
expect.Model(&user).Updates(updated).WillSucceed(1, 1)
db.Model(&user).Updates(updated)
if err := expect.AssertExpectations(); err != nil {
t.Errorf("Expectations were not met %s", err.Error())
}
if user.Age != updated.Age {
t.Errorf("Should have age %d but got %d", user.Age, updated.Age)
}
}

View File

@ -139,6 +139,11 @@ func (role Role) IsAdmin() bool {
type Num int64
func (i Num) Value() (driver.Value, error) {
// guaranteed ok
return int64(i), nil
}
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte: