diff --git a/callbacks/delete.go b/callbacks/delete.go index a1fd0a57..08737505 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -26,82 +26,87 @@ func BeforeDelete(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } - if restricted { - for column, v := range selectColumns { - if v { - if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { - switch rel.Type { - case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) - withoutConditions := false - if db.Statement.Unscoped { - tx = tx.Unscoped() - } + for column, v := range selectColumns { + if !v { + continue + } - if len(db.Statement.Selects) > 0 { - selects := make([]string, 0, len(db.Statement.Selects)) - for _, s := range db.Statement.Selects { - if s == clause.Associations { - selects = append(selects, s) - } else if strings.HasPrefix(s, column+".") { - selects = append(selects, strings.TrimPrefix(s, column+".")) - } - } + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } - if len(selects) > 0 { - tx = tx.Select(selects) - } - } + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } - for _, cond := range queryConds { - if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { - withoutConditions = true - break - } - } - - if !withoutConditions { - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } - } - case schema.Many2Many: - var ( - queryConds = make([]clause.Expression, 0, len(rel.References)) - foreignFields = make([]*schema.Field, 0, len(rel.References)) - relForeignKeys = make([]string, 0, len(rel.References)) - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) - ) - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - foreignFields = append(foreignFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - queryConds = append(queryConds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) - column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) - queryConds = append(queryConds, clause.IN{Column: column, Values: values}) - - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) } } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return } } } + } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 9882590c..c887c6c0 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -145,27 +145,30 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues[idx], _ = field.ValueOf(elem) } - if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { - for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) - if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { - reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) - } + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", + elem.Interface())) + continue + } - reflectFieldValue = reflect.Indirect(reflectFieldValue) - switch reflectFieldValue.Kind() { - case reflect.Struct: - rel.Field.Set(data, reflectResults.Index(i).Interface()) - case reflect.Slice, reflect.Array: - if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) - } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) - } + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, elem.Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } - } else { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) } } }