Nested transaction support via savepoint
This commit is contained in:
parent
01b6601142
commit
93ed4a2d09
18
main.go
18
main.go
@ -25,6 +25,7 @@ type DB struct {
|
|||||||
logger logger
|
logger logger
|
||||||
search *search
|
search *search
|
||||||
values sync.Map
|
values sync.Map
|
||||||
|
savepoint string
|
||||||
|
|
||||||
// global db
|
// global db
|
||||||
parent *DB
|
parent *DB
|
||||||
@ -538,7 +539,8 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
|||||||
c.dialect.SetDB(c.db)
|
c.dialect.SetDB(c.db)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(ErrCantStartTransaction)
|
c.savepoint = randName()
|
||||||
|
c.AddError(c.Exec(sqlSavepoint(c.Dialect().GetName(), c.savepoint)).Error)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
@ -547,7 +549,11 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
|||||||
func (s *DB) Commit() *DB {
|
func (s *DB) Commit() *DB {
|
||||||
var emptySQLTx *sql.Tx
|
var emptySQLTx *sql.Tx
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
|
if savepoint := s.savepoint; savepoint == "" {
|
||||||
s.AddError(db.Commit())
|
s.AddError(db.Commit())
|
||||||
|
} else {
|
||||||
|
s.savepoint = ""
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
@ -558,9 +564,14 @@ func (s *DB) Commit() *DB {
|
|||||||
func (s *DB) Rollback() *DB {
|
func (s *DB) Rollback() *DB {
|
||||||
var emptySQLTx *sql.Tx
|
var emptySQLTx *sql.Tx
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
|
if savepoint := s.savepoint; savepoint == "" {
|
||||||
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
|
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||||
s.AddError(err)
|
s.AddError(err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
s.savepoint = ""
|
||||||
|
s.AddError(s.Exec(sqlRollback(s.Dialect().GetName(), savepoint)).Error)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
@ -572,12 +583,16 @@ func (s *DB) Rollback() *DB {
|
|||||||
func (s *DB) RollbackUnlessCommitted() *DB {
|
func (s *DB) RollbackUnlessCommitted() *DB {
|
||||||
var emptySQLTx *sql.Tx
|
var emptySQLTx *sql.Tx
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
|
if savepoint := s.savepoint; savepoint == "" {
|
||||||
err := db.Rollback()
|
err := db.Rollback()
|
||||||
// Ignore the error indicating that the transaction has already
|
// Ignore the error indicating that the transaction has already
|
||||||
// been committed.
|
// been committed.
|
||||||
if err != sql.ErrTxDone {
|
if err != sql.ErrTxDone {
|
||||||
s.AddError(err)
|
s.AddError(err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
s.Exec(sqlRollback(s.Dialect().GetName(), savepoint))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
@ -815,6 +830,7 @@ func (s *DB) clone() *DB {
|
|||||||
parent: s.parent,
|
parent: s.parent,
|
||||||
logger: s.logger,
|
logger: s.logger,
|
||||||
logMode: s.logMode,
|
logMode: s.logMode,
|
||||||
|
savepoint: s.savepoint,
|
||||||
Value: s.Value,
|
Value: s.Value,
|
||||||
Error: s.Error,
|
Error: s.Error,
|
||||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||||
|
104
main_test.go
104
main_test.go
@ -519,6 +519,110 @@ func TestTransactionReadonly(t *testing.T) {
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransactionNestedCallback(t *testing.T) {
|
||||||
|
name := "TestTransactionNestedCallback"
|
||||||
|
type TransactionNestedCallback struct {
|
||||||
|
Id int64
|
||||||
|
Age int64
|
||||||
|
Name string `sql:"size:255"`
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&TransactionNestedCallback{})
|
||||||
|
defer func() {
|
||||||
|
DB.Error = nil
|
||||||
|
DB.DropTable(&TransactionNestedCallback{})
|
||||||
|
}()
|
||||||
|
tx := DB.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatal(tx.Error)
|
||||||
|
}
|
||||||
|
u := &TransactionNestedCallback{Name: name, Age: 1}
|
||||||
|
if err := tx.Save(&u).Error; err != nil {
|
||||||
|
t.Fatal("No error should raise")
|
||||||
|
}
|
||||||
|
id := u.Id
|
||||||
|
if id == 0 {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
tx = tx.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatal(tx.Error)
|
||||||
|
}
|
||||||
|
u.Age = 2
|
||||||
|
if err := tx.Save(u).Error; err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
|
||||||
|
t.Fatal("Should find saved record")
|
||||||
|
}
|
||||||
|
tx = tx.Rollback()
|
||||||
|
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
|
||||||
|
t.Fatal("Should not find rollbacked record")
|
||||||
|
}
|
||||||
|
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err != nil {
|
||||||
|
t.Fatal("Should find saved record")
|
||||||
|
}
|
||||||
|
tx = tx.Rollback()
|
||||||
|
if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err == nil {
|
||||||
|
t.Fatal("Should not find rollbacked record")
|
||||||
|
}
|
||||||
|
// Test rollback outside transaction error
|
||||||
|
if err := DB.Rollback().Error; err == nil {
|
||||||
|
t.Fatal("Rollback outside transaction should report error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionNestedCommit(t *testing.T) {
|
||||||
|
type TransactionNestedCommit struct {
|
||||||
|
Id int64
|
||||||
|
Age int64
|
||||||
|
Name string `sql:"size:255"`
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&TransactionNestedCommit{})
|
||||||
|
name := "TestTransactionNestedCommit"
|
||||||
|
defer func() {
|
||||||
|
DB.Error = nil
|
||||||
|
DB.DropTable(&TransactionNestedCommit{})
|
||||||
|
}()
|
||||||
|
tx := DB.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatal(tx.Error)
|
||||||
|
}
|
||||||
|
u := TransactionNestedCommit{Name: name, Age: 1}
|
||||||
|
if err := tx.Save(&u).Error; err != nil {
|
||||||
|
t.Fatal("No error should raise")
|
||||||
|
}
|
||||||
|
id := u.Id
|
||||||
|
if id == 0 {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
tx = tx.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatal(tx.Error)
|
||||||
|
}
|
||||||
|
u.Age = 2
|
||||||
|
if err := tx.Save(u).Error; err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
|
||||||
|
t.Fatal("Should find saved record")
|
||||||
|
}
|
||||||
|
tx = tx.Commit()
|
||||||
|
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil {
|
||||||
|
t.Fatal("Should find the commited record")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx = tx.Commit()
|
||||||
|
if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil {
|
||||||
|
t.Fatal("Should find the commited record")
|
||||||
|
}
|
||||||
|
// Test commit outside transaction error
|
||||||
|
if err := DB.Commit().Error; err == nil {
|
||||||
|
t.Fatal("Commit outside transaction should report error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRow(t *testing.T) {
|
func TestRow(t *testing.T) {
|
||||||
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
24
utils.go
24
utils.go
@ -2,7 +2,9 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -34,6 +36,7 @@ func init() {
|
|||||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||||
}
|
}
|
||||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
type safeMap struct {
|
type safeMap struct {
|
||||||
@ -224,3 +227,24 @@ func addExtraSpaceIfExist(str string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randName() string {
|
||||||
|
data := make([]byte, 7)
|
||||||
|
rand.Read(data)
|
||||||
|
|
||||||
|
return "n" + hex.EncodeToString(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqlSavepoint(dialect string, savepoint string) string {
|
||||||
|
if dialect == "mssql" {
|
||||||
|
return fmt.Sprintf("SAVE TRAN %s;", savepoint)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("SAVEPOINT %s;", savepoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqlRollback(dialect string, savepoint string) string {
|
||||||
|
if dialect == "mssql" {
|
||||||
|
return fmt.Sprintf("ROLLBACK TRAN %s;", savepoint)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s;", savepoint)
|
||||||
|
}
|
||||||
|
56
utils_test.go
Normal file
56
utils_test.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestRandName(t *testing.T) {
|
||||||
|
names := map[string]bool{}
|
||||||
|
max := 1000
|
||||||
|
for i := 0; i < max; i++ {
|
||||||
|
name := randName()
|
||||||
|
if l := len(name); l >= 16 {
|
||||||
|
t.Fatal(l, name)
|
||||||
|
}
|
||||||
|
names[name] = true
|
||||||
|
}
|
||||||
|
if len(names) != max {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlSavepoint(t *testing.T) {
|
||||||
|
tcs := []struct {
|
||||||
|
dialect string
|
||||||
|
sql string
|
||||||
|
}{
|
||||||
|
{"mssql", "SAVE TRAN nced368066575bc;"},
|
||||||
|
{"sqlite3", "SAVEPOINT nced368066575bc;"},
|
||||||
|
{"mysql", "SAVEPOINT nced368066575bc;"},
|
||||||
|
{"postgres", "SAVEPOINT nced368066575bc;"},
|
||||||
|
}
|
||||||
|
savepoint := "nced368066575bc"
|
||||||
|
for _, tc := range tcs {
|
||||||
|
sql := sqlSavepoint(tc.dialect, savepoint)
|
||||||
|
if sql != tc.sql {
|
||||||
|
t.Fatal(sql, tc.sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlRollback(t *testing.T) {
|
||||||
|
tcs := []struct {
|
||||||
|
dialect string
|
||||||
|
sql string
|
||||||
|
}{
|
||||||
|
{"mssql", "ROLLBACK TRAN nced368066575bc;"},
|
||||||
|
{"sqlite3", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
|
||||||
|
{"mysql", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
|
||||||
|
{"postgres", "ROLLBACK TO SAVEPOINT nced368066575bc;"},
|
||||||
|
}
|
||||||
|
savepoint := "nced368066575bc"
|
||||||
|
for _, tc := range tcs {
|
||||||
|
sql := sqlRollback(tc.dialect, savepoint)
|
||||||
|
if sql != tc.sql {
|
||||||
|
t.Fatal(sql, tc.sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user