Add naive get rows function

This commit is contained in:
Ian Tan 2017-11-20 16:04:38 +08:00
parent a4ba25028d
commit b4591d4db8
3 changed files with 92 additions and 9 deletions

View File

@ -12,14 +12,37 @@ type Recorder struct {
stmt string stmt string
} }
type Stmt struct {
stmtType string
sql string
args []interface{}
}
func getStmtFromLog(values ...interface{}) Stmt {
var statement Stmt
if len(values) > 1 {
var (
level = values[0]
)
if level == "sql" {
statement.args = values[4].([]interface{})
statement.sql = values[3].(string)
}
return statement
}
return statement
}
// Print just sets the last recorded SQL statement // Print just sets the last recorded SQL statement
// TODO: find a better way to extract SQL from log messages // TODO: find a better way to extract SQL from log messages
func (r *Recorder) Print(args ...interface{}) { func (r *Recorder) Print(args ...interface{}) {
msgs := LogFormatter(args...) statement := getStmtFromLog(args...)
if len(msgs) >= 4 { if statement.sql != "" {
if v, ok := msgs[3].(string); ok { r.stmt = statement.sql
r.stmt = v
}
} }
} }

View File

@ -16,12 +16,40 @@ type ExpectedExec interface {
// SqlmockQuery implements Query for asserter go-sqlmock // SqlmockQuery implements Query for asserter go-sqlmock
type SqlmockQuery struct { type SqlmockQuery struct {
scope *Scope
query *sqlmock.ExpectedQuery query *sqlmock.ExpectedQuery
} }
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
rows := sqlmock.NewRows([]string{"column1", "column2", "column3"}) var (
rows = rows.AddRow("someval1", "someval2", "someval3") columns []string
rows *sqlmock.Rows
values []driver.Value
)
q.scope = &Scope{Value: out}
fields := q.scope.Fields()
for _, field := range fields {
if field.IsNormal {
var (
column = field.StructField.DBName
value = field.Field.Interface()
)
if isValue := driver.IsValue(value); isValue {
columns = append(columns, column)
values = append(values, value)
} else if valuer, ok := value.(driver.Valuer); ok {
if underlyingValue, err := valuer.Value(); err == nil {
values = append(values, underlyingValue)
columns = append(columns, field.StructField.DBName)
}
}
}
}
rows = sqlmock.NewRows(columns).AddRow(values...)
return rows return rows
} }
@ -34,7 +62,8 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
} }
type SqlmockExec struct { type SqlmockExec struct {
exec *sqlmock.ExpectedExec scope *Scope
exec *sqlmock.ExpectedExec
} }
func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec { func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec {

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"reflect"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -39,9 +40,39 @@ func TestQuery(t *testing.T) {
} }
expect.First(&User{}) expect.First(&User{})
db.First(&User{}) db.LogMode(true).First(&User{})
if err := expect.AssertExpectations(); err != nil { if err := expect.AssertExpectations(); err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestQueryReturn(t *testing.T) {
db, expect, err := gorm.NewDefaultExpecter()
defer func() {
db.Close()
}()
if err != nil {
t.Fatal(err)
}
in := &User{Id: 1}
expectedOut := User{Id: 1, Name: "jinzhu"}
expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
db.First(in)
if e := expect.AssertExpectations(); e != nil {
t.Error(e)
}
if in.Name != "jinzhu" {
t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
}
if ne := reflect.DeepEqual(*in, expectedOut); !ne {
t.Errorf("Not equal")
}
}