fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 23:49:21 2022 +0800
    Refactor #5140
commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be
Author: a631807682 <631807682@qq.com>
Date:   Sun Mar 13 12:52:08 2022 +0800
    test: add test for LoadOrStoreVisitMap
commit 9d5c68e41000fd15dea124797dd5f2656bf6b304
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:33:47 2022 +0800
    chore: add more comment
commit bfffefb179c883389b72bef8f04469c0a8418043
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:28:48 2022 +0800
    fix: should check values has been saved instead of rel.Name
commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:48:01 2022 +0800
    chore: go lint
commit fe4715c5bd4ac28950c97dded9848710d8becb88
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:27:24 2022 +0800
    chore: add test comment
commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:22:33 2022 +0800
    fix: circular reference save
			
			
This commit is contained in:
		
							parent
							
								
									2990790fbc
								
							
						
					
					
						commit
						9b9ae325bb
					
				| @ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { | ||||
| 					} | ||||
| 
 | ||||
| 					if elems.Len() > 0 { | ||||
| 						if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { | ||||
| 						if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { | ||||
| 							for i := 0; i < elems.Len(); i++ { | ||||
| 								setupReferences(objs[i], elems.Index(i)) | ||||
| 							} | ||||
| @ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { | ||||
| 							rv = rv.Addr() | ||||
| 						} | ||||
| 
 | ||||
| 						if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { | ||||
| 						if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { | ||||
| 							setupReferences(db.Statement.ReflectValue, rv) | ||||
| 						} | ||||
| 					} | ||||
| @ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||
| 						} | ||||
| 
 | ||||
| 						saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) | ||||
| 						saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) | ||||
| 					} | ||||
| 				case reflect.Struct: | ||||
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { | ||||
| @ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||
| 						} | ||||
| 
 | ||||
| 						saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) | ||||
| 						saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| @ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||
| 					} | ||||
| 
 | ||||
| 					saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) | ||||
| 					saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| @ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | ||||
| 				// optimize elems of reflect value length
 | ||||
| 				if elemLen := elems.Len(); elemLen > 0 { | ||||
| 					if v, ok := selectColumns[rel.Name+".*"]; !ok || v { | ||||
| 						saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) | ||||
| 						saveAssociations(db, rel, elems, selectColumns, restricted, nil) | ||||
| 					} | ||||
| 
 | ||||
| 					for i := 0; i < elemLen; i++ { | ||||
| @ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { | ||||
| func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { | ||||
| 	// stop save association loop
 | ||||
| 	if checkAssociationsSaved(db, rValues) { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	var ( | ||||
| 		selects, omits []string | ||||
| 		onConflict     = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) | ||||
| 		refName        = rel.Name + "." | ||||
| 		values         = rValues.Interface() | ||||
| 	) | ||||
| 
 | ||||
| 	for name, ok := range selectColumns { | ||||
| @ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, | ||||
| 
 | ||||
| 	return db.AddError(tx.Create(values).Error) | ||||
| } | ||||
| 
 | ||||
| // check association values has been saved
 | ||||
| // if values kind is Struct, check it has been saved
 | ||||
| // if values kind is Slice/Array, check all items have been saved
 | ||||
| var visitMapStoreKey = "gorm:saved_association_map" | ||||
| 
 | ||||
| func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { | ||||
| 	if visit, ok := db.Get(visitMapStoreKey); ok { | ||||
| 		if v, ok := visit.(*visitMap); ok { | ||||
| 			if loadOrStoreVisitMap(v, values) { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		vistMap := make(visitMap) | ||||
| 		loadOrStoreVisitMap(&vistMap, values) | ||||
| 		db.Set(visitMapStoreKey, &vistMap) | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| @ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) { | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type visitMap = map[reflect.Value]bool | ||||
| 
 | ||||
| // Check if circular values, return true if loaded
 | ||||
| func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { | ||||
| 	if v.Kind() == reflect.Ptr { | ||||
| 		v = v.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	switch v.Kind() { | ||||
| 	case reflect.Slice, reflect.Array: | ||||
| 		loaded = true | ||||
| 		for i := 0; i < v.Len(); i++ { | ||||
| 			if !loadOrStoreVisitMap(vistMap, v.Index(i)) { | ||||
| 				loaded = false | ||||
| 			} | ||||
| 		} | ||||
| 	case reflect.Struct, reflect.Interface: | ||||
| 		if v.CanAddr() { | ||||
| 			p := v.Addr() | ||||
| 			if _, ok := (*vistMap)[p]; ok { | ||||
| 				return true | ||||
| 			} | ||||
| 			(*vistMap)[p] = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
|  | ||||
							
								
								
									
										36
									
								
								callbacks/visit_map_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								callbacks/visit_map_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,36 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestLoadOrStoreVisitMap(t *testing.T) { | ||||
| 	var vm visitMap | ||||
| 	var loaded bool | ||||
| 	type testM struct { | ||||
| 		Name string | ||||
| 	} | ||||
| 
 | ||||
| 	t1 := testM{Name: "t1"} | ||||
| 	t2 := testM{Name: "t2"} | ||||
| 	t3 := testM{Name: "t3"} | ||||
| 
 | ||||
| 	vm = make(visitMap) | ||||
| 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { | ||||
| 		t.Fatalf("loaded should be false") | ||||
| 	} | ||||
| 
 | ||||
| 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { | ||||
| 		t.Fatalf("loaded should be true") | ||||
| 	} | ||||
| 
 | ||||
| 	// t1 already exist but t2 not
 | ||||
| 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { | ||||
| 		t.Fatalf("loaded should be false") | ||||
| 	} | ||||
| 
 | ||||
| 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { | ||||
| 		t.Fatalf("loaded should be true") | ||||
| 	} | ||||
| } | ||||
| @ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) { | ||||
| 		t.Errorf("Failed to preload AppliesToProduct") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSaveBelongsCircularReference(t *testing.T) { | ||||
| 	parent := Parent{} | ||||
| 	DB.Create(&parent) | ||||
| 
 | ||||
| 	child := Child{ParentID: &parent.ID, Parent: &parent} | ||||
| 	DB.Create(&child) | ||||
| 
 | ||||
| 	parent.FavChildID = child.ID | ||||
| 	parent.FavChild = &child | ||||
| 	DB.Save(&parent) | ||||
| 
 | ||||
| 	var parent1 Parent | ||||
| 	DB.First(&parent1, parent.ID) | ||||
| 	AssertObjEqual(t, parent, parent1, "ID", "FavChildID") | ||||
| 
 | ||||
| 	// Save and Updates is the same
 | ||||
| 	DB.Updates(&parent) | ||||
| 	DB.First(&parent1, parent.ID) | ||||
| 	AssertObjEqual(t, parent, parent1, "ID", "FavChildID") | ||||
| } | ||||
| 
 | ||||
| func TestSaveHasManyCircularReference(t *testing.T) { | ||||
| 	parent := Parent{} | ||||
| 	DB.Create(&parent) | ||||
| 
 | ||||
| 	child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} | ||||
| 	child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} | ||||
| 
 | ||||
| 	parent.Children = []*Child{&child, &child1} | ||||
| 	DB.Save(&parent) | ||||
| 
 | ||||
| 	var children []*Child | ||||
| 	DB.Where("parent_id = ?", parent.ID).Find(&children) | ||||
| 	if len(children) != len(parent.Children) || | ||||
| 		children[0].ID != parent.Children[0].ID || | ||||
| 		children[1].ID != parent.Children[1].ID { | ||||
| 		t.Errorf("circular reference children save not equal children:%v parent.Children:%v", | ||||
| 			children, parent.Children) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { | ||||
| 
 | ||||
| func RunMigrations() { | ||||
| 	var err error | ||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} | ||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) | ||||
| 
 | ||||
|  | ||||
| @ -80,3 +80,17 @@ type Order struct { | ||||
| 	Coupon   *Coupon | ||||
| 	CouponID string | ||||
| } | ||||
| 
 | ||||
| type Parent struct { | ||||
| 	gorm.Model | ||||
| 	FavChildID uint | ||||
| 	FavChild   *Child | ||||
| 	Children   []*Child | ||||
| } | ||||
| 
 | ||||
| type Child struct { | ||||
| 	gorm.Model | ||||
| 	Name     string | ||||
| 	ParentID *uint | ||||
| 	Parent   *Parent | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 chenrui
						chenrui