diff --git a/gorm.go b/gorm.go index 84d4b433..51791caf 100644 --- a/gorm.go +++ b/gorm.go @@ -375,6 +375,10 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool + if connector, ok := connPool.(SQLConnector); ok && connector != nil { + return connector.GetSQLConn(db) + } + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } diff --git a/interfaces.go b/interfaces.go index 3bcc3d57..ce62508a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,6 +77,12 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } +// SQLConnector represents SQL db connector which takes into account the current +// database context +type SQLConnector interface { + GetSQLConn(db *DB) (*sql.DB, error) +} + // Rows rows interface type Rows interface { Columns() ([]string, error)