Add stmt_store

This commit is contained in:
Jinzhu 2025-04-24 17:42:38 +08:00
parent dfa1b81f65
commit 886a406556
2 changed files with 136 additions and 170 deletions

View File

@ -0,0 +1,106 @@
package stmt_store
import (
"database/sql"
"fmt"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func NewStmt(isTransaction bool) *Stmt {
return &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
}
func (stmt *Stmt) Done() {
close(stmt.prepared)
}
func (stmt *Stmt) AddError(err error) {
stmt.prepareErr = err
}
func (stmt *Stmt) Error() error {
<-stmt.prepared
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
type Store interface {
Get(key string) (*Stmt, bool)
Set(key string, value *Stmt)
Delete(key string)
AllMap() map[string]*Stmt
}
type StmtStore struct {
lru *lru.LRU[string, *Stmt]
}
const (
defaultMaxSize = (1 << 63) - 1
defaultTTL = time.Hour * 24
)
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Print("close stmt err panic ")
}
}()
err := v.Close()
if err != nil {
fmt.Print("close stmt err: ", err.Error())
}
}()
}
}
return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
func (s *StmtStore) AllMap() map[string]*Stmt {
return s.lru.KeyValues()
}
func (s *StmtStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
return stmt, ok
}
func (s *StmtStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *StmtStore) Delete(key string) {
s.lru.Remove(key)
}

View File

@ -5,66 +5,23 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"gorm.io/gorm/internal/lru"
"reflect"
"sync"
"time"
"gorm.io/gorm/internal/stmt_store"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
type PreparedStmtDB struct {
Stmts StmtStore
Stmts stmt_store.Store
Mux *sync.RWMutex
ConnPool
}
const default_max_size = (1 << 63) - 1
const default_ttl = time.Hour * 24
// newPrepareStmtCache creates a new statement cache with the specified maximum size and time-to-live (TTL).
// Parameters:
// - PrepareStmtMaxSize: An integer specifying the maximum number of prepared statements to cache.
// If this value is less than or equal to 0, the function will panic.
// - PrepareStmtTTL: A time.Duration specifying the TTL for cached statements.
// If this value differs from the default TTL, it will be used instead.
//
// Returns:
// - A pointer to a store.StmtStore instance configured with the provided parameters.
//
// The function initializes an LRU (Least Recently Used) cache for prepared statements,
// using either the provided size and TTL or default values
func newPrepareStmtCache(PrepareStmtMaxSize int,
PrepareStmtTTL time.Duration) *StmtStore {
var lru_size = default_max_size
var lru_ttl = default_ttl
var stmts StmtStore
if PrepareStmtMaxSize < 0 {
panic("PrepareStmtMaxSize must > 0")
}
if PrepareStmtMaxSize != 0 {
lru_size = PrepareStmtMaxSize
}
if PrepareStmtTTL != default_ttl {
lru_ttl = PrepareStmtTTL
}
lru := &LruStmtStore{}
lru.newLru(lru_size, lru_ttl)
stmts = lru
return &stmts
}
func NewPreparedStmtDB(connPool ConnPool, PrepareStmtMaxSize int,
PrepareStmtTTL time.Duration) *PreparedStmtDB {
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: *newPrepareStmtCache(PrepareStmtMaxSize,
PrepareStmtTTL),
Stmts: stmt_store.New(maxSize, ttl),
Mux: &sync.RWMutex{},
}
}
@ -89,13 +46,7 @@ func (db *PreparedStmtDB) Close() {
}
for _, stmt := range db.Stmts.AllMap() {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
go stmt.Close()
}
// setting db.Stmts to nil to avoid further using
db.Stmts = nil
@ -108,28 +59,20 @@ func (sdb *PreparedStmtDB) Reset() {
return
}
for _, stmt := range sdb.Stmts.AllMap() {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
}
//Migrator
defaultStmt := newPrepareStmtCache(0, 0)
sdb.Stmts = *defaultStmt
go stmt.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
// Migrator
sdb.Stmts = stmt_store.New(0, 0)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) {
db.Mux.RLock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
if err := stmt.Error(); err != nil {
return stmt_store.Stmt{}, err
}
return *stmt, nil
@ -140,12 +83,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.Lock()
if db.Stmts != nil {
// double check
if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
if err := stmt.Error(); err != nil {
return stmt_store.Stmt{}, err
}
return *stmt, nil
@ -155,15 +96,15 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
db.Mux.Unlock()
return Stmt{}, ErrInvalidDB
return stmt_store.Stmt{}, ErrInvalidDB
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts.set(query, &cacheStmt)
cacheStmt := stmt_store.NewStmt(isTransaction)
db.Stmts.Set(query, cacheStmt)
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
defer cacheStmt.Done()
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
@ -172,19 +113,18 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
cacheStmt.AddError(err)
db.Mux.Lock()
db.Stmts.delete(query)
//delete(db.Stmts.AllMap(), query)
db.Stmts.Delete(query)
db.Mux.Unlock()
return Stmt{}, err
return stmt_store.Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.Mux.Unlock()
return cacheStmt, nil
return *cacheStmt, nil
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -216,8 +156,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
db.Stmts.delete(query)
//delete(db.Stmts.AllMap(), query)
db.Stmts.Delete(query)
}
}
return result, err
@ -232,8 +171,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
defer db.Mux.Unlock()
go stmt.Close()
db.Stmts.delete(query)
//delete(db.Stmts.AllMap(), query)
db.Stmts.Delete(query)
}
}
return rows, err
@ -287,8 +225,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
tx.PreparedStmtDB.Stmts.delete(query)
//delete(tx.PreparedStmtDB.Stmts.AllMap(), query)
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return result, err
@ -303,8 +240,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
tx.PreparedStmtDB.Stmts.delete(query)
//delete(tx.PreparedStmtDB.Stmts.AllMap(), query)
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return rows, err
@ -325,79 +261,3 @@ func (tx *PreparedStmtTX) Ping() error {
}
return conn.Ping()
}
type StmtStore interface {
get(key string) (*Stmt, bool)
set(key string, value *Stmt)
delete(key string)
AllMap() map[string]*Stmt
}
/*
type DefaultStmtStore struct {
defaultStmt map[string]*Stmt
}
func (s *DefaultStmtStore) Init() *DefaultStmtStore {
s.defaultStmt = make(map[string]*Stmt)
return s
}
func (s *DefaultStmtStore) AllMap() map[string]*Stmt {
return s.defaultStmt
}
func (s *DefaultStmtStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.defaultStmt[key]
return stmt, ok
}
func (s *DefaultStmtStore) Set(key string, value *Stmt) {
s.defaultStmt[key] = value
}
func (s *DefaultStmtStore) Delete(key string) {
delete(s.defaultStmt, key)
}
*/
type LruStmtStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *LruStmtStore) newLru(size int, ttl time.Duration) {
onEvicted := func(k string, v *Stmt) {
if v != nil {
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Print("close stmt err panic ")
}
}()
if v != nil {
err := v.Close()
if err != nil {
//
fmt.Print("close stmt err: ", err.Error())
}
}
}()
}
}
s.lru = lru.NewLRU[string, *Stmt](size, onEvicted, ttl)
}
func (s *LruStmtStore) AllMap() map[string]*Stmt {
return s.lru.KeyValues()
}
func (s *LruStmtStore) get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
return stmt, ok
}
func (s *LruStmtStore) set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *LruStmtStore) delete(key string) {
s.lru.Remove(key)
}