From c2a28c63c3145c9d0a74a2e7b930c5de0839ca3e Mon Sep 17 00:00:00 2001 From: Ian Tan Date: Wed, 22 Nov 2017 20:53:49 +0800 Subject: [PATCH] Fix mock preload many2many generating empty relation rows --- callback_query_preload.go | 11 +++++++++++ expecter_result.go | 18 +++++++++++++++++- expecter_test.go | 12 ++++++++---- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce..d136aee4 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -60,6 +60,7 @@ func preloadCallback(scope *Scope) { currentScope.handleBelongsToPreload(field, currentPreloadConditions) case "many_to_many": currentScope.handleManyToManyPreload(field, currentPreloadConditions) + default: scope.Err(errors.New("unsupported relation")) } @@ -264,6 +265,8 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ // handleManyToManyPreload used to preload many to many associations func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { + // spew.Println("___ENTERING HANDLE MANY TO MANY___\r\n") + // spew.Printf("___POPULATING %s___:\r\n%s\r\n", field.Name, spew.Sdump(field)) var ( relation = field.Relationship joinTableHandler = relation.JoinTableHandler @@ -303,6 +306,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } rows, err := preloadDB.Rows() + // spew.Printf("___RETURNED ROWS___: \r\n%s\r\n", spew.Sdump(rows)) if scope.Err(err) != nil { return @@ -312,6 +316,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface columns, _ := rows.Columns() for rows.Next() { var ( + // This is a Language zero value struct elem = reflect.New(fieldType).Elem() fields = scope.New(elem.Addr().Interface()).Fields() ) @@ -323,6 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } scope.scan(rows, columns, append(fields, joinTableFields...)) + // spew.Printf("___FIELDS___: \r\n%s\r\n", spew.Sdump(fields)) var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table @@ -351,12 +357,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface foreignFieldNames = []string{} ) + // spew.Printf("Foreign fields: %s", spew.Sdump(relation.ForeignFieldNames)) for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { foreignFieldNames = append(foreignFieldNames, field.Name) } } + // spew.Printf("Scope value: %s", spew.Sdump(indirectScopeValue)) if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) @@ -367,6 +375,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } + + // spew.Printf("Field source map: %s", spew.Sdump(fieldsSourceMap)) + // spew.Printf("Link hash: %s", spew.Sdump(linkHash)) for source, link := range linkHash { for i, field := range fieldsSourceMap[source] { //If not 0 this means Value is a pointer and we already added preloaded models to it diff --git a/expecter_result.go b/expecter_result.go index 1e5e02b2..f5c35c93 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" + "github.com/davecgh/go-spew/spew" sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" ) @@ -46,6 +47,7 @@ func getRowForFields(fields []*Field) []driver.Value { } concreteVal := value.Interface() + // spew.Printf("%v: %v\r\n", field.Name, concreteVal) if driver.IsValue(concreteVal) { values = append(values, concreteVal) @@ -68,6 +70,12 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi columns []string ) + // we need to check for zero values + if reflect.DeepEqual(rVal.Interface(), reflect.New(rVal.Type()).Elem().Interface()) { + spew.Printf("FOUND EMPTY INTERFACE FOR %s\r\n", fieldName) + return nil, false + } + switch relation.Kind { case "has_one": scope := &Scope{Value: rVal.Interface()} @@ -94,6 +102,9 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi columns = append(columns, field.DBName) } } + spew.Printf("___GENERATING ROWS FOR %s___\r\n", fieldName) + spew.Printf("___COLUMNS___:\r\n%s\r\n", spew.Sdump(columns)) + columns = append(columns, "user_id", "language_id") rows = sqlmock.NewRows(columns) @@ -102,11 +113,14 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi for i := 0; i < rVal.Len(); i++ { scope := &Scope{Value: rVal.Index(i).Interface()} row := getRowForFields(scope.Fields()) + row = append(row, 1, 1) rows = rows.AddRow(row...) } + + return rows, true } - return rows, true + return nil, false default: return nil, false } @@ -163,6 +177,7 @@ func (q *SqlmockQuery) getRowsForOutType(out interface{}) []*sqlmock.Rows { relationRows, hasRows := getRelationRows(rVal, name, relation) if hasRows { + spew.Printf("___GENERATED ROWS FOR %s___\r\n: %s\r\n", name, spew.Sdump(relationRows)) rowsSet = append(rowsSet, relationRows) } } @@ -181,6 +196,7 @@ func (q *SqlmockQuery) Returns(out interface{}) ExpectedQuery { for i, query := range q.queries { query.WillReturnRows(rows[i]) + spew.Printf("___SET RETURN ROW___: %s", spew.Sdump(rows[i])) } return q diff --git a/expecter_test.go b/expecter_test.go index fbdbbb2d..64c70e90 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + "github.com/davecgh/go-spew/spew" "github.com/jinzhu/gorm" ) @@ -181,7 +182,7 @@ func TestMockPreloadMany2Many(t *testing.T) { } in := User{Id: 1} - languages := []Language{Language{Name: "ZH"}, Language{Name: "EN"}} + languages := []Language{Language{Name: "ZH"}} out := User{Id: 1, Languages: languages} expect.Preload("Languages").Find(&in).Returns(out) @@ -191,7 +192,10 @@ func TestMockPreloadMany2Many(t *testing.T) { t.Error(err) } - // if !reflect.DeepEqual(in, out) { - // t.Error("In and out are not equal") - // } + spew.Printf("______IN______\r\n%s\r\n", spew.Sdump(in)) + spew.Printf("______OUT______\r\n%s\r\n", spew.Sdump(out)) + + if !reflect.DeepEqual(in, out) { + t.Error("In and out are not equal") + } }