Support GetDBConnWithContext PreparedStmtDB

This commit is contained in:
Jinzhu 2023-08-10 13:30:48 +08:00
parent 3c34bc2f59
commit 15162afaf2

View File

@ -30,15 +30,19 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
} }
} }
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
if sqldb, ok := db.ConnPool.(*sql.DB); ok { if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil return sqldb, nil
} }
if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(gormdb)
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB return nil, ErrInvalidDB
} }
@ -54,15 +58,15 @@ func (db *PreparedStmtDB) Close() {
} }
} }
func (db *PreparedStmtDB) Reset() { func (sdb *PreparedStmtDB) Reset() {
db.Mux.Lock() sdb.Mux.Lock()
defer db.Mux.Unlock() defer sdb.Mux.Unlock()
for _, stmt := range db.Stmts { for _, stmt := range sdb.Stmts {
go stmt.Close() go stmt.Close()
} }
db.PreparedSQL = make([]string, 0, 100) sdb.PreparedSQL = make([]string, 0, 100)
db.Stmts = make(map[string]*Stmt) sdb.Stmts = make(map[string]*Stmt)
} }
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {