diff --git a/finisher_api.go b/finisher_api.go index 4b3829a2..c04adbf7 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -214,8 +214,8 @@ func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { // Commit commit a transaction func (db DB) Commit() DB { - if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { - db.AddError(comminter.Commit()) + if commiter, ok := db.Statement.ConnPool.(TxCommiter); ok && commiter != nil { + db.AddError(commiter.Commit()) } else { db.AddError(ErrInvalidTransaction) } @@ -224,8 +224,22 @@ func (db DB) Commit() DB { // Rollback rollback a transaction func (db DB) Rollback() DB { - if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { - db.AddError(comminter.Rollback()) + if commiter, ok := db.Statement.ConnPool.(TxCommiter); ok && commiter != nil { + db.AddError(commiter.Rollback()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +// RollbackUnlessCommitted rollbacks a transaction if it is not yet commited +func (db DB) RollbackUnlessCommitted() DB { + if commiter, ok := db.Statement.ConnPool.(TxCommiter); ok && commiter != nil { + err := commiter.Rollback() + if err == nil || err == sql.ErrTxDone { + return db + } + db.AddError(err) } else { db.AddError(ErrInvalidTransaction) } diff --git a/interfaces.go b/interfaces.go index 9dd00c15..d32e26a9 100644 --- a/interfaces.go +++ b/interfaces.go @@ -33,6 +33,7 @@ type TxBeginner interface { type TxCommiter interface { Commit() error Rollback() error + RollbackUnlessCommitted() error } type BeforeCreateInterface interface {