支持lru淘汰preparestmt cache
This commit is contained in:
parent
5225c20309
commit
3ae5fdee0c
@ -24,20 +24,25 @@ type PreparedStmtDB struct {
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: func() StmtStore {
|
||||
func newPrepareStmtCache(prepareStmtLruConfig *PrepareStmtLruConfig) *StmtStore {
|
||||
var stmts StmtStore
|
||||
if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open {
|
||||
if prepareStmtLruConfig.Size <= 0 {
|
||||
panic("LRU prepareStmtLruConfig.Size must > 0")
|
||||
}
|
||||
lru := &LruStmtStore{}
|
||||
lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL)
|
||||
stmts = lru
|
||||
} else {
|
||||
stmts = &DefaultStmtStore{}
|
||||
defaultStmtStore := &DefaultStmtStore{}
|
||||
stmts = defaultStmtStore.init()
|
||||
}
|
||||
return stmts
|
||||
}(),
|
||||
return &stmts
|
||||
}
|
||||
func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: *newPrepareStmtCache(prepareStmtLruConfig),
|
||||
Mux: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
@ -57,6 +62,9 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
if db.Stmts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, stmt := range db.Stmts.AllMap() {
|
||||
go func(s *Stmt) {
|
||||
@ -74,7 +82,9 @@ func (db *PreparedStmtDB) Close() {
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
|
||||
if sdb.Stmts == nil {
|
||||
return
|
||||
}
|
||||
for _, stmt := range sdb.Stmts.AllMap() {
|
||||
go func(s *Stmt) {
|
||||
// make sure the stmt must finish preparation first
|
||||
@ -84,11 +94,14 @@ func (sdb *PreparedStmtDB) Reset() {
|
||||
}
|
||||
}(stmt)
|
||||
}
|
||||
sdb.Stmts = &DefaultStmtStore{}
|
||||
defaultStmt := &DefaultStmtStore{}
|
||||
defaultStmt.init()
|
||||
sdb.Stmts = defaultStmt
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
// wait for other goroutines prepared
|
||||
@ -99,9 +112,11 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
db.Mux.Lock()
|
||||
if db.Stmts != nil {
|
||||
// double check
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
@ -113,6 +128,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
}
|
||||
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
|
||||
// which cause by calling Close and executing SQL concurrently
|
||||
if db.Stmts == nil {
|
||||
@ -295,11 +311,15 @@ type StmtStore interface {
|
||||
AllMap() map[string]*Stmt
|
||||
}
|
||||
|
||||
// 默认的 map 实现
|
||||
type DefaultStmtStore struct {
|
||||
defaultStmt map[string]*Stmt
|
||||
}
|
||||
|
||||
func (s *DefaultStmtStore) init() *DefaultStmtStore {
|
||||
s.defaultStmt = make(map[string]*Stmt)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DefaultStmtStore) AllMap() map[string]*Stmt {
|
||||
return s.defaultStmt
|
||||
}
|
||||
|
@ -92,6 +92,48 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
||||
tx2.Commit()
|
||||
}
|
||||
|
||||
func TestPreparedStmtLruFromTransaction(t *testing.T) {
|
||||
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtLruConfig: &gorm.PrepareStmtLruConfig{10, 20 * time.Second, true}})
|
||||
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
if err := tx.Error; err != nil {
|
||||
t.Errorf("Failed to start transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
t.Errorf("Failed to commit transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
|
||||
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
|
||||
tx2 := db.Begin()
|
||||
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
|
||||
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
tx2.Commit()
|
||||
time.Sleep(time.Second * 40)
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts.AllMap()), 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
@ -117,8 +159,8 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 2)
|
||||
for _, stmt := range conn.Stmts {
|
||||
AssertEqual(t, len(conn.Stmts.AllMap()), 2)
|
||||
for _, stmt := range conn.Stmts.AllMap() {
|
||||
if stmt == nil {
|
||||
t.Fatalf("stmt cannot bee nil")
|
||||
}
|
||||
@ -155,7 +197,7 @@ func TestPreparedStmtReset(t *testing.T) {
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
if len(pdb.Stmts.AllMap()) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
@ -164,7 +206,7 @@ func TestPreparedStmtReset(t *testing.T) {
|
||||
pdb.Reset()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
if len(pdb.Stmts.AllMap()) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user