From bb6023b24bf0da774c39a6e7e8c9786126270383 Mon Sep 17 00:00:00 2001 From: Yang Qing Date: Fri, 5 Jul 2019 16:46:51 +0800 Subject: [PATCH] support polymorphic to many2many relation --- callback_save.go | 4 +- join_table_handler.go | 45 ++++++- model_struct.go | 18 ++- polymorphic_many2many_test.go | 243 ++++++++++++++++++++++++++++++++++ scope.go | 55 ++++++-- 5 files changed, 349 insertions(+), 16 deletions(-) create mode 100644 polymorphic_many2many_test.go diff --git a/callback_save.go b/callback_save.go index 3b4e0589..ee3f9b10 100644 --- a/callback_save.go +++ b/callback_save.go @@ -120,7 +120,9 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + if relationship.Kind != "many_to_many" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } } } diff --git a/join_table_handler.go b/join_table_handler.go index a036d46d..bb221376 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -42,6 +42,10 @@ type JoinTableHandler struct { TableName string `sql:"-"` Source JoinTableSource `sql:"-"` Destination JoinTableSource `sql:"-"` + // polymorphic + polymorphicType string + polymorphicDBName string + polymorphicValue string } // SourceForeignKeys return source foreign keys @@ -75,6 +79,11 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s AssociationDBName: dbName, }) } + + // polymorphic + s.polymorphicType = relationship.PolymorphicType + s.polymorphicDBName = relationship.PolymorphicDBName + s.polymorphicValue = relationship.PolymorphicValue } // Table return join table's table name @@ -122,6 +131,14 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source values = append(values, value) } + // polymorphic type field + if s.polymorphicType != "" { + assignColumns = append(assignColumns, scope.Quote(s.polymorphicDBName)) + binVars = append(binVars, `?`) + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(s.polymorphicDBName))) + values = append(values, s.polymorphicValue) + } + for _, value := range values { values = append(values, value) } @@ -188,18 +205,34 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) var condString string - if len(foreignFieldValues) > 0 { + // polymorphic + if s.polymorphicType != "" { + var idDBName = ToColumnName(strings.TrimSuffix(s.polymorphicType, "Type") + "ID") var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) + for index := 0; index < len(foreignDBNames); index++ { + quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+idDBName) } condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) + condString = fmt.Sprintf("%v AND %v", condString, fmt.Sprintf("%v=%v", toQueryCondition(scope, []string{tableName + "." + s.polymorphicDBName}), "?")) + interfaceSlice := make([]interface{}, 1) + interfaceSlice[0] = s.polymorphicValue + foreignFieldValues = append(foreignFieldValues, interfaceSlice) } else { - condString = fmt.Sprintf("1 <> 1") + if len(foreignFieldValues) > 0 { + var quotedForeignDBNames []string + for _, dbName := range foreignDBNames { + quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) + } + + condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) + + keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) + values = append(values, toQueryValues(keys)) + } else { + condString = fmt.Sprintf("1 <> 1") + } } return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). diff --git a/model_struct.go b/model_struct.go index 5234b287..8c593cd7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -291,6 +291,22 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.Kind = "many_to_many" { // Foreign Keys for Source + // Deal with POLYMORPHIC tag + var associationType = reflectType.Name() + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { + // Post has many tags, Video has many tags too, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('posts') as foreign key + associationType = polymorphic + relationship.PolymorphicType = polymorphic + "Type" + relationship.PolymorphicDBName = ToColumnName(polymorphic + "Type") + // if Post has multiple set of tags set name of the set (instead of default 'posts') + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } + } + joinTableDBNames := []string{} if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { @@ -314,7 +330,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(associationType) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } diff --git a/polymorphic_many2many_test.go b/polymorphic_many2many_test.go new file mode 100644 index 00000000..3e81f001 --- /dev/null +++ b/polymorphic_many2many_test.go @@ -0,0 +1,243 @@ +package gorm_test + +import ( + "testing" +) + +// DB Tables structure: + +// simple_posts +// id - integer +// name - string + +// simple_videos +// id - integer +// name - string + +// simple_tags +// id - integer +// name - string + +// taggables +// tag_id - integer +// taggable_id - integer +// taggable_type - string + +type SimplePost struct { + Id int + Name string + Tags []*SimpleTag `gorm:"many2many:taggables;polymorphic:taggable;"` +} + +type SimpleVideo struct { + Id int + Name string + Tags []*SimpleTag `gorm:"many2many:taggables;polymorphic:taggable;polymorphic_value:video"` +} + +type SimpleTag struct { + Id int + Name string +} + +func TestPolymorphicMany2many(t *testing.T) { + DB.DropTableIfExists(&SimpleTag{}, &SimplePost{}, &SimpleVideo{}, "taggables") + DB.AutoMigrate(&SimpleTag{}, &SimplePost{}, &SimpleVideo{}) + + DB.LogMode(true) + + tag1 := SimpleTag{Name: "hero"} + tag2 := SimpleTag{Name: "bloods"} + tag3 := SimpleTag{Name: "frendship"} + tag4 := SimpleTag{Name: "romantic"} + tag5 := SimpleTag{Name: "gruesome"} + + // Test Save associations together + post1 := SimplePost{Name: "First Post", Tags: []*SimpleTag{&tag1, &tag2, &tag3}} + err := DB.Save(&post1).Error + if err != nil { + t.Errorf("Data init fail : %v \n", err) + } + count := DB.Model(&post1).Association("Tags").Count() + if count != 3 { + t.Errorf("Post1 should have 3 associations to tags, but got %d", count) + } + + post2 := SimplePost{Name: "Second Post"} + video1 := SimpleVideo{Name: "First Video"} + video2 := SimpleVideo{Name: "Second Video"} + DB.Save(&post2).Save(&video1).Save(&video2) + + // Test Append + DB.Model(&post2).Association("Tags").Append(&tag2, &tag4) + DB.Model(&video1).Association("Tags").Append(&tag1, &tag2, &tag5) + DB.Model(&video2).Association("Tags").Append(&tag2, &tag3, &tag4) + + count = DB.Model(&post2).Association("Tags").Count() + if count != 2 { + t.Errorf("Post2 should have 2 associations to tags, but got %d", count) + } + + exists := false + for _, t := range post2.Tags { + if exists = t.Name == "bloods"; exists { + break + } + } + + if !exists { + t.Errorf("Post2 should have a tag named 'bloods'") + } + + count = DB.Model(&video1).Association("Tags").Count() + if count != 3 { + t.Errorf("Video1 should have 3 associations to tags, but got %d", count) + } + + // Test Replace + tag6 := SimpleTag{Name: "tag6"} + DB.Model(&post2).Association("Tags").Replace(&tag5, &tag4, &tag6) + tag2Exists := false + tag4Exists := false + tag5Exists := false + tag6Exists := false + for _, t := range post2.Tags { + if !tag2Exists { + tag2Exists = t.Name == "bloods" + } + if !tag4Exists { + tag4Exists = t.Name == "romantic" + } + if !tag5Exists { + tag5Exists = t.Name == "gruesome" + } + if !tag6Exists { + tag6Exists = t.Name == "tag6" + } + } + if tag2Exists { + t.Errorf("Post2 should NOT HAVE a tag named 'bloods'") + } + if !tag4Exists { + t.Errorf("Post2 should HAVE a tag named 'romantic'") + } + if !tag5Exists { + t.Errorf("Post2 should HAVE a tag named 'gruesome'") + } + if !tag6Exists { + t.Errorf("Post2 should HAVE a tag named 'tag6'") + } + + // Test Delete + DB.Model(&post1).Association("Tags").Delete(&tag1) + count = DB.Model(&post2).Association("Tags").Count() + if count != 3 { + t.Errorf("Post1 should be removed 1 association, should remain 3, but %d", count) + } + + // Test Clear + count = DB.Model(&video2).Association("Tags").Count() + if count != 3 { + t.Errorf("Video2 should have 3 association, but got %d", count) + } + DB.Model(&video2).Association("Tags").Clear() + count = DB.Model(&video2).Association("Tags").Count() + if count != 0 { + t.Errorf("Video2 should be removed all association, but got %d", count) + } + + DB.LogMode(false) +} + +func TestNamedPolymorphicMany2many(t *testing.T) { + DB.DropTableIfExists(&SimpleTag{}, &SimplePost{}, &SimpleVideo{}, "taggables") + DB.AutoMigrate(&SimpleTag{}, &SimplePost{}, &SimpleVideo{}) + + DB.LogMode(true) + + tag1 := SimpleTag{Name: "hero"} + tag2 := SimpleTag{Name: "bloods"} + tag3 := SimpleTag{Name: "frendship"} + tag4 := SimpleTag{Name: "romantic"} + tag5 := SimpleTag{Name: "gruesome"} + + // Test Save associations together + post1 := SimplePost{Name: "First Post", Tags: []*SimpleTag{&tag1, &tag2, &tag3}} + err := DB.Save(&post1).Error + if err != nil { + t.Errorf("Data init fail : %v \n", err) + } + count := DB.Model(&post1).Association("Tags").Count() + if count != 3 { + t.Errorf("Post1 should have 3 associations to tags, but got %d", count) + } + + post2 := SimplePost{Name: "Second Post"} + video1 := SimpleVideo{Name: "First Video"} + video2 := SimpleVideo{Name: "Second Video"} + DB.Save(&post2).Save(&video1).Save(&video2) + + // Test Append + DB.Model(&post2).Association("Tags").Append(&tag2, &tag4) + DB.Model(&video1).Association("Tags").Append(&tag1, &tag2, &tag5) + DB.Model(&video2).Association("Tags").Append(&tag2, &tag3, &tag4) + + count = DB.Model(&video1).Association("Tags").Count() + if count != 3 { + t.Errorf("Video1 should have 3 associations to tags, but got %d", count) + } + + exists := false + for _, t := range video1.Tags { + if exists = t.Name == "bloods"; exists { + break + } + } + + if !exists { + t.Errorf("Video1 should have a tag named 'bloods'") + } + + // Test Replace + tag6 := SimpleTag{Name: "tag6"} + DB.Model(&video1).Association("Tags").Replace(&tag2, &tag4, &tag6) + tag2Exists := false + tag4Exists := false + tag5Exists := false + tag6Exists := false + for _, t := range video1.Tags { + if !tag2Exists { + tag2Exists = t.Name == "bloods" + } + if !tag4Exists { + tag4Exists = t.Name == "romantic" + } + if !tag5Exists { + tag5Exists = t.Name == "gruesome" + } + if !tag6Exists { + tag6Exists = t.Name == "tag6" + } + } + if !tag2Exists { + t.Errorf("Video1 should HAVE a tag named 'bloods'") + } + if !tag4Exists { + t.Errorf("Video1 should HAVE a tag named 'romantic'") + } + if tag5Exists { + t.Errorf("Video1 should NOT HAVE a tag named 'gruesome'") + } + if !tag6Exists { + t.Errorf("Video1 should HAVE a tag named 'tag6'") + } + + // Test Delete + DB.Model(&video1).Association("Tags").Delete(&tag2) + count = DB.Model(&video1).Association("Tags").Count() + if count != 2 { + t.Errorf("video1 should be removed 1 association, should remain 2, but %d", count) + } + + DB.LogMode(false) +} diff --git a/scope.go b/scope.go index 541fe522..953eafaf 100644 --- a/scope.go +++ b/scope.go @@ -1151,14 +1151,53 @@ func (scope *Scope) createJoinTable(field *StructField) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) + + if relationship.PolymorphicType == "" { + for idx, fieldName := range relationship.ForeignFieldNames { + if field, ok := scope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) + primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) + } + } + } else { + // Deal with POLYMORPHIC tag + // Create OwnerType & OwnerID columns in the middle table + mockStruct := struct { + mockField string + }{} + reflectType := reflect.ValueOf(mockStruct).Type() + fieldStruct := reflectType.Field(0) + field := &StructField{ + Struct: fieldStruct, + Name: fieldStruct.Name, + Names: []string{fieldStruct.Name}, + Tag: fieldStruct.Tag, + TagSettings: parseTagSetting(fieldStruct.Tag), + } + + quotedType := scope.Quote(relationship.PolymorphicDBName) + sqlTypes = append(sqlTypes, quotedType+" "+scope.Dialect().DataTypeOf(field)) + primaryKeys = append(primaryKeys, quotedType) + + morphIDName := ToColumnName(strings.TrimSuffix(relationship.PolymorphicType, "Type") + "ID") + if !scope.Dialect().HasColumn(joinTable, morphIDName) { + if len(relationship.ForeignFieldNames) > 0 { + foreignFieldName := relationship.ForeignFieldNames[0] + if field, ok := toScope.FieldByName(foreignFieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") + // 使用外键 ID 类型 + quotedID := scope.Quote(morphIDName) + sqlTypes = append(sqlTypes, quotedID+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) + primaryKeys = append(primaryKeys, quotedID) + } + } } }