From 32045fdd7d7a298f09f7ffdca286c3097cfda293 Mon Sep 17 00:00:00 2001 From: black-06 Date: Thu, 4 May 2023 19:30:45 +0800 Subject: [PATCH] feat: unscoped association (#5899) (#6246) * feat: unscoped association (#5899) * modify name because mysql character is latin1 * work only on has association * format * Unscoped on belongs_to association --- association.go | 63 ++++++++++++++++++++--- tests/associations_belongs_to_test.go | 55 ++++++++++++++++++++ tests/associations_has_many_test.go | 74 +++++++++++++++++++++++++++ 3 files changed, 186 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 6719a1d0..7c93ebea 100644 --- a/association.go +++ b/association.go @@ -14,6 +14,7 @@ import ( type Association struct { DB *DB Relationship *schema.Relationship + Unscope bool Error error } @@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association { return association } +func (association *Association) Unscoped() *Association { + return &Association{ + DB: association.DB, + Relationship: association.Relationship, + Error: association.Error, + Unscope: true, + } +} + func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { association.Error = association.buildCondition().Find(out, conds...).Error @@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} + } + } + // save associations if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { return association.Error } // set old associations's foreign key to null - reflectValue := association.DB.Statement.ReflectValue - rel := association.Relationship switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error { association.Error = association.DB.UpdateColumns(updateMap).Error } + if association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + if association.Unscope { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error + } else { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } } case schema.Many2Many: var ( @@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.BelongsTo: - tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { @@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error { conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + } case schema.HasOne, schema.HasMany: - tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + model := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := association.DB.Model(model) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { @@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + association.Error = tx.Clauses(conds...).Delete(model).Error + } else { + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + } case schema.Many2Many: var ( primaryFields, relPrimaryFields []*schema.Field diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 99e8aa79..6befb5f2 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -251,3 +251,58 @@ func TestBelongsToDefaultValue(t *testing.T) { err := DB.Create(&user).Error AssertEqual(t, err, nil) } + +func TestBelongsToAssociationUnscoped(t *testing.T) { + type ItemParent struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + } + type ItemChild struct { + gorm.Model + Name string `gorm:"type:varchar(50)"` + ItemParentID uint + ItemParent ItemParent + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemParent{}, &ItemChild{}) + tx.AutoMigrate(&ItemParent{}, &ItemChild{}) + + item := ItemChild{ + Name: "name", + ItemParent: ItemParent{ + Logo: "logo", + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + tx = tx.Debug() + + // test replace + if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ + Logo: "updated logo", + }); err != nil { + t.Errorf("failed to replace item parent, got error: %v", err) + } + + var parents []ItemParent + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 1 { + t.Errorf("expected %d parents, got %d", 1, len(parents)) + } + + // test delete + if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil { + t.Errorf("failed to delete item parent, got error: %v", err) + } + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 0 { + t.Errorf("expected %d parents, got %d", 0, len(parents)) + } +} diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 002ae636..c31c4b40 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -471,3 +472,76 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Toys").Clear() AssertAssociationCount(t, users, "Toys", 0, "After Clear") } + +func TestHasManyAssociationUnscoped(t *testing.T) { + type ItemContent struct { + gorm.Model + ItemID uint `gorm:"not null"` + Name string `gorm:"not null;type:varchar(50)"` + LanguageCode string `gorm:"not null;type:varchar(2)"` + } + type Item struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + Contents []ItemContent `gorm:"foreignKey:ItemID"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemContent{}, &Item{}) + tx.AutoMigrate(&ItemContent{}, &Item{}) + + item := Item{ + Logo: "logo", + Contents: []ItemContent{ + {Name: "name", LanguageCode: "en"}, + {Name: "ar name", LanguageCode: "ar"}, + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + // test Replace + if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{ + {Name: "updated name", LanguageCode: "en"}, + {Name: "ar updated name", LanguageCode: "ar"}, + {Name: "le nom", LanguageCode: "fr"}, + }); err != nil { + t.Errorf("failed to replace item content, got error: %v", err) + } + + if count := tx.Model(&item).Association("Contents").Count(); count != 3 { + t.Errorf("expected %d contents, got %d", 3, count) + } + + var contents []ItemContent + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 3 { + t.Errorf("expected %d contents, got %d", 3, len(contents)) + } + + // test delete + if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil { + t.Errorf("failed to delete Contents, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 2 { + t.Errorf("expected %d contents, got %d", 2, count) + } + + // test clear + if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil { + t.Errorf("failed to clear contents association, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 0 { + t.Errorf("expected %d contents, got %d", 0, count) + } + + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 0 { + t.Errorf("expected %d contents, got %d", 0, len(contents)) + } +}