From b73096d61d6201c65c88c6f82a07135e86b3d52a Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Tue, 2 Aug 2022 14:23:56 +0800 Subject: [PATCH] fix: prepare deadlock --- prepare_stmt.go | 23 +++++++++-------------- tests/prepared_stmt_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index b062b0d6..ff26bfd8 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -50,23 +50,18 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact } db.Mux.RUnlock() - db.Mux.Lock() - defer db.Mux.Unlock() - - // double check - if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - return stmt, nil - } else if ok { - go stmt.Close() - } - stmt, err := conn.PrepareContext(ctx, query) - if err == nil { - db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} - db.PreparedSQL = append(db.PreparedSQL, query) + if err != nil { + return Stmt{}, err } - return db.Stmts[query], err + cacheStmt := Stmt{Stmt: stmt, Transaction: isTransaction} + db.Mux.Lock() + db.Stmts[query] = cacheStmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, err } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 8730e547..3b426377 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "sync" "testing" "time" @@ -88,3 +89,27 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } tx2.Commit() } + +func TestPreparedStmtDeadlock(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 2; i++ { + wg.Add(1) + go func(j int) { + user := User{Name: "jinzhu"} + tx.Create(&user) + + var result User + tx.First(&result) + wg.Done() + }(i) + } + wg.Wait() +}