From ed83e82acff39bfcc8c6dc9923d2ac3f29007693 Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Mon, 21 Dec 2020 14:08:06 +0100 Subject: [PATCH] add RollbackUnlessComitted, to make it easy to defer a Rollback. --- finisher_api.go | 15 +++++++++++++++ tests/transaction_test.go | 22 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index 03bcd20f..e94ca4db 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -584,6 +584,21 @@ func (db *DB) Rollback() *DB { return db } +// RollbackUnlessComitted rollback a transaction unless it has already been comitted +func (db *DB) RollbackUnlessComitted() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if !reflect.ValueOf(committer).IsNil() { + err := committer.Rollback() + if err != nil && err != sql.ErrTxDone { + db.AddError(err) + } + } + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { db.AddError(savePointer.SavePoint(db, name)) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c17fea3b..b625eb8b 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -366,3 +366,25 @@ func TestTransactionOnClosedConn(t *testing.T) { t.Errorf("should returns error when commit with closed conn, got error %v", err) } } + +func TestTransactionRollbackUnlessComitted(t *testing.T) { + { + tx := DB.Begin() + tx.Commit() + + tx.Rollback() + if tx.Error == nil { + t.Fatalf("Expected error") + } + } + + { + tx := DB.Begin() + tx.Commit() + + tx.RollbackUnlessComitted() + if tx.Error != nil { + t.Fatalf("Did not expect error, got: %v", tx.Error) + } + } +}