add RollbackUnlessComitted, to make it easy to defer a Rollback.

This commit is contained in:
Ruben de Vries 2020-12-21 14:08:06 +01:00
parent 77bf4aecc6
commit ed83e82acf
No known key found for this signature in database
GPG Key ID: AA1D39B3B776AA3C
2 changed files with 37 additions and 0 deletions

View File

@ -584,6 +584,21 @@ func (db *DB) Rollback() *DB {
return db
}
// RollbackUnlessComitted rollback a transaction unless it has already been comitted
func (db *DB) RollbackUnlessComitted() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
err := committer.Rollback()
if err != nil && err != sql.ErrTxDone {
db.AddError(err)
}
}
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
db.AddError(savePointer.SavePoint(db, name))

View File

@ -366,3 +366,25 @@ func TestTransactionOnClosedConn(t *testing.T) {
t.Errorf("should returns error when commit with closed conn, got error %v", err)
}
}
func TestTransactionRollbackUnlessComitted(t *testing.T) {
{
tx := DB.Begin()
tx.Commit()
tx.Rollback()
if tx.Error == nil {
t.Fatalf("Expected error")
}
}
{
tx := DB.Begin()
tx.Commit()
tx.RollbackUnlessComitted()
if tx.Error != nil {
t.Fatalf("Did not expect error, got: %v", tx.Error)
}
}
}