diff --git a/main.go b/main.go index 1fe33b00..2d7b4f95 100644 --- a/main.go +++ b/main.go @@ -565,22 +565,26 @@ func (s *DB) PrependAfterCommitCallback(f func(db *DB)) { } // AppendAfterRollbackCallback invoke callback after transaction rolled back -// If no transaction is currently running, it won't be called +// If no transaction is currently running, it will be invoked immediately // This method appends the given callback function to the register callback functions. // This means, that the callback will be invoked after all previously registered callbacks. func (s *DB) AppendAfterRollbackCallback(f func(db *DB)) { if db, ok := s.db.(sqlTx); ok && db != nil { s.afterRollbackCallbacks = append(s.afterRollbackCallbacks, f) + } else { + f(s) } } // PrependAfterRollbackCallback invoke callback after transaction rolled back. -// If no transaction is currently running, it won't be called +// If no transaction is currently running, it will be invoked immediately // This method prepends the given callback function to the register callback functions. // This means, that the callback will be invoked before all previously registered callbacks. func (s *DB) PrependAfterRollbackCallback(f func(db *DB)) { if db, ok := s.db.(sqlTx); ok && db != nil { s.afterRollbackCallbacks = append([]func(db *DB){f}, s.afterRollbackCallbacks...) + } else { + f(s) } } @@ -622,6 +626,20 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { panicked := true defer func() { if panicked || err != nil { + // Note that the callbacks are called right before the actual rollback happens + // because entities might have been updated/inserted/deleted in the transaction + // and the callbacks might need to access them. If the callbacks were called + // after the rollback, the entity changes would be gone. + // The downside of this implementation is that a possible rollback error will be hidden + // for the actual callbacks. + for _, callback := range tx.afterRollbackCallbacks { + callback(tx) + } + + for _, callback := range tx.afterTransactionCallbacks { + callback(tx) + } + rollbackErr := tx.Rollback().Error if rollbackErr != nil { if err == nil { @@ -631,13 +649,6 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { } } - for _, callback := range tx.afterRollbackCallbacks { - callback(s) - } - - for _, callback := range tx.afterTransactionCallbacks { - callback(s) - } } }() err = f(tx)