diff --git a/finisher_api.go b/finisher_api.go index b5cbfaa6..f6aa42e0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "unsafe" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -647,3 +648,25 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx.callbacks.Raw().Execute(tx) return } + +// RollbackUnlessCommitted if transaction not commited, will rollback this transaction +func (db *DB) RollbackUnlessCommitted() *DB { + txCommitter, ok := db.Statement.ConnPool.(TxCommitter) + if !ok { + db.AddError(ErrInvalidTransaction) + return db + } + tx, ok := txCommitter.(*sql.Tx) + if !ok { + db.AddError(ErrInvalidTransaction) + return db + } + type Tx struct { + done int32 + } + x := (*Tx)(unsafe.Pointer(tx)) + if x.done == 0 { + db.Rollback() + } + return db +} diff --git a/tests/rollback_unless_committed_test.go b/tests/rollback_unless_committed_test.go new file mode 100644 index 00000000..88aa0200 --- /dev/null +++ b/tests/rollback_unless_committed_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" +) + +func TestRollbackUnlessCommitted(t *testing.T) { + type Product struct { + gorm.Model + Code string + Price uint + } + DB.Migrator().DropTable(&Product{}) + if err := DB.Migrator().AutoMigrate(&Product{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + err := DB.RollbackUnlessCommitted().Error + if !errors.Is(err, gorm.ErrInvalidTransaction) { + t.Fatalf("want err %v, but get err %v", gorm.ErrInvalidTransaction, err) + } + + tx := DB.Begin() + tx.Create(&Product{Code: "D42", Price: 100}) + err = tx.RollbackUnlessCommitted().Error + if err != nil { + t.Fatalf("RollbackUnlessCommitted failed, got err %v", err) + } + var count int64 + DB.Model(&Product{}).Where("price = ?", 100).Count(&count) + if count != 0 { + t.Fatalf("count should be 0, but get %d", count) + } + + tx1 := DB.Begin() + tx1.Create(&Product{Code: "D42", Price: 100}) + tx1.Commit() + err = tx.RollbackUnlessCommitted().Error + if err != nil { + t.Fatalf("RollbackUnlessCommitted failed, got err %v", err) + } + DB.Model(&Product{}).Where("price = ?", 100).Count(&count) + if count != 1 { + t.Fatalf("count should be 1, but get %d", count) + } +}