diff --git a/expecter.go b/expecter.go index c18a7f55..6ded475e 100644 --- a/expecter.go +++ b/expecter.go @@ -12,14 +12,37 @@ type Recorder struct { stmt string } +type Stmt struct { + stmtType string + sql string + args []interface{} +} + +func getStmtFromLog(values ...interface{}) Stmt { + var statement Stmt + + if len(values) > 1 { + var ( + level = values[0] + ) + + if level == "sql" { + statement.args = values[4].([]interface{}) + statement.sql = values[3].(string) + } + + return statement + } + + return statement +} + // Print just sets the last recorded SQL statement // TODO: find a better way to extract SQL from log messages func (r *Recorder) Print(args ...interface{}) { - msgs := LogFormatter(args...) - if len(msgs) >= 4 { - if v, ok := msgs[3].(string); ok { - r.stmt = v - } + statement := getStmtFromLog(args...) + if statement.sql != "" { + r.stmt = statement.sql } } diff --git a/expecter_result.go b/expecter_result.go index d8d60095..80bc8ed7 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -16,12 +16,40 @@ type ExpectedExec interface { // SqlmockQuery implements Query for asserter go-sqlmock type SqlmockQuery struct { + scope *Scope query *sqlmock.ExpectedQuery } func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { - rows := sqlmock.NewRows([]string{"column1", "column2", "column3"}) - rows = rows.AddRow("someval1", "someval2", "someval3") + var ( + columns []string + rows *sqlmock.Rows + values []driver.Value + ) + + q.scope = &Scope{Value: out} + fields := q.scope.Fields() + + for _, field := range fields { + if field.IsNormal { + var ( + column = field.StructField.DBName + value = field.Field.Interface() + ) + + if isValue := driver.IsValue(value); isValue { + columns = append(columns, column) + values = append(values, value) + } else if valuer, ok := value.(driver.Valuer); ok { + if underlyingValue, err := valuer.Value(); err == nil { + values = append(values, underlyingValue) + columns = append(columns, field.StructField.DBName) + } + } + } + } + + rows = sqlmock.NewRows(columns).AddRow(values...) return rows } @@ -34,7 +62,8 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { } type SqlmockExec struct { - exec *sqlmock.ExpectedExec + scope *Scope + exec *sqlmock.ExpectedExec } func (e *SqlmockExec) Returns(result driver.Result) ExpectedExec { diff --git a/expecter_test.go b/expecter_test.go index c93bd06f..cb77e01c 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "reflect" "testing" "github.com/jinzhu/gorm" @@ -39,9 +40,39 @@ func TestQuery(t *testing.T) { } expect.First(&User{}) - db.First(&User{}) + db.LogMode(true).First(&User{}) if err := expect.AssertExpectations(); err != nil { t.Error(err) } } + +func TestQueryReturn(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := &User{Id: 1} + expectedOut := User{Id: 1, Name: "jinzhu"} + + expect.First(in).Returns(User{Id: 1, Name: "jinzhu"}) + + db.First(in) + + if e := expect.AssertExpectations(); e != nil { + t.Error(e) + } + + if in.Name != "jinzhu" { + t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name) + } + + if ne := reflect.DeepEqual(*in, expectedOut); !ne { + t.Errorf("Not equal") + } +}