diff --git a/expecter.go b/expecter.go index 6ded475e..7251dd3a 100644 --- a/expecter.go +++ b/expecter.go @@ -3,7 +3,6 @@ package gorm import ( "database/sql" "errors" - "fmt" "regexp" ) @@ -148,6 +147,6 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { // Find triggers a Query func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { - fmt.Printf("Expecting query: %s\n", "some query involving Find") - return h.adapter.ExpectQuery("some find condition") + h.gorm.Find(out, where...) + return h.adapter.ExpectQuery(regexp.QuoteMeta(h.recorder.stmt)) } diff --git a/expecter_result.go b/expecter_result.go index 80bc8ed7..1339c195 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -2,7 +2,10 @@ package gorm import ( "database/sql/driver" + "fmt" + "reflect" + "github.com/davecgh/go-spew/spew" sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" ) @@ -20,36 +23,76 @@ type SqlmockQuery struct { query *sqlmock.ExpectedQuery } -func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { - var ( - columns []string - rows *sqlmock.Rows - values []driver.Value - ) - - q.scope = &Scope{Value: out} - fields := q.scope.Fields() - +func getRowForFields(fields []*Field) []driver.Value { + var values []driver.Value for _, field := range fields { if field.IsNormal { - var ( - column = field.StructField.DBName - value = field.Field.Interface() - ) + value := field.Field - if isValue := driver.IsValue(value); isValue { - columns = append(columns, column) - values = append(values, value) - } else if valuer, ok := value.(driver.Valuer); ok { - if underlyingValue, err := valuer.Value(); err == nil { - values = append(values, underlyingValue) - columns = append(columns, field.StructField.DBName) + // dereference pointers + if field.Field.Kind() == reflect.Ptr { + value = reflect.Indirect(field.Field) + } + + // check if we have a zero Value + if !value.IsValid() { + values = append(values, nil) + continue + } + + concreteVal := value.Interface() + + // if we already have a driver.Value, just append + _, isValuer := concreteVal.(driver.Valuer) + spew.Printf("%s: %v\r\n", field.DBName, isValuer) + + if driver.IsValue(concreteVal) { + values = append(values, concreteVal) + } else if valuer, ok := concreteVal.(driver.Valuer); ok { + if convertedValue, err := valuer.Value(); err == nil { + values = append(values, convertedValue) } } } } - rows = sqlmock.NewRows(columns).AddRow(values...) + return values +} + +func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows { + var columns []string + + for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields { + if field.IsNormal { + columns = append(columns, field.DBName) + } + } + + rows := sqlmock.NewRows(columns) + + outVal := indirect(reflect.ValueOf(out)) + + if outVal.Kind() == reflect.Slice { + outSlice := []interface{}{} + for i := 0; i < outVal.Len(); i++ { + outSlice = append(outSlice, outVal.Index(i).Interface()) + } + + for _, outElem := range outSlice { + scope := &Scope{Value: outElem} + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + } + } else if outVal.Kind() == reflect.Struct { + scope := &Scope{Value: out} + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + } else { + panic(fmt.Errorf("Can only get rows for slice or struct")) + } + + spew.Dump(columns) + spew.Dump(rows) return rows } diff --git a/expecter_test.go b/expecter_test.go index cb77e01c..ba8cb8b7 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -57,22 +57,67 @@ func TestQueryReturn(t *testing.T) { t.Fatal(err) } - in := &User{Id: 1} - expectedOut := User{Id: 1, Name: "jinzhu"} + in := User{Id: 1} + out := User{Id: 1, Name: "jinzhu"} - expect.First(in).Returns(User{Id: 1, Name: "jinzhu"}) + expect.First(&in).Returns(out) - db.First(in) + db.First(&in) if e := expect.AssertExpectations(); e != nil { t.Error(e) } if in.Name != "jinzhu" { - t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name) + t.Errorf("Expected %s, got %s", out.Name, in.Name) } - if ne := reflect.DeepEqual(*in, expectedOut); !ne { + if ne := reflect.DeepEqual(in, out); !ne { t.Errorf("Not equal") } } + +func TestFindStructDest(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := &User{Id: 1} + + expect.Find(in) + db.Find(&User{Id: 1}) + + if e := expect.AssertExpectations(); e != nil { + t.Error(e) + } +} + +func TestFindSlice(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := []User{} + out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}} + + expect.Find(&in).Returns(&out) + db.Find(&in) + + if e := expect.AssertExpectations(); e != nil { + t.Error(e) + } + + if ne := reflect.DeepEqual(in, out); !ne { + t.Error("Expected equal slices") + } +} diff --git a/migration_test.go b/migration_test.go index 6c10f62e..0d389268 100644 --- a/migration_test.go +++ b/migration_test.go @@ -139,9 +139,9 @@ func (role Role) IsAdmin() bool { type Num int64 -func (i *Num) Value() (driver.Value, error) { +func (i Num) Value() (driver.Value, error) { // guaranteed ok - return int64(*i), nil + return int64(i), nil } func (i *Num) Scan(src interface{}) error {