Nested transaction support via savepoint
This commit is contained in:
parent
01b6601142
commit
93ed4a2d09
18
main.go
18
main.go
@ -25,6 +25,7 @@ type DB struct {
|
||||
logger logger
|
||||
search *search
|
||||
values sync.Map
|
||||
savepoint string
|
||||
|
||||
// global db
|
||||
parent *DB
|
||||
@ -538,7 +539,8 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
||||
c.dialect.SetDB(c.db)
|
||||
c.AddError(err)
|
||||
} else {
|
||||
c.AddError(ErrCantStartTransaction)
|
||||
c.savepoint = randName()
|
||||
c.AddError(c.Exec(sqlSavepoint(c.Dialect().GetName(), c.savepoint)).Error)
|
||||
}
|
||||
return c
|
||||
}
|
||||
@ -547,7 +549,11 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
||||
func (s *DB) Commit() *DB {
|
||||
var emptySQLTx *sql.Tx
|
||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||
if savepoint := s.savepoint; savepoint == "" {
|
||||
s.AddError(db.Commit())
|
||||
} else {
|
||||
s.savepoint = ""
|
||||
}
|
||||
} else {
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
@ -558,9 +564,14 @@ func (s *DB) Commit() *DB {
|
||||
func (s *DB) Rollback() *DB {
|
||||
var emptySQLTx *sql.Tx
|
||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||
if savepoint := s.savepoint; savepoint == "" {
|
||||
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
s.AddError(err)
|
||||
}
|
||||
} else {
|
||||
s.savepoint = ""
|
||||
s.AddError(s.Exec(sqlRollback(s.Dialect().GetName(), savepoint)).Error)
|
||||
}
|
||||
} else {
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
@ -572,12 +583,16 @@ func (s *DB) Rollback() *DB {
|
||||
func (s *DB) RollbackUnlessCommitted() *DB {
|
||||
var emptySQLTx *sql.Tx
|
||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||
if savepoint := s.savepoint; savepoint == "" {
|
||||
err := db.Rollback()
|
||||
// Ignore the error indicating that the transaction has already
|
||||
// been committed.
|
||||
if err != sql.ErrTxDone {
|
||||
s.AddError(err)
|
||||
}
|
||||
} else {
|
||||
s.Exec(sqlRollback(s.Dialect().GetName(), savepoint))
|
||||
}
|
||||
} else {
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
@ -815,6 +830,7 @@ func (s *DB) clone() *DB {
|
||||
parent: s.parent,
|
||||
logger: s.logger,
|
||||
logMode: s.logMode,
|
||||
savepoint: s.savepoint,
|
||||
Value: s.Value,
|
||||
Error: s.Error,
|
||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||
|
104
main_test.go
104
main_test.go
@ -519,6 +519,110 @@ func TestTransactionReadonly(t *testing.T) {
|
||||
tx.Rollback()
|
||||
}
|
||||
|
||||
func TestTransactionNestedCallback(t *testing.T) {
|
||||
name := "TestTransactionNestedCallback"
|
||||
type TransactionNestedCallback struct {
|
||||
Id int64
|
||||
Age int64
|
||||
Name string `sql:"size:255"`
|
||||
}
|
||||
DB.AutoMigrate(&TransactionNestedCallback{})
|
||||
defer func() {
|
||||
DB.Error = nil
|
||||
DB.DropTable(&TransactionNestedCallback{})
|
||||
}()
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatal(tx.Error)
|
||||
}
|
||||
u := &TransactionNestedCallback{Name: name, Age: 1}
|
||||
if err := tx.Save(&u).Error; err != nil {
|
||||
t.Fatal("No error should raise")
|
||||
}
|
||||
id := u.Id
|
||||
if id == 0 {
|
||||
t.Fatal()
|
||||
}
|
||||
tx = tx.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatal(tx.Error)
|
||||
}
|
||||
u.Age = 2
|
||||
if err := tx.Save(u).Error; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
|
||||
t.Fatal("Should find saved record")
|
||||
}
|
||||
tx = tx.Rollback()
|
||||
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
|
||||
t.Fatal("Should not find rollbacked record")
|
||||
}
|
||||
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err != nil {
|
||||
t.Fatal("Should find saved record")
|
||||
}
|
||||
tx = tx.Rollback()
|
||||
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err == nil {
|
||||
t.Fatal("Should not find rollbacked record")
|
||||
}
|
||||
// Test rollback outside transaction error
|
||||
if err := DB.Rollback().Error; err == nil {
|
||||
t.Fatal("Rollback outside transaction should report error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionNestedCommit(t *testing.T) {
|
||||
type TransactionNestedCommit struct {
|
||||
Id int64
|
||||
Age int64
|
||||
Name string `sql:"size:255"`
|
||||
}
|
||||
DB.AutoMigrate(&TransactionNestedCommit{})
|
||||
name := "TestTransactionNestedCommit"
|
||||
defer func() {
|
||||
DB.Error = nil
|
||||
DB.DropTable(&TransactionNestedCommit{})
|
||||
}()
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatal(tx.Error)
|
||||
}
|
||||
u := TransactionNestedCommit{Name: name, Age: 1}
|
||||
if err := tx.Save(&u).Error; err != nil {
|
||||
t.Fatal("No error should raise")
|
||||
}
|
||||
id := u.Id
|
||||
if id == 0 {
|
||||
t.Fatal()
|
||||
}
|
||||
tx = tx.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatal(tx.Error)
|
||||
}
|
||||
u.Age = 2
|
||||
if err := tx.Save(u).Error; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
|
||||
t.Fatal("Should find saved record")
|
||||
}
|
||||
tx = tx.Commit()
|
||||
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
|
||||
t.Fatal("Should find the commited record")
|
||||
}
|
||||
|
||||
tx = tx.Commit()
|
||||
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
|
||||
t.Fatal("Should find the commited record")
|
||||
}
|
||||
// Test commit outside transaction error
|
||||
if err := DB.Commit().Error; err == nil {
|
||||
t.Fatal("Commit outside transaction should report error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRow(t *testing.T) {
|
||||
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||
|
24
utils.go
24
utils.go
@ -2,7 +2,9 @@ package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
@ -34,6 +36,7 @@ func init() {
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||
}
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
type safeMap struct {
|
||||
@ -224,3 +227,24 @@ func addExtraSpaceIfExist(str string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func randName() string {
|
||||
data := make([]byte, 7)
|
||||
rand.Read(data)
|
||||
|
||||
return "n" + hex.EncodeToString(data)
|
||||
}
|
||||
|
||||
func sqlSavepoint(dialect string, savepoint string) string {
|
||||
if dialect == "mssql" {
|
||||
return fmt.Sprintf("SAVE TRAN %s;", savepoint)
|
||||
}
|
||||
return fmt.Sprintf("SAVEPOINT %s;", savepoint)
|
||||
}
|
||||
|
||||
func sqlRollback(dialect string, savepoint string) string {
|
||||
if dialect == "mssql" {
|
||||
return fmt.Sprintf("ROLLBACK TRAN %s;", savepoint)
|
||||
}
|
||||
return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s;", savepoint)
|
||||
}
|
||||
|
56
utils_test.go
Normal file
56
utils_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package gorm
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRandName(t *testing.T) {
|
||||
names := map[string]bool{}
|
||||
max := 1000
|
||||
for i := 0; i < max; i++ {
|
||||
name := randName()
|
||||
if l := len(name); l >= 16 {
|
||||
t.Fatal(l, name)
|
||||
}
|
||||
names[name] = true
|
||||
}
|
||||
if len(names) != max {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlSavepoint(t *testing.T) {
|
||||
tcs := []struct {
|
||||
dialect string
|
||||
sql string
|
||||
}{
|
||||
{"mssql", "SAVE TRAN nced368066575bc;"},
|
||||
{"sqlite3", "SAVEPOINT nced368066575bc;"},
|
||||
{"mysql", "SAVEPOINT nced368066575bc;"},
|
||||
{"postgres", "SAVEPOINT nced368066575bc;"},
|
||||
}
|
||||
savepoint := "nced368066575bc"
|
||||
for _, tc := range tcs {
|
||||
sql := sqlSavepoint(tc.dialect, savepoint)
|
||||
if sql != tc.sql {
|
||||
t.Fatal(sql, tc.sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlRollback(t *testing.T) {
|
||||
tcs := []struct {
|
||||
dialect string
|
||||
sql string
|
||||
}{
|
||||
{"mssql", "ROLLBACK TRAN nced368066575bc;"},
|
||||
{"sqlite3", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
|
||||
{"mysql", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
|
||||
{"postgres", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
|
||||
}
|
||||
savepoint := "nced368066575bc"
|
||||
for _, tc := range tcs {
|
||||
sql := sqlRollback(tc.dialect, savepoint)
|
||||
if sql != tc.sql {
|
||||
t.Fatal(sql, tc.sql)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user