Automatically insert join table values for many_to_many mock queries
This commit is contained in:
parent
5ef2153ab4
commit
4128722761
@ -27,6 +27,7 @@ type ExpectedExec interface {
|
||||
type SqlmockQuery struct {
|
||||
mock sqlmock.Sqlmock
|
||||
queries []Stmt
|
||||
scope *Scope
|
||||
}
|
||||
|
||||
func getRowForFields(fields []*Field) []driver.Value {
|
||||
@ -94,7 +95,7 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
|
||||
rows = rows.AddRow(row...)
|
||||
|
||||
return rows, true
|
||||
case "has_many", "many_to_many":
|
||||
case "has_many":
|
||||
elem := rVal.Type().Elem()
|
||||
scope := &Scope{Value: reflect.New(elem).Interface()}
|
||||
|
||||
@ -103,9 +104,38 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
|
||||
columns = append(columns, field.DBName)
|
||||
}
|
||||
}
|
||||
// spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
|
||||
// spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
|
||||
columns = append(columns, "user_id", "language_id")
|
||||
|
||||
rows = sqlmock.NewRows(columns)
|
||||
|
||||
if rVal.Len() > 0 {
|
||||
for i := 0; i < rVal.Len(); i++ {
|
||||
scope := &Scope{Value: rVal.Index(i).Interface()}
|
||||
row := getRowForFields(scope.Fields())
|
||||
rows = rows.AddRow(row...)
|
||||
}
|
||||
|
||||
return rows, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
case "many_to_many":
|
||||
elem := rVal.Type().Elem()
|
||||
scope := &Scope{Value: reflect.New(elem).Interface()}
|
||||
joinTable := relation.JoinTableHandler.(*JoinTableHandler)
|
||||
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if field.IsNormal {
|
||||
columns = append(columns, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range joinTable.Source.ForeignKeys {
|
||||
columns = append(columns, key.DBName)
|
||||
}
|
||||
|
||||
for _, key := range joinTable.Destination.ForeignKeys {
|
||||
columns = append(columns, key.DBName)
|
||||
}
|
||||
|
||||
rows = sqlmock.NewRows(columns)
|
||||
|
||||
@ -114,7 +144,15 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
|
||||
for i := 0; i < rVal.Len(); i++ {
|
||||
scope := &Scope{Value: rVal.Index(i).Interface()}
|
||||
row := getRowForFields(scope.Fields())
|
||||
row = append(row, 1, 1)
|
||||
|
||||
// need to append the values for join table keys
|
||||
sourcePk := q.scope.PrimaryKeyValue()
|
||||
destModelType := joinTable.Destination.ModelType
|
||||
destModelVal := reflect.New(destModelType).Interface()
|
||||
destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue()
|
||||
|
||||
row = append(row, sourcePk, destPkVal)
|
||||
|
||||
rows = rows.AddRow(row...)
|
||||
}
|
||||
|
||||
@ -152,8 +190,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
|
||||
rows = rows.AddRow(row...)
|
||||
}
|
||||
} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
|
||||
scope := &Scope{Value: out}
|
||||
row := getRowForFields(scope.Fields())
|
||||
row := getRowForFields(q.scope.Fields())
|
||||
rows = rows.AddRow(row...)
|
||||
} else {
|
||||
panic(fmt.Errorf("Can only get rows for slice or struct"))
|
||||
@ -167,6 +204,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
|
||||
// the underlying mock db
|
||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
|
||||
scope := (&Scope{}).New(out)
|
||||
q.scope = scope
|
||||
outVal := indirect(reflect.ValueOf(out))
|
||||
|
||||
// rows := q.getRowsForOutType(out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user