fix: support zeroValue tag on DeletedAt

Signed-off-by: qiankunli <qiankun.li@qq.com>
This commit is contained in:
qiankunli 2023-01-31 23:30:09 +08:00
parent d834dd60b7
commit 80e7785499
3 changed files with 94 additions and 8 deletions

View File

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"reflect" "reflect"
"github.com/jinzhu/now"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -45,11 +46,27 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
} }
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f}} return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
func parseZeroValueTag(f *schema.Field) sql.NullString {
// parse zeroValue tag if not nil
tagSetting := schema.ParseTagSetting(f.Tag.Get("gorm"), ";")
zeroValueTag := tagSetting["ZEROVALUE"]
zeroValue := sql.NullString{Valid: false}
if len(zeroValueTag) > 0 {
// validate it
_, err := now.Parse(zeroValueTag)
if err == nil {
zeroValue = sql.NullString{String: zeroValueTag, Valid: true}
}
}
return zeroValue
} }
type SoftDeleteQueryClause struct { type SoftDeleteQueryClause struct {
Field *schema.Field ZeroValue sql.NullString
Field *schema.Field
} }
func (sd SoftDeleteQueryClause) Name() string { func (sd SoftDeleteQueryClause) Name() string {
@ -78,18 +95,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{ stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
}}) }})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{} stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
} }
} }
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f}} return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
} }
type SoftDeleteUpdateClause struct { type SoftDeleteUpdateClause struct {
Field *schema.Field ZeroValue sql.NullString
Field *schema.Field
} }
func (sd SoftDeleteUpdateClause) Name() string { func (sd SoftDeleteUpdateClause) Name() string {
@ -109,11 +127,12 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
} }
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f}} return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
} }
type SoftDeleteDeleteClause struct { type SoftDeleteDeleteClause struct {
Field *schema.Field ZeroValue sql.NullString
Field *schema.Field
} }
func (sd SoftDeleteDeleteClause) Name() string { func (sd SoftDeleteDeleteClause) Name() string {

View File

@ -7,6 +7,7 @@ import (
"regexp" "regexp"
"testing" "testing"
"github.com/jinzhu/now"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -98,3 +99,69 @@ func TestDeletedAtOneOr(t *testing.T) {
t.Fatalf("invalid sql generated, got %v", actualSQL) t.Fatalf("invalid sql generated, got %v", actualSQL)
} }
} }
type Book struct {
ID uint
Name string
Pages uint
DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"`
}
func TestSoftDeleteZeroValue(t *testing.T) {
book := Book{Name: "jinzhu", Pages: 10}
DB.Save(&book)
var count int64
if DB.Model(&Book{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
}
var pages uint
if DB.Model(&Book{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages)
}
if err := DB.Delete(&book).Error; err != nil {
t.Fatalf("No error should happen when soft delete user, but got %v", err)
}
zeroTime, _ := now.Parse("1970-01-01 00:00:01")
if book.DeletedAt.Time.Equal(zeroTime) {
t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt)
}
if DB.First(&Book{}, "name = ?", book.Name).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
count = 0
if DB.Model(&Book{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 {
t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
}
pages = 0
if err := DB.Model(&Book{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 {
t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err)
}
if err := DB.Unscoped().First(&Book{}, "name = ?", book.Name).Error; err != nil {
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
}
count = 0
if DB.Unscoped().Model(&Book{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
}
pages = 0
if DB.Unscoped().Model(&Book{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages)
}
DB.Unscoped().Delete(&book)
if err := DB.Unscoped().First(&Book{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
t.Errorf("Can't find permanently deleted record")
}
}

View File

@ -100,7 +100,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
func RunMigrations() { func RunMigrations() {
var err error var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Book{}}
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })