diff --git a/association.go b/association.go index 89103fd1..7c93ebea 100644 --- a/association.go +++ b/association.go @@ -74,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} + } + } + // save associations if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { return association.Error } // set old associations's foreign key to null - reflectValue := association.DB.Statement.ReflectValue - rel := association.Relationship switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -101,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error { association.Error = association.DB.UpdateColumns(updateMap).Error } + if association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -198,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.BelongsTo: - tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { @@ -212,6 +232,18 @@ func (association *Association) Delete(values ...interface{}) error { conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + } case schema.HasOne, schema.HasMany: model := reflect.New(rel.FieldSchema.ModelType).Interface() tx := association.DB.Model(model) diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 99e8aa79..6befb5f2 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -251,3 +251,58 @@ func TestBelongsToDefaultValue(t *testing.T) { err := DB.Create(&user).Error AssertEqual(t, err, nil) } + +func TestBelongsToAssociationUnscoped(t *testing.T) { + type ItemParent struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + } + type ItemChild struct { + gorm.Model + Name string `gorm:"type:varchar(50)"` + ItemParentID uint + ItemParent ItemParent + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemParent{}, &ItemChild{}) + tx.AutoMigrate(&ItemParent{}, &ItemChild{}) + + item := ItemChild{ + Name: "name", + ItemParent: ItemParent{ + Logo: "logo", + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + tx = tx.Debug() + + // test replace + if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ + Logo: "updated logo", + }); err != nil { + t.Errorf("failed to replace item parent, got error: %v", err) + } + + var parents []ItemParent + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 1 { + t.Errorf("expected %d parents, got %d", 1, len(parents)) + } + + // test delete + if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil { + t.Errorf("failed to delete item parent, got error: %v", err) + } + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 0 { + t.Errorf("expected %d parents, got %d", 0, len(parents)) + } +}