diff --git a/schema/field.go b/schema/field.go index b5103d53..f1e41ece 100644 --- a/schema/field.go +++ b/schema/field.go @@ -62,6 +62,7 @@ type Field struct { Creatable bool Updatable bool Readable bool + UpdateOnSoftDelete bool AutoCreateTime TimeType AutoUpdateTime TimeType HasDefaultValue bool @@ -113,6 +114,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Creatable: true, Updatable: true, Readable: true, + UpdateOnSoftDelete: false, PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), @@ -329,6 +331,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if _, ok := field.TagSettings["UPDATEONSOFTDELETE"]; ok { + field.UpdateOnSoftDelete = true + } + // setup permission if val, ok := field.TagSettings["-"]; ok { val = strings.ToLower(strings.TrimSpace(val)) diff --git a/soft_delete.go b/soft_delete.go index 5673d3b8..6a1454c0 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -142,9 +142,28 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { curTime := stmt.DB.NowFunc() - stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + + var setColumns clause.Set = clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}} stmt.SetColumn(sd.Field.DBName, curTime, true) + modelRefVal := reflect.ValueOf(stmt.Model) + if modelRefVal.Kind() == reflect.Ptr { + modelRefVal = reflect.Indirect(modelRefVal) + } + + if stmt.Model != nil && stmt.Schema != nil { + // Add additional fields to update in the same operation! + for _, field := range stmt.Schema.Fields { + if field.UpdateOnSoftDelete { + fieldVal := modelRefVal.FieldByName(field.Name).Interface() + setColumns = append(setColumns, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: fieldVal}) + stmt.SetColumn(field.DBName, fieldVal) + } + } + } + + stmt.AddClause(setColumns) + if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)