refact prepare stmt store

This commit is contained in:
Jinzhu 2025-04-25 11:29:39 +08:00
parent 886a406556
commit 14dc8ed9e0
4 changed files with 76 additions and 223 deletions

View File

@ -1,8 +1,9 @@
package stmt_store package stmt_store
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "sync"
"time" "time"
"gorm.io/gorm/internal/lru" "gorm.io/gorm/internal/lru"
@ -15,24 +16,7 @@ type Stmt struct {
prepareErr error prepareErr error
} }
func NewStmt(isTransaction bool) *Stmt {
return &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
}
func (stmt *Stmt) Done() {
close(stmt.prepared)
}
func (stmt *Stmt) AddError(err error) {
stmt.prepareErr = err
}
func (stmt *Stmt) Error() error { func (stmt *Stmt) Error() error {
<-stmt.prepared
return stmt.prepareErr return stmt.prepareErr
} }
@ -46,13 +30,14 @@ func (stmt *Stmt) Close() error {
} }
type Store interface { type Store interface {
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
Keys() []string
Get(key string) (*Stmt, bool) Get(key string) (*Stmt, bool)
Set(key string, value *Stmt) Set(key string, value *Stmt)
Delete(key string) Delete(key string)
AllMap() map[string]*Stmt
} }
type StmtStore struct { type LRUStore struct {
lru *lru.LRU[string, *Stmt] lru *lru.LRU[string, *Stmt]
} }
@ -72,35 +57,52 @@ func New(size int, ttl time.Duration) Store {
onEvicted := func(k string, v *Stmt) { onEvicted := func(k string, v *Stmt) {
if v != nil { if v != nil {
go func() { go v.Close()
defer func() {
if r := recover(); r != nil {
fmt.Print("close stmt err panic ")
}
}()
err := v.Close()
if err != nil {
fmt.Print("close stmt err: ", err.Error())
}
}()
} }
} }
return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} return &LRUStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
} }
func (s *StmtStore) AllMap() map[string]*Stmt { func (s *LRUStore) Keys() []string {
return s.lru.KeyValues() return s.lru.Keys()
} }
func (s *StmtStore) Get(key string) (*Stmt, bool) { func (s *LRUStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key) stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok return stmt, ok
} }
func (s *StmtStore) Set(key string, value *Stmt) { func (s *LRUStore) Set(key string, value *Stmt) {
s.lru.Add(key, value) s.lru.Add(key, value)
} }
func (s *StmtStore) Delete(key string) { func (s *LRUStore) Delete(key string) {
s.lru.Remove(key) s.lru.Remove(key)
} }
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
func (s *LRUStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
s.Set(key, cacheStmt)
locker.Unlock()
defer close(cacheStmt.prepared)
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
return cacheStmt, nil
}

View File

@ -18,6 +18,7 @@ type PreparedStmtDB struct {
ConnPool ConnPool
} }
// NewPreparedStmtDB creates a new PreparedStmtDB instance
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB { func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{ return &PreparedStmtDB{
ConnPool: connPool, ConnPool: connPool,
@ -26,6 +27,7 @@ func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *Prepa
} }
} }
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok { if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil return sqldb, nil
@ -38,93 +40,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
return nil, ErrInvalidDB return nil, ErrInvalidDB
} }
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock() defer db.Mux.Unlock()
if db.Stmts == nil {
return
}
for _, stmt := range db.Stmts.AllMap() { for _, key := range db.Stmts.Keys() {
go stmt.Close() db.Stmts.Delete(key)
} }
// setting db.Stmts to nil to avoid further using
db.Stmts = nil
} }
func (sdb *PreparedStmtDB) Reset() { // Reset Deprecated use Close instead
sdb.Mux.Lock() func (db *PreparedStmtDB) Reset() {
defer sdb.Mux.Unlock() db.Close()
if sdb.Stmts == nil {
return
}
for _, stmt := range sdb.Stmts.AllMap() {
go stmt.Close()
}
// Migrator
sdb.Stmts = stmt_store.New(0, 0)
} }
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock() db.Mux.RLock()
if db.Stmts != nil { if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock() db.Mux.RUnlock()
if err := stmt.Error(); err != nil { return stmt, stmt.Error()
return stmt_store.Stmt{}, err
}
return *stmt, nil
} }
} }
db.Mux.RUnlock() db.Mux.RUnlock()
// retry
db.Mux.Lock() db.Mux.Lock()
if db.Stmts != nil { if db.Stmts != nil {
// double check
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock() db.Mux.Unlock()
if err := stmt.Error(); err != nil { return stmt, stmt.Error()
return stmt_store.Stmt{}, err
}
return *stmt, nil
} }
} }
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
db.Mux.Unlock()
return stmt_store.Stmt{}, ErrInvalidDB
}
// cache preparing stmt first
cacheStmt := stmt_store.NewStmt(isTransaction)
db.Stmts.Set(query, cacheStmt)
db.Mux.Unlock()
// prepare completed return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
defer cacheStmt.Done()
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.AddError(err)
db.Mux.Lock()
db.Stmts.Delete(query)
db.Mux.Unlock()
return stmt_store.Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.Mux.Unlock()
return *cacheStmt, nil
} }
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -153,9 +103,6 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
if err == nil { if err == nil {
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
db.Stmts.Delete(query) db.Stmts.Delete(query)
} }
} }
@ -167,10 +114,6 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
if err == nil { if err == nil {
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
db.Stmts.Delete(query) db.Stmts.Delete(query)
} }
} }
@ -221,10 +164,6 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
if err == nil { if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
tx.PreparedStmtDB.Stmts.Delete(query) tx.PreparedStmtDB.Stmts.Delete(query)
} }
} }
@ -236,10 +175,6 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
if err == nil { if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
tx.PreparedStmtDB.Stmts.Delete(query) tx.PreparedStmtDB.Stmts.Delete(query)
} }
} }

View File

@ -1,15 +1,17 @@
module gorm.io/gorm/tests module gorm.io/gorm/tests
go 1.18 go 1.23.0
toolchain go1.24.2
require ( require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.10.0
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.10 gorm.io/driver/postgres v1.5.11
gorm.io/driver/sqlite v1.5.6 gorm.io/driver/sqlite v1.5.7
gorm.io/driver/sqlserver v1.5.4 gorm.io/driver/sqlserver v1.5.4
gorm.io/gorm v1.25.12 gorm.io/gorm v1.25.12
) )
@ -17,7 +19,7 @@ require (
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-sql-driver/mysql v1.9.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
@ -25,12 +27,12 @@ require (
github.com/jackc/pgx/v5 v5.7.1 // indirect github.com/jackc/pgx/v5 v5.7.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/microsoft/go-mssqldb v1.7.2 // indirect github.com/microsoft/go-mssqldb v1.7.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.29.0 // indirect golang.org/x/crypto v0.37.0 // indirect
golang.org/x/text v0.20.0 // indirect golang.org/x/text v0.24.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -130,13 +129,13 @@ func TestPreparedStmtLruFromTransaction(t *testing.T) {
tx2.Commit() tx2.Commit()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
lens := len(conn.Stmts.AllMap()) lens := len(conn.Stmts.Keys())
if lens == 0 { if lens == 0 {
t.Fatalf("lru should not be empty") t.Fatalf("lru should not be empty")
} }
time.Sleep(time.Second * 40) time.Sleep(time.Second * 40)
AssertEqual(t, ok, true) AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts.AllMap()), 0) AssertEqual(t, len(conn.Stmts.Keys()), 0)
} }
func TestPreparedStmtDeadlock(t *testing.T) { func TestPreparedStmtDeadlock(t *testing.T) {
@ -164,9 +163,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true) AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts.AllMap()), 2) AssertEqual(t, len(conn.Stmts.Keys()), 2)
for _, stmt := range conn.Stmts.AllMap() { for _, stmt := range conn.Stmts.Keys() {
if stmt == nil { if stmt == "" {
t.Fatalf("stmt cannot bee nil") t.Fatalf("stmt cannot bee nil")
} }
} }
@ -190,10 +189,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
} }
} }
func TestPreparedStmtReset(t *testing.T) { func TestPreparedStmtClose(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true}) tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{}) user := *GetUser("prepared_stmt_close", Config{})
tx = tx.Create(&user) tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
@ -202,16 +201,16 @@ func TestPreparedStmtReset(t *testing.T) {
} }
pdb.Mux.Lock() pdb.Mux.Lock()
if len(pdb.Stmts.AllMap()) == 0 { if len(pdb.Stmts.Keys()) == 0 {
pdb.Mux.Unlock() pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty") t.Fatalf("prepared stmt can not be empty")
} }
pdb.Mux.Unlock() pdb.Mux.Unlock()
pdb.Reset() pdb.Close()
pdb.Mux.Lock() pdb.Mux.Lock()
defer pdb.Mux.Unlock() defer pdb.Mux.Unlock()
if len(pdb.Stmts.AllMap()) != 0 { if len(pdb.Stmts.Keys()) != 0 {
t.Fatalf("prepared stmt should be empty") t.Fatalf("prepared stmt should be empty")
} }
} }
@ -221,10 +220,10 @@ func isUsingClosedConnError(err error) bool {
return err.Error() == "sql: statement is closed" return err.Error() == "sql: statement is closed"
} }
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently // TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt // this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
func TestPreparedStmtConcurrentReset(t *testing.T) { func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_reset" name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{}) user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user) createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil { if createTx.Error != nil {
@ -267,7 +266,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
<-writerFinish <-writerFinish
pdb.Reset() pdb.Close()
}() }()
wg.Wait() wg.Wait()
@ -276,88 +275,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
t.Fatalf("should is a unexpected error") t.Fatalf("should is a unexpected error")
} }
} }
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// for example: one goroutine found error and just close the database, and others are executing SQL
// this test making sure that the gorm would not get a Segmentation Fault,
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
// and all of the goroutine must got gorm.ErrInvalidDB after database close
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}
// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
loopCount := 100
var wg sync.WaitGroup
var lastErr error
closeValid := make(chan struct{}, loopCount)
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
var lastRunIndex int
var closeFinishedAt int64
wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(closeValid)
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
if lastRunIndex == closeStartIdx {
closeValid <- struct{}{}
}
var tmp User
now := time.Now().UnixNano()
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil {
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
lastErr = errors.New("must got error after database closed")
break
}
continue
}
lastErr = err
break
}
}(user.ID)
wg.Add(1)
go func() {
defer wg.Done()
for range closeValid {
for i := 0; i < loopCount; i++ {
pdb.Close() // the Close method must can be call multiple times
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
}
}
}()
wg.Wait()
var tmp User
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
if err != gorm.ErrInvalidDB {
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
}
// must be error
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
}
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
}
if pdb.Stmts != nil {
t.Fatalf("stmts must be nil")
}
}