Nested transaction support via savepoint

This commit is contained in:
JUN JIE NAN 2019-06-21 16:44:37 +08:00
parent 01b6601142
commit 93ed4a2d09
4 changed files with 209 additions and 9 deletions

18
main.go
View File

@ -25,6 +25,7 @@ type DB struct {
logger logger
search *search
values sync.Map
savepoint string
// global db
parent *DB
@ -538,7 +539,8 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c.dialect.SetDB(c.db)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
c.savepoint = randName()
c.AddError(c.Exec(sqlSavepoint(c.Dialect().GetName(), c.savepoint)).Error)
}
return c
}
@ -547,7 +549,11 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
func (s *DB) Commit() *DB {
var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
if savepoint := s.savepoint; savepoint == "" {
s.AddError(db.Commit())
} else {
s.savepoint = ""
}
} else {
s.AddError(ErrInvalidTransaction)
}
@ -558,9 +564,14 @@ func (s *DB) Commit() *DB {
func (s *DB) Rollback() *DB {
var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
if savepoint := s.savepoint; savepoint == "" {
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
s.AddError(err)
}
} else {
s.savepoint = ""
s.AddError(s.Exec(sqlRollback(s.Dialect().GetName(), savepoint)).Error)
}
} else {
s.AddError(ErrInvalidTransaction)
}
@ -572,12 +583,16 @@ func (s *DB) Rollback() *DB {
func (s *DB) RollbackUnlessCommitted() *DB {
var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
if savepoint := s.savepoint; savepoint == "" {
err := db.Rollback()
// Ignore the error indicating that the transaction has already
// been committed.
if err != sql.ErrTxDone {
s.AddError(err)
}
} else {
s.Exec(sqlRollback(s.Dialect().GetName(), savepoint))
}
} else {
s.AddError(ErrInvalidTransaction)
}
@ -815,6 +830,7 @@ func (s *DB) clone() *DB {
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
savepoint: s.savepoint,
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,

View File

@ -519,6 +519,110 @@ func TestTransactionReadonly(t *testing.T) {
tx.Rollback()
}
func TestTransactionNestedCallback(t *testing.T) {
name := "TestTransactionNestedCallback"
type TransactionNestedCallback struct {
Id int64
Age int64
Name string `sql:"size:255"`
}
DB.AutoMigrate(&TransactionNestedCallback{})
defer func() {
DB.Error = nil
DB.DropTable(&TransactionNestedCallback{})
}()
tx := DB.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u := &TransactionNestedCallback{Name: name, Age: 1}
if err := tx.Save(&u).Error; err != nil {
t.Fatal("No error should raise")
}
id := u.Id
if id == 0 {
t.Fatal()
}
tx = tx.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u.Age = 2
if err := tx.Save(u).Error; err != nil {
t.Fatal(err)
}
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
t.Fatal("Should find saved record")
}
tx = tx.Rollback()
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
t.Fatal("Should not find rollbacked record")
}
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err != nil {
t.Fatal("Should find saved record")
}
tx = tx.Rollback()
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err == nil {
t.Fatal("Should not find rollbacked record")
}
// Test rollback outside transaction error
if err := DB.Rollback().Error; err == nil {
t.Fatal("Rollback outside transaction should report error")
}
}
func TestTransactionNestedCommit(t *testing.T) {
type TransactionNestedCommit struct {
Id int64
Age int64
Name string `sql:"size:255"`
}
DB.AutoMigrate(&TransactionNestedCommit{})
name := "TestTransactionNestedCommit"
defer func() {
DB.Error = nil
DB.DropTable(&TransactionNestedCommit{})
}()
tx := DB.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u := TransactionNestedCommit{Name: name, Age: 1}
if err := tx.Save(&u).Error; err != nil {
t.Fatal("No error should raise")
}
id := u.Id
if id == 0 {
t.Fatal()
}
tx = tx.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u.Age = 2
if err := tx.Save(u).Error; err != nil {
t.Fatal(err)
}
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
t.Fatal("Should find saved record")
}
tx = tx.Commit()
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
t.Fatal("Should find the commited record")
}
tx = tx.Commit()
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
t.Fatal("Should find the commited record")
}
// Test commit outside transaction error
if err := DB.Commit().Error; err == nil {
t.Fatal("Commit outside transaction should report error")
}
}
func TestRow(t *testing.T) {
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}

View File

@ -2,7 +2,9 @@ package gorm
import (
"database/sql/driver"
"encoding/hex"
"fmt"
"math/rand"
"reflect"
"regexp"
"runtime"
@ -34,6 +36,7 @@ func init() {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
rand.Seed(time.Now().UnixNano())
}
type safeMap struct {
@ -224,3 +227,24 @@ func addExtraSpaceIfExist(str string) string {
}
return ""
}
func randName() string {
data := make([]byte, 7)
rand.Read(data)
return "n" + hex.EncodeToString(data)
}
func sqlSavepoint(dialect string, savepoint string) string {
if dialect == "mssql" {
return fmt.Sprintf("SAVE TRAN %s;", savepoint)
}
return fmt.Sprintf("SAVEPOINT %s;", savepoint)
}
func sqlRollback(dialect string, savepoint string) string {
if dialect == "mssql" {
return fmt.Sprintf("ROLLBACK TRAN %s;", savepoint)
}
return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s;", savepoint)
}

56
utils_test.go Normal file
View File

@ -0,0 +1,56 @@
package gorm
import "testing"
func TestRandName(t *testing.T) {
names := map[string]bool{}
max := 1000
for i := 0; i < max; i++ {
name := randName()
if l := len(name); l >= 16 {
t.Fatal(l, name)
}
names[name] = true
}
if len(names) != max {
t.Fatal()
}
}
func TestSqlSavepoint(t *testing.T) {
tcs := []struct {
dialect string
sql string
}{
{"mssql", "SAVE TRAN nced368066575bc;"},
{"sqlite3", "SAVEPOINT nced368066575bc;"},
{"mysql", "SAVEPOINT nced368066575bc;"},
{"postgres", "SAVEPOINT nced368066575bc;"},
}
savepoint := "nced368066575bc"
for _, tc := range tcs {
sql := sqlSavepoint(tc.dialect, savepoint)
if sql != tc.sql {
t.Fatal(sql, tc.sql)
}
}
}
func TestSqlRollback(t *testing.T) {
tcs := []struct {
dialect string
sql string
}{
{"mssql", "ROLLBACK TRAN nced368066575bc;"},
{"sqlite3", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
{"mysql", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
{"postgres", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
}
savepoint := "nced368066575bc"
for _, tc := range tcs {
sql := sqlRollback(tc.dialect, savepoint)
if sql != tc.sql {
t.Fatal(sql, tc.sql)
}
}
}