add RollbackUnlessCommitted

This commit is contained in:
labulaka521 2021-04-08 00:00:47 +08:00
parent 8cfa9d98f0
commit c98a6e1847
2 changed files with 71 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"unsafe"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
@ -647,3 +648,25 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx.callbacks.Raw().Execute(tx)
return
}
// RollbackUnlessCommitted if transaction not commited, will rollback this transaction
func (db *DB) RollbackUnlessCommitted() *DB {
txCommitter, ok := db.Statement.ConnPool.(TxCommitter)
if !ok {
db.AddError(ErrInvalidTransaction)
return db
}
tx, ok := txCommitter.(*sql.Tx)
if !ok {
db.AddError(ErrInvalidTransaction)
return db
}
type Tx struct {
done int32
}
x := (*Tx)(unsafe.Pointer(tx))
if x.done == 0 {
db.Rollback()
}
return db
}

View File

@ -0,0 +1,48 @@
package tests_test
import (
"errors"
"testing"
"gorm.io/gorm"
)
func TestRollbackUnlessCommitted(t *testing.T) {
type Product struct {
gorm.Model
Code string
Price uint
}
DB.Migrator().DropTable(&Product{})
if err := DB.Migrator().AutoMigrate(&Product{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
err := DB.RollbackUnlessCommitted().Error
if !errors.Is(err, gorm.ErrInvalidTransaction) {
t.Fatalf("want err %v, but get err %v", gorm.ErrInvalidTransaction, err)
}
tx := DB.Begin()
tx.Create(&Product{Code: "D42", Price: 100})
err = tx.RollbackUnlessCommitted().Error
if err != nil {
t.Fatalf("RollbackUnlessCommitted failed, got err %v", err)
}
var count int64
DB.Model(&Product{}).Where("price = ?", 100).Count(&count)
if count != 0 {
t.Fatalf("count should be 0, but get %d", count)
}
tx1 := DB.Begin()
tx1.Create(&Product{Code: "D42", Price: 100})
tx1.Commit()
err = tx.RollbackUnlessCommitted().Error
if err != nil {
t.Fatalf("RollbackUnlessCommitted failed, got err %v", err)
}
DB.Model(&Product{}).Where("price = ?", 100).Count(&count)
if count != 1 {
t.Fatalf("count should be 1, but get %d", count)
}
}