diff --git a/expecter_result.go b/expecter_result.go index f8faab5b..f68ec388 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -27,6 +27,7 @@ type ExpectedExec interface { type SqlmockQuery struct { mock sqlmock.Sqlmock queries []Stmt + scope *Scope } func getRowForFields(fields []*Field) []driver.Value { @@ -94,7 +95,7 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel rows = rows.AddRow(row...) return rows, true - case "has_many", "many_to_many": + case "has_many": elem := rVal.Type().Elem() scope := &Scope{Value: reflect.New(elem).Interface()} @@ -103,9 +104,38 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel columns = append(columns, field.DBName) } } - // spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName) - // spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns)) - columns = append(columns, "user_id", "language_id") + + rows = sqlmock.NewRows(columns) + + if rVal.Len() > 0 { + for i := 0; i < rVal.Len(); i++ { + scope := &Scope{Value: rVal.Index(i).Interface()} + row := getRowForFields(scope.Fields()) + rows = rows.AddRow(row...) + } + + return rows, true + } + + return nil, false + case "many_to_many": + elem := rVal.Type().Elem() + scope := &Scope{Value: reflect.New(elem).Interface()} + joinTable := relation.JoinTableHandler.(*JoinTableHandler) + + for _, field := range scope.GetModelStruct().StructFields { + if field.IsNormal { + columns = append(columns, field.DBName) + } + } + + for _, key := range joinTable.Source.ForeignKeys { + columns = append(columns, key.DBName) + } + + for _, key := range joinTable.Destination.ForeignKeys { + columns = append(columns, key.DBName) + } rows = sqlmock.NewRows(columns) @@ -114,7 +144,15 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel for i := 0; i < rVal.Len(); i++ { scope := &Scope{Value: rVal.Index(i).Interface()} row := getRowForFields(scope.Fields()) - row = append(row, 1, 1) + + // need to append the values for join table keys + sourcePk := q.scope.PrimaryKeyValue() + destModelType := joinTable.Destination.ModelType + destModelVal := reflect.New(destModelType).Interface() + destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue() + + row = append(row, sourcePk, destPkVal) + rows = rows.AddRow(row...) } @@ -152,8 +190,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows { rows = rows.AddRow(row...) } } else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1 - scope := &Scope{Value: out} - row := getRowForFields(scope.Fields()) + row := getRowForFields(q.scope.Fields()) rows = rows.AddRow(row...) } else { panic(fmt.Errorf("Can only get rows for slice or struct")) @@ -167,6 +204,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows { // the underlying mock db func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { scope := (&Scope{}).New(out) + q.scope = scope outVal := indirect(reflect.ValueOf(out)) // rows := q.getRowsForOutType(out)