支持lru淘汰preparestmt cache

This commit is contained in:
xiezhaodong 2025-04-15 11:42:00 +08:00
parent 5225c20309
commit 3ae5fdee0c
2 changed files with 98 additions and 36 deletions

View File

@ -24,20 +24,25 @@ type PreparedStmtDB struct {
ConnPool
}
func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: func() StmtStore {
func newPrepareStmtCache(prepareStmtLruConfig *PrepareStmtLruConfig) *StmtStore {
var stmts StmtStore
if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open {
if prepareStmtLruConfig.Size <= 0 {
panic("LRU prepareStmtLruConfig.Size must > 0")
}
lru := &LruStmtStore{}
lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL)
stmts = lru
} else {
stmts = &DefaultStmtStore{}
defaultStmtStore := &DefaultStmtStore{}
stmts = defaultStmtStore.init()
}
return stmts
}(),
return &stmts
}
func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: *newPrepareStmtCache(prepareStmtLruConfig),
Mux: &sync.RWMutex{},
}
}
@ -57,6 +62,9 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
if db.Stmts == nil {
return
}
for _, stmt := range db.Stmts.AllMap() {
go func(s *Stmt) {
@ -74,7 +82,9 @@ func (db *PreparedStmtDB) Close() {
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()
if sdb.Stmts == nil {
return
}
for _, stmt := range sdb.Stmts.AllMap() {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
@ -84,11 +94,14 @@ func (sdb *PreparedStmtDB) Reset() {
}
}(stmt)
}
sdb.Stmts = &DefaultStmtStore{}
defaultStmt := &DefaultStmtStore{}
defaultStmt.init()
sdb.Stmts = defaultStmt
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
// wait for other goroutines prepared
@ -99,9 +112,11 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
return *stmt, nil
}
}
db.Mux.RUnlock()
db.Mux.Lock()
if db.Stmts != nil {
// double check
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
@ -113,6 +128,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
return *stmt, nil
}
}
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
@ -295,11 +311,15 @@ type StmtStore interface {
AllMap() map[string]*Stmt
}
// 默认的 map 实现
type DefaultStmtStore struct {
defaultStmt map[string]*Stmt
}
func (s *DefaultStmtStore) init() *DefaultStmtStore {
s.defaultStmt = make(map[string]*Stmt)
return s
}
func (s *DefaultStmtStore) AllMap() map[string]*Stmt {
return s.defaultStmt
}

View File

@ -92,6 +92,48 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
tx2.Commit()
}
func TestPreparedStmtLruFromTransaction(t *testing.T) {
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtLruConfig: &gorm.PrepareStmtLruConfig{10, 20 * time.Second, true}})
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
if err := tx.Error; err != nil {
t.Errorf("Failed to start transaction, got error %v\n", err)
}
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
tx.Rollback()
t.Errorf("Failed to run one transaction, got error %v\n", err)
}
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
tx.Rollback()
t.Errorf("Failed to run one transaction, got error %v\n", err)
}
if err := tx.Commit().Error; err != nil {
t.Errorf("Failed to commit transaction, got error %v\n", err)
}
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
}
tx2 := db.Begin()
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
}
tx2.Commit()
time.Sleep(time.Second * 40)
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts.AllMap()), 0)
}
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
@ -117,8 +159,8 @@ func TestPreparedStmtDeadlock(t *testing.T) {
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
AssertEqual(t, len(conn.Stmts.AllMap()), 2)
for _, stmt := range conn.Stmts.AllMap() {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
@ -155,7 +197,7 @@ func TestPreparedStmtReset(t *testing.T) {
}
pdb.Mux.Lock()
if len(pdb.Stmts) == 0 {
if len(pdb.Stmts.AllMap()) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
@ -164,7 +206,7 @@ func TestPreparedStmtReset(t *testing.T) {
pdb.Reset()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts) != 0 {
if len(pdb.Stmts.AllMap()) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}