Automatically insert join table values for many_to_many mock queries

This commit is contained in:
Ian Tan 2017-11-23 18:39:06 +08:00
parent 5ef2153ab4
commit 4128722761

View File

@ -27,6 +27,7 @@ type ExpectedExec interface {
type SqlmockQuery struct {
mock sqlmock.Sqlmock
queries []Stmt
scope *Scope
}
func getRowForFields(fields []*Field) []driver.Value {
@ -94,7 +95,7 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
rows = rows.AddRow(row...)
return rows, true
case "has_many", "many_to_many":
case "has_many":
elem := rVal.Type().Elem()
scope := &Scope{Value: reflect.New(elem).Interface()}
@ -103,9 +104,38 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
columns = append(columns, field.DBName)
}
}
// spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
// spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
columns = append(columns, "user_id", "language_id")
rows = sqlmock.NewRows(columns)
if rVal.Len() > 0 {
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
rows = rows.AddRow(row...)
}
return rows, true
}
return nil, false
case "many_to_many":
elem := rVal.Type().Elem()
scope := &Scope{Value: reflect.New(elem).Interface()}
joinTable := relation.JoinTableHandler.(*JoinTableHandler)
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
columns = append(columns, field.DBName)
}
}
for _, key := range joinTable.Source.ForeignKeys {
columns = append(columns, key.DBName)
}
for _, key := range joinTable.Destination.ForeignKeys {
columns = append(columns, key.DBName)
}
rows = sqlmock.NewRows(columns)
@ -114,7 +144,15 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
for i := 0; i < rVal.Len(); i++ {
scope := &Scope{Value: rVal.Index(i).Interface()}
row := getRowForFields(scope.Fields())
row = append(row, 1, 1)
// need to append the values for join table keys
sourcePk := q.scope.PrimaryKeyValue()
destModelType := joinTable.Destination.ModelType
destModelVal := reflect.New(destModelType).Interface()
destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue()
row = append(row, sourcePk, destPkVal)
rows = rows.AddRow(row...)
}
@ -152,8 +190,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
rows = rows.AddRow(row...)
}
} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
scope := &Scope{Value: out}
row := getRowForFields(scope.Fields())
row := getRowForFields(q.scope.Fields())
rows = rows.AddRow(row...)
} else {
panic(fmt.Errorf("Can only get rows for slice or struct"))
@ -167,6 +204,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
// the underlying mock db
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
scope := (&Scope{}).New(out)
q.scope = scope
outVal := indirect(reflect.ValueOf(out))
// rows := q.getRowsForOutType(out)