diff --git a/main.go b/main.go index b3cc0bb5..1fe33b00 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,9 @@ type DB struct { dialect Dialect singularTable bool - afterCommitCallbacks []func(db *DB) + afterCommitCallbacks []func(db *DB) + afterRollbackCallbacks []func(db *DB) + afterTransactionCallbacks []func(db *DB) } // Open initialize a new db connection, need to import driver first, e.g: @@ -538,9 +540,11 @@ func (s *DB) Rollback() *DB { return s } -// AfterCommit invoke callback after transaction committed +// AppendAfterCommitCallback invoke callback after transaction committed // If no transaction is currently running, it will be invoked immediately -func (s *DB) AfterCommit(f func(db *DB)) { +// This method appends the given callback function to the register callback functions. +// This means, that the callback will be invoked after all previously registered callbacks. +func (s *DB) AppendAfterCommitCallback(f func(db *DB)) { if db, ok := s.db.(sqlTx); ok && db != nil { s.afterCommitCallbacks = append(s.afterCommitCallbacks, f) } else { @@ -548,6 +552,62 @@ func (s *DB) AfterCommit(f func(db *DB)) { } } +// PrependAfterCommitCallback invoke callback after transaction committed +// If no transaction is currently running, it will be invoked immediately +// This method prepends the given callback function to the register callback functions. +// This means, that the callback will be invoked before all previously registered callbacks. +func (s *DB) PrependAfterCommitCallback(f func(db *DB)) { + if db, ok := s.db.(sqlTx); ok && db != nil { + s.afterCommitCallbacks = append([]func(db *DB){f}, s.afterCommitCallbacks...) + } else { + f(s) + } +} + +// AppendAfterRollbackCallback invoke callback after transaction rolled back +// If no transaction is currently running, it won't be called +// This method appends the given callback function to the register callback functions. +// This means, that the callback will be invoked after all previously registered callbacks. +func (s *DB) AppendAfterRollbackCallback(f func(db *DB)) { + if db, ok := s.db.(sqlTx); ok && db != nil { + s.afterRollbackCallbacks = append(s.afterRollbackCallbacks, f) + } +} + +// PrependAfterRollbackCallback invoke callback after transaction rolled back. +// If no transaction is currently running, it won't be called +// This method prepends the given callback function to the register callback functions. +// This means, that the callback will be invoked before all previously registered callbacks. +func (s *DB) PrependAfterRollbackCallback(f func(db *DB)) { + if db, ok := s.db.(sqlTx); ok && db != nil { + s.afterRollbackCallbacks = append([]func(db *DB){f}, s.afterRollbackCallbacks...) + } +} + +// AppendAfterTransactionCallback invoke callback after transaction finished (committed or rolled back). +// If no transaction is currently running, it will be invoked immediately +// This method appends the given callback function to the register callback functions. +// This means, that the callback will be invoked after all previously registered callbacks. +func (s *DB) AppendAfterTransactionCallback(f func(db *DB)) { + if db, ok := s.db.(sqlTx); ok && db != nil { + s.afterTransactionCallbacks = append(s.afterTransactionCallbacks, f) + } else { + f(s) + } +} + +// PrependAfterTransactionCallback invoke callback after transaction finished (committed or rolled back). +// If no transaction is currently running, it will be invoked immediately +// This method prepends the given callback function to the register callback functions. +// This means, that the callback will be invoked before all previously registered callbacks. +func (s *DB) PrependAfterTransactionCallback(f func(db *DB)) { + if db, ok := s.db.(sqlTx); ok && db != nil { + s.afterTransactionCallbacks = append([]func(db *DB){f}, s.afterTransactionCallbacks...) + } else { + f(s) + } +} + // WrapInTx wraps a method in a transaction func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { if _, ok := s.db.(*sql.Tx); ok { @@ -570,14 +630,26 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { err = fmt.Errorf("Transacton code and rollback failed: %s; %s", err, rollbackErr) } } + + for _, callback := range tx.afterRollbackCallbacks { + callback(s) + } + + for _, callback := range tx.afterTransactionCallbacks { + callback(s) + } } }() err = f(tx) if err == nil { err = tx.Commit().Error if err == nil { - for i := len(tx.afterCommitCallbacks) - 1; i >= 0; i-- { - tx.afterCommitCallbacks[i](s) + for _, callback := range tx.afterCommitCallbacks { + callback(s) + } + + for _, callback := range tx.afterTransactionCallbacks { + callback(s) } } }