Add naive get rows function
This commit is contained in:
parent
a4ba25028d
commit
b4591d4db8
33
expecter.go
33
expecter.go
@ -12,14 +12,37 @@ type Recorder struct {
|
||||
stmt string
|
||||
}
|
||||
|
||||
type Stmt struct {
|
||||
stmtType string
|
||||
sql string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
func getStmtFromLog(values ...interface{}) Stmt {
|
||||
var statement Stmt
|
||||
|
||||
if len(values) > 1 {
|
||||
var (
|
||||
level = values[0]
|
||||
)
|
||||
|
||||
if level == "sql" {
|
||||
statement.args = values[4].([]interface{})
|
||||
statement.sql = values[3].(string)
|
||||
}
|
||||
|
||||
return statement
|
||||
}
|
||||
|
||||
return statement
|
||||
}
|
||||
|
||||
// Print just sets the last recorded SQL statement
|
||||
// TODO: find a better way to extract SQL from log messages
|
||||
func (r *Recorder) Print(args ...interface{}) {
|
||||
msgs := LogFormatter(args...)
|
||||
if len(msgs) >= 4 {
|
||||
if v, ok := msgs[3].(string); ok {
|
||||
r.stmt = v
|
||||
}
|
||||
statement := getStmtFromLog(args...)
|
||||
if statement.sql != "" {
|
||||
r.stmt = statement.sql
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16,12 +16,40 @@ type ExpectedExec interface {
|
||||
|
||||
// SqlmockQuery implements Query for asserter go-sqlmock
|
||||
type SqlmockQuery struct {
|
||||
scope *Scope
|
||||
query *sqlmock.ExpectedQuery
|
||||
}
|
||||
|
||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
|
||||
rows := sqlmock.NewRows([]string{"column1", "column2", "column3"})
|
||||
rows = rows.AddRow("someval1", "someval2", "someval3")
|
||||
var (
|
||||
columns []string
|
||||
rows *sqlmock.Rows
|
||||
values []driver.Value
|
||||
)
|
||||
|
||||
q.scope = &Scope{Value: out}
|
||||
fields := q.scope.Fields()
|
||||
|
||||
for _, field := range fields {
|
||||
if field.IsNormal {
|
||||
var (
|
||||
column = field.StructField.DBName
|
||||
value = field.Field.Interface()
|
||||
)
|
||||
|
||||
if isValue := driver.IsValue(value); isValue {
|
||||
columns = append(columns, column)
|
||||
values = append(values, value)
|
||||
} else if valuer, ok := value.(driver.Valuer); ok {
|
||||
if underlyingValue, err := valuer.Value(); err == nil {
|
||||
values = append(values, underlyingValue)
|
||||
columns = append(columns, field.StructField.DBName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rows = sqlmock.NewRows(columns).AddRow(values...)
|
||||
|
||||
return rows
|
||||
}
|
||||
@ -34,6 +62,7 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
||||
}
|
||||
|
||||
type SqlmockExec struct {
|
||||
scope *Scope
|
||||
exec *sqlmock.ExpectedExec
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
@ -39,9 +40,39 @@ func TestQuery(t *testing.T) {
|
||||
}
|
||||
|
||||
expect.First(&User{})
|
||||
db.First(&User{})
|
||||
db.LogMode(true).First(&User{})
|
||||
|
||||
if err := expect.AssertExpectations(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryReturn(t *testing.T) {
|
||||
db, expect, err := gorm.NewDefaultExpecter()
|
||||
defer func() {
|
||||
db.Close()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
in := &User{Id: 1}
|
||||
expectedOut := User{Id: 1, Name: "jinzhu"}
|
||||
|
||||
expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
|
||||
|
||||
db.First(in)
|
||||
|
||||
if e := expect.AssertExpectations(); e != nil {
|
||||
t.Error(e)
|
||||
}
|
||||
|
||||
if in.Name != "jinzhu" {
|
||||
t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
|
||||
}
|
||||
|
||||
if ne := reflect.DeepEqual(*in, expectedOut); !ne {
|
||||
t.Errorf("Not equal")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user