Add association tests for composite primary key
This commit is contained in:
		
							parent
							
								
									934e97b018
								
							
						
					
					
						commit
						9455215e61
					
				@ -143,7 +143,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
 | 
			
		||||
			var newPrimaryKeys [][]interface{}
 | 
			
		||||
			var associationForeignFieldNames []string
 | 
			
		||||
 | 
			
		||||
			if relationship.Kind == "many2many" {
 | 
			
		||||
			if relationship.Kind == "many_to_many" {
 | 
			
		||||
				// If many to many relations, get it from foreign key
 | 
			
		||||
				associationForeignFieldNames = relationship.AssociationForeignFieldNames
 | 
			
		||||
			} else {
 | 
			
		||||
@ -156,6 +156,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
 | 
			
		||||
 | 
			
		||||
			if len(newPrimaryKeys) > 0 {
 | 
			
		||||
				sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
 | 
			
		||||
				newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
 | 
			
		||||
 | 
			
		||||
@ -85,15 +85,14 @@ func (structField *StructField) clone() *StructField {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Relationship struct {
 | 
			
		||||
	Kind                               string
 | 
			
		||||
	PolymorphicType                    string
 | 
			
		||||
	PolymorphicDBName                  string
 | 
			
		||||
	ForeignFieldNames                  []string
 | 
			
		||||
	ForeignDBNames                     []string
 | 
			
		||||
	AssociationForeignFieldNames       []string
 | 
			
		||||
	AssociationForeignStructFieldNames []string
 | 
			
		||||
	AssociationForeignDBNames          []string
 | 
			
		||||
	JoinTableHandler                   JoinTableHandlerInterface
 | 
			
		||||
	Kind                         string
 | 
			
		||||
	PolymorphicType              string
 | 
			
		||||
	PolymorphicDBName            string
 | 
			
		||||
	ForeignFieldNames            []string
 | 
			
		||||
	ForeignDBNames               []string
 | 
			
		||||
	AssociationForeignFieldNames []string
 | 
			
		||||
	AssociationForeignDBNames    []string
 | 
			
		||||
	JoinTableHandler             JoinTableHandlerInterface
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
			
		||||
@ -263,7 +262,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
			
		||||
								for _, name := range associationForeignKeys {
 | 
			
		||||
									if field, ok := toScope.FieldByName(name); ok {
 | 
			
		||||
										relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
 | 
			
		||||
										relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
 | 
			
		||||
										joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
 | 
			
		||||
										relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
 | 
			
		||||
									}
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,9 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -20,10 +21,21 @@ type Tag struct {
 | 
			
		||||
	Value  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func compareTags(tags []Tag, contents []string) bool {
 | 
			
		||||
	var tagContents []string
 | 
			
		||||
	for _, tag := range tags {
 | 
			
		||||
		tagContents = append(tagContents, tag.Value)
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(tagContents)
 | 
			
		||||
	sort.Strings(contents)
 | 
			
		||||
	return reflect.DeepEqual(tagContents, contents)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
 | 
			
		||||
		DB.Exec(fmt.Sprintf("drop table blog_tags;"))
 | 
			
		||||
		DB.AutoMigrate(&Blog{}, &Tag{})
 | 
			
		||||
		DB.DropTable(&Blog{}, &Tag{})
 | 
			
		||||
		DB.DropTable("blog_tags")
 | 
			
		||||
		DB.CreateTable(&Blog{}, &Tag{})
 | 
			
		||||
		blog := Blog{
 | 
			
		||||
			Locale:  "ZH",
 | 
			
		||||
			Subject: "subject",
 | 
			
		||||
@ -35,12 +47,70 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Save(&blog)
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})
 | 
			
		||||
		if !compareTags(blog.Tags, []string{"tag1", "tag2"}) {
 | 
			
		||||
			t.Errorf("Blog should has two tags")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Append
 | 
			
		||||
		var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Append([]*Tag{tag3})
 | 
			
		||||
		if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 3 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tags []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags, "Tags")
 | 
			
		||||
		if len(tags) != 3 {
 | 
			
		||||
			t.Errorf("should found 3 tags with blog")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var blog1 Blog
 | 
			
		||||
		DB.Preload("Tags").Find(&blog1)
 | 
			
		||||
		if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Preload many2many relations")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Replace
 | 
			
		||||
		var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
 | 
			
		||||
		var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Replace(tag5, tag6)
 | 
			
		||||
		var tags2 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags2, "Tags")
 | 
			
		||||
		if !compareTags(tags2, []string{"tag5", "tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 2 tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 2 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Delete
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Delete(tag5)
 | 
			
		||||
		var tags3 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags3, "Tags")
 | 
			
		||||
		if !compareTags(tags3, []string{"tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 1 tags after Delete")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 1 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Delete")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Delete(tag3)
 | 
			
		||||
		var tags4 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags4, "Tags")
 | 
			
		||||
		if !compareTags(tags4, []string{"tag6"}) {
 | 
			
		||||
			t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Clear
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Clear()
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 0 {
 | 
			
		||||
			t.Errorf("All tags should be cleared")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user