support polymorphic to many2many relation

This commit is contained in:
Yang Qing 2019-07-05 16:46:51 +08:00
parent 443f1de146
commit bb6023b24b
5 changed files with 349 additions and 16 deletions

View File

@ -120,9 +120,11 @@ func saveAfterAssociationsCallback(scope *Scope) {
}
if relationship.PolymorphicType != "" {
if relationship.Kind != "many_to_many" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
}
}
if newScope.PrimaryKeyZero() {
if autoCreate {

View File

@ -42,6 +42,10 @@ type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
// polymorphic
polymorphicType string
polymorphicDBName string
polymorphicValue string
}
// SourceForeignKeys return source foreign keys
@ -75,6 +79,11 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
AssociationDBName: dbName,
})
}
// polymorphic
s.polymorphicType = relationship.PolymorphicType
s.polymorphicDBName = relationship.PolymorphicDBName
s.polymorphicValue = relationship.PolymorphicValue
}
// Table return join table's table name
@ -122,6 +131,14 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source
values = append(values, value)
}
// polymorphic type field
if s.polymorphicType != "" {
assignColumns = append(assignColumns, scope.Quote(s.polymorphicDBName))
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(s.polymorphicDBName)))
values = append(values, s.polymorphicValue)
}
for _, value := range values {
values = append(values, value)
}
@ -188,6 +205,21 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
var condString string
// polymorphic
if s.polymorphicType != "" {
var idDBName = ToColumnName(strings.TrimSuffix(s.polymorphicType, "Type") + "ID")
var quotedForeignDBNames []string
for index := 0; index < len(foreignDBNames); index++ {
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+idDBName)
}
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
condString = fmt.Sprintf("%v AND %v", condString, fmt.Sprintf("%v=%v", toQueryCondition(scope, []string{tableName + "." + s.polymorphicDBName}), "?"))
interfaceSlice := make([]interface{}, 1)
interfaceSlice[0] = s.polymorphicValue
foreignFieldValues = append(foreignFieldValues, interfaceSlice)
} else {
if len(foreignFieldValues) > 0 {
var quotedForeignDBNames []string
for _, dbName := range foreignDBNames {
@ -201,6 +233,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
} else {
condString = fmt.Sprintf("1 <> 1")
}
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...)

View File

@ -291,6 +291,22 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
relationship.Kind = "many_to_many"
{ // Foreign Keys for Source
// Deal with POLYMORPHIC tag
var associationType = reflectType.Name()
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
// Post has many tags, Video has many tags too, tag polymorphic is Owner, then associationType is Owner
// Toy use OwnerID, OwnerType ('posts') as foreign key
associationType = polymorphic
relationship.PolymorphicType = polymorphic + "Type"
relationship.PolymorphicDBName = ToColumnName(polymorphic + "Type")
// if Post has multiple set of tags set name of the set (instead of default 'posts')
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
relationship.PolymorphicValue = value
} else {
relationship.PolymorphicValue = scope.TableName()
}
}
joinTableDBNames := []string{}
if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
@ -314,7 +330,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
// if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else {
defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
defaultJointableForeignKey := ToColumnName(associationType) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
}
}

View File

@ -0,0 +1,243 @@
package gorm_test
import (
"testing"
)
// DB Tables structure:
// simple_posts
// id - integer
// name - string
// simple_videos
// id - integer
// name - string
// simple_tags
// id - integer
// name - string
// taggables
// tag_id - integer
// taggable_id - integer
// taggable_type - string
type SimplePost struct {
Id int
Name string
Tags []*SimpleTag `gorm:"many2many:taggables;polymorphic:taggable;"`
}
type SimpleVideo struct {
Id int
Name string
Tags []*SimpleTag `gorm:"many2many:taggables;polymorphic:taggable;polymorphic_value:video"`
}
type SimpleTag struct {
Id int
Name string
}
func TestPolymorphicMany2many(t *testing.T) {
DB.DropTableIfExists(&SimpleTag{}, &SimplePost{}, &SimpleVideo{}, "taggables")
DB.AutoMigrate(&SimpleTag{}, &SimplePost{}, &SimpleVideo{})
DB.LogMode(true)
tag1 := SimpleTag{Name: "hero"}
tag2 := SimpleTag{Name: "bloods"}
tag3 := SimpleTag{Name: "frendship"}
tag4 := SimpleTag{Name: "romantic"}
tag5 := SimpleTag{Name: "gruesome"}
// Test Save associations together
post1 := SimplePost{Name: "First Post", Tags: []*SimpleTag{&tag1, &tag2, &tag3}}
err := DB.Save(&post1).Error
if err != nil {
t.Errorf("Data init fail : %v \n", err)
}
count := DB.Model(&post1).Association("Tags").Count()
if count != 3 {
t.Errorf("Post1 should have 3 associations to tags, but got %d", count)
}
post2 := SimplePost{Name: "Second Post"}
video1 := SimpleVideo{Name: "First Video"}
video2 := SimpleVideo{Name: "Second Video"}
DB.Save(&post2).Save(&video1).Save(&video2)
// Test Append
DB.Model(&post2).Association("Tags").Append(&tag2, &tag4)
DB.Model(&video1).Association("Tags").Append(&tag1, &tag2, &tag5)
DB.Model(&video2).Association("Tags").Append(&tag2, &tag3, &tag4)
count = DB.Model(&post2).Association("Tags").Count()
if count != 2 {
t.Errorf("Post2 should have 2 associations to tags, but got %d", count)
}
exists := false
for _, t := range post2.Tags {
if exists = t.Name == "bloods"; exists {
break
}
}
if !exists {
t.Errorf("Post2 should have a tag named 'bloods'")
}
count = DB.Model(&video1).Association("Tags").Count()
if count != 3 {
t.Errorf("Video1 should have 3 associations to tags, but got %d", count)
}
// Test Replace
tag6 := SimpleTag{Name: "tag6"}
DB.Model(&post2).Association("Tags").Replace(&tag5, &tag4, &tag6)
tag2Exists := false
tag4Exists := false
tag5Exists := false
tag6Exists := false
for _, t := range post2.Tags {
if !tag2Exists {
tag2Exists = t.Name == "bloods"
}
if !tag4Exists {
tag4Exists = t.Name == "romantic"
}
if !tag5Exists {
tag5Exists = t.Name == "gruesome"
}
if !tag6Exists {
tag6Exists = t.Name == "tag6"
}
}
if tag2Exists {
t.Errorf("Post2 should NOT HAVE a tag named 'bloods'")
}
if !tag4Exists {
t.Errorf("Post2 should HAVE a tag named 'romantic'")
}
if !tag5Exists {
t.Errorf("Post2 should HAVE a tag named 'gruesome'")
}
if !tag6Exists {
t.Errorf("Post2 should HAVE a tag named 'tag6'")
}
// Test Delete
DB.Model(&post1).Association("Tags").Delete(&tag1)
count = DB.Model(&post2).Association("Tags").Count()
if count != 3 {
t.Errorf("Post1 should be removed 1 association, should remain 3, but %d", count)
}
// Test Clear
count = DB.Model(&video2).Association("Tags").Count()
if count != 3 {
t.Errorf("Video2 should have 3 association, but got %d", count)
}
DB.Model(&video2).Association("Tags").Clear()
count = DB.Model(&video2).Association("Tags").Count()
if count != 0 {
t.Errorf("Video2 should be removed all association, but got %d", count)
}
DB.LogMode(false)
}
func TestNamedPolymorphicMany2many(t *testing.T) {
DB.DropTableIfExists(&SimpleTag{}, &SimplePost{}, &SimpleVideo{}, "taggables")
DB.AutoMigrate(&SimpleTag{}, &SimplePost{}, &SimpleVideo{})
DB.LogMode(true)
tag1 := SimpleTag{Name: "hero"}
tag2 := SimpleTag{Name: "bloods"}
tag3 := SimpleTag{Name: "frendship"}
tag4 := SimpleTag{Name: "romantic"}
tag5 := SimpleTag{Name: "gruesome"}
// Test Save associations together
post1 := SimplePost{Name: "First Post", Tags: []*SimpleTag{&tag1, &tag2, &tag3}}
err := DB.Save(&post1).Error
if err != nil {
t.Errorf("Data init fail : %v \n", err)
}
count := DB.Model(&post1).Association("Tags").Count()
if count != 3 {
t.Errorf("Post1 should have 3 associations to tags, but got %d", count)
}
post2 := SimplePost{Name: "Second Post"}
video1 := SimpleVideo{Name: "First Video"}
video2 := SimpleVideo{Name: "Second Video"}
DB.Save(&post2).Save(&video1).Save(&video2)
// Test Append
DB.Model(&post2).Association("Tags").Append(&tag2, &tag4)
DB.Model(&video1).Association("Tags").Append(&tag1, &tag2, &tag5)
DB.Model(&video2).Association("Tags").Append(&tag2, &tag3, &tag4)
count = DB.Model(&video1).Association("Tags").Count()
if count != 3 {
t.Errorf("Video1 should have 3 associations to tags, but got %d", count)
}
exists := false
for _, t := range video1.Tags {
if exists = t.Name == "bloods"; exists {
break
}
}
if !exists {
t.Errorf("Video1 should have a tag named 'bloods'")
}
// Test Replace
tag6 := SimpleTag{Name: "tag6"}
DB.Model(&video1).Association("Tags").Replace(&tag2, &tag4, &tag6)
tag2Exists := false
tag4Exists := false
tag5Exists := false
tag6Exists := false
for _, t := range video1.Tags {
if !tag2Exists {
tag2Exists = t.Name == "bloods"
}
if !tag4Exists {
tag4Exists = t.Name == "romantic"
}
if !tag5Exists {
tag5Exists = t.Name == "gruesome"
}
if !tag6Exists {
tag6Exists = t.Name == "tag6"
}
}
if !tag2Exists {
t.Errorf("Video1 should HAVE a tag named 'bloods'")
}
if !tag4Exists {
t.Errorf("Video1 should HAVE a tag named 'romantic'")
}
if tag5Exists {
t.Errorf("Video1 should NOT HAVE a tag named 'gruesome'")
}
if !tag6Exists {
t.Errorf("Video1 should HAVE a tag named 'tag6'")
}
// Test Delete
DB.Model(&video1).Association("Tags").Delete(&tag2)
count = DB.Model(&video1).Association("Tags").Count()
if count != 2 {
t.Errorf("video1 should be removed 1 association, should remain 2, but %d", count)
}
DB.LogMode(false)
}

View File

@ -1151,6 +1151,8 @@ func (scope *Scope) createJoinTable(field *StructField) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string
if relationship.PolymorphicType == "" {
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.FieldByName(fieldName); ok {
foreignKeyStruct := field.clone()
@ -1161,6 +1163,43 @@ func (scope *Scope) createJoinTable(field *StructField) {
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}
} else {
// Deal with POLYMORPHIC tag
// Create OwnerType & OwnerID columns in the middle table
mockStruct := struct {
mockField string
}{}
reflectType := reflect.ValueOf(mockStruct).Type()
fieldStruct := reflectType.Field(0)
field := &StructField{
Struct: fieldStruct,
Name: fieldStruct.Name,
Names: []string{fieldStruct.Name},
Tag: fieldStruct.Tag,
TagSettings: parseTagSetting(fieldStruct.Tag),
}
quotedType := scope.Quote(relationship.PolymorphicDBName)
sqlTypes = append(sqlTypes, quotedType+" "+scope.Dialect().DataTypeOf(field))
primaryKeys = append(primaryKeys, quotedType)
morphIDName := ToColumnName(strings.TrimSuffix(relationship.PolymorphicType, "Type") + "ID")
if !scope.Dialect().HasColumn(joinTable, morphIDName) {
if len(relationship.ForeignFieldNames) > 0 {
foreignFieldName := relationship.ForeignFieldNames[0]
if field, ok := toScope.FieldByName(foreignFieldName); ok {
foreignKeyStruct := field.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
// 使用外键 ID 类型
quotedID := scope.Quote(morphIDName)
sqlTypes = append(sqlTypes, quotedID+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, quotedID)
}
}
}
}
for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.FieldByName(fieldName); ok {