From c6511a0dc6c882ba5e8868d07dbd98968da5d582 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Fri, 3 Sep 2021 05:46:22 +0000 Subject: [PATCH] feat: support unique --- clause/on_conflict.go | 2 +- model.go | 7 ++++ schema/index.go | 20 +++++++++++ soft_delete_unique.go | 80 +++++++++++++++++++++++++++++++++++++++++++ tests/query_test.go | 2 +- 5 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 soft_delete_unique.go diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 64ee7f53..309c5fcd 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -26,7 +26,7 @@ func (onConflict OnConflict) Build(builder Builder) { } builder.WriteString(`) `) } - + if len(onConflict.TargetWhere.Exprs) > 0 { builder.WriteString(" WHERE ") onConflict.TargetWhere.Build(builder) diff --git a/model.go b/model.go index 3334d17c..201ced5e 100644 --- a/model.go +++ b/model.go @@ -13,3 +13,10 @@ type Model struct { UpdatedAt time.Time DeletedAt DeletedAt `gorm:"index"` } + +type ModelSupportUnique struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedFlag DeletedFlag `gorm:"type:BIGINT UNSIGNED NOT NULL DEFAULT 0" json:"deleted_flag"` +} diff --git a/schema/index.go b/schema/index.go index b54e08ad..ce501cc0 100644 --- a/schema/index.go +++ b/schema/index.go @@ -51,6 +51,26 @@ func (schema *Schema) ParseIndexes() map[string]Index { } idx.Fields = append(idx.Fields, index.Fields...) + + // create combined index for unique + if index.Class == "UNIQUE" { + if df := schema.LookUpField("deleted_flag"); df != nil { + var exists bool + for _, f := range idx.Fields { + if f.Field.Name == df.Name { + exists = true + break + } + } + + if !exists { + idx.Fields = append(idx.Fields, IndexOption{ + Field: df, + }) + } + } + } + sort.Slice(idx.Fields, func(i, j int) bool { return idx.Fields[i].priority < idx.Fields[j].priority }) diff --git a/soft_delete_unique.go b/soft_delete_unique.go new file mode 100644 index 00000000..bb58fd25 --- /dev/null +++ b/soft_delete_unique.go @@ -0,0 +1,80 @@ +package gorm + +import ( + "fmt" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +type DeletedFlag uint + +// // Scan implements the Scanner interface. +// func (n *DeletedFlag) Scan(value interface{}) error { +// return (*sql.NullTime)(n).Scan(value) +// } + +// // Value implements the driver Valuer interface. +// func (n DeletedFlag) Value() (driver.Value, error) { +// if !n.Valid { +// return nil, nil +// } +// return n.Time, nil +// } + +// func (n DeletedFlag) MarshalJSON() ([]byte, error) { +// if n.Valid { +// return json.Marshal(n.Time) +// } +// return json.Marshal(nil) +// } + +// func (n *DeletedFlag) UnmarshalJSON(b []byte) error { +// if string(b) == "null" { +// n.Valid = false +// return nil +// } +// err := json.Unmarshal(b, &n.Time) +// if err == nil { +// n.Valid = true +// } +// return err +// } + +func (DeletedFlag) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteQueryClause{Field: f}} +} + +func (DeletedFlag) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteUniqueDeleteClause{Field: f}} +} + +type SoftDeleteUniqueDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteUniqueDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteUniqueDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteUniqueDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteUniqueDeleteClause) ModifyStatement(stmt *Statement) { + re := regexp.MustCompile(`UPDATE (.*) WHERE `) + if sql := stmt.SQL.String(); sql != "" { + setClause := re.FindStringSubmatch(sql)[1] + if setClause == "" { + return + } + + newSetClause := fmt.Sprintf("%s, %s = `%s`.`id`", setClause, sd.Field.DBName, stmt.Table) + stmt.SQL.Reset() + stmt.SQL.WriteString(strings.Replace(sql, setClause, newSetClause, 1)) + } +} diff --git a/tests/query_test.go b/tests/query_test.go index 8a476598..e3321ed7 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -440,7 +440,7 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } - + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())