From a336f51444af6ea727308630346efccd8b0e716b Mon Sep 17 00:00:00 2001 From: Timothy Stranex Date: Sun, 16 Mar 2014 18:24:32 +0200 Subject: [PATCH 1/2] Add DB.Tx() method to provice access to the underlying sql.Tx instance. --- main.go | 12 ++++++++++++ main_test.go | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/main.go b/main.go index ca1d24bb..f205b6c3 100644 --- a/main.go +++ b/main.go @@ -28,10 +28,22 @@ func Open(driver, source string) (DB, error) { return db, err } +// Return the underlying sql.DB instance. +// +// If called inside a transaction, it will panic. +// Use Tx() instead in this case. func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// Return the underlying sql.Tx instance. +// +// If called outside of a transaction, it will panic. +// Use DB() instead in this case. +func (s *DB) Tx() *sql.Tx { + return s.db.(*sql.Tx) +} + func (s *DB) Callback() *callback { s.parent.callback = s.parent.callback.clone() return s.parent.callback diff --git a/main_test.go b/main_test.go index 47611908..14de6701 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,6 +1542,11 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } + sql_tx := tx.Tx() // This shouldn't panic. + if sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx, but got nil") + } + tx.Rollback() if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { From 42448cb5d6f68cf2c3a61bf68b217c5bcfb5ab30 Mon Sep 17 00:00:00 2001 From: Timothy Stranex Date: Mon, 17 Mar 2014 12:08:44 +0200 Subject: [PATCH 2/2] Add DB.CommonDB() instead of DB.Tx(), as discussed in the PR thread. --- main.go | 15 +++++---------- main_test.go | 7 +++---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index f205b6c3..942880b5 100644 --- a/main.go +++ b/main.go @@ -28,20 +28,15 @@ func Open(driver, source string) (DB, error) { return db, err } -// Return the underlying sql.DB instance. -// -// If called inside a transaction, it will panic. -// Use Tx() instead in this case. func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } -// Return the underlying sql.Tx instance. -// -// If called outside of a transaction, it will panic. -// Use DB() instead in this case. -func (s *DB) Tx() *sql.Tx { - return s.db.(*sql.Tx) +// Return the underlying sql.DB or sql.Tx instance. +// Use of this method is discouraged. It's mainly intended to allow +// coexistence with legacy non-GORM code. +func (s *DB) CommonDB() sqlCommon { + return s.db } func (s *DB) Callback() *callback { diff --git a/main_test.go b/main_test.go index 14de6701..32d9ac62 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,10 +1542,9 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } - sql_tx := tx.Tx() // This shouldn't panic. - if sql_tx == nil { - t.Errorf("Should return the underlying sql.Tx, but got nil") - } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } tx.Rollback()