117 lines
2.5 KiB
Go
117 lines
2.5 KiB
Go
package gorm
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/davecgh/go-spew/spew"
|
|
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
|
)
|
|
|
|
// ExpectedQuery is an expectation on a query whose result can be configured.
type ExpectedQuery interface {
	// Returns sets the result the query will produce from out and returns
	// the expectation so calls can be chained. SqlmockQuery accepts a struct
	// or a slice of model values here.
	Returns(out interface{}) ExpectedQuery
}
|
|
|
|
// ExpectedExec is an expectation on a statement execution (an exec rather
// than a query) whose result can be configured.
type ExpectedExec interface {
	// Returns sets the driver.Result the execution will report and returns
	// the expectation so calls can be chained.
	Returns(result driver.Result) ExpectedExec
}
|
|
|
|
// SqlmockQuery implements Query for asserter go-sqlmock
|
|
type SqlmockQuery struct {
|
|
scope *Scope
|
|
query *sqlmock.ExpectedQuery
|
|
}
|
|
|
|
func getRowForFields(fields []*Field) []driver.Value {
|
|
var values []driver.Value
|
|
for _, field := range fields {
|
|
if field.IsNormal {
|
|
value := field.Field
|
|
|
|
// dereference pointers
|
|
if field.Field.Kind() == reflect.Ptr {
|
|
value = reflect.Indirect(field.Field)
|
|
}
|
|
|
|
// check if we have a zero Value
|
|
if !value.IsValid() {
|
|
values = append(values, nil)
|
|
continue
|
|
}
|
|
|
|
concreteVal := value.Interface()
|
|
|
|
// if we already have a driver.Value, just append
|
|
_, isValuer := concreteVal.(driver.Valuer)
|
|
spew.Printf("%s: %v\r\n", field.DBName, isValuer)
|
|
|
|
if driver.IsValue(concreteVal) {
|
|
values = append(values, concreteVal)
|
|
} else if valuer, ok := concreteVal.(driver.Valuer); ok {
|
|
if convertedValue, err := valuer.Value(); err == nil {
|
|
values = append(values, convertedValue)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return values
|
|
}
|
|
|
|
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
|
var columns []string
|
|
|
|
for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
|
|
if field.IsNormal {
|
|
columns = append(columns, field.DBName)
|
|
}
|
|
}
|
|
|
|
rows := sqlmock.NewRows(columns)
|
|
|
|
outVal := indirect(reflect.ValueOf(out))
|
|
|
|
if outVal.Kind() == reflect.Slice {
|
|
outSlice := []interface{}{}
|
|
for i := 0; i < outVal.Len(); i++ {
|
|
outSlice = append(outSlice, outVal.Index(i).Interface())
|
|
}
|
|
|
|
for _, outElem := range outSlice {
|
|
scope := &Scope{Value: outElem}
|
|
row := getRowForFields(scope.Fields())
|
|
rows = rows.AddRow(row...)
|
|
}
|
|
} else if outVal.Kind() == reflect.Struct {
|
|
scope := &Scope{Value: out}
|
|
row := getRowForFields(scope.Fields())
|
|
rows = rows.AddRow(row...)
|
|
} else {
|
|
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
|
}
|
|
|
|
spew.Dump(columns)
|
|
spew.Dump(rows)
|
|
|
|
return rows
|
|
}
|
|
|
|
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
|
rows := q.getRowsForOutType(out)
|
|
q.query = q.query.WillReturnRows(rows)
|
|
|
|
return q
|
|
}
|
|
|
|
type SqlmockExec struct {
|
|
scope *Scope
|
|
exec *sqlmock.ExpectedExec
|
|
}
|
|
|
|
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec {
|
|
e.exec = e.exec.WillReturnResult(result)
|
|
|
|
return e
|
|
}
|