gorm/expecter_result.go
2017-11-21 19:02:54 +08:00

175 lines
4.4 KiB
Go

package gorm
import (
"database/sql/driver"
"fmt"
"reflect"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
// ExpectedQuery represents an expected query that will be executed and can
// return some rows. It presents a fluent API for chaining calls to other
// expectations
type ExpectedQuery interface {
Returns(model interface{}) ExpectedQuery
}
// ExpectedExec represents an expected exec that will be executed and can
// return a result. It presents a fluent API for chaining calls to other
// expectations
type ExpectedExec interface {
Returns(result driver.Result) ExpectedExec
}
// SqlmockQuery implements Query for go-sqlmock
type SqlmockQuery struct {
queries []*sqlmock.ExpectedQuery
}
func getRowForFields(fields []*Field) []driver.Value {
var values []driver.Value
for _, field := range fields {
if field.IsNormal {
value := field.Field
// dereference pointers
if field.Field.Kind() == reflect.Ptr {
value = reflect.Indirect(field.Field)
}
// check if we have a zero Value
// just append nil if it's not valid, so sqlmock won't complain
if !value.IsValid() {
values = append(values, nil)
continue
}
concreteVal := value.Interface()
if driver.IsValue(concreteVal) {
values = append(values, concreteVal)
} else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 {
values = append(values, value.Int())
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
if convertedValue, err := valuer.Value(); err == nil {
values = append(values, convertedValue)
}
}
}
}
return values
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
var (
columns []string
relations = make(map[string]*Relationship)
rowsSet []*sqlmock.Rows
)
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
// we get the primary model's columns here
if field.IsNormal {
columns = append(columns, field.DBName)
}
// check relations
if !field.IsNormal {
relationVal := reflect.ValueOf(field.Relationship)
isNil := relationVal.IsNil()
if !isNil {
relations[field.Name] = field.Relationship
}
}
}
rows := sqlmock.NewRows(columns)
outVal := indirect(reflect.ValueOf(out))
if outVal.Kind() == reflect.Slice {
outSlice := []interface{}{}
for i := 0; i < outVal.Len(); i++ {
outSlice = append(outSlice, outVal.Index(i).Interface())
}
for _, outElem := range outSlice {
scope := &Scope{Value: outElem}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
}
} else if outVal.Kind() == reflect.Struct {
scope := &Scope{Value: out}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
rowsSet = append(rowsSet, rows)
for name, relation := range relations {
switch relation.Kind {
case "has_many":
rVal := outVal.FieldByName(name)
rType := rVal.Type().Elem()
rScope := &Scope{Value: reflect.New(rType).Interface()}
rColumns := []string{}
for _, field := range rScope.GetModelStruct().StructFields {
rColumns = append(rColumns, field.DBName)
}
hasReturnRows := rVal.Len() > 0
// in this case we definitely have a slice
if hasReturnRows {
rRows := sqlmock.NewRows(rColumns)
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rRows = rRows.AddRow(row...)
rowsSet = append(rowsSet, rRows)
}
}
case "has_one":
case "many2many":
default:
continue
}
}
} else {
panic(fmt.Errorf("Can only get rows for slice or struct"))
}
return rowsSet
}
// Returns accepts an out type which should either be a struct or slice. Under
// the hood, it converts a gorm model struct to sql.Rows that can be passed to
// the underlying mock db
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
rows := q.getRowsForOutType(out)
for i, query := range q.queries {
query.WillReturnRows(rows[i])
}
return q
}
// SqlmockExec implements Exec for go-sqlmock
type SqlmockExec struct {
exec *sqlmock.ExpectedExec
}
// Returns accepts a driver.Result. It is passed directly to the underlying
// mock db. Useful for checking DAO behaviour in the event that the incorrect
// number of rows are affected by an Exec
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec {
e.exec = e.exec.WillReturnResult(result)
return e
}