Automatically insert join table values for many_to_many mock queries
This commit is contained in:
		
							parent
							
								
									5ef2153ab4
								
							
						
					
					
						commit
						4128722761
					
				@ -27,6 +27,7 @@ type ExpectedExec interface {
 | 
			
		||||
type SqlmockQuery struct {
 | 
			
		||||
	mock    sqlmock.Sqlmock
 | 
			
		||||
	queries []Stmt
 | 
			
		||||
	scope   *Scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getRowForFields(fields []*Field) []driver.Value {
 | 
			
		||||
@ -94,7 +95,7 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
 | 
			
		||||
		rows = rows.AddRow(row...)
 | 
			
		||||
 | 
			
		||||
		return rows, true
 | 
			
		||||
	case "has_many", "many_to_many":
 | 
			
		||||
	case "has_many":
 | 
			
		||||
		elem := rVal.Type().Elem()
 | 
			
		||||
		scope := &Scope{Value: reflect.New(elem).Interface()}
 | 
			
		||||
 | 
			
		||||
@ -103,9 +104,38 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
 | 
			
		||||
				columns = append(columns, field.DBName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName)
 | 
			
		||||
		// spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns))
 | 
			
		||||
		columns = append(columns, "user_id", "language_id")
 | 
			
		||||
 | 
			
		||||
		rows = sqlmock.NewRows(columns)
 | 
			
		||||
 | 
			
		||||
		if rVal.Len() > 0 {
 | 
			
		||||
			for i := 0; i < rVal.Len(); i++ {
 | 
			
		||||
				scope := &Scope{Value: rVal.Index(i).Interface()}
 | 
			
		||||
				row := getRowForFields(scope.Fields())
 | 
			
		||||
				rows = rows.AddRow(row...)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return rows, true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil, false
 | 
			
		||||
	case "many_to_many":
 | 
			
		||||
		elem := rVal.Type().Elem()
 | 
			
		||||
		scope := &Scope{Value: reflect.New(elem).Interface()}
 | 
			
		||||
		joinTable := relation.JoinTableHandler.(*JoinTableHandler)
 | 
			
		||||
 | 
			
		||||
		for _, field := range scope.GetModelStruct().StructFields {
 | 
			
		||||
			if field.IsNormal {
 | 
			
		||||
				columns = append(columns, field.DBName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, key := range joinTable.Source.ForeignKeys {
 | 
			
		||||
			columns = append(columns, key.DBName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, key := range joinTable.Destination.ForeignKeys {
 | 
			
		||||
			columns = append(columns, key.DBName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rows = sqlmock.NewRows(columns)
 | 
			
		||||
 | 
			
		||||
@ -114,7 +144,15 @@ func (q *SqlmockQuery) getRelationRows(rVal reflect.Value, fieldName string, rel
 | 
			
		||||
			for i := 0; i < rVal.Len(); i++ {
 | 
			
		||||
				scope := &Scope{Value: rVal.Index(i).Interface()}
 | 
			
		||||
				row := getRowForFields(scope.Fields())
 | 
			
		||||
				row = append(row, 1, 1)
 | 
			
		||||
 | 
			
		||||
				// need to append the values for join table keys
 | 
			
		||||
				sourcePk := q.scope.PrimaryKeyValue()
 | 
			
		||||
				destModelType := joinTable.Destination.ModelType
 | 
			
		||||
				destModelVal := reflect.New(destModelType).Interface()
 | 
			
		||||
				destPkVal := (&Scope{Value: destModelVal}).PrimaryKeyValue()
 | 
			
		||||
 | 
			
		||||
				row = append(row, sourcePk, destPkVal)
 | 
			
		||||
 | 
			
		||||
				rows = rows.AddRow(row...)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -152,8 +190,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
 | 
			
		||||
			rows = rows.AddRow(row...)
 | 
			
		||||
		}
 | 
			
		||||
	} else if outVal.Kind() == reflect.Struct { // SELECT with LIMIT 1
 | 
			
		||||
		scope := &Scope{Value: out}
 | 
			
		||||
		row := getRowForFields(scope.Fields())
 | 
			
		||||
		row := getRowForFields(q.scope.Fields())
 | 
			
		||||
		rows = rows.AddRow(row...)
 | 
			
		||||
	} else {
 | 
			
		||||
		panic(fmt.Errorf("Can only get rows for slice or struct"))
 | 
			
		||||
@ -167,6 +204,7 @@ func (q *SqlmockQuery) getDestRows(out interface{}) *sqlmock.Rows {
 | 
			
		||||
// the underlying mock db
 | 
			
		||||
func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery {
 | 
			
		||||
	scope := (&Scope{}).New(out)
 | 
			
		||||
	q.scope = scope
 | 
			
		||||
	outVal := indirect(reflect.ValueOf(out))
 | 
			
		||||
 | 
			
		||||
	// rows := q.getRowsForOutType(out)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user