Add naive get rows function

This commit is contained in:
Ian Tan 2017-11-20 16:04:38 +08:00
parent a4ba25028d
commit b4591d4db8
3 changed files with 92 additions and 9 deletions

View File

@ -12,14 +12,37 @@ type Recorder struct {
stmt string
}
type Stmt struct {
stmtType string
sql string
args []interface{}
}
func getStmtFromLog(values ...interface{}) Stmt {
var statement Stmt
if len(values) > 1 {
var (
level = values[0]
)
if level == "sql" {
statement.args = values[4].([]interface{})
statement.sql = values[3].(string)
}
return statement
}
return statement
}
// Print just sets the last recorded SQL statement
// TODO: find a better way to extract SQL from log messages
func (r *Recorder) Print(args ...interface{}) {
msgs := LogFormatter(args...)
if len(msgs) >= 4 {
if v, ok := msgs[3].(string); ok {
r.stmt = v
}
statement := getStmtFromLog(args...)
if statement.sql != "" {
r.stmt = statement.sql
}
}

View File

@ -16,12 +16,40 @@ type ExpectedExec interface {
// SqlmockQuery implements Query for asserter go-sqlmock
type SqlmockQuery struct {
scope *Scope
query *sqlmock.ExpectedQuery
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
rows := sqlmock.NewRows([]string{"column1", "column2", "column3"})
rows = rows.AddRow("someval1", "someval2", "someval3")
var (
columns []string
rows *sqlmock.Rows
values []driver.Value
)
q.scope = &Scope{Value: out}
fields := q.scope.Fields()
for _, field := range fields {
if field.IsNormal {
var (
column = field.StructField.DBName
value = field.Field.Interface()
)
if isValue := driver.IsValue(value); isValue {
columns = append(columns, column)
values = append(values, value)
} else if valuer, ok := value.(driver.Valuer); ok {
if underlyingValue, err := valuer.Value(); err == nil {
values = append(values, underlyingValue)
columns = append(columns, field.StructField.DBName)
}
}
}
}
rows = sqlmock.NewRows(columns).AddRow(values...)
return rows
}
@ -34,7 +62,8 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
}
type SqlmockExec struct {
exec *sqlmock.ExpectedExec
scope *Scope
exec *sqlmock.ExpectedExec
}
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec {

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"reflect"
"testing"
"github.com/jinzhu/gorm"
@ -39,9 +40,39 @@ func TestQuery(t *testing.T) {
}
expect.First(&User{})
db.First(&User{})
db.LogMode(true).First(&User{})
if err := expect.AssertExpectations(); err != nil {
t.Error(err)
}
}
func TestQueryReturn(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
in := &User{Id: 1}
expectedOut := User{Id: 1, Name: "jinzhu"}
expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
db.First(in)
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
if in.Name != "jinzhu" {
t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
}
if ne := reflect.DeepEqual(*in, expectedOut); !ne {
t.Errorf("Not equal")
}
}