Refactor code for extracting has_many relations

This commit is contained in:
Ian Tan 2017-11-21 19:32:17 +08:00
parent b06542dc77
commit 7aa08b9014
3 changed files with 36 additions and 37 deletions

View File

@ -2,8 +2,6 @@ package gorm
import (
"regexp"
"github.com/davecgh/go-spew/spew"
)
// Recorder satisfies the logger interface
@ -40,7 +38,6 @@ func getStmtFromLog(values ...interface{}) Stmt {
// Print just sets the last recorded SQL statement
// TODO: find a better way to extract SQL from log messages
func (r *Recorder) Print(args ...interface{}) {
spew.Dump(args...)
statement := getStmtFromLog(args...)
if statement.sql != "" {

View File

@ -6,8 +6,6 @@ import (
"fmt"
"io"
"sync"
"github.com/davecgh/go-spew/spew"
)
var pool *NoopDriver
@ -141,8 +139,6 @@ func (c *NoopConnection) open() (*sql.DB, error) {
return db, err
}
fmt.Println(db.Ping())
return db, db.Ping()
}
@ -161,7 +157,6 @@ func (c *NoopConnection) Close() error {
// Begin implements sql/driver.Conn
func (c *NoopConnection) Begin() (driver.Tx, error) {
fmt.Println("Called Begin()")
return c, nil
}
@ -172,13 +167,11 @@ func (c *NoopConnection) Exec(query string, args []driver.Value) (driver.Result,
// Prepare implements sql/driver.Conn
func (c *NoopConnection) Prepare(query string) (driver.Stmt, error) {
spew.Dump(query)
return &NoopStmt{}, nil
}
// Query implements sql/driver.Conn
func (c *NoopConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
spew.Dump(args)
return &NoopRows{}, nil
}

View File

@ -62,6 +62,38 @@ func getRowForFields(fields []*Field) []driver.Value {
return values
}
func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationship) (*sqlmock.Rows, bool) {
var (
rows *sqlmock.Rows
columns []string
)
switch relation.Kind {
case "has_many":
elem := rVal.Type().Elem()
scope := &Scope{Value: reflect.New(elem).Interface()}
for _, field := range scope.GetModelStruct().StructFields {
columns = append(columns, field.DBName)
}
rows = sqlmock.NewRows(columns)
// in this case we definitely have a slice
if rVal.Len() > 0 {
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
}
}
return rows, true
default:
return nil, false
}
}
func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
var (
columns []string
@ -109,34 +141,11 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows {
rowsSet = append(rowsSet, rows)
for name, relation := range relations {
switch relation.Kind {
case "has_many":
rVal := outVal.FieldByName(name)
rType := rVal.Type().Elem()
rScope := &Scope{Value: reflect.New(rType).Interface()}
rColumns := []string{}
relationRows, hasRows := getRelationRows(rVal, name, relation)
for _, field := range rScope.GetModelStruct().StructFields {
rColumns = append(rColumns, field.DBName)
}
hasReturnRows := rVal.Len() > 0
// in this case we definitely have a slice
if hasReturnRows {
rRows := sqlmock.NewRows(rColumns)
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rRows = rRows.AddRow(row...)
rowsSet = append(rowsSet, rRows)
}
}
case "has_one":
case "many2many":
default:
continue
if hasRows {
rowsSet = append(rowsSet, relationRows)
}
}
} else {