diff --git a/main.go b/main.go index ca1d24bb..942880b5 100644 --- a/main.go +++ b/main.go @@ -32,6 +32,13 @@ func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// Return the underlying sql.DB or sql.Tx instance. +// Use of this method is discouraged. It's mainly intended to allow +// coexistence with legacy non-GORM code. +func (s *DB) CommonDB() sqlCommon { + return s.db +} + func (s *DB) Callback() *callback { s.parent.callback = s.parent.callback.clone() return s.parent.callback diff --git a/main_test.go b/main_test.go index 47611908..32d9ac62 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,6 +1542,10 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + tx.Rollback() if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {