From 80e7785499e80352dd01197c848f69df304dbe89 Mon Sep 17 00:00:00 2001 From: qiankunli Date: Tue, 31 Jan 2023 23:30:09 +0800 Subject: [PATCH] fix: support zeroValue tag on DeletedAt Signed-off-by: qiankunli --- soft_delete.go | 33 +++++++++++++++---- tests/soft_delete_test.go | 67 +++++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 2 +- 3 files changed, 94 insertions(+), 8 deletions(-) diff --git a/soft_delete.go b/soft_delete.go index 6d646288..2ab34ecd 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -6,6 +6,7 @@ import ( "encoding/json" "reflect" + "github.com/jinzhu/now" "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -45,11 +46,27 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error { } func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteQueryClause{Field: f}} + return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +func parseZeroValueTag(f *schema.Field) sql.NullString { + // parse zeroValue tag if not nil + tagSetting := schema.ParseTagSetting(f.Tag.Get("gorm"), ";") + zeroValueTag := tagSetting["ZEROVALUE"] + zeroValue := sql.NullString{Valid: false} + if len(zeroValueTag) > 0 { + // validate it + _, err := now.Parse(zeroValueTag) + if err == nil { + zeroValue = sql.NullString{String: zeroValueTag, Valid: true} + } + } + return zeroValue } type SoftDeleteQueryClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteQueryClause) Name() string { @@ -78,18 +95,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { } stmt.AddClause(clause.Where{Exprs: []clause.Expression{ - clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, }}) stmt.Clauses["soft_delete_enabled"] = clause.Clause{} } } func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteUpdateClause{Field: f}} + return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} } type SoftDeleteUpdateClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteUpdateClause) Name() string { @@ -109,11 +127,12 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { } func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteDeleteClause{Field: f}} + return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} } type SoftDeleteDeleteClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteDeleteClause) Name() string { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 1f9a4786..2a7ff2a1 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -7,6 +7,7 @@ import ( "regexp" "testing" + "github.com/jinzhu/now" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -98,3 +99,69 @@ func TestDeletedAtOneOr(t *testing.T) { t.Fatalf("invalid sql generated, got %v", actualSQL) } } + +type Book struct { + ID uint + Name string + Pages uint + DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"` +} + +func TestSoftDeleteZeroValue(t *testing.T) { + + book := Book{Name: "jinzhu", Pages: 10} + DB.Save(&book) + + var count int64 + if DB.Model(&Book{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + + var pages uint + if DB.Model(&Book{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages) + } + + if err := DB.Delete(&book).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + zeroTime, _ := now.Parse("1970-01-01 00:00:01") + if book.DeletedAt.Time.Equal(zeroTime) { + t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt) + } + + if DB.First(&Book{}, "name = ?", book.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + count = 0 + if DB.Model(&Book{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + + pages = 0 + if err := DB.Model(&Book{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 { + t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err) + } + + if err := DB.Unscoped().First(&Book{}, "name = ?", book.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + count = 0 + if DB.Unscoped().Model(&Book{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + + pages = 0 + if DB.Unscoped().Model(&Book{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages) + } + + DB.Unscoped().Delete(&book) + if err := DB.Unscoped().First(&Book{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("Can't find permanently deleted record") + } + +} diff --git a/tests/tests_test.go b/tests/tests_test.go index dcba3cbf..7b518169 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -100,7 +100,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Book{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })