Nested transaction support via savepoint

This commit is contained in:
JUN JIE NAN 2019-06-21 16:44:37 +08:00
parent 01b6601142
commit 93ed4a2d09
4 changed files with 209 additions and 9 deletions

34
main.go
View File

@ -25,6 +25,7 @@ type DB struct {
logger logger logger logger
search *search search *search
values sync.Map values sync.Map
savepoint string
// global db // global db
parent *DB parent *DB
@ -538,7 +539,8 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c.dialect.SetDB(c.db) c.dialect.SetDB(c.db)
c.AddError(err) c.AddError(err)
} else { } else {
c.AddError(ErrCantStartTransaction) c.savepoint = randName()
c.AddError(c.Exec(sqlSavepoint(c.Dialect().GetName(), c.savepoint)).Error)
} }
return c return c
} }
@ -547,7 +549,11 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
func (s *DB) Commit() *DB { func (s *DB) Commit() *DB {
var emptySQLTx *sql.Tx var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
s.AddError(db.Commit()) if savepoint := s.savepoint; savepoint == "" {
s.AddError(db.Commit())
} else {
s.savepoint = ""
}
} else { } else {
s.AddError(ErrInvalidTransaction) s.AddError(ErrInvalidTransaction)
} }
@ -558,8 +564,13 @@ func (s *DB) Commit() *DB {
func (s *DB) Rollback() *DB { func (s *DB) Rollback() *DB {
var emptySQLTx *sql.Tx var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
if err := db.Rollback(); err != nil && err != sql.ErrTxDone { if savepoint := s.savepoint; savepoint == "" {
s.AddError(err) if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
s.AddError(err)
}
} else {
s.savepoint = ""
s.AddError(s.Exec(sqlRollback(s.Dialect().GetName(), savepoint)).Error)
} }
} else { } else {
s.AddError(ErrInvalidTransaction) s.AddError(ErrInvalidTransaction)
@ -572,11 +583,15 @@ func (s *DB) Rollback() *DB {
func (s *DB) RollbackUnlessCommitted() *DB { func (s *DB) RollbackUnlessCommitted() *DB {
var emptySQLTx *sql.Tx var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
err := db.Rollback() if savepoint := s.savepoint; savepoint == "" {
// Ignore the error indicating that the transaction has already err := db.Rollback()
// been committed. // Ignore the error indicating that the transaction has already
if err != sql.ErrTxDone { // been committed.
s.AddError(err) if err != sql.ErrTxDone {
s.AddError(err)
}
} else {
s.Exec(sqlRollback(s.Dialect().GetName(), savepoint))
} }
} else { } else {
s.AddError(ErrInvalidTransaction) s.AddError(ErrInvalidTransaction)
@ -815,6 +830,7 @@ func (s *DB) clone() *DB {
parent: s.parent, parent: s.parent,
logger: s.logger, logger: s.logger,
logMode: s.logMode, logMode: s.logMode,
savepoint: s.savepoint,
Value: s.Value, Value: s.Value,
Error: s.Error, Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate, blockGlobalUpdate: s.blockGlobalUpdate,

View File

@ -519,6 +519,110 @@ func TestTransactionReadonly(t *testing.T) {
tx.Rollback() tx.Rollback()
} }
func TestTransactionNestedCallback(t *testing.T) {
name := "TestTransactionNestedCallback"
type TransactionNestedCallback struct {
Id int64
Age int64
Name string `sql:"size:255"`
}
DB.AutoMigrate(&TransactionNestedCallback{})
defer func() {
DB.Error = nil
DB.DropTable(&TransactionNestedCallback{})
}()
tx := DB.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u := &TransactionNestedCallback{Name: name, Age: 1}
if err := tx.Save(&u).Error; err != nil {
t.Fatal("No error should raise")
}
id := u.Id
if id == 0 {
t.Fatal()
}
tx = tx.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u.Age = 2
if err := tx.Save(u).Error; err != nil {
t.Fatal(err)
}
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
t.Fatal("Should find saved record")
}
tx = tx.Rollback()
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
t.Fatal("Should not find rollbacked record")
}
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err != nil {
t.Fatal("Should find saved record")
}
tx = tx.Rollback()
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err == nil {
t.Fatal("Should not find rollbacked record")
}
// Test rollback outside transaction error
if err := DB.Rollback().Error; err == nil {
t.Fatal("Rollback outside transaction should report error")
}
}
func TestTransactionNestedCommit(t *testing.T) {
type TransactionNestedCommit struct {
Id int64
Age int64
Name string `sql:"size:255"`
}
DB.AutoMigrate(&TransactionNestedCommit{})
name := "TestTransactionNestedCommit"
defer func() {
DB.Error = nil
DB.DropTable(&TransactionNestedCommit{})
}()
tx := DB.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u := TransactionNestedCommit{Name: name, Age: 1}
if err := tx.Save(&u).Error; err != nil {
t.Fatal("No error should raise")
}
id := u.Id
if id == 0 {
t.Fatal()
}
tx = tx.Begin()
if tx.Error != nil {
t.Fatal(tx.Error)
}
u.Age = 2
if err := tx.Save(u).Error; err != nil {
t.Fatal(err)
}
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
t.Fatal("Should find saved record")
}
tx = tx.Commit()
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
t.Fatal("Should find the commited record")
}
tx = tx.Commit()
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
t.Fatal("Should find the commited record")
}
// Test commit outside transaction error
if err := DB.Commit().Error; err == nil {
t.Fatal("Commit outside transaction should report error")
}
}
func TestRow(t *testing.T) { func TestRow(t *testing.T) {
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}

View File

@ -2,7 +2,9 @@ package gorm
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/hex"
"fmt" "fmt"
"math/rand"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
@ -34,6 +36,7 @@ func init() {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
} }
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
rand.Seed(time.Now().UnixNano())
} }
type safeMap struct { type safeMap struct {
@ -224,3 +227,24 @@ func addExtraSpaceIfExist(str string) string {
} }
return "" return ""
} }
func randName() string {
data := make([]byte, 7)
rand.Read(data)
return "n" + hex.EncodeToString(data)
}
func sqlSavepoint(dialect string, savepoint string) string {
if dialect == "mssql" {
return fmt.Sprintf("SAVE TRAN %s;", savepoint)
}
return fmt.Sprintf("SAVEPOINT %s;", savepoint)
}
func sqlRollback(dialect string, savepoint string) string {
if dialect == "mssql" {
return fmt.Sprintf("ROLLBACK TRAN %s;", savepoint)
}
return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s;", savepoint)
}

56
utils_test.go Normal file
View File

@ -0,0 +1,56 @@
package gorm
import "testing"
func TestRandName(t *testing.T) {
names := map[string]bool{}
max := 1000
for i := 0; i < max; i++ {
name := randName()
if l := len(name); l >= 16 {
t.Fatal(l, name)
}
names[name] = true
}
if len(names) != max {
t.Fatal()
}
}
func TestSqlSavepoint(t *testing.T) {
tcs := []struct {
dialect string
sql string
}{
{"mssql", "SAVE TRAN nced368066575bc;"},
{"sqlite3", "SAVEPOINT nced368066575bc;"},
{"mysql", "SAVEPOINT nced368066575bc;"},
{"postgres", "SAVEPOINT nced368066575bc;"},
}
savepoint := "nced368066575bc"
for _, tc := range tcs {
sql := sqlSavepoint(tc.dialect, savepoint)
if sql != tc.sql {
t.Fatal(sql, tc.sql)
}
}
}
func TestSqlRollback(t *testing.T) {
tcs := []struct {
dialect string
sql string
}{
{"mssql", "ROLLBACK TRAN nced368066575bc;"},
{"sqlite3", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
{"mysql", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
{"postgres", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
}
savepoint := "nced368066575bc"
for _, tc := range tcs {
sql := sqlRollback(tc.dialect, savepoint)
if sql != tc.sql {
t.Fatal(sql, tc.sql)
}
}
}