Merge branch 'kimiby-many2many_preload'
This commit is contained in:
commit
e2689d335b
@ -13,6 +13,7 @@ type JoinTableHandlerInterface interface {
|
|||||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
|
PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
SourceForeignKeys() []JoinTableForeignKey
|
SourceForeignKeys() []JoinTableForeignKey
|
||||||
DestinationForeignKeys() []JoinTableForeignKey
|
DestinationForeignKeys() []JoinTableForeignKey
|
||||||
}
|
}
|
||||||
@ -137,8 +138,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||||||
var queryConditions []string
|
var queryConditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
if s.Source.ModelType == modelType {
|
if s.Source.ModelType == modelType {
|
||||||
|
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
|
||||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,3 +154,34 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||||
|
quotedTable := handler.Table(db)
|
||||||
|
|
||||||
|
scope := db.NewScope(source)
|
||||||
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
var joinConditions []string
|
||||||
|
var queryConditions []string
|
||||||
|
var values []interface{}
|
||||||
|
if s.Source.ModelType == modelType {
|
||||||
|
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName))
|
||||||
|
|
||||||
|
keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name})
|
||||||
|
values = append(values, toQueryValues(keys))
|
||||||
|
|
||||||
|
queryConditions = append(queryConditions, condString)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||||
|
Where(strings.Join(queryConditions, " AND "), values...)
|
||||||
|
} else {
|
||||||
|
db.Error = errors.New("wrong source type for join table handler")
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -62,14 +62,16 @@ func (structField *StructField) clone() *StructField {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Relationship struct {
|
type Relationship struct {
|
||||||
Kind string
|
Kind string
|
||||||
PolymorphicType string
|
PolymorphicType string
|
||||||
PolymorphicDBName string
|
PolymorphicDBName string
|
||||||
ForeignFieldNames []string
|
ForeignFieldNames []string
|
||||||
ForeignDBNames []string
|
ForeignStructFieldNames []string
|
||||||
AssociationForeignFieldNames []string
|
ForeignDBNames []string
|
||||||
AssociationForeignDBNames []string
|
AssociationForeignFieldNames []string
|
||||||
JoinTableHandler JoinTableHandlerInterface
|
AssociationForeignStructFieldNames []string
|
||||||
|
AssociationForeignDBNames []string
|
||||||
|
JoinTableHandler JoinTableHandlerInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
@ -224,6 +226,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
for _, foreignKey := range foreignKeys {
|
for _, foreignKey := range foreignKeys {
|
||||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
||||||
|
relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name)
|
||||||
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
@ -242,6 +245,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
for _, name := range associationForeignKeys {
|
for _, name := range associationForeignKeys {
|
||||||
if field, ok := toScope.FieldByName(name); ok {
|
if field, ok := toScope.FieldByName(name); ok {
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||||
|
relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
|
||||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
|
139
preload.go
139
preload.go
@ -10,11 +10,13 @@ import (
|
|||||||
|
|
||||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
||||||
for _, column := range columns {
|
for _, column := range columns {
|
||||||
result := reflect.Indirect(value).FieldByName(column).Interface()
|
if reflect.Indirect(value).FieldByName(column).IsValid() {
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
result := reflect.Indirect(value).FieldByName(column).Interface()
|
||||||
result, _ = r.Value()
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
|
result, _ = r.Value()
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
}
|
}
|
||||||
results = append(results, result)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -61,7 +63,7 @@ func Preload(scope *Scope) {
|
|||||||
case "belongs_to":
|
case "belongs_to":
|
||||||
currentScope.handleBelongsToPreload(field, conditions)
|
currentScope.handleBelongsToPreload(field, conditions)
|
||||||
case "many_to_many":
|
case "many_to_many":
|
||||||
fallthrough
|
currentScope.handleHasManyToManyPreload(field, conditions)
|
||||||
default:
|
default:
|
||||||
currentScope.Err(errors.New("not supported relation"))
|
currentScope.Err(errors.New("not supported relation"))
|
||||||
}
|
}
|
||||||
@ -189,6 +191,133 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
joinTableHandler := relation.JoinTableHandler
|
||||||
|
destType := field.StructField.Struct.Type.Elem()
|
||||||
|
var isPtr bool
|
||||||
|
if destType.Kind() == reflect.Ptr {
|
||||||
|
isPtr = true
|
||||||
|
destType = destType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var destKeys []string
|
||||||
|
var sourceKeys []string
|
||||||
|
var linkHash = make(map[string][]string)
|
||||||
|
|
||||||
|
for _, key := range joinTableHandler.DestinationForeignKeys() {
|
||||||
|
destKeys = append(destKeys, key.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||||
|
sourceKeys = append(sourceKeys, key.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := reflect.New(field.Struct.Type).Elem()
|
||||||
|
|
||||||
|
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName())
|
||||||
|
preloadJoinDB := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value)
|
||||||
|
if len(conditions) > 0 {
|
||||||
|
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
|
||||||
|
}
|
||||||
|
rows, err := preloadJoinDB.Rows()
|
||||||
|
|
||||||
|
if scope.Err(err) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
columns, _ := rows.Columns()
|
||||||
|
for rows.Next() {
|
||||||
|
elem := reflect.New(destType).Elem()
|
||||||
|
var values = make([]interface{}, len(columns))
|
||||||
|
|
||||||
|
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||||
|
|
||||||
|
for index, column := range columns {
|
||||||
|
if field, ok := fields[column]; ok {
|
||||||
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
|
values[index] = field.Field.Addr().Interface()
|
||||||
|
} else {
|
||||||
|
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var i interface{}
|
||||||
|
values[index] = &i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Err(rows.Scan(values...))
|
||||||
|
|
||||||
|
var destKey []interface{}
|
||||||
|
var sourceKey []interface{}
|
||||||
|
|
||||||
|
for index, column := range columns {
|
||||||
|
value := values[index]
|
||||||
|
if field, ok := fields[column]; ok {
|
||||||
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
|
field.Field.Set(reflect.ValueOf(value).Elem())
|
||||||
|
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
|
||||||
|
field.Field.Set(v)
|
||||||
|
}
|
||||||
|
} else if strInSlice(column, destKeys) {
|
||||||
|
destKey = append(destKey, *(value.(*interface{})))
|
||||||
|
} else if strInSlice(column, sourceKeys) {
|
||||||
|
sourceKey = append(sourceKey, *(value.(*interface{})))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(destKey) != 0 && len(sourceKey) != 0 {
|
||||||
|
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
if isPtr {
|
||||||
|
results = reflect.Append(results, elem.Addr())
|
||||||
|
} else {
|
||||||
|
results = reflect.Append(results, elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
|
objects := scope.IndirectValue()
|
||||||
|
for j := 0; j < objects.Len(); j++ {
|
||||||
|
var checked []string
|
||||||
|
|
||||||
|
object := reflect.Indirect(objects.Index(j))
|
||||||
|
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
|
||||||
|
|
||||||
|
for i := 0; i < results.Len(); i++ {
|
||||||
|
result := results.Index(i)
|
||||||
|
value := getRealValue(result, relation.ForeignStructFieldNames)
|
||||||
|
|
||||||
|
if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) {
|
||||||
|
f := object.FieldByName(field.Name)
|
||||||
|
f.Set(reflect.Append(f, result))
|
||||||
|
checked = append(checked, toString(value))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
object := scope.IndirectValue()
|
||||||
|
var checked []string
|
||||||
|
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
|
||||||
|
|
||||||
|
for i := 0; i < results.Len(); i++ {
|
||||||
|
result := results.Index(i)
|
||||||
|
value := getRealValue(result, relation.ForeignStructFieldNames)
|
||||||
|
|
||||||
|
if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) {
|
||||||
|
f := object.FieldByName(field.Name)
|
||||||
|
f.Set(reflect.Append(f, result))
|
||||||
|
checked = append(checked, toString(value))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
|
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
|
||||||
values := scope.IndirectValue()
|
values := scope.IndirectValue()
|
||||||
switch values.Kind() {
|
switch values.Kind() {
|
||||||
|
@ -603,6 +603,86 @@ func TestNestedPreload9(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManyToManyPreload(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Level1 struct {
|
||||||
|
ID uint `gorm:"primary_key;"`
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
Level2 struct {
|
||||||
|
ID uint `gorm:"primary_key;"`
|
||||||
|
Value string
|
||||||
|
Level1s []*Level1 `gorm:"many2many:levels;"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.DropTableIfExists(&Level2{})
|
||||||
|
DB.DropTableIfExists(&Level1{})
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := Level2{Value: "Bob", Level1s: []*Level1{
|
||||||
|
{Value: "ru"},
|
||||||
|
{Value: "en"},
|
||||||
|
}}
|
||||||
|
if err := DB.Save(&want).Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want2 := Level2{Value: "Tom", Level1s: []*Level1{
|
||||||
|
{Value: "zh"},
|
||||||
|
{Value: "de"},
|
||||||
|
}}
|
||||||
|
if err := DB.Save(&want2).Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Level2
|
||||||
|
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got2 Level2
|
||||||
|
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got2, want2) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got3 []Level2
|
||||||
|
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got4 []Level2
|
||||||
|
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ruLevel1 Level1
|
||||||
|
var zhLevel1 Level1
|
||||||
|
DB.First(&ruLevel1, "value = ?", "ru")
|
||||||
|
DB.First(&zhLevel1, "value = ?", "zh")
|
||||||
|
|
||||||
|
got.Level1s = []*Level1{&ruLevel1}
|
||||||
|
got2.Level1s = []*Level1{&zhLevel1}
|
||||||
|
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toJSONString(v interface{}) []byte {
|
func toJSONString(v interface{}) []byte {
|
||||||
r, _ := json.MarshalIndent(v, "", " ")
|
r, _ := json.MarshalIndent(v, "", " ")
|
||||||
return r
|
return r
|
||||||
|
@ -71,3 +71,16 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
|||||||
}
|
}
|
||||||
return attrs
|
return attrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toString(a interface{}) string {
|
||||||
|
return fmt.Sprintf("%v", a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func strInSlice(a string, list []string) bool {
|
||||||
|
for _, b := range list {
|
||||||
|
if b == a {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user