diff --git a/README.md b/README.md index f0dc21f6..9f4cf14f 100644 --- a/README.md +++ b/README.md @@ -560,7 +560,7 @@ db.Table("deleted_users").Pluck("name", &names) ## Callbacks Callback is a function defined to a struct, the function would be run when reflect a struct to database. -If the function return an error, will prevent following operations. (for example, stop inserting, updating) +If a function return error, gorm will prevent future operations and do rollback Those callbacks are defined now: @@ -570,12 +570,21 @@ Those callbacks are defined now: `BeforeDelete`, `AfterDelete` ```go +// Won't update readonly user func (u *User) BeforeUpdate() (err error) { if u.readonly() { err = errors.New("Read Only User") } return } + +// If have more than 1000 users, will rollback the insertion +func (u *User) AfterCreate() (err error) { + if (u.Id > 1000) { // just an example, don't use Id to count users + err = errors.New("Only 1000 users allowed") + } + return +} ``` ## Specify Table Name diff --git a/chain.go b/chain.go index 48acc163..1df4412c 100644 --- a/chain.go +++ b/chain.go @@ -132,22 +132,16 @@ func (s *Chain) Select(value interface{}) *Chain { } func (s *Chain) Save(value interface{}) *Chain { - do := s.do(value) - tx_started := do.begin() + do := s.do(value).begin() do.save() - if tx_started { - do.commit() - } + do.commit_or_rollback() return s } func (s *Chain) Delete(value interface{}) *Chain { - do := s.do(value) - tx_started := do.begin() + do := s.do(value).begin() do.delete() - if tx_started { - do.commit() - } + do.commit_or_rollback() return s } @@ -156,12 +150,9 @@ func (s *Chain) Update(attrs ...interface{}) *Chain { } func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain { - do := s.do(s.value) - tx_started := do.begin() - do.setUpdateAttrs(values, ignore_protected_attrs...).update() - if tx_started { - do.commit() - } + do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...) + do.update() + do.commit_or_rollback() return s } diff --git a/do.go b/do.go index cde02ccf..7fe13012 100644 --- a/do.go +++ b/do.go @@ -17,6 +17,7 @@ type Do struct { db sql_common guessedTableName string specifiedTableName string + startedTransaction bool model *Model value interface{} @@ -756,20 +757,26 @@ func (s *Do) autoMigrate() *Do { return s } -func (s *Do) begin() bool { +func (s *Do) begin() *Do { if db, ok := s.db.(sql_db); ok { tx, err := db.Begin() if err == nil { s.db = interface{}(tx).(sql_common) - return true + s.startedTransaction = true } } - return false + return s } -func (s *Do) commit() { - if db, ok := s.db.(sql_tx); ok { - s.err(db.Commit()) +func (s *Do) commit_or_rollback() { + if s.startedTransaction { + if db, ok := s.db.(sql_tx); ok { + if s.chain.hasError() { + db.Rollback() + } else { + db.Commit() + } + } } } diff --git a/gorm_test.go b/gorm_test.go index 5b67006c..6a445a84 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -610,8 +610,12 @@ func (s *Product) AfterUpdate() { s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 } -func (s *Product) AfterSave() { +func (s *Product) AfterSave() (err error) { + if s.Code == "after_save_error" { + err = errors.New("Can't save") + } s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return } func (s *Product) BeforeDelete() (err error) { @@ -622,8 +626,12 @@ func (s *Product) BeforeDelete() (err error) { return } -func (s *Product) AfterDelete() { +func (s *Product) AfterDelete() (err error) { + if s.Code == "after_delete_error" { + err = errors.New("Can't delete") + } s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return } func (p *Product) GetCallTimes() []int64 { return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes} @@ -703,6 +711,23 @@ func TestRunCallbacksAndGetErrors(t *testing.T) { if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil { t.Errorf("Should not delete record due to errors happened in callback") } + + p4 := Product{Code: "after_save_error", Price: 100} + db.Save(&p4) + if err := db.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Errorf("Record should be reverted if get an error after save", err) + } + + p5 := Product{Code: "after_delete_error", Price: 100} + db.Save(&p5) + if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Errorf("Record should be found", err) + } + + db.Delete(&p5) + if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Errorf("Record should be found because failed to delete", err) + } } func TestFillSmallerStructCorrectly(t *testing.T) { @@ -1365,6 +1390,13 @@ func TestTransaction(t *testing.T) { } } +func (s *CreditCard) BeforeSave() (err error) { + if s.Number == "0000" { + err = errors.New("invalid credit card") + } + return +} + func BenchmarkGorm(b *testing.B) { for x := 0; x < b.N; x++ { email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}