diff --git a/main.go b/main.go index e24638a6..c37190a1 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ type DB struct { logger logger search *search values sync.Map + savepoint string // global db parent *DB @@ -538,7 +539,8 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c.dialect.SetDB(c.db) c.AddError(err) } else { - c.AddError(ErrCantStartTransaction) + c.savepoint = randName() + c.AddError(c.Exec(sqlSavepoint(c.Dialect().GetName(), c.savepoint)).Error) } return c } @@ -547,7 +549,11 @@ func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { func (s *DB) Commit() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Commit()) + if savepoint := s.savepoint; savepoint == "" { + s.AddError(db.Commit()) + } else { + s.savepoint = "" + } } else { s.AddError(ErrInvalidTransaction) } @@ -558,8 +564,13 @@ func (s *DB) Commit() *DB { func (s *DB) Rollback() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - if err := db.Rollback(); err != nil && err != sql.ErrTxDone { - s.AddError(err) + if savepoint := s.savepoint; savepoint == "" { + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.savepoint = "" + s.AddError(s.Exec(sqlRollback(s.Dialect().GetName(), savepoint)).Error) } } else { s.AddError(ErrInvalidTransaction) @@ -572,11 +583,15 @@ func (s *DB) Rollback() *DB { func (s *DB) RollbackUnlessCommitted() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - err := db.Rollback() - // Ignore the error indicating that the transaction has already - // been committed. - if err != sql.ErrTxDone { - s.AddError(err) + if savepoint := s.savepoint; savepoint == "" { + err := db.Rollback() + // Ignore the error indicating that the transaction has already + // been committed. + if err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.Exec(sqlRollback(s.Dialect().GetName(), savepoint)) } } else { s.AddError(ErrInvalidTransaction) @@ -815,6 +830,7 @@ func (s *DB) clone() *DB { parent: s.parent, logger: s.logger, logMode: s.logMode, + savepoint: s.savepoint, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, diff --git a/main_test.go b/main_test.go index 68bf7419..9fe57b09 100644 --- a/main_test.go +++ b/main_test.go @@ -519,6 +519,110 @@ func TestTransactionReadonly(t *testing.T) { tx.Rollback() } +func TestTransactionNestedCallback(t *testing.T) { + name := "TestTransactionNestedCallback" + type TransactionNestedCallback struct { + Id int64 + Age int64 + Name string `sql:"size:255"` + } + DB.AutoMigrate(&TransactionNestedCallback{}) + defer func() { + DB.Error = nil + DB.DropTable(&TransactionNestedCallback{}) + }() + tx := DB.Begin() + if tx.Error != nil { + t.Fatal(tx.Error) + } + u := &TransactionNestedCallback{Name: name, Age: 1} + if err := tx.Save(&u).Error; err != nil { + t.Fatal("No error should raise") + } + id := u.Id + if id == 0 { + t.Fatal() + } + tx = tx.Begin() + if tx.Error != nil { + t.Fatal(tx.Error) + } + u.Age = 2 + if err := tx.Save(u).Error; err != nil { + t.Fatal(err) + } + + if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil { + t.Fatal("Should find saved record") + } + tx = tx.Rollback() + if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil { + t.Fatal("Should not find rollbacked record") + } + if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err != nil { + t.Fatal("Should find saved record") + } + tx = tx.Rollback() + if err := tx.First(&TransactionNestedCallback{Id: id}, "name = ? and age = ?", name, 1).Error; err == nil { + t.Fatal("Should not find rollbacked record") + } + // Test rollback outside transaction error + if err := DB.Rollback().Error; err == nil { + t.Fatal("Rollback outside transaction should report error") + } +} + +func TestTransactionNestedCommit(t *testing.T) { + type TransactionNestedCommit struct { + Id int64 + Age int64 + Name string `sql:"size:255"` + } + DB.AutoMigrate(&TransactionNestedCommit{}) + name := "TestTransactionNestedCommit" + defer func() { + DB.Error = nil + DB.DropTable(&TransactionNestedCommit{}) + }() + tx := DB.Begin() + if tx.Error != nil { + t.Fatal(tx.Error) + } + u := TransactionNestedCommit{Name: name, Age: 1} + if err := tx.Save(&u).Error; err != nil { + t.Fatal("No error should raise") + } + id := u.Id + if id == 0 { + t.Fatal() + } + tx = tx.Begin() + if tx.Error != nil { + t.Fatal(tx.Error) + } + u.Age = 2 + if err := tx.Save(u).Error; err != nil { + t.Fatal(err) + } + + if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil { + t.Fatal("Should find saved record") + } + tx = tx.Commit() + if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err != nil { + t.Fatal("Should find the commited record") + } + + tx = tx.Commit() + if err := tx.First(&TransactionNestedCommit{Id: id}, "name = ? and age = ?", name, 2).Error; err == nil { + t.Fatal("Should find the commited record") + } + // Test commit outside transaction error + if err := DB.Commit().Error; err == nil { + t.Fatal("Commit outside transaction should report error") + } +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} diff --git a/utils.go b/utils.go index e58e57a5..78d28e37 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,9 @@ package gorm import ( "database/sql/driver" + "encoding/hex" "fmt" + "math/rand" "reflect" "regexp" "runtime" @@ -34,6 +36,7 @@ func init() { commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) } commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) + rand.Seed(time.Now().UnixNano()) } type safeMap struct { @@ -224,3 +227,24 @@ func addExtraSpaceIfExist(str string) string { } return "" } + +func randName() string { + data := make([]byte, 7) + rand.Read(data) + + return "n" + hex.EncodeToString(data) +} + +func sqlSavepoint(dialect string, savepoint string) string { + if dialect == "mssql" { + return fmt.Sprintf("SAVE TRAN %s;", savepoint) + } + return fmt.Sprintf("SAVEPOINT %s;", savepoint) +} + +func sqlRollback(dialect string, savepoint string) string { + if dialect == "mssql" { + return fmt.Sprintf("ROLLBACK TRAN %s;", savepoint) + } + return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s;", savepoint) +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 00000000..205ab681 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,56 @@ +package gorm + +import "testing" + +func TestRandName(t *testing.T) { + names := map[string]bool{} + max := 1000 + for i := 0; i < max; i++ { + name := randName() + if l := len(name); l >= 16 { + t.Fatal(l, name) + } + names[name] = true + } + if len(names) != max { + t.Fatal() + } +} + +func TestSqlSavepoint(t *testing.T) { + tcs := []struct { + dialect string + sql string + }{ + {"mssql", "SAVE TRAN nced368066575bc;"}, + {"sqlite3", "SAVEPOINT nced368066575bc;"}, + {"mysql", "SAVEPOINT nced368066575bc;"}, + {"postgres", "SAVEPOINT nced368066575bc;"}, + } + savepoint := "nced368066575bc" + for _, tc := range tcs { + sql := sqlSavepoint(tc.dialect, savepoint) + if sql != tc.sql { + t.Fatal(sql, tc.sql) + } + } +} + +func TestSqlRollback(t *testing.T) { + tcs := []struct { + dialect string + sql string + }{ + {"mssql", "ROLLBACK TRAN nced368066575bc;"}, + {"sqlite3", "ROLLBACK TO SAVEPOINT nced368066575bc;"}, + {"mysql", "ROLLBACK TO SAVEPOINT nced368066575bc;"}, + {"postgres", "ROLLBACK TO SAVEPOINT nced368066575bc;"}, + } + savepoint := "nced368066575bc" + for _, tc := range tcs { + sql := sqlRollback(tc.dialect, savepoint) + if sql != tc.sql { + t.Fatal(sql, tc.sql) + } + } +}