Added more possibilities for transaction hooks

This commit is contained in:
Gerhard Gruber 2023-07-17 15:07:40 +02:00 committed by lukasbash
parent 246897fa79
commit d89c6ace73

80
main.go
View File

@ -30,6 +30,8 @@ type DB struct {
singularTable bool
afterCommitCallbacks []func(db *DB)
afterRollbackCallbacks []func(db *DB)
afterTransactionCallbacks []func(db *DB)
}
// Open initialize a new db connection, need to import driver first, e.g:
@ -538,9 +540,11 @@ func (s *DB) Rollback() *DB {
return s
}
// AfterCommit invoke callback after transaction committed
// AppendAfterCommitCallback invoke callback after transaction committed
// If no transaction is currently running, it will be invoked immediately
func (s *DB) AfterCommit(f func(db *DB)) {
// This method appends the given callback function to the register callback functions.
// This means, that the callback will be invoked after all previously registered callbacks.
func (s *DB) AppendAfterCommitCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterCommitCallbacks = append(s.afterCommitCallbacks, f)
} else {
@ -548,6 +552,62 @@ func (s *DB) AfterCommit(f func(db *DB)) {
}
}
// PrependAfterCommitCallback invoke callback after transaction committed
// If no transaction is currently running, it will be invoked immediately
// This method prepends the given callback function to the register callback functions.
// This means, that the callback will be invoked before all previously registered callbacks.
func (s *DB) PrependAfterCommitCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterCommitCallbacks = append([]func(db *DB){f}, s.afterCommitCallbacks...)
} else {
f(s)
}
}
// AppendAfterRollbackCallback invoke callback after transaction rolled back
// If no transaction is currently running, it won't be called
// This method appends the given callback function to the register callback functions.
// This means, that the callback will be invoked after all previously registered callbacks.
func (s *DB) AppendAfterRollbackCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterRollbackCallbacks = append(s.afterRollbackCallbacks, f)
}
}
// PrependAfterRollbackCallback invoke callback after transaction rolled back.
// If no transaction is currently running, it won't be called
// This method prepends the given callback function to the register callback functions.
// This means, that the callback will be invoked before all previously registered callbacks.
func (s *DB) PrependAfterRollbackCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterRollbackCallbacks = append([]func(db *DB){f}, s.afterRollbackCallbacks...)
}
}
// AppendAfterTransactionCallback invoke callback after transaction finished (committed or rolled back).
// If no transaction is currently running, it will be invoked immediately
// This method appends the given callback function to the register callback functions.
// This means, that the callback will be invoked after all previously registered callbacks.
func (s *DB) AppendAfterTransactionCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterTransactionCallbacks = append(s.afterTransactionCallbacks, f)
} else {
f(s)
}
}
// PrependAfterTransactionCallback invoke callback after transaction finished (committed or rolled back).
// If no transaction is currently running, it will be invoked immediately
// This method prepends the given callback function to the register callback functions.
// This means, that the callback will be invoked before all previously registered callbacks.
func (s *DB) PrependAfterTransactionCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterTransactionCallbacks = append([]func(db *DB){f}, s.afterTransactionCallbacks...)
} else {
f(s)
}
}
// WrapInTx wraps a method in a transaction
func (s *DB) WrapInTx(f func(tx *DB) error) (err error) {
if _, ok := s.db.(*sql.Tx); ok {
@ -570,14 +630,26 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) {
err = fmt.Errorf("Transacton code and rollback failed: %s; %s", err, rollbackErr)
}
}
for _, callback := range tx.afterRollbackCallbacks {
callback(s)
}
for _, callback := range tx.afterTransactionCallbacks {
callback(s)
}
}
}()
err = f(tx)
if err == nil {
err = tx.Commit().Error
if err == nil {
for i := len(tx.afterCommitCallbacks) - 1; i >= 0; i-- {
tx.afterCommitCallbacks[i](s)
for _, callback := range tx.afterCommitCallbacks {
callback(s)
}
for _, callback := range tx.afterTransactionCallbacks {
callback(s)
}
}
}