feat: unscoped association (#5899)
This commit is contained in:
parent
e9637024d3
commit
eb70c3a84c
@ -14,6 +14,7 @@ import (
|
|||||||
type Association struct {
|
type Association struct {
|
||||||
DB *DB
|
DB *DB
|
||||||
Relationship *schema.Relationship
|
Relationship *schema.Relationship
|
||||||
|
Unscope bool
|
||||||
Error error
|
Error error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
|
|||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (association *Association) Unscoped() *Association {
|
||||||
|
return &Association{
|
||||||
|
DB: association.DB,
|
||||||
|
Relationship: association.Relationship,
|
||||||
|
Error: association.Error,
|
||||||
|
Unscope: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||||
@ -75,7 +85,6 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
case schema.BelongsTo:
|
case schema.BelongsTo:
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
updateMap := map[string]interface{}{}
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
@ -85,11 +94,16 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
if association.Unscope {
|
||||||
updateMap[ref.ForeignKey.DBName] = nil
|
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
association.Error = association.DB.Delete(modelValue).Error
|
||||||
|
} else {
|
||||||
|
updateMap := map[string]interface{}{}
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
updateMap[ref.ForeignKey.DBName] = nil
|
||||||
|
}
|
||||||
|
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
|
||||||
}
|
}
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
var (
|
var (
|
||||||
@ -119,7 +133,11 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
|
|
||||||
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
||||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
if association.Unscope {
|
||||||
|
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
|
||||||
|
} else {
|
||||||
|
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
var (
|
var (
|
||||||
@ -184,7 +202,8 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
|
|
||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
case schema.BelongsTo:
|
case schema.BelongsTo:
|
||||||
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
model := reflect.New(rel.Schema.ModelType).Interface()
|
||||||
|
tx := association.DB.Model(model)
|
||||||
|
|
||||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
||||||
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
||||||
@ -197,9 +216,14 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
if association.Unscope {
|
||||||
|
association.Error = tx.Clauses(conds...).Delete(model).Error
|
||||||
|
} else {
|
||||||
|
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||||
|
}
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
model := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
tx := association.DB.Model(model)
|
||||||
|
|
||||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||||
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
||||||
@ -212,7 +236,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
if association.Unscope {
|
||||||
|
association.Error = tx.Clauses(conds...).Delete(model).Error
|
||||||
|
} else {
|
||||||
|
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||||
|
}
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
var (
|
var (
|
||||||
primaryFields, relPrimaryFields []*schema.Field
|
primaryFields, relPrimaryFields []*schema.Field
|
||||||
|
@ -394,3 +394,76 @@ func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
|||||||
|
|
||||||
AssertEqual(t, result, user)
|
AssertEqual(t, result, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAssociationUnscoped(t *testing.T) {
|
||||||
|
type ItemContent struct {
|
||||||
|
gorm.Model
|
||||||
|
ItemID uint `gorm:"not null"`
|
||||||
|
Name string `gorm:"not null;type:varchar(50)"`
|
||||||
|
LanguageCode string `gorm:"not null;type:varchar(2)"`
|
||||||
|
}
|
||||||
|
type Item struct {
|
||||||
|
gorm.Model
|
||||||
|
Logo string `gorm:"not null;type:varchar(50)"`
|
||||||
|
Contents []ItemContent `gorm:"foreignKey:ItemID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Session(&gorm.Session{})
|
||||||
|
tx.Migrator().DropTable(&ItemContent{}, &Item{})
|
||||||
|
tx.AutoMigrate(&ItemContent{}, &Item{})
|
||||||
|
|
||||||
|
item := Item{
|
||||||
|
Logo: "logo",
|
||||||
|
Contents: []ItemContent{
|
||||||
|
{Name: "name", LanguageCode: "en"},
|
||||||
|
{Name: "الاسم", LanguageCode: "ar"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := tx.Create(&item).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create items, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test Replace
|
||||||
|
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
|
||||||
|
{Name: "updated name", LanguageCode: "en"},
|
||||||
|
{Name: "الاسم المحدث", LanguageCode: "ar"},
|
||||||
|
{Name: "le nom", LanguageCode: "fr"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Errorf("failed to replace item content, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 3, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var contents []ItemContent
|
||||||
|
if err := tx.Find(&contents).Error; err != nil {
|
||||||
|
t.Errorf("failed to find contents, got error: %v", err)
|
||||||
|
}
|
||||||
|
if len(contents) != 3 {
|
||||||
|
t.Errorf("expected 3 contents, got %d", len(contents))
|
||||||
|
}
|
||||||
|
|
||||||
|
// test delete
|
||||||
|
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
|
||||||
|
t.Errorf("failed to delete Contents, got error: %v", err)
|
||||||
|
}
|
||||||
|
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test clear
|
||||||
|
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
|
||||||
|
t.Errorf("failed to clear contents association, got error: %v", err)
|
||||||
|
}
|
||||||
|
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 0, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Find(&contents).Error; err != nil {
|
||||||
|
t.Errorf("failed to find contents, got error: %v", err)
|
||||||
|
}
|
||||||
|
if len(contents) != 0 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 0, len(contents))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user