From 638805ed62a1b4a61f6ac475dafbd242a2ce1d3f Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Wed, 19 Oct 2022 13:45:43 +0800 Subject: [PATCH] feat(PreparedStmtDB): support reset --- prepare_stmt.go | 6 ++++++ tests/prepared_stmt_test.go | 27 ++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 3934bb97..0d2ed15e 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -44,6 +44,12 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Close() + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = map[string](*Stmt){} +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c7f251f2..6c141851 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,8 +2,8 @@ package tests_test import ( "context" - "sync" "errors" + "sync" "testing" "time" @@ -168,3 +168,28 @@ func TestPreparedStmtInTransaction(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + user := *GetUser("prepared_stmt_reset", Config{}) + tx.Create(&user) + + pdb.Mux.Lock() + if len(pdb.PreparedSQL) == 0 || len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.PreparedSQL) != 0 || len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +}