test: adds AfterError tests
This commit is contained in:
parent
2e00b2bd7d
commit
352d9a9abb
@ -24,6 +24,7 @@ type Product struct {
|
|||||||
AfterSaveCallTimes int64
|
AfterSaveCallTimes int64
|
||||||
BeforeDeleteCallTimes int64
|
BeforeDeleteCallTimes int64
|
||||||
AfterDeleteCallTimes int64
|
AfterDeleteCallTimes int64
|
||||||
|
AfterErrorCallTimes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeCreate(tx *gorm.DB) (err error) {
|
func (s *Product) BeforeCreate(tx *gorm.DB) (err error) {
|
||||||
@ -88,8 +89,16 @@ func (s *Product) AfterDelete(tx *gorm.DB) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterError(tx *gorm.DB) (err error) {
|
||||||
|
if s.Code == "after_error_error" {
|
||||||
|
err = errors.New("can't handle this error")
|
||||||
|
}
|
||||||
|
s.AfterErrorCallTimes = s.AfterErrorCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Product) GetCallTimes() []int64 {
|
func (s *Product) GetCallTimes() []int64 {
|
||||||
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
|
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes, s.AfterErrorCallTimes}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunCallbacks(t *testing.T) {
|
func TestRunCallbacks(t *testing.T) {
|
||||||
@ -99,18 +108,18 @@ func TestRunCallbacks(t *testing.T) {
|
|||||||
p := Product{Code: "unique_code", Price: 100}
|
p := Product{Code: "unique_code", Price: 100}
|
||||||
DB.Save(&p)
|
DB.Save(&p)
|
||||||
|
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0, 0}) {
|
||||||
t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("Code = ?", "unique_code").First(&p)
|
DB.Where("Code = ?", "unique_code").First(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1, 0}) {
|
||||||
t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes())
|
t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Price = 200
|
p.Price = 200
|
||||||
DB.Save(&p)
|
DB.Save(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1, 0}) {
|
||||||
t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,12 +130,12 @@ func TestRunCallbacks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("Code = ?", "unique_code").First(&p)
|
DB.Where("Code = ?", "unique_code").First(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2, 0}) {
|
||||||
t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Delete(&p)
|
DB.Delete(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2, 0}) {
|
||||||
t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,6 +143,10 @@ func TestRunCallbacks(t *testing.T) {
|
|||||||
t.Fatalf("Can't find a deleted record")
|
t.Fatalf("Can't find a deleted record")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2, 1}) {
|
||||||
|
t.Fatalf("AfterError should be called because First raises error when doesn't fint, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
beforeCallTimes := p.AfterFindCallTimes
|
beforeCallTimes := p.AfterFindCallTimes
|
||||||
if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil {
|
if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil {
|
||||||
t.Fatalf("Find don't raise error when record not found")
|
t.Fatalf("Find don't raise error when record not found")
|
||||||
@ -142,6 +155,12 @@ func TestRunCallbacks(t *testing.T) {
|
|||||||
if p.AfterFindCallTimes != beforeCallTimes {
|
if p.AfterFindCallTimes != beforeCallTimes {
|
||||||
t.Fatalf("AfterFind should not be called")
|
t.Fatalf("AfterFind should not be called")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Product{})
|
||||||
|
DB.Create(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{2, 3, 1, 1, 0, 0, 1, 1, 2, 2}) {
|
||||||
|
t.Fatalf("should call BeforeCreate, BeforeSave and AfterError, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCallbacksWithErrors(t *testing.T) {
|
func TestCallbacksWithErrors(t *testing.T) {
|
||||||
@ -208,6 +227,14 @@ func TestCallbacksWithErrors(t *testing.T) {
|
|||||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||||
t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback")
|
t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Product{})
|
||||||
|
err := DB.Create(&Product{
|
||||||
|
Code: "after_error_error",
|
||||||
|
}).Error
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "can't handle this error") {
|
||||||
|
t.Fatalf("error on AfterError should be appended to the previous error, but got %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Product2 struct {
|
type Product2 struct {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user