add RollbackUnlessCommitted
This commit is contained in:
parent
8cfa9d98f0
commit
c98a6e1847
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
@ -647,3 +648,25 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx.callbacks.Raw().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// RollbackUnlessCommitted if transaction not commited, will rollback this transaction
|
||||
func (db *DB) RollbackUnlessCommitted() *DB {
|
||||
txCommitter, ok := db.Statement.ConnPool.(TxCommitter)
|
||||
if !ok {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
return db
|
||||
}
|
||||
tx, ok := txCommitter.(*sql.Tx)
|
||||
if !ok {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
return db
|
||||
}
|
||||
type Tx struct {
|
||||
done int32
|
||||
}
|
||||
x := (*Tx)(unsafe.Pointer(tx))
|
||||
if x.done == 0 {
|
||||
db.Rollback()
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
48
tests/rollback_unless_committed_test.go
Normal file
48
tests/rollback_unless_committed_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestRollbackUnlessCommitted(t *testing.T) {
|
||||
type Product struct {
|
||||
gorm.Model
|
||||
Code string
|
||||
Price uint
|
||||
}
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
if err := DB.Migrator().AutoMigrate(&Product{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
err := DB.RollbackUnlessCommitted().Error
|
||||
if !errors.Is(err, gorm.ErrInvalidTransaction) {
|
||||
t.Fatalf("want err %v, but get err %v", gorm.ErrInvalidTransaction, err)
|
||||
}
|
||||
|
||||
tx := DB.Begin()
|
||||
tx.Create(&Product{Code: "D42", Price: 100})
|
||||
err = tx.RollbackUnlessCommitted().Error
|
||||
if err != nil {
|
||||
t.Fatalf("RollbackUnlessCommitted failed, got err %v", err)
|
||||
}
|
||||
var count int64
|
||||
DB.Model(&Product{}).Where("price = ?", 100).Count(&count)
|
||||
if count != 0 {
|
||||
t.Fatalf("count should be 0, but get %d", count)
|
||||
}
|
||||
|
||||
tx1 := DB.Begin()
|
||||
tx1.Create(&Product{Code: "D42", Price: 100})
|
||||
tx1.Commit()
|
||||
err = tx.RollbackUnlessCommitted().Error
|
||||
if err != nil {
|
||||
t.Fatalf("RollbackUnlessCommitted failed, got err %v", err)
|
||||
}
|
||||
DB.Model(&Product{}).Where("price = ?", 100).Count(&count)
|
||||
if count != 1 {
|
||||
t.Fatalf("count should be 1, but get %d", count)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user