diff --git a/main.go b/main.go index 14a76b19..7744b04c 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,7 @@ type DB struct { values map[string]interface{} } -func Open(dialect string, drivesources ...string) (DB, error) { +func makeDB(dialect string, drivesources ...string) (*DB, string, string, error) { var db DB var err error var driver = dialect @@ -54,10 +54,26 @@ func Open(dialect string, drivesources ...string) (DB, error) { db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}} - db.db, err = sql.Open(driver, source) db.parent = &db } - return db, err + + return &db, driver, source, err +} + +func Open(dialect string, drivesources ...string) (DB, error) { + db, driver, source, err := makeDB(dialect, drivesources...) + if err == nil { + db.db, err = sql.Open(driver, source) + } + return *db, err +} + +func Wrap(wrapDb sqlCommon, dialect string, drivesources ...string) (DB, error) { + db, _, _, err := makeDB(dialect, drivesources...) + if err == nil { + db.db = wrapDb + } + return *db, err } func (s *DB) Close() error {