Refactor code for extracting has_many relations

This commit is contained in:
Ian Tan 2017-11-21 19:32:17 +08:00
parent b06542dc77
commit 7aa08b9014
3 changed files with 36 additions and 37 deletions

View File

@ -2,8 +2,6 @@ package gorm
import ( import (
"regexp" "regexp"
"github.com/davecgh/go-spew/spew"
) )
// Recorder satisfies the logger interface // Recorder satisfies the logger interface
@ -40,7 +38,6 @@ func getStmtFromLog(values ...interface{}) Stmt {
// Print just sets the last recorded SQL statement // Print just sets the last recorded SQL statement
// TODO: find a better way to extract SQL from log messages // TODO: find a better way to extract SQL from log messages
func (r *Recorder) Print(args ...interface{}) { func (r *Recorder) Print(args ...interface{}) {
spew.Dump(args...)
statement := getStmtFromLog(args...) statement := getStmtFromLog(args...)
if statement.sql != "" { if statement.sql != "" {

View File

@ -6,8 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"sync" "sync"
"github.com/davecgh/go-spew/spew"
) )
var pool *NoopDriver var pool *NoopDriver
@ -141,8 +139,6 @@ func (c *NoopConnection) open() (*sql.DB, error) {
return db, err return db, err
} }
fmt.Println(db.Ping())
return db, db.Ping() return db, db.Ping()
} }
@ -161,7 +157,6 @@ func (c *NoopConnection) Close() error {
// Begin implements sql/driver.Conn // Begin implements sql/driver.Conn
func (c *NoopConnection) Begin() (driver.Tx, error) { func (c *NoopConnection) Begin() (driver.Tx, error) {
fmt.Println("Called Begin()")
return c, nil return c, nil
} }
@ -172,13 +167,11 @@ func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result,
// Prepare implements sql/driver.Conn // Prepare implements sql/driver.Conn
func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) { func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) {
spew.Dump(query)
return &NoopStmt{}, nil return &NoopStmt{}, nil
} }
// Query implements sql/driver.Conn // Query implements sql/driver.Conn
func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) { func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
spew.Dump(args)
return &NoopRows{}, nil return &NoopRows{}, nil
} }

View File

@ -62,6 +62,38 @@ func getRowForFields(fields []*Field) []driver.Value {
return values return values
} }
func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
var (
rows *sqlmock.Rows
columns []string
)
switch relation.Kind {
case "has_many":
elem := rVal.Type().Elem()
scope := &Scope{Value: reflect.New(elem).Interface()}
for _, field := range scope.GetModelStruct().StructFields {
columns = append(columns, field.DBName)
}
rows = sqlmock.NewRows(columns)
// in this case we definitely have a slice
if rVal.Len() > 0 {
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
}
}
return rows, true
default:
return nil, false
}
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
var ( var (
columns []string columns []string
@ -109,34 +141,11 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
rowsSet = append(rowsSet, rows) rowsSet = append(rowsSet, rows)
for name, relation := range relations { for name, relation := range relations {
switch relation.Kind { rVal := outVal.FieldByName(name)
case "has_many": relationRows, hasRows := getRelationRows(rVal, name, relation)
rVal := outVal.FieldByName(name)
rType := rVal.Type().Elem()
rScope := &Scope{Value: reflect.New(rType).Interface()}
rColumns := []string{}
for _, field := range rScope.GetModelStruct().StructFields { if hasRows {
rColumns = append(rColumns, field.DBName) rowsSet = append(rowsSet, relationRows)
}
hasReturnRows := rVal.Len() > 0
// in this case we definitely have a slice
if hasReturnRows {
rRows := sqlmock.NewRows(rColumns)
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rRows = rRows.AddRow(row...)
rowsSet = append(rowsSet, rRows)
}
}
case "has_one":
case "many2many":
default:
continue
} }
} }
} else { } else {