From 681268cc43a2aa665e5577680b88ac77b9e5b64c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 16:31:09 +0800 Subject: [PATCH] Refactor Create/Query/Update/DeleteClauses interface --- schema/field.go | 22 -------------------- schema/interfaces.go | 8 ++++---- schema/schema.go | 17 ++++++++++++++++ soft_delete.go | 48 +++++++++++++++++++++++++++++++++----------- 4 files changed, 57 insertions(+), 38 deletions(-) diff --git a/schema/field.go b/schema/field.go index 78eeccdc..bc47e543 100644 --- a/schema/field.go +++ b/schema/field.go @@ -88,23 +88,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) - } - - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) - } - - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) - } - - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) - } - // if field is valuer, used its value or first fields as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { @@ -353,11 +336,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } else { schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } diff --git a/schema/interfaces.go b/schema/interfaces.go index f5d07843..e8e51e4c 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -7,17 +7,17 @@ type GormDataTypeInterface interface { } type CreateClausesInterface interface { - CreateClauses() []clause.Interface + CreateClauses(*Field) []clause.Interface } type QueryClausesInterface interface { - QueryClauses() []clause.Interface + QueryClauses(*Field) []clause.Interface } type UpdateClausesInterface interface { - UpdateClauses() []clause.Interface + UpdateClauses(*Field) []clause.Interface } type DeleteClausesInterface interface { - DeleteClauses() []clause.Interface + DeleteClauses(*Field) []clause.Interface } diff --git a/schema/schema.go b/schema/schema.go index 9206c24e..d81da4b8 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -219,6 +219,23 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return schema, schema.err } } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } diff --git a/soft_delete.go b/soft_delete.go index 180bf745..875623bc 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) { return n.Time, nil } -func (DeletedAt) QueryClauses() []clause.Interface { +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{ clause.Where{Exprs: []clause.Expression{ clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: nil, }, }}, } } -func (DeletedAt) DeleteClauses() []clause.Interface { - return []clause.Interface{SoftDeleteClause{}} +type SoftDeleteQueryClause struct { + Field *schema.Field } -type SoftDeleteClause struct { -} - -func (SoftDeleteClause) Name() string { +func (sd SoftDeleteQueryClause) Name() string { return "" } -func (SoftDeleteClause) Build(clause.Builder) { +func (sd SoftDeleteQueryClause) Build(clause.Builder) { } -func (SoftDeleteClause) MergeClause(*clause.Clause) { +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } -func (SoftDeleteClause) ModifyStatement(stmt *Statement) { +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)