fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 23:49:21 2022 +0800
    Refactor #5140
commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be
Author: a631807682 <631807682@qq.com>
Date:   Sun Mar 13 12:52:08 2022 +0800
    test: add test for LoadOrStoreVisitMap
commit 9d5c68e41000fd15dea124797dd5f2656bf6b304
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:33:47 2022 +0800
    chore: add more comment
commit bfffefb179c883389b72bef8f04469c0a8418043
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:28:48 2022 +0800
    fix: should check values has been saved instead of rel.Name
commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:48:01 2022 +0800
    chore: go lint
commit fe4715c5bd4ac28950c97dded9848710d8becb88
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:27:24 2022 +0800
    chore: add test comment
commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:22:33 2022 +0800
    fix: circular reference save
			
			
This commit is contained in:
		
							parent
							
								
									2990790fbc
								
							
						
					
					
						commit
						9b9ae325bb
					
				| @ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { | |||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					if elems.Len() > 0 { | 					if elems.Len() > 0 { | ||||||
| 						if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { | 						if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { | ||||||
| 							for i := 0; i < elems.Len(); i++ { | 							for i := 0; i < elems.Len(); i++ { | ||||||
| 								setupReferences(objs[i], elems.Index(i)) | 								setupReferences(objs[i], elems.Index(i)) | ||||||
| 							} | 							} | ||||||
| @ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { | |||||||
| 							rv = rv.Addr() | 							rv = rv.Addr() | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { | 						if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { | ||||||
| 							setupReferences(db.Statement.ReflectValue, rv) | 							setupReferences(db.Statement.ReflectValue, rv) | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| @ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | |||||||
| 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) | 						saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) | ||||||
| 					} | 					} | ||||||
| 				case reflect.Struct: | 				case reflect.Struct: | ||||||
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { | 					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { | ||||||
| @ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | |||||||
| 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) | 						saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | |||||||
| 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) | 					saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| @ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { | |||||||
| 				// optimize elems of reflect value length
 | 				// optimize elems of reflect value length
 | ||||||
| 				if elemLen := elems.Len(); elemLen > 0 { | 				if elemLen := elems.Len(); elemLen > 0 { | ||||||
| 					if v, ok := selectColumns[rel.Name+".*"]; !ok || v { | 					if v, ok := selectColumns[rel.Name+".*"]; !ok || v { | ||||||
| 						saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) | 						saveAssociations(db, rel, elems, selectColumns, restricted, nil) | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					for i := 0; i < elemLen; i++ { | 					for i := 0; i < elemLen; i++ { | ||||||
| @ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { | func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { | ||||||
|  | 	// stop save association loop
 | ||||||
|  | 	if checkAssociationsSaved(db, rValues) { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	var ( | 	var ( | ||||||
| 		selects, omits []string | 		selects, omits []string | ||||||
| 		onConflict     = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) | 		onConflict     = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) | ||||||
| 		refName        = rel.Name + "." | 		refName        = rel.Name + "." | ||||||
|  | 		values         = rValues.Interface() | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	for name, ok := range selectColumns { | 	for name, ok := range selectColumns { | ||||||
| @ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, | |||||||
| 
 | 
 | ||||||
| 	return db.AddError(tx.Create(values).Error) | 	return db.AddError(tx.Create(values).Error) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // check association values has been saved
 | ||||||
|  | // if values kind is Struct, check it has been saved
 | ||||||
|  | // if values kind is Slice/Array, check all items have been saved
 | ||||||
|  | var visitMapStoreKey = "gorm:saved_association_map" | ||||||
|  | 
 | ||||||
|  | func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { | ||||||
|  | 	if visit, ok := db.Get(visitMapStoreKey); ok { | ||||||
|  | 		if v, ok := visit.(*visitMap); ok { | ||||||
|  | 			if loadOrStoreVisitMap(v, values) { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		vistMap := make(visitMap) | ||||||
|  | 		loadOrStoreVisitMap(&vistMap, values) | ||||||
|  | 		db.Set(visitMapStoreKey, &vistMap) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"reflect" | ||||||
| 	"sort" | 	"sort" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| @ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type visitMap = map[reflect.Value]bool | ||||||
|  | 
 | ||||||
|  | // Check if circular values, return true if loaded
 | ||||||
|  | func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { | ||||||
|  | 	if v.Kind() == reflect.Ptr { | ||||||
|  | 		v = v.Elem() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch v.Kind() { | ||||||
|  | 	case reflect.Slice, reflect.Array: | ||||||
|  | 		loaded = true | ||||||
|  | 		for i := 0; i < v.Len(); i++ { | ||||||
|  | 			if !loadOrStoreVisitMap(vistMap, v.Index(i)) { | ||||||
|  | 				loaded = false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	case reflect.Struct, reflect.Interface: | ||||||
|  | 		if v.CanAddr() { | ||||||
|  | 			p := v.Addr() | ||||||
|  | 			if _, ok := (*vistMap)[p]; ok { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			(*vistMap)[p] = true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | |||||||
							
								
								
									
										36
									
								
								callbacks/visit_map_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								callbacks/visit_map_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,36 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestLoadOrStoreVisitMap(t *testing.T) { | ||||||
|  | 	var vm visitMap | ||||||
|  | 	var loaded bool | ||||||
|  | 	type testM struct { | ||||||
|  | 		Name string | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	t1 := testM{Name: "t1"} | ||||||
|  | 	t2 := testM{Name: "t2"} | ||||||
|  | 	t3 := testM{Name: "t3"} | ||||||
|  | 
 | ||||||
|  | 	vm = make(visitMap) | ||||||
|  | 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { | ||||||
|  | 		t.Fatalf("loaded should be false") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { | ||||||
|  | 		t.Fatalf("loaded should be true") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// t1 already exist but t2 not
 | ||||||
|  | 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { | ||||||
|  | 		t.Fatalf("loaded should be false") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { | ||||||
|  | 		t.Fatalf("loaded should be true") | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) { | |||||||
| 		t.Errorf("Failed to preload AppliesToProduct") | 		t.Errorf("Failed to preload AppliesToProduct") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestSaveBelongsCircularReference(t *testing.T) { | ||||||
|  | 	parent := Parent{} | ||||||
|  | 	DB.Create(&parent) | ||||||
|  | 
 | ||||||
|  | 	child := Child{ParentID: &parent.ID, Parent: &parent} | ||||||
|  | 	DB.Create(&child) | ||||||
|  | 
 | ||||||
|  | 	parent.FavChildID = child.ID | ||||||
|  | 	parent.FavChild = &child | ||||||
|  | 	DB.Save(&parent) | ||||||
|  | 
 | ||||||
|  | 	var parent1 Parent | ||||||
|  | 	DB.First(&parent1, parent.ID) | ||||||
|  | 	AssertObjEqual(t, parent, parent1, "ID", "FavChildID") | ||||||
|  | 
 | ||||||
|  | 	// Save and Updates is the same
 | ||||||
|  | 	DB.Updates(&parent) | ||||||
|  | 	DB.First(&parent1, parent.ID) | ||||||
|  | 	AssertObjEqual(t, parent, parent1, "ID", "FavChildID") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSaveHasManyCircularReference(t *testing.T) { | ||||||
|  | 	parent := Parent{} | ||||||
|  | 	DB.Create(&parent) | ||||||
|  | 
 | ||||||
|  | 	child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} | ||||||
|  | 	child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} | ||||||
|  | 
 | ||||||
|  | 	parent.Children = []*Child{&child, &child1} | ||||||
|  | 	DB.Save(&parent) | ||||||
|  | 
 | ||||||
|  | 	var children []*Child | ||||||
|  | 	DB.Where("parent_id = ?", parent.ID).Find(&children) | ||||||
|  | 	if len(children) != len(parent.Children) || | ||||||
|  | 		children[0].ID != parent.Children[0].ID || | ||||||
|  | 		children[1].ID != parent.Children[1].ID { | ||||||
|  | 		t.Errorf("circular reference children save not equal children:%v parent.Children:%v", | ||||||
|  | 			children, parent.Children) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { | |||||||
| 
 | 
 | ||||||
| func RunMigrations() { | func RunMigrations() { | ||||||
| 	var err error | 	var err error | ||||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} | 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| 	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) | 	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -80,3 +80,17 @@ type Order struct { | |||||||
| 	Coupon   *Coupon | 	Coupon   *Coupon | ||||||
| 	CouponID string | 	CouponID string | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type Parent struct { | ||||||
|  | 	gorm.Model | ||||||
|  | 	FavChildID uint | ||||||
|  | 	FavChild   *Child | ||||||
|  | 	Children   []*Child | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Child struct { | ||||||
|  | 	gorm.Model | ||||||
|  | 	Name     string | ||||||
|  | 	ParentID *uint | ||||||
|  | 	Parent   *Parent | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 chenrui
						chenrui