Add stmt_store
This commit is contained in:
parent
dfa1b81f65
commit
886a406556
106
internal/stmt_store/stmt_store.go
Normal file
106
internal/stmt_store/stmt_store.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package stmt_store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm/internal/lru"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Stmt struct {
|
||||||
|
*sql.Stmt
|
||||||
|
Transaction bool
|
||||||
|
prepared chan struct{}
|
||||||
|
prepareErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStmt(isTransaction bool) *Stmt {
|
||||||
|
return &Stmt{
|
||||||
|
Transaction: isTransaction,
|
||||||
|
prepared: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Stmt) Done() {
|
||||||
|
close(stmt.prepared)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Stmt) AddError(err error) {
|
||||||
|
stmt.prepareErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Stmt) Error() error {
|
||||||
|
<-stmt.prepared
|
||||||
|
|
||||||
|
return stmt.prepareErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Stmt) Close() error {
|
||||||
|
<-stmt.prepared
|
||||||
|
|
||||||
|
if stmt.Stmt != nil {
|
||||||
|
return stmt.Stmt.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store interface {
|
||||||
|
Get(key string) (*Stmt, bool)
|
||||||
|
Set(key string, value *Stmt)
|
||||||
|
Delete(key string)
|
||||||
|
AllMap() map[string]*Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
type StmtStore struct {
|
||||||
|
lru *lru.LRU[string, *Stmt]
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMaxSize = (1 << 63) - 1
|
||||||
|
defaultTTL = time.Hour * 24
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(size int, ttl time.Duration) Store {
|
||||||
|
if size <= 0 {
|
||||||
|
size = defaultMaxSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = defaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
onEvicted := func(k string, v *Stmt) {
|
||||||
|
if v != nil {
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
fmt.Print("close stmt err panic ")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err := v.Close()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Print("close stmt err: ", err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StmtStore) AllMap() map[string]*Stmt {
|
||||||
|
return s.lru.KeyValues()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StmtStore) Get(key string) (*Stmt, bool) {
|
||||||
|
stmt, ok := s.lru.Get(key)
|
||||||
|
return stmt, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StmtStore) Set(key string, value *Stmt) {
|
||||||
|
s.lru.Add(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StmtStore) Delete(key string) {
|
||||||
|
s.lru.Remove(key)
|
||||||
|
}
|
200
prepare_stmt.go
200
prepare_stmt.go
@ -5,67 +5,24 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"gorm.io/gorm/internal/lru"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm/internal/stmt_store"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Stmt struct {
|
|
||||||
*sql.Stmt
|
|
||||||
Transaction bool
|
|
||||||
prepared chan struct{}
|
|
||||||
prepareErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
type PreparedStmtDB struct {
|
type PreparedStmtDB struct {
|
||||||
Stmts StmtStore
|
Stmts stmt_store.Store
|
||||||
Mux *sync.RWMutex
|
Mux *sync.RWMutex
|
||||||
ConnPool
|
ConnPool
|
||||||
}
|
}
|
||||||
|
|
||||||
const default_max_size = (1 << 63) - 1
|
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
|
||||||
const default_ttl = time.Hour * 24
|
|
||||||
|
|
||||||
// newPrepareStmtCache creates a new statement cache with the specified maximum size and time-to-live (TTL).
|
|
||||||
// Parameters:
|
|
||||||
// - PrepareStmtMaxSize: An integer specifying the maximum number of prepared statements to cache.
|
|
||||||
// If this value is less than or equal to 0, the function will panic.
|
|
||||||
// - PrepareStmtTTL: A time.Duration specifying the TTL for cached statements.
|
|
||||||
// If this value differs from the default TTL, it will be used instead.
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
// - A pointer to a store.StmtStore instance configured with the provided parameters.
|
|
||||||
//
|
|
||||||
// The function initializes an LRU (Least Recently Used) cache for prepared statements,
|
|
||||||
// using either the provided size and TTL or default values
|
|
||||||
func newPrepareStmtCache(PrepareStmtMaxSize int,
|
|
||||||
PrepareStmtTTL time.Duration) *StmtStore {
|
|
||||||
var lru_size = default_max_size
|
|
||||||
var lru_ttl = default_ttl
|
|
||||||
var stmts StmtStore
|
|
||||||
if PrepareStmtMaxSize < 0 {
|
|
||||||
panic("PrepareStmtMaxSize must > 0")
|
|
||||||
}
|
|
||||||
if PrepareStmtMaxSize != 0 {
|
|
||||||
lru_size = PrepareStmtMaxSize
|
|
||||||
}
|
|
||||||
if PrepareStmtTTL != default_ttl {
|
|
||||||
lru_ttl = PrepareStmtTTL
|
|
||||||
}
|
|
||||||
lru := &LruStmtStore{}
|
|
||||||
lru.newLru(lru_size, lru_ttl)
|
|
||||||
stmts = lru
|
|
||||||
return &stmts
|
|
||||||
}
|
|
||||||
func NewPreparedStmtDB(connPool ConnPool, PrepareStmtMaxSize int,
|
|
||||||
PrepareStmtTTL time.Duration) *PreparedStmtDB {
|
|
||||||
return &PreparedStmtDB{
|
return &PreparedStmtDB{
|
||||||
ConnPool: connPool,
|
ConnPool: connPool,
|
||||||
Stmts: *newPrepareStmtCache(PrepareStmtMaxSize,
|
Stmts: stmt_store.New(maxSize, ttl),
|
||||||
PrepareStmtTTL),
|
Mux: &sync.RWMutex{},
|
||||||
Mux: &sync.RWMutex{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,13 +46,7 @@ func (db *PreparedStmtDB) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, stmt := range db.Stmts.AllMap() {
|
for _, stmt := range db.Stmts.AllMap() {
|
||||||
go func(s *Stmt) {
|
go stmt.Close()
|
||||||
// make sure the stmt must finish preparation first
|
|
||||||
<-s.prepared
|
|
||||||
if s.Stmt != nil {
|
|
||||||
_ = s.Close()
|
|
||||||
}
|
|
||||||
}(stmt)
|
|
||||||
}
|
}
|
||||||
// setting db.Stmts to nil to avoid further using
|
// setting db.Stmts to nil to avoid further using
|
||||||
db.Stmts = nil
|
db.Stmts = nil
|
||||||
@ -108,28 +59,20 @@ func (sdb *PreparedStmtDB) Reset() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, stmt := range sdb.Stmts.AllMap() {
|
for _, stmt := range sdb.Stmts.AllMap() {
|
||||||
go func(s *Stmt) {
|
go stmt.Close()
|
||||||
// make sure the stmt must finish preparation first
|
|
||||||
<-s.prepared
|
|
||||||
if s.Stmt != nil {
|
|
||||||
_ = s.Close()
|
|
||||||
}
|
|
||||||
}(stmt)
|
|
||||||
}
|
}
|
||||||
//Migrator
|
|
||||||
defaultStmt := newPrepareStmtCache(0, 0)
|
// Migrator
|
||||||
sdb.Stmts = *defaultStmt
|
sdb.Stmts = stmt_store.New(0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) {
|
||||||
db.Mux.RLock()
|
db.Mux.RLock()
|
||||||
if db.Stmts != nil {
|
if db.Stmts != nil {
|
||||||
if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) {
|
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||||
db.Mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
// wait for other goroutines prepared
|
if err := stmt.Error(); err != nil {
|
||||||
<-stmt.prepared
|
return stmt_store.Stmt{}, err
|
||||||
if stmt.prepareErr != nil {
|
|
||||||
return Stmt{}, stmt.prepareErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return *stmt, nil
|
return *stmt, nil
|
||||||
@ -140,12 +83,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
|||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
if db.Stmts != nil {
|
if db.Stmts != nil {
|
||||||
// double check
|
// double check
|
||||||
if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) {
|
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
// wait for other goroutines prepared
|
if err := stmt.Error(); err != nil {
|
||||||
<-stmt.prepared
|
return stmt_store.Stmt{}, err
|
||||||
if stmt.prepareErr != nil {
|
|
||||||
return Stmt{}, stmt.prepareErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return *stmt, nil
|
return *stmt, nil
|
||||||
@ -155,15 +96,15 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
|||||||
// which cause by calling Close and executing SQL concurrently
|
// which cause by calling Close and executing SQL concurrently
|
||||||
if db.Stmts == nil {
|
if db.Stmts == nil {
|
||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
return Stmt{}, ErrInvalidDB
|
return stmt_store.Stmt{}, ErrInvalidDB
|
||||||
}
|
}
|
||||||
// cache preparing stmt first
|
// cache preparing stmt first
|
||||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
cacheStmt := stmt_store.NewStmt(isTransaction)
|
||||||
db.Stmts.set(query, &cacheStmt)
|
db.Stmts.Set(query, cacheStmt)
|
||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
|
|
||||||
// prepare completed
|
// prepare completed
|
||||||
defer close(cacheStmt.prepared)
|
defer cacheStmt.Done()
|
||||||
|
|
||||||
// Reason why cannot lock conn.PrepareContext
|
// Reason why cannot lock conn.PrepareContext
|
||||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||||
@ -172,19 +113,18 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
|||||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||||
stmt, err := conn.PrepareContext(ctx, query)
|
stmt, err := conn.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cacheStmt.prepareErr = err
|
cacheStmt.AddError(err)
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
db.Stmts.delete(query)
|
db.Stmts.Delete(query)
|
||||||
//delete(db.Stmts.AllMap(), query)
|
|
||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
return Stmt{}, err
|
return stmt_store.Stmt{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
cacheStmt.Stmt = stmt
|
cacheStmt.Stmt = stmt
|
||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
|
|
||||||
return cacheStmt, nil
|
return *cacheStmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||||
@ -216,8 +156,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
|||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer db.Mux.Unlock()
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
db.Stmts.delete(query)
|
db.Stmts.Delete(query)
|
||||||
//delete(db.Stmts.AllMap(), query)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
@ -232,8 +171,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
|||||||
defer db.Mux.Unlock()
|
defer db.Mux.Unlock()
|
||||||
|
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
db.Stmts.delete(query)
|
db.Stmts.Delete(query)
|
||||||
//delete(db.Stmts.AllMap(), query)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rows, err
|
return rows, err
|
||||||
@ -287,8 +225,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
tx.PreparedStmtDB.Stmts.delete(query)
|
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||||
//delete(tx.PreparedStmtDB.Stmts.AllMap(), query)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
@ -303,8 +240,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
|||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
tx.PreparedStmtDB.Stmts.delete(query)
|
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||||
//delete(tx.PreparedStmtDB.Stmts.AllMap(), query)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rows, err
|
return rows, err
|
||||||
@ -325,79 +261,3 @@ func (tx *PreparedStmtTX) Ping() error {
|
|||||||
}
|
}
|
||||||
return conn.Ping()
|
return conn.Ping()
|
||||||
}
|
}
|
||||||
|
|
||||||
type StmtStore interface {
|
|
||||||
get(key string) (*Stmt, bool)
|
|
||||||
set(key string, value *Stmt)
|
|
||||||
delete(key string)
|
|
||||||
AllMap() map[string]*Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
type DefaultStmtStore struct {
|
|
||||||
defaultStmt map[string]*Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultStmtStore) Init() *DefaultStmtStore {
|
|
||||||
s.defaultStmt = make(map[string]*Stmt)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultStmtStore) AllMap() map[string]*Stmt {
|
|
||||||
return s.defaultStmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultStmtStore) Get(key string) (*Stmt, bool) {
|
|
||||||
stmt, ok := s.defaultStmt[key]
|
|
||||||
return stmt, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultStmtStore) Set(key string, value *Stmt) {
|
|
||||||
s.defaultStmt[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultStmtStore) Delete(key string) {
|
|
||||||
delete(s.defaultStmt, key)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
type LruStmtStore struct {
|
|
||||||
lru *lru.LRU[string, *Stmt]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LruStmtStore) newLru(size int, ttl time.Duration) {
|
|
||||||
onEvicted := func(k string, v *Stmt) {
|
|
||||||
if v != nil {
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
fmt.Print("close stmt err panic ")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if v != nil {
|
|
||||||
err := v.Close()
|
|
||||||
if err != nil {
|
|
||||||
//
|
|
||||||
fmt.Print("close stmt err: ", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.lru = lru.NewLRU[string, *Stmt](size, onEvicted, ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LruStmtStore) AllMap() map[string]*Stmt {
|
|
||||||
return s.lru.KeyValues()
|
|
||||||
}
|
|
||||||
func (s *LruStmtStore) get(key string) (*Stmt, bool) {
|
|
||||||
stmt, ok := s.lru.Get(key)
|
|
||||||
return stmt, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LruStmtStore) set(key string, value *Stmt) {
|
|
||||||
s.lru.Add(key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LruStmtStore) delete(key string) {
|
|
||||||
s.lru.Remove(key)
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user