From 8ae27d7ec78011f7e54a962ef46ba8ee615df229 Mon Sep 17 00:00:00 2001 From: Ian Tan Date: Wed, 22 Nov 2017 15:23:02 +0800 Subject: [PATCH] Add partial support for many_to_many --- expecter_result.go | 10 ++++++---- expecter_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/expecter_result.go b/expecter_result.go index e96f7c8c..59c2379b 100644 --- a/expecter_result.go +++ b/expecter_result.go @@ -49,8 +49,8 @@ func getRowForFields(fields []*Field) []driver.Value { if driver.IsValue(concreteVal) { values = append(values, concreteVal) - } else if value.Kind() == reflect.Int || value.Kind() == reflect.Int8 || value.Kind() == reflect.Int16 || value.Kind() == reflect.Int64 { - values = append(values, value.Int()) + } else if num, err := driver.DefaultParameterConverter.ConvertValue(concreteVal); err == nil { + values = append(values, num) } else if valuer, ok := concreteVal.(driver.Valuer); ok { if convertedValue, err := valuer.Value(); err == nil { values = append(values, convertedValue) @@ -69,12 +69,14 @@ func getRelationRows(rVal reflect.Value, fieldName string, relation *Relationshi ) switch relation.Kind { - case "has_many": + case "has_many", "many_to_many": elem := rVal.Type().Elem() scope := &Scope{Value: reflect.New(elem).Interface()} for _, field := range scope.GetModelStruct().StructFields { - columns = append(columns, field.DBName) + if field.IsNormal { + columns = append(columns, field.DBName) + } } rows = sqlmock.NewRows(columns) diff --git a/expecter_test.go b/expecter_test.go index 9eedd49c..6b84f198 100644 --- a/expecter_test.go +++ b/expecter_test.go @@ -119,7 +119,7 @@ func TestFindSlice(t *testing.T) { } } -func TestMockPreload(t *testing.T) { +func TestMockPreloadHasMany(t *testing.T) { db, expect, err := gorm.NewDefaultExpecter() defer func() { db.Close() @@ -144,3 +144,29 @@ func TestMockPreload(t *testing.T) { t.Error("In and out are not equal") } } + +func TestMockPreloadMany2Many(t *testing.T) { + db, expect, err := gorm.NewDefaultExpecter() + defer func() { + db.Close() + }() + + if err != nil { + t.Fatal(err) + } + + in := User{Id: 1} + languages := []Language{Language{Name: "ZH"}, Language{Name: "EN"}} + out := User{Id: 1, Languages: languages} + + expect.Preload("Languages").Find(&in).Returns(out) + db.Preload("Languages").Find(&in) + + if err := expect.AssertExpectations(); err != nil { + t.Error(err) + } + + // if !reflect.DeepEqual(in, out) { + // t.Error("In and out are not equal") + // } +}