From ac3ec858a6c375a466f613c86b053726abbe3755 Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 26 Jul 2018 19:35:53 -0400 Subject: [PATCH] Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions (#1939) * Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions. * Adds a test case for tables creations and autoMigrate in the same transaction. --- main.go | 5 ++++- migration_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 25c3a06b..3a5d6b0c 100644 --- a/main.go +++ b/main.go @@ -119,7 +119,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -484,6 +484,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -748,6 +750,7 @@ func (s *DB) clone() *DB { Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } for key, value := range s.values { diff --git a/migration_test.go b/migration_test.go index 7c694485..78555dcc 100644 --- a/migration_test.go +++ b/migration_test.go @@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/scope.go b/scope.go index 397ccf0b..5eb98963 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database