add RollbackUnlessComitted, to make it easy to defer a Rollback.
This commit is contained in:
parent
77bf4aecc6
commit
ed83e82acf
@ -584,6 +584,21 @@ func (db *DB) Rollback() *DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// RollbackUnlessComitted rollback a transaction unless it has already been comitted
|
||||
func (db *DB) RollbackUnlessComitted() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
if !reflect.ValueOf(committer).IsNil() {
|
||||
err := committer.Rollback()
|
||||
if err != nil && err != sql.ErrTxDone {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
|
@ -366,3 +366,25 @@ func TestTransactionOnClosedConn(t *testing.T) {
|
||||
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionRollbackUnlessComitted(t *testing.T) {
|
||||
{
|
||||
tx := DB.Begin()
|
||||
tx.Commit()
|
||||
|
||||
tx.Rollback()
|
||||
if tx.Error == nil {
|
||||
t.Fatalf("Expected error")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tx := DB.Begin()
|
||||
tx.Commit()
|
||||
|
||||
tx.RollbackUnlessComitted()
|
||||
if tx.Error != nil {
|
||||
t.Fatalf("Did not expect error, got: %v", tx.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user