Add RollbackUnlessCommitted

This commit is contained in:
Denis Dorozhkin 2020-03-09 19:14:56 +02:00
parent 2a0c3e39f2
commit 8d862ce981
2 changed files with 19 additions and 4 deletions

View File

@ -214,8 +214,8 @@ func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) {
// Commit commit a transaction // Commit commit a transaction
func (db DB) Commit() DB { func (db DB) Commit() DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if commiter, ok := db.Statement.ConnPool.(TxCommiter); ok && commiter != nil {
db.AddError(comminter.Commit()) db.AddError(commiter.Commit())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }
@ -224,8 +224,22 @@ func (db DB) Commit() DB {
// Rollback rollback a transaction // Rollback rollback a transaction
func (db DB) Rollback() DB { func (db DB) Rollback() DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if commiter, ok := db.Statement.ConnPool.(TxCommiter); ok && commiter != nil {
db.AddError(comminter.Rollback()) db.AddError(commiter.Rollback())
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
// RollbackUnlessCommitted rollbacks a transaction if it is not yet commited
func (db DB) RollbackUnlessCommitted() DB {
if commiter, ok := db.Statement.ConnPool.(TxCommiter); ok && commiter != nil {
err := commiter.Rollback()
if err == nil || err == sql.ErrTxDone {
return db
}
db.AddError(err)
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }

View File

@ -33,6 +33,7 @@ type TxBeginner interface {
type TxCommiter interface { type TxCommiter interface {
Commit() error Commit() error
Rollback() error Rollback() error
RollbackUnlessCommitted() error
} }
type BeforeCreateInterface interface { type BeforeCreateInterface interface {