Refactor Create/Query/Update/DeleteClauses interface

This commit is contained in:
Jinzhu 2020-08-17 16:31:09 +08:00
parent 2a716e04e6
commit 681268cc43
4 changed files with 57 additions and 38 deletions

View File

@ -88,23 +88,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...)
}
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer) valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer { if isValuer {
@ -353,11 +336,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.TagSettings[k] = v ef.TagSettings[k] = v
} }
} }
field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...)
field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...)
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...)
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...)
} else { } else {
schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
} }

View File

@ -7,17 +7,17 @@ type GormDataTypeInterface interface {
} }
type CreateClausesInterface interface { type CreateClausesInterface interface {
CreateClauses() []clause.Interface CreateClauses(*Field) []clause.Interface
} }
type QueryClausesInterface interface { type QueryClausesInterface interface {
QueryClauses() []clause.Interface QueryClauses(*Field) []clause.Interface
} }
type UpdateClausesInterface interface { type UpdateClausesInterface interface {
UpdateClauses() []clause.Interface UpdateClauses(*Field) []clause.Interface
} }
type DeleteClausesInterface interface { type DeleteClausesInterface interface {
DeleteClauses() []clause.Interface DeleteClauses(*Field) []clause.Interface
} }

View File

@ -219,6 +219,23 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
return schema, schema.err return schema, schema.err
} }
} }
fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
} }
} }

View File

@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) {
return n.Time, nil return n.Time, nil
} }
func (DeletedAt) QueryClauses() []clause.Interface { func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{ return []clause.Interface{
clause.Where{Exprs: []clause.Expression{ clause.Where{Exprs: []clause.Expression{
clause.Eq{ clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName},
Value: nil, Value: nil,
}, },
}}, }},
} }
} }
func (DeletedAt) DeleteClauses() []clause.Interface { type SoftDeleteQueryClause struct {
return []clause.Interface{SoftDeleteClause{}} Field *schema.Field
} }
type SoftDeleteClause struct { func (sd SoftDeleteQueryClause) Name() string {
}
func (SoftDeleteClause) Name() string {
return "" return ""
} }
func (SoftDeleteClause) Build(clause.Builder) { func (sd SoftDeleteQueryClause) Build(clause.Builder) {
} }
func (SoftDeleteClause) MergeClause(*clause.Clause) { func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
} }
func (SoftDeleteClause) ModifyStatement(stmt *Statement) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
}})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
}
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
}
type SoftDeleteDeleteClause struct {
Field *schema.Field
}
func (sd SoftDeleteDeleteClause) Name() string {
return ""
}
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
}
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.String() == "" { if stmt.SQL.String() == "" {
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}})
if stmt.Schema != nil { if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)