diff --git a/gorm.go b/gorm.go index c852e60c..419f9155 100644 --- a/gorm.go +++ b/gorm.go @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]Stmt{}, + Stmts: map[string](*Stmt){}, Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index d16ae6c3..360b06a8 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,10 +9,12 @@ import ( type Stmt struct { *sql.Stmt Transaction bool + prepared chan struct{} + prepareErr error } type PreparedStmtDB struct { - Stmts map[string]Stmt + Stmts map[string]*Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool @@ -46,18 +48,40 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - return stmt, nil + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } db.Mux.RUnlock() + // cache preparing stmt first + db.Mux.Lock() + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext (suppose the maxopen is 1). + // 1. g1 begin tx, now `db.ConnPool` db.numOpen == 1 + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() return Stmt{}, err } - cacheStmt := Stmt{Stmt: stmt, Transaction: isTransaction} db.Mux.Lock() - db.Stmts[query] = cacheStmt + cacheStmt.Stmt = stmt db.PreparedSQL = append(db.PreparedSQL, query) db.Mux.Unlock() diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 637d8d1c..918bd2b7 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -112,4 +112,42 @@ func TestPreparedStmtDeadlock(t *testing.T) { }() } wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 2) + for _, stmt := range conn.Stmts { + if stmt == nil { + t.Fatalf("stmt cannot bee nil") + } + } + + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtError(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + // err prepare + tag := Tag{Locale: "zh"} + tx.Table("users").Find(&tag) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 0) + AssertEqual(t, sqlDB.Stats().InUse, 0) }