fix: association many2many duplicate elem (#5473)
* fix: association many2many duplicate elem * chore: gofumpt style
This commit is contained in:
		
							parent
							
								
									235c093bb9
								
							
						
					
					
						commit
						c74bc57add
					
				| @ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 					fieldType = reflect.PtrTo(fieldType) | ||||
| 				} | ||||
| 				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) | ||||
| 				distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) | ||||
| 				joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) | ||||
| 				objs := []reflect.Value{} | ||||
| 
 | ||||
| @ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 					joins = reflect.Append(joins, joinValue) | ||||
| 				} | ||||
| 
 | ||||
| 				identityMap := map[string]bool{} | ||||
| 				appendToElems := func(v reflect.Value) { | ||||
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { | ||||
| 						f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) | ||||
| 
 | ||||
| 						for i := 0; i < f.Len(); i++ { | ||||
| 							elem := f.Index(i) | ||||
| 
 | ||||
| 							objs = append(objs, v) | ||||
| 							if isPtr { | ||||
| 								elems = reflect.Append(elems, elem) | ||||
| 							} else { | ||||
| 								elems = reflect.Append(elems, elem.Addr()) | ||||
| 							if !isPtr { | ||||
| 								elem = elem.Addr() | ||||
| 							} | ||||
| 							objs = append(objs, v) | ||||
| 							elems = reflect.Append(elems, elem) | ||||
| 
 | ||||
| 							relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) | ||||
| 							for _, pf := range rel.FieldSchema.PrimaryFields { | ||||
| 								if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { | ||||
| 									relPrimaryValues = append(relPrimaryValues, pfv) | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							cacheKey := utils.ToStringKey(relPrimaryValues) | ||||
| 							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { | ||||
| 								identityMap[cacheKey] = true | ||||
| 								distinctElems = reflect.Append(distinctElems, elem) | ||||
| 							} | ||||
| 
 | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| @ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 				// optimize elems of reflect value length
 | ||||
| 				if elemLen := elems.Len(); elemLen > 0 { | ||||
| 					if v, ok := selectColumns[rel.Name+".*"]; !ok || v { | ||||
| 						saveAssociations(db, rel, elems, selectColumns, restricted, nil) | ||||
| 						saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) | ||||
| 					} | ||||
| 
 | ||||
| 					for i := 0; i < elemLen; i++ { | ||||
|  | ||||
| @ -3,6 +3,7 @@ package tests_test | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| @ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { | ||||
| 	DB.Model(&users).Association("Team").Clear() | ||||
| 	AssertAssociationCount(t, users, "Team", 0, "After Clear") | ||||
| } | ||||
| 
 | ||||
| func TestDuplicateMany2ManyAssociation(t *testing.T) { | ||||
| 	user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ | ||||
| 		{Code: "TestDuplicateMany2ManyAssociation-language-1"}, | ||||
| 		{Code: "TestDuplicateMany2ManyAssociation-language-2"}, | ||||
| 	}} | ||||
| 
 | ||||
| 	user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ | ||||
| 		{Code: "TestDuplicateMany2ManyAssociation-language-1"}, | ||||
| 		{Code: "TestDuplicateMany2ManyAssociation-language-3"}, | ||||
| 	}} | ||||
| 	users := []*User{&user1, &user2} | ||||
| 	var err error | ||||
| 	err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error | ||||
| 	AssertEqual(t, nil, err) | ||||
| 
 | ||||
| 	var findUser1 User | ||||
| 	err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error | ||||
| 	AssertEqual(t, nil, err) | ||||
| 	AssertEqual(t, user1, findUser1) | ||||
| 
 | ||||
| 	var findUser2 User | ||||
| 	err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error | ||||
| 	AssertEqual(t, nil, err) | ||||
| 	AssertEqual(t, user2, findUser2) | ||||
| } | ||||
|  | ||||
| @ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) { | ||||
| 	value, ok = ct.DefaultValue() | ||||
| 	AssertEqual(t, "", value) | ||||
| 	AssertEqual(t, false, ok) | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func findColumnType(dest interface{}, columnName string) ( | ||||
| 	foundColumn gorm.ColumnType, err error) { | ||||
| 	foundColumn gorm.ColumnType, err error, | ||||
| ) { | ||||
| 	columnTypes, err := DB.Migrator().ColumnTypes(dest) | ||||
| 	if err != nil { | ||||
| 		err = fmt.Errorf("ColumnTypes err:%v", err) | ||||
|  | ||||
| @ -113,7 +113,6 @@ func TestSerializer(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	AssertEqual(t, result, data) | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestSerializerAssignFirstOrCreate(t *testing.T) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Cr
						Cr