From 42c3f39163a1676bafcce69bd1d34252fd6bf653 Mon Sep 17 00:00:00 2001 From: kimiby Date: Sun, 16 Aug 2015 10:10:11 +0300 Subject: [PATCH 1/7] m2m preload --- join_table_handler.go | 42 +++++++++++++++ preload.go | 121 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 158 insertions(+), 5 deletions(-) diff --git a/join_table_handler.go b/join_table_handler.go index 10e1e848..1fb25e5d 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,8 +13,10 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB SourceForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey + DestinationType() reflect.Type } type JoinTableForeignKey struct { @@ -153,3 +155,43 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so return db } } + +func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB { + quotedTable := handler.Table(db) + + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + var joinConditions []string + var queryConditions []string + var values []interface{} + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).inlineCondition(conditions...).QuotedTableName() + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Source.ForeignKeys { + condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName)) + + keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name}) + values = append(values, toQueryValues(keys)) + + queryConditions = append(queryConditions, condString) + } + + if len(conditions) > 0 { + queryConditions = append(queryConditions, toString(conditions[0])) + values = append(values, conditions[1:]...) + } + + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) + } else { + db.Error = errors.New("wrong source type for join table handler") + return db + } +} + +func (s JoinTableHandler) DestinationType() reflect.Type { + return s.Destination.ModelType +} diff --git a/preload.go b/preload.go index 0db6fbde..75be26dc 100644 --- a/preload.go +++ b/preload.go @@ -10,11 +10,22 @@ import ( func getRealValue(value reflect.Value, columns []string) (results []interface{}) { for _, column := range columns { - result := reflect.Indirect(value).FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() + if reflect.Indirect(value).FieldByName(column).IsValid() { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } else { + column = upFL(column) + if reflect.Indirect(value).FieldByName(column).IsValid() { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } } - results = append(results, result) } return } @@ -61,7 +72,7 @@ func Preload(scope *Scope) { case "belongs_to": currentScope.handleBelongsToPreload(field, conditions) case "many_to_many": - fallthrough + currentScope.handleHasManyToManyPreload(field, conditions) default: currentScope.Err(errors.New("not supported relation")) } @@ -189,6 +200,106 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } +func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + joinTableHandler := relation.JoinTableHandler + destType := joinTableHandler.DestinationType() + + db := scope.NewDB().Table(scope.db.NewScope(reflect.New(destType).Elem().Interface()).TableName()) + + var destKeys []string + var sourceKeys []string + + linkHash := make(map[string][]string) + + for _, key := range joinTableHandler.DestinationForeignKeys() { + destKeys = append(destKeys, key.DBName) + } + + for _, key := range joinTableHandler.SourceForeignKeys() { + sourceKeys = append(sourceKeys, key.DBName) + } + + results := reflect.New(field.Struct.Type).Elem() + rows, err := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value, conditions...).Rows() + + if scope.Err(err) != nil { + return + } + defer rows.Close() + + columns, _ := rows.Columns() + for rows.Next() { + elem := reflect.New(destType).Elem() + var values = make([]interface{}, len(columns)) + + fields := scope.New(elem.Addr().Interface()).Fields() + + for index, column := range columns { + if field, ok := fields[column]; ok { + if field.Field.Kind() == reflect.Ptr { + values[index] = field.Field.Addr().Interface() + } else { + values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() + } + } else { + var i interface{} + values[index] = &i + } + } + + scope.Err(rows.Scan(values...)) + + var destKey []interface{} + var sourceKey []interface{} + + for index, column := range columns { + value := values[index] + if field, ok := fields[column]; ok { + if field.Field.Kind() == reflect.Ptr { + field.Field.Set(reflect.ValueOf(value).Elem()) + } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { + field.Field.Set(v) + } + } else if strInSlice(column, destKeys) { + destKey = append(destKey, *(value.(*interface{}))) + } else if strInSlice(column, sourceKeys) { + sourceKey = append(sourceKey, *(value.(*interface{}))) + } + } + + if len(destKey) != 0 && len(sourceKey) != 0 { + linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey)) + } + + results = reflect.Append(results, elem) + + } + + if scope.IndirectValue().Kind() == reflect.Slice { + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + var checked []string + + object := reflect.Indirect(objects.Index(j)) + source := getRealValue(object, relation.AssociationForeignFieldNames) + + for i := 0; i < results.Len(); i++ { + result := results.Index(i) + value := getRealValue(result, relation.ForeignFieldNames) + + if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + checked = append(checked, toString(value)) + continue + } + } + } + } +} + func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { values := scope.IndirectValue() switch values.Kind() { From fbce0d98f6c36027255f73ecf787fc324e0a8007 Mon Sep 17 00:00:00 2001 From: kimiby Date: Sun, 16 Aug 2015 10:35:58 +0300 Subject: [PATCH 2/7] m2m preload utils --- utils_private.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/utils_private.go b/utils_private.go index 6f609ae0..8b43453f 100644 --- a/utils_private.go +++ b/utils_private.go @@ -5,6 +5,8 @@ import ( "reflect" "regexp" "runtime" + "unicode" + "unicode/utf8" ) func fileWithLineNum() string { @@ -71,3 +73,24 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { } return attrs } + +func toString(a interface{}) string { + return fmt.Sprintf("%v", a) +} + +func strInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +func upFL(s string) string { + if s == "" { + return "" + } + r, n := utf8.DecodeRuneInString(s) + return string(unicode.ToUpper(r)) + s[n:] +} From f35dee5531fccaf56720c659898da1835001b315 Mon Sep 17 00:00:00 2001 From: kimiby Date: Sun, 16 Aug 2015 12:25:25 +0300 Subject: [PATCH 3/7] preload_m2m fix and test --- preload.go | 16 ++++++++++++++++ preload_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/preload.go b/preload.go index 75be26dc..dd19a8bc 100644 --- a/preload.go +++ b/preload.go @@ -297,6 +297,22 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } } } + } else { + object := scope.IndirectValue() + var checked []string + source := getRealValue(object, relation.AssociationForeignFieldNames) + + for i := 0; i < results.Len(); i++ { + result := results.Index(i) + value := getRealValue(result, relation.ForeignFieldNames) + + if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + checked = append(checked, toString(value)) + continue + } + } } } diff --git a/preload_test.go b/preload_test.go index a6647bbd..aec6d816 100644 --- a/preload_test.go +++ b/preload_test.go @@ -603,6 +603,44 @@ func TestNestedPreload9(t *testing.T) { } } +func TestManyToManyPreload(t *testing.T) { + type ( + Level1 struct { + ID uint `gorm:"primary_key;"` + Value string + } + Level2 struct { + ID uint `gorm:"primary_key;"` + Value string + Level1s []Level1 `gorm:"many2many:levels;"` + } + ) + + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level2{Value: "Bob", Level1s: []Level1{ + Level1{Value: "ru"}, + Level1{Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + panic(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r From 281c5d10f6e00b1644f1473c7b7ecc6cf88724cb Mon Sep 17 00:00:00 2001 From: kimiby Date: Sun, 16 Aug 2015 12:36:23 +0300 Subject: [PATCH 4/7] preload_m2m improve --- model_struct.go | 20 ++++++++++++-------- preload.go | 17 ++++------------- utils_private.go | 10 ---------- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/model_struct.go b/model_struct.go index 26c58fc5..db6d9a88 100644 --- a/model_struct.go +++ b/model_struct.go @@ -62,14 +62,16 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignStructFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignStructFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } func (scope *Scope) GetModelStruct() *ModelStruct { @@ -224,6 +226,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if field, ok := scope.FieldByName(foreignKey); ok { relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name) joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } @@ -242,6 +245,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, name := range associationForeignKeys { if field, ok := toScope.FieldByName(name); ok { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } diff --git a/preload.go b/preload.go index dd19a8bc..c7810b63 100644 --- a/preload.go +++ b/preload.go @@ -16,15 +16,6 @@ func getRealValue(value reflect.Value, columns []string) (results []interface{}) result, _ = r.Value() } results = append(results, result) - } else { - column = upFL(column) - if reflect.Indirect(value).FieldByName(column).IsValid() { - result := reflect.Indirect(value).FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } } } return @@ -283,11 +274,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf var checked []string object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, relation.AssociationForeignFieldNames) + source := getRealValue(object, relation.AssociationForeignStructFieldNames) for i := 0; i < results.Len(); i++ { result := results.Index(i) - value := getRealValue(result, relation.ForeignFieldNames) + value := getRealValue(result, relation.ForeignStructFieldNames) if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { f := object.FieldByName(field.Name) @@ -300,11 +291,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } else { object := scope.IndirectValue() var checked []string - source := getRealValue(object, relation.AssociationForeignFieldNames) + source := getRealValue(object, relation.AssociationForeignStructFieldNames) for i := 0; i < results.Len(); i++ { result := results.Index(i) - value := getRealValue(result, relation.ForeignFieldNames) + value := getRealValue(result, relation.ForeignStructFieldNames) if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { f := object.FieldByName(field.Name) diff --git a/utils_private.go b/utils_private.go index 8b43453f..b82aa807 100644 --- a/utils_private.go +++ b/utils_private.go @@ -5,8 +5,6 @@ import ( "reflect" "regexp" "runtime" - "unicode" - "unicode/utf8" ) func fileWithLineNum() string { @@ -86,11 +84,3 @@ func strInSlice(a string, list []string) bool { } return false } - -func upFL(s string) string { - if s == "" { - return "" - } - r, n := utf8.DecodeRuneInString(s) - return string(unicode.ToUpper(r)) + s[n:] -} From 3ca4a72a4c3a189e8cba1d348c0b41070846e7ed Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2015 16:38:01 +0800 Subject: [PATCH 5/7] Handle DB error in callbacks --- scope.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 7b6764a6..fe3fccc3 100644 --- a/scope.go +++ b/scope.go @@ -213,7 +213,9 @@ func (scope *Scope) CallMethod(name string, checkError bool) { case func(s *Scope) error: scope.Err(f(scope)) case func(s *DB) error: - scope.Err(f(scope.NewDB())) + newDB := scope.NewDB() + scope.Err(f(newDB)) + scope.Err(newDB.Error) default: scope.Err(fmt.Errorf("unsupported function %v", name)) } From 27511118fe5fb20c0371a38804646fb8f7b183cf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2015 22:19:10 +0800 Subject: [PATCH 6/7] Update Preload Many2Many tests --- preload_test.go | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/preload_test.go b/preload_test.go index aec6d816..db37f37d 100644 --- a/preload_test.go +++ b/preload_test.go @@ -624,21 +624,47 @@ func TestManyToManyPreload(t *testing.T) { } want := Level2{Value: "Bob", Level1s: []Level1{ - Level1{Value: "ru"}, - Level1{Value: "en"}, + {Value: "ru"}, + {Value: "en"}, }} if err := DB.Save(&want).Error; err != nil { panic(err) } + want2 := Level2{Value: "Tom", Level1s: []Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + panic(err) + } + var got Level2 - if err := DB.Preload("Level1s").Find(&got).Error; err != nil { + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { panic(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } } func toJSONString(v interface{}) []byte { From f8e2f0456223494021cecb97cb387940309d063a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2015 23:09:07 +0800 Subject: [PATCH 7/7] Some Tweaks for Preload Many2Many, Add tests with inline conditions --- join_table_handler.go | 18 ++++-------------- preload.go | 27 +++++++++++++++++++-------- preload_test.go | 22 +++++++++++++++++++--- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/join_table_handler.go b/join_table_handler.go index 1fb25e5d..162af4a8 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,10 +13,9 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB + PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB SourceForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey - DestinationType() reflect.Type } type JoinTableForeignKey struct { @@ -139,8 +138,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so var queryConditions []string var values []interface{} if s.Source.ModelType == modelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } @@ -156,7 +155,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so } } -func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB { +func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { quotedTable := handler.Table(db) scope := db.NewScope(source) @@ -165,8 +164,8 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db var queryConditions []string var values []interface{} if s.Source.ModelType == modelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).inlineCondition(conditions...).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } @@ -179,11 +178,6 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db queryConditions = append(queryConditions, condString) } - if len(conditions) > 0 { - queryConditions = append(queryConditions, toString(conditions[0])) - values = append(values, conditions[1:]...) - } - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). Where(strings.Join(queryConditions, " AND "), values...) } else { @@ -191,7 +185,3 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db return db } } - -func (s JoinTableHandler) DestinationType() reflect.Type { - return s.Destination.ModelType -} diff --git a/preload.go b/preload.go index c7810b63..c0869a2a 100644 --- a/preload.go +++ b/preload.go @@ -195,14 +195,16 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf relation := field.Relationship joinTableHandler := relation.JoinTableHandler - destType := joinTableHandler.DestinationType() - - db := scope.NewDB().Table(scope.db.NewScope(reflect.New(destType).Elem().Interface()).TableName()) + destType := field.StructField.Struct.Type.Elem() + var isPtr bool + if destType.Kind() == reflect.Ptr { + isPtr = true + destType = destType.Elem() + } var destKeys []string var sourceKeys []string - - linkHash := make(map[string][]string) + var linkHash = make(map[string][]string) for _, key := range joinTableHandler.DestinationForeignKeys() { destKeys = append(destKeys, key.DBName) @@ -213,7 +215,13 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } results := reflect.New(field.Struct.Type).Elem() - rows, err := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value, conditions...).Rows() + + db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()) + preloadJoinDB := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value) + if len(conditions) > 0 { + preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) + } + rows, err := preloadJoinDB.Rows() if scope.Err(err) != nil { return @@ -264,8 +272,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey)) } - results = reflect.Append(results, elem) - + if isPtr { + results = reflect.Append(results, elem.Addr()) + } else { + results = reflect.Append(results, elem) + } } if scope.IndirectValue().Kind() == reflect.Slice { diff --git a/preload_test.go b/preload_test.go index db37f37d..043e24d6 100644 --- a/preload_test.go +++ b/preload_test.go @@ -612,7 +612,7 @@ func TestManyToManyPreload(t *testing.T) { Level2 struct { ID uint `gorm:"primary_key;"` Value string - Level1s []Level1 `gorm:"many2many:levels;"` + Level1s []*Level1 `gorm:"many2many:levels;"` } ) @@ -623,7 +623,7 @@ func TestManyToManyPreload(t *testing.T) { panic(err) } - want := Level2{Value: "Bob", Level1s: []Level1{ + want := Level2{Value: "Bob", Level1s: []*Level1{ {Value: "ru"}, {Value: "en"}, }} @@ -631,7 +631,7 @@ func TestManyToManyPreload(t *testing.T) { panic(err) } - want2 := Level2{Value: "Tom", Level1s: []Level1{ + want2 := Level2{Value: "Tom", Level1s: []*Level1{ {Value: "zh"}, {Value: "de"}, }} @@ -665,6 +665,22 @@ func TestManyToManyPreload(t *testing.T) { if !reflect.DeepEqual(got3, []Level2{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } } func toJSONString(v interface{}) []byte {