feat: unscoped association (#5899)
This commit is contained in:
		
							parent
							
								
									e9637024d3
								
							
						
					
					
						commit
						eb70c3a84c
					
				@ -14,6 +14,7 @@ import (
 | 
				
			|||||||
type Association struct {
 | 
					type Association struct {
 | 
				
			||||||
	DB           *DB
 | 
						DB           *DB
 | 
				
			||||||
	Relationship *schema.Relationship
 | 
						Relationship *schema.Relationship
 | 
				
			||||||
 | 
						Unscope      bool
 | 
				
			||||||
	Error        error
 | 
						Error        error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
 | 
				
			|||||||
	return association
 | 
						return association
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (association *Association) Unscoped() *Association {
 | 
				
			||||||
 | 
						return &Association{
 | 
				
			||||||
 | 
							DB:           association.DB,
 | 
				
			||||||
 | 
							Relationship: association.Relationship,
 | 
				
			||||||
 | 
							Error:        association.Error,
 | 
				
			||||||
 | 
							Unscope:      true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
 | 
					func (association *Association) Find(out interface{}, conds ...interface{}) error {
 | 
				
			||||||
	if association.Error == nil {
 | 
						if association.Error == nil {
 | 
				
			||||||
		association.Error = association.buildCondition().Find(out, conds...).Error
 | 
							association.Error = association.buildCondition().Find(out, conds...).Error
 | 
				
			||||||
@ -75,7 +85,6 @@ func (association *Association) Replace(values ...interface{}) error {
 | 
				
			|||||||
		switch rel.Type {
 | 
							switch rel.Type {
 | 
				
			||||||
		case schema.BelongsTo:
 | 
							case schema.BelongsTo:
 | 
				
			||||||
			if len(values) == 0 {
 | 
								if len(values) == 0 {
 | 
				
			||||||
				updateMap := map[string]interface{}{}
 | 
					 | 
				
			||||||
				switch reflectValue.Kind() {
 | 
									switch reflectValue.Kind() {
 | 
				
			||||||
				case reflect.Slice, reflect.Array:
 | 
									case reflect.Slice, reflect.Array:
 | 
				
			||||||
					for i := 0; i < reflectValue.Len(); i++ {
 | 
										for i := 0; i < reflectValue.Len(); i++ {
 | 
				
			||||||
@ -85,11 +94,16 @@ func (association *Association) Replace(values ...interface{}) error {
 | 
				
			|||||||
					association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
 | 
										association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				for _, ref := range rel.References {
 | 
									if association.Unscope {
 | 
				
			||||||
					updateMap[ref.ForeignKey.DBName] = nil
 | 
										modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
 | 
				
			||||||
 | 
										association.Error = association.DB.Delete(modelValue).Error
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										updateMap := map[string]interface{}{}
 | 
				
			||||||
 | 
										for _, ref := range rel.References {
 | 
				
			||||||
 | 
											updateMap[ref.ForeignKey.DBName] = nil
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										association.Error = association.DB.UpdateColumns(updateMap).Error
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					 | 
				
			||||||
				association.Error = association.DB.UpdateColumns(updateMap).Error
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case schema.HasOne, schema.HasMany:
 | 
							case schema.HasOne, schema.HasMany:
 | 
				
			||||||
			var (
 | 
								var (
 | 
				
			||||||
@ -119,7 +133,11 @@ func (association *Association) Replace(values ...interface{}) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
 | 
								if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
 | 
				
			||||||
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
 | 
									column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
 | 
				
			||||||
				association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
 | 
									if association.Unscope {
 | 
				
			||||||
 | 
										association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case schema.Many2Many:
 | 
							case schema.Many2Many:
 | 
				
			||||||
			var (
 | 
								var (
 | 
				
			||||||
@ -184,7 +202,8 @@ func (association *Association) Delete(values ...interface{}) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		switch rel.Type {
 | 
							switch rel.Type {
 | 
				
			||||||
		case schema.BelongsTo:
 | 
							case schema.BelongsTo:
 | 
				
			||||||
			tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
 | 
								model := reflect.New(rel.Schema.ModelType).Interface()
 | 
				
			||||||
 | 
								tx := association.DB.Model(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
 | 
								_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
 | 
				
			||||||
			if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
 | 
								if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
 | 
				
			||||||
@ -197,9 +216,14 @@ func (association *Association) Delete(values ...interface{}) error {
 | 
				
			|||||||
			relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
 | 
								relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
 | 
				
			||||||
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
 | 
								conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
 | 
								if association.Unscope {
 | 
				
			||||||
 | 
									association.Error = tx.Clauses(conds...).Delete(model).Error
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		case schema.HasOne, schema.HasMany:
 | 
							case schema.HasOne, schema.HasMany:
 | 
				
			||||||
			tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
 | 
								model := reflect.New(rel.FieldSchema.ModelType).Interface()
 | 
				
			||||||
 | 
								tx := association.DB.Model(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
 | 
								_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
 | 
				
			||||||
			if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
 | 
								if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
 | 
				
			||||||
@ -212,7 +236,11 @@ func (association *Association) Delete(values ...interface{}) error {
 | 
				
			|||||||
			relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
 | 
								relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
 | 
				
			||||||
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
 | 
								conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
 | 
								if association.Unscope {
 | 
				
			||||||
 | 
									association.Error = tx.Clauses(conds...).Delete(model).Error
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		case schema.Many2Many:
 | 
							case schema.Many2Many:
 | 
				
			||||||
			var (
 | 
								var (
 | 
				
			||||||
				primaryFields, relPrimaryFields     []*schema.Field
 | 
									primaryFields, relPrimaryFields     []*schema.Field
 | 
				
			||||||
 | 
				
			|||||||
@ -394,3 +394,76 @@ func TestAssociationEmptyPrimaryKey(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	AssertEqual(t, result, user)
 | 
						AssertEqual(t, result, user)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestAssociationUnscoped(t *testing.T) {
 | 
				
			||||||
 | 
						type ItemContent struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							ItemID       uint   `gorm:"not null"`
 | 
				
			||||||
 | 
							Name         string `gorm:"not null;type:varchar(50)"`
 | 
				
			||||||
 | 
							LanguageCode string `gorm:"not null;type:varchar(2)"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						type Item struct {
 | 
				
			||||||
 | 
							gorm.Model
 | 
				
			||||||
 | 
							Logo     string        `gorm:"not null;type:varchar(50)"`
 | 
				
			||||||
 | 
							Contents []ItemContent `gorm:"foreignKey:ItemID"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx := DB.Session(&gorm.Session{})
 | 
				
			||||||
 | 
						tx.Migrator().DropTable(&ItemContent{}, &Item{})
 | 
				
			||||||
 | 
						tx.AutoMigrate(&ItemContent{}, &Item{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						item := Item{
 | 
				
			||||||
 | 
							Logo: "logo",
 | 
				
			||||||
 | 
							Contents: []ItemContent{
 | 
				
			||||||
 | 
								{Name: "name", LanguageCode: "en"},
 | 
				
			||||||
 | 
								{Name: "الاسم", LanguageCode: "ar"},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := tx.Create(&item).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("failed to create items, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test Replace
 | 
				
			||||||
 | 
						if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
 | 
				
			||||||
 | 
							{Name: "updated name", LanguageCode: "en"},
 | 
				
			||||||
 | 
							{Name: "الاسم المحدث", LanguageCode: "ar"},
 | 
				
			||||||
 | 
							{Name: "le nom", LanguageCode: "fr"},
 | 
				
			||||||
 | 
						}); err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to replace item content, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
 | 
				
			||||||
 | 
							t.Errorf("expected %d contents, got %d", 3, count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var contents []ItemContent
 | 
				
			||||||
 | 
						if err := tx.Find(&contents).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to find contents, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(contents) != 3 {
 | 
				
			||||||
 | 
							t.Errorf("expected 3 contents, got %d", len(contents))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test delete
 | 
				
			||||||
 | 
						if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to delete Contents, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
 | 
				
			||||||
 | 
							t.Errorf("expected %d contents, got %d", 2, count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test clear
 | 
				
			||||||
 | 
						if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to clear contents association, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
 | 
				
			||||||
 | 
							t.Errorf("expected %d contents, got %d", 0, count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := tx.Find(&contents).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to find contents, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(contents) != 0 {
 | 
				
			||||||
 | 
							t.Errorf("expected %d contents, got %d", 0, len(contents))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user