Add naive get rows function
This commit is contained in:
parent
a4ba25028d
commit
b4591d4db8
33
expecter.go
33
expecter.go
@ -12,14 +12,37 @@ type Recorder struct {
|
|||||||
stmt string
|
stmt string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Stmt struct {
|
||||||
|
stmtType string
|
||||||
|
sql string
|
||||||
|
args []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStmtFromLog(values ...interface{}) Stmt {
|
||||||
|
var statement Stmt
|
||||||
|
|
||||||
|
if len(values) > 1 {
|
||||||
|
var (
|
||||||
|
level = values[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
if level == "sql" {
|
||||||
|
statement.args = values[4].([]interface{})
|
||||||
|
statement.sql = values[3].(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
return statement
|
||||||
|
}
|
||||||
|
|
||||||
|
return statement
|
||||||
|
}
|
||||||
|
|
||||||
// Print just sets the last recorded SQL statement
|
// Print just sets the last recorded SQL statement
|
||||||
// TODO: find a better way to extract SQL from log messages
|
// TODO: find a better way to extract SQL from log messages
|
||||||
func (r *Recorder) Print(args ...interface{}) {
|
func (r *Recorder) Print(args ...interface{}) {
|
||||||
msgs := LogFormatter(args...)
|
statement := getStmtFromLog(args...)
|
||||||
if len(msgs) >= 4 {
|
if statement.sql != "" {
|
||||||
if v, ok := msgs[3].(string); ok {
|
r.stmt = statement.sql
|
||||||
r.stmt = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,12 +16,40 @@ type ExpectedExec interface {
|
|||||||
|
|
||||||
// SqlmockQuery implements Query for asserter go-sqlmock
|
// SqlmockQuery implements Query for asserter go-sqlmock
|
||||||
type SqlmockQuery struct {
|
type SqlmockQuery struct {
|
||||||
|
scope *Scope
|
||||||
query *sqlmock.ExpectedQuery
|
query *sqlmock.ExpectedQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
||||||
rows := sqlmock.NewRows([]string{"column1", "column2", "column3"})
|
var (
|
||||||
rows = rows.AddRow("someval1", "someval2", "someval3")
|
columns []string
|
||||||
|
rows *sqlmock.Rows
|
||||||
|
values []driver.Value
|
||||||
|
)
|
||||||
|
|
||||||
|
q.scope = &Scope{Value: out}
|
||||||
|
fields := q.scope.Fields()
|
||||||
|
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.IsNormal {
|
||||||
|
var (
|
||||||
|
column = field.StructField.DBName
|
||||||
|
value = field.Field.Interface()
|
||||||
|
)
|
||||||
|
|
||||||
|
if isValue := driver.IsValue(value); isValue {
|
||||||
|
columns = append(columns, column)
|
||||||
|
values = append(values, value)
|
||||||
|
} else if valuer, ok := value.(driver.Valuer); ok {
|
||||||
|
if underlyingValue, err := valuer.Value(); err == nil {
|
||||||
|
values = append(values, underlyingValue)
|
||||||
|
columns = append(columns, field.StructField.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows = sqlmock.NewRows(columns).AddRow(values...)
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
}
|
}
|
||||||
@ -34,7 +62,8 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SqlmockExec struct {
|
type SqlmockExec struct {
|
||||||
exec *sqlmock.ExpectedExec
|
scope *Scope
|
||||||
|
exec *sqlmock.ExpectedExec
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec {
|
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
@ -39,9 +40,39 @@ func TestQuery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
expect.First(&User{})
|
expect.First(&User{})
|
||||||
db.First(&User{})
|
db.LogMode(true).First(&User{})
|
||||||
|
|
||||||
if err := expect.AssertExpectations(); err != nil {
|
if err := expect.AssertExpectations(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryReturn(t *testing.T) {
|
||||||
|
db, expect, err := gorm.NewDefaultExpecter()
|
||||||
|
defer func() {
|
||||||
|
db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in := &User{Id: 1}
|
||||||
|
expectedOut := User{Id: 1, Name: "jinzhu"}
|
||||||
|
|
||||||
|
expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
|
||||||
|
|
||||||
|
db.First(in)
|
||||||
|
|
||||||
|
if e := expect.AssertExpectations(); e != nil {
|
||||||
|
t.Error(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.Name != "jinzhu" {
|
||||||
|
t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ne := reflect.DeepEqual(*in, expectedOut); !ne {
|
||||||
|
t.Errorf("Not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user