diff --git a/finisher_api.go b/finisher_api.go index 03bcd20f..e94ca4db 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -584,6 +584,21 @@ func (db *DB) Rollback() *DB { return db } +// RollbackUnlessComitted rollback a transaction unless it has already been comitted +func (db *DB) RollbackUnlessComitted() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if !reflect.ValueOf(committer).IsNil() { + err := committer.Rollback() + if err != nil && err != sql.ErrTxDone { + db.AddError(err) + } + } + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { db.AddError(savePointer.SavePoint(db, name)) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c17fea3b..b625eb8b 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -366,3 +366,25 @@ func TestTransactionOnClosedConn(t *testing.T) { t.Errorf("should returns error when commit with closed conn, got error %v", err) } } + +func TestTransactionRollbackUnlessComitted(t *testing.T) { + { + tx := DB.Begin() + tx.Commit() + + tx.Rollback() + if tx.Error == nil { + t.Fatalf("Expected error") + } + } + + { + tx := DB.Begin() + tx.Commit() + + tx.RollbackUnlessComitted() + if tx.Error != nil { + t.Fatalf("Did not expect error, got: %v", tx.Error) + } + } +}