Support delete associations with Select when deleting
This commit is contained in:
		
							parent
							
								
									53caa85cf4
								
							
						
					
					
						commit
						70a7bd52ca
					
				| @ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | |||||||
| 	deleteCallback := db.Callback().Delete() | 	deleteCallback := db.Callback().Delete() | ||||||
| 	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | 	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||||
| 	deleteCallback.Register("gorm:before_delete", BeforeDelete) | 	deleteCallback.Register("gorm:before_delete", BeforeDelete) | ||||||
|  | 	deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) | ||||||
| 	deleteCallback.Register("gorm:delete", Delete) | 	deleteCallback.Register("gorm:delete", Delete) | ||||||
| 	deleteCallback.Register("gorm:after_delete", AfterDelete) | 	deleteCallback.Register("gorm:after_delete", AfterDelete) | ||||||
| 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
|  | |||||||
| @ -21,6 +21,59 @@ func BeforeDelete(db *gorm.DB) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func DeleteBeforeAssociations(db *gorm.DB) { | ||||||
|  | 	if db.Error == nil && db.Statement.Schema != nil { | ||||||
|  | 		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) | ||||||
|  | 
 | ||||||
|  | 		if restricted { | ||||||
|  | 			for column, v := range selectColumns { | ||||||
|  | 				if v { | ||||||
|  | 					if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { | ||||||
|  | 						switch rel.Type { | ||||||
|  | 						case schema.HasOne, schema.HasMany: | ||||||
|  | 							queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) | ||||||
|  | 							modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||||
|  | 							tx := db.Session(&gorm.Session{}).Model(modelValue) | ||||||
|  | 							if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||||
|  | 								return | ||||||
|  | 							} | ||||||
|  | 						case schema.Many2Many: | ||||||
|  | 							var ( | ||||||
|  | 								queryConds     []clause.Expression | ||||||
|  | 								foreignFields  []*schema.Field | ||||||
|  | 								relForeignKeys []string | ||||||
|  | 								modelValue     = reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
|  | 								table          = rel.JoinTable.Table | ||||||
|  | 								tx             = db.Session(&gorm.Session{}).Model(modelValue).Table(table) | ||||||
|  | 							) | ||||||
|  | 
 | ||||||
|  | 							for _, ref := range rel.References { | ||||||
|  | 								if ref.OwnPrimaryKey { | ||||||
|  | 									foreignFields = append(foreignFields, ref.PrimaryKey) | ||||||
|  | 									relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||||
|  | 								} else if ref.PrimaryValue != "" { | ||||||
|  | 									queryConds = append(queryConds, clause.Eq{ | ||||||
|  | 										Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, | ||||||
|  | 										Value:  ref.PrimaryValue, | ||||||
|  | 									}) | ||||||
|  | 								} | ||||||
|  | 							} | ||||||
|  | 
 | ||||||
|  | 							_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) | ||||||
|  | 							column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) | ||||||
|  | 							queryConds = append(queryConds, clause.IN{Column: column, Values: values}) | ||||||
|  | 
 | ||||||
|  | 							if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { | ||||||
|  | 								return | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func Delete(db *gorm.DB) { | func Delete(db *gorm.DB) { | ||||||
| 	if db.Error == nil { | 	if db.Error == nil { | ||||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) { | |||||||
| 		t.Errorf("should returns no error while enable global update, but got err %v", err) | 		t.Errorf("should returns no error while enable global update, but got err %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestDeleteWithAssociations(t *testing.T) { | ||||||
|  | 	user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Create(user).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create user, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to delete user, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { | ||||||
|  | 		if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { | ||||||
|  | 			t.Errorf("user's %v expects: %v, got %v", key, value, count) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { | ||||||
|  | 		if count := DB.Model(&user).Association(key).Count(); count != value { | ||||||
|  | 			t.Errorf("user's %v expects: %v, got %v", key, value, count) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestDeleteSliceWithAssociations(t *testing.T) { | ||||||
|  | 	users := []User{ | ||||||
|  | 		*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), | ||||||
|  | 		*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), | ||||||
|  | 		*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), | ||||||
|  | 		*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Create(users).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to create user, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to delete user, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { | ||||||
|  | 		if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { | ||||||
|  | 			t.Errorf("user's %v expects: %v, got %v", key, value, count) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { | ||||||
|  | 		if count := DB.Model(&users).Association(key).Count(); count != value { | ||||||
|  | 			t.Errorf("user's %v expects: %v, got %v", key, value, count) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -5,12 +5,14 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Person struct { | type Person struct { | ||||||
| 	ID        int | 	ID        int | ||||||
| 	Name      string | 	Name      string | ||||||
| 	Addresses []Address `gorm:"many2many:person_addresses;"` | 	Addresses []Address `gorm:"many2many:person_addresses;"` | ||||||
|  | 	DeletedAt gorm.DeletedAt | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type Address struct { | type Address struct { | ||||||
| @ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) { | |||||||
| 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { | 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { | ||||||
| 		t.Fatalf("address should be deleted when clear with unscoped") | 		t.Fatalf("address should be deleted when clear with unscoped") | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	address2_1 := Address{Name: "address 2-1"} | ||||||
|  | 	address2_2 := Address{Name: "address 2-2"} | ||||||
|  | 	person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} | ||||||
|  | 	DB.Create(&person2) | ||||||
|  | 	if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to delete person, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { | ||||||
|  | 		t.Errorf("person's addresses expects 2, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { | ||||||
|  | 		t.Errorf("person's addresses expects 2, got %v", count) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ func FileWithLineNum() string { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func IsValidDBNameChar(c rune) bool { | func IsValidDBNameChar(c rune) bool { | ||||||
| 	return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' | 	return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func CheckTruth(val interface{}) bool { | func CheckTruth(val interface{}) bool { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu