refact prepare stmt store

This commit is contained in:
Jinzhu 2025-04-25 11:29:39 +08:00
parent 886a406556
commit 14dc8ed9e0
4 changed files with 76 additions and 223 deletions

View File

@ -1,8 +1,9 @@
package stmt_store
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
"gorm.io/gorm/internal/lru"
@ -15,24 +16,7 @@ type Stmt struct {
prepareErr error
}
func NewStmt(isTransaction bool) *Stmt {
return &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
}
func (stmt *Stmt) Done() {
close(stmt.prepared)
}
func (stmt *Stmt) AddError(err error) {
stmt.prepareErr = err
}
func (stmt *Stmt) Error() error {
<-stmt.prepared
return stmt.prepareErr
}
@ -46,13 +30,14 @@ func (stmt *Stmt) Close() error {
}
type Store interface {
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
Keys() []string
Get(key string) (*Stmt, bool)
Set(key string, value *Stmt)
Delete(key string)
AllMap() map[string]*Stmt
}
type StmtStore struct {
type LRUStore struct {
lru *lru.LRU[string, *Stmt]
}
@ -72,35 +57,52 @@ func New(size int, ttl time.Duration) Store {
onEvicted := func(k string, v *Stmt) {
if v != nil {
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Print("close stmt err panic ")
}
}()
err := v.Close()
if err != nil {
fmt.Print("close stmt err: ", err.Error())
}
}()
go v.Close()
}
}
return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
return &LRUStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
func (s *StmtStore) AllMap() map[string]*Stmt {
return s.lru.KeyValues()
func (s *LRUStore) Keys() []string {
return s.lru.Keys()
}
func (s *StmtStore) Get(key string) (*Stmt, bool) {
func (s *LRUStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *StmtStore) Set(key string, value *Stmt) {
func (s *LRUStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *StmtStore) Delete(key string) {
func (s *LRUStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
func (s *LRUStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
s.Set(key, cacheStmt)
locker.Unlock()
defer close(cacheStmt.prepared)
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
return cacheStmt, nil
}

View File

@ -18,6 +18,7 @@ type PreparedStmtDB struct {
ConnPool
}
// NewPreparedStmtDB creates a new PreparedStmtDB instance
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
@ -26,6 +27,7 @@ func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *Prepa
}
}
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
@ -38,93 +40,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
return nil, ErrInvalidDB
}
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
if db.Stmts == nil {
return
for _, key := range db.Stmts.Keys() {
db.Stmts.Delete(key)
}
}
for _, stmt := range db.Stmts.AllMap() {
go stmt.Close()
}
// setting db.Stmts to nil to avoid further using
db.Stmts = nil
// Reset Deprecated use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()
if sdb.Stmts == nil {
return
}
for _, stmt := range sdb.Stmts.AllMap() {
go stmt.Close()
}
// Migrator
sdb.Stmts = stmt_store.New(0, 0)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) {
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
if err := stmt.Error(); err != nil {
return stmt_store.Stmt{}, err
}
return *stmt, nil
return stmt, stmt.Error()
}
}
db.Mux.RUnlock()
// retry
db.Mux.Lock()
if db.Stmts != nil {
// double check
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
if err := stmt.Error(); err != nil {
return stmt_store.Stmt{}, err
}
return *stmt, nil
return stmt, stmt.Error()
}
}
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
db.Mux.Unlock()
return stmt_store.Stmt{}, ErrInvalidDB
}
// cache preparing stmt first
cacheStmt := stmt_store.NewStmt(isTransaction)
db.Stmts.Set(query, cacheStmt)
db.Mux.Unlock()
// prepare completed
defer cacheStmt.Done()
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.AddError(err)
db.Mux.Lock()
db.Stmts.Delete(query)
db.Mux.Unlock()
return stmt_store.Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.Mux.Unlock()
return *cacheStmt, nil
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -153,9 +103,6 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
db.Stmts.Delete(query)
}
}
@ -167,10 +114,6 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
db.Stmts.Delete(query)
}
}
@ -221,10 +164,6 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
@ -236,10 +175,6 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
tx.PreparedStmtDB.Stmts.Delete(query)
}
}

View File

@ -1,15 +1,17 @@
module gorm.io/gorm/tests
go 1.18
go 1.23.0
toolchain go1.24.2
require (
github.com/google/uuid v1.6.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.9
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.10
gorm.io/driver/sqlite v1.5.6
gorm.io/driver/postgres v1.5.11
gorm.io/driver/sqlite v1.5.7
gorm.io/driver/sqlserver v1.5.4
gorm.io/gorm v1.25.12
)
@ -17,7 +19,7 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-sql-driver/mysql v1.9.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@ -25,12 +27,12 @@ require (
github.com/jackc/pgx/v5 v5.7.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/microsoft/go-mssqldb v1.7.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.29.0 // indirect
golang.org/x/text v0.20.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/text v0.24.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
@ -130,13 +129,13 @@ func TestPreparedStmtLruFromTransaction(t *testing.T) {
tx2.Commit()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
lens := len(conn.Stmts.AllMap())
lens := len(conn.Stmts.Keys())
if lens == 0 {
t.Fatalf("lru should not be empty")
}
time.Sleep(time.Second * 40)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts.AllMap()), 0)
AssertEqual(t, len(conn.Stmts.Keys()), 0)
}
func TestPreparedStmtDeadlock(t *testing.T) {
@ -164,9 +163,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts.AllMap()), 2)
for _, stmt := range conn.Stmts.AllMap() {
if stmt == nil {
AssertEqual(t, len(conn.Stmts.Keys()), 2)
for _, stmt := range conn.Stmts.Keys() {
if stmt == "" {
t.Fatalf("stmt cannot bee nil")
}
}
@ -190,10 +189,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
}
}
func TestPreparedStmtReset(t *testing.T) {
func TestPreparedStmtClose(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{})
user := *GetUser("prepared_stmt_close", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
@ -202,16 +201,16 @@ func TestPreparedStmtReset(t *testing.T) {
}
pdb.Mux.Lock()
if len(pdb.Stmts.AllMap()) == 0 {
if len(pdb.Stmts.Keys()) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Reset()
pdb.Close()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts.AllMap()) != 0 {
if len(pdb.Stmts.Keys()) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}
@ -221,10 +220,10 @@ func isUsingClosedConnError(err error) bool {
return err.Error() == "sql: statement is closed"
}
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
func TestPreparedStmtConcurrentReset(t *testing.T) {
name := "prepared_stmt_concurrent_reset"
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
@ -267,7 +266,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
go func() {
defer wg.Done()
<-writerFinish
pdb.Reset()
pdb.Close()
}()
wg.Wait()
@ -276,88 +275,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
t.Fatalf("should is a unexpected error")
}
}
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// for example: one goroutine found error and just close the database, and others are executing SQL
// this test making sure that the gorm would not get a Segmentation Fault,
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
// and all of the goroutine must got gorm.ErrInvalidDB after database close
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}
// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
loopCount := 100
var wg sync.WaitGroup
var lastErr error
closeValid := make(chan struct{}, loopCount)
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
var lastRunIndex int
var closeFinishedAt int64
wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(closeValid)
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
if lastRunIndex == closeStartIdx {
closeValid <- struct{}{}
}
var tmp User
now := time.Now().UnixNano()
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil {
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
lastErr = errors.New("must got error after database closed")
break
}
continue
}
lastErr = err
break
}
}(user.ID)
wg.Add(1)
go func() {
defer wg.Done()
for range closeValid {
for i := 0; i < loopCount; i++ {
pdb.Close() // the Close method must can be call multiple times
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
}
}
}()
wg.Wait()
var tmp User
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
if err != gorm.ErrInvalidDB {
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
}
// must be error
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
}
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
}
if pdb.Stmts != nil {
t.Fatalf("stmts must be nil")
}
}