fix: prepare deadlock

This commit is contained in:
a631807682 2022-08-12 17:32:43 +08:00
parent 9916a99d54
commit 7461e04e22
No known key found for this signature in database
GPG Key ID: 137D1D75522168AB
3 changed files with 67 additions and 5 deletions

View File

@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
preparedStmt := &PreparedStmtDB{ preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool, ConnPool: db.ConnPool,
Stmts: map[string]Stmt{}, Stmts: map[string](*Stmt){},
Mux: &sync.RWMutex{}, Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100), PreparedSQL: make([]string, 0, 100),
} }

View File

@ -9,10 +9,12 @@ import (
type Stmt struct { type Stmt struct {
*sql.Stmt *sql.Stmt
Transaction bool Transaction bool
prepared chan struct{}
prepareErr error
} }
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts map[string]Stmt Stmts map[string]*Stmt
PreparedSQL []string PreparedSQL []string
Mux *sync.RWMutex Mux *sync.RWMutex
ConnPool ConnPool
@ -46,18 +48,40 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock() db.Mux.RUnlock()
return stmt, nil // wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
} }
db.Mux.RUnlock() db.Mux.RUnlock()
// cache preparing stmt first
db.Mux.Lock()
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext (suppose the maxopen is 1).
// 1. g1 begin tx, now `db.ConnPool` db.numOpen == 1
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query) stmt, err := conn.PrepareContext(ctx, query)
if err != nil { if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err return Stmt{}, err
} }
cacheStmt := Stmt{Stmt: stmt, Transaction: isTransaction}
db.Mux.Lock() db.Mux.Lock()
db.Stmts[query] = cacheStmt cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query) db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock() db.Mux.Unlock()

View File

@ -112,4 +112,42 @@ func TestPreparedStmtDeadlock(t *testing.T) {
}() }()
} }
wg.Wait() wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
} }