Improved TX callback handling for rollback callbacks

This commit is contained in:
Lukas Jorg 2023-08-02 16:04:48 +02:00 committed by Gerhard Gruber
parent d89c6ace73
commit 3ca61d4d7b

29
main.go
View File

@ -565,22 +565,26 @@ func (s *DB) PrependAfterCommitCallback(f func(db *DB)) {
} }
// AppendAfterRollbackCallback invoke callback after transaction rolled back // AppendAfterRollbackCallback invoke callback after transaction rolled back
// If no transaction is currently running, it won't be called // If no transaction is currently running, it will be invoked immediately
// This method appends the given callback function to the register callback functions. // This method appends the given callback function to the register callback functions.
// This means, that the callback will be invoked after all previously registered callbacks. // This means, that the callback will be invoked after all previously registered callbacks.
func (s *DB) AppendAfterRollbackCallback(f func(db *DB)) { func (s *DB) AppendAfterRollbackCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil { if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterRollbackCallbacks = append(s.afterRollbackCallbacks, f) s.afterRollbackCallbacks = append(s.afterRollbackCallbacks, f)
} else {
f(s)
} }
} }
// PrependAfterRollbackCallback invoke callback after transaction rolled back. // PrependAfterRollbackCallback invoke callback after transaction rolled back.
// If no transaction is currently running, it won't be called // If no transaction is currently running, it will be invoked immediately
// This method prepends the given callback function to the register callback functions. // This method prepends the given callback function to the register callback functions.
// This means, that the callback will be invoked before all previously registered callbacks. // This means, that the callback will be invoked before all previously registered callbacks.
func (s *DB) PrependAfterRollbackCallback(f func(db *DB)) { func (s *DB) PrependAfterRollbackCallback(f func(db *DB)) {
if db, ok := s.db.(sqlTx); ok && db != nil { if db, ok := s.db.(sqlTx); ok && db != nil {
s.afterRollbackCallbacks = append([]func(db *DB){f}, s.afterRollbackCallbacks...) s.afterRollbackCallbacks = append([]func(db *DB){f}, s.afterRollbackCallbacks...)
} else {
f(s)
} }
} }
@ -622,6 +626,20 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) {
panicked := true panicked := true
defer func() { defer func() {
if panicked || err != nil { if panicked || err != nil {
// Note that the callbacks are called right before the actual rollback happens
// because entities might have been updated/inserted/deleted in the transaction
// and the callbacks might need to access them. If the callbacks were called
// after the rollback, the entity changes would be gone.
// The downside of this implementation is that a possible rollback error will be hidden
// for the actual callbacks.
for _, callback := range tx.afterRollbackCallbacks {
callback(tx)
}
for _, callback := range tx.afterTransactionCallbacks {
callback(tx)
}
rollbackErr := tx.Rollback().Error rollbackErr := tx.Rollback().Error
if rollbackErr != nil { if rollbackErr != nil {
if err == nil { if err == nil {
@ -631,13 +649,6 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) {
} }
} }
for _, callback := range tx.afterRollbackCallbacks {
callback(s)
}
for _, callback := range tx.afterTransactionCallbacks {
callback(s)
}
} }
}() }()
err = f(tx) err = f(tx)