From 7aa08b901450ace44b12f252c2ae945119ae5da9 Mon Sep 17 00:00:00 2001 From: Ian Tan Date: Tue, 21 Nov 2017 19:32:17 +0800 Subject: [PATCH] Refactor code for extracting has_many relations --- expecter.go | 3 --- expecter_noop.go | 7 ------ expecter_result.go | 63 ++++++++++++++++++++++++++-------------------- 3 files changed, 36 insertions(+), 37 deletions(-) diff --git a/expecter.go b/expecter.go index 5da09d09..41d23ce3 100644 --- a/expecter.go +++ b/expecter.go @@ -2,8 +2,6 @@ package gorm import ( "regexp" - - "github.com/davecgh/go-spew/spew" ) // Recorder satisfies the logger interface @@ -40,7 +38,6 @@ func getStmtFromLog(values ...interface{}) Stmt { // Print just sets the last recorded SQL statement // TODO: find a better way to extract SQL from log messages func (r *Recorder) Print(args ...interface{}) { - spew.Dump(args...) statement := getStmtFromLog(args...) if statement.sql != "" { diff --git a/expecter_noop.go b/expecter_noop.go index 0ea9732d..fc9c9d89 100644 --- a/expecter_noop.go +++ b/expecter_noop.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "sync" - - "github.com/davecgh/go-spew/spew" ) var pool *NoopDriver @@ -141,8 +139,6 @@ func (c *NoopConnection) open() (*sql.DB, error) { return db, err } - fmt.Println(db.Ping()) - return db, db.Ping() } @@ -161,7 +157,6 @@ func (c *NoopConnection) Close() error { // Begin implements sql/driver.Conn func (c *NoopConnection) Begin() (driver.Tx, error) { - fmt.Println("Called Begin()") return c, nil } @@ -172,13 +167,11 @@ func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result, // Prepare implements sql/driver.Conn func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) { - spew.Dump(query) return &NoopStmt{}, nil } // Query implements sql/driver.Conn func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) { - spew.Dump(args) return &NoopRows{}, nil } diff --git a/expecter_result.go b/expecter_result.go index d659f97e..e96f7c8c 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -62,6 +62,38 @@ func getRowForFields(fields []*Field) []driver.Value { return values } +func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) { + var ( + rows *sqlmock.Rows + columns []string + ) + + switch relation.Kind { + case "has_many": + elem := rVal.Type().Elem() + scope := &Scope{Value: reflect.New(elem).Interface()} + + for _, field := range scope.GetModelStruct().StructFields { + columns = append(columns, field.DBName) + } + + rows = sqlmock.NewRows(columns) + + // in this case we definitely have a slice + if rVal.Len() > 0 { + for i := 0; i < rVal.Len(); i++ { + scope := &Scope{Value: rVal.Index(i).Interface()} + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + } + } + + return rows, true + default: + return nil, false + } +} + func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { var ( columns []string @@ -109,34 +141,11 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { rowsSet = append(rowsSet, rows) for name, relation := range relations { - switch relation.Kind { - case "has_many": - rVal := outVal.FieldByName(name) - rType := rVal.Type().Elem() - rScope := &Scope{Value: reflect.New(rType).Interface()} - rColumns := []string{} + rVal := outVal.FieldByName(name) + relationRows, hasRows := getRelationRows(rVal, name, relation) - for _, field := range rScope.GetModelStruct().StructFields { - rColumns = append(rColumns, field.DBName) - } - - hasReturnRows := rVal.Len() > 0 - - // in this case we definitely have a slice - if hasReturnRows { - rRows := sqlmock.NewRows(rColumns) - - for i := 0; i < rVal.Len(); i++ { - scope := &Scope{Value: rVal.Index(i).Interface()} - row := getRowForFields(scope.Fields()) - rRows = rRows.AddRow(row...) - rowsSet = append(rowsSet, rRows) - } - } - case "has_one": - case "many2many": - default: - continue + if hasRows { + rowsSet = append(rowsSet, relationRows) } } } else {