支持lru淘汰preparestmt cache
This commit is contained in:
parent
5225c20309
commit
3ae5fdee0c
@ -24,21 +24,26 @@ type PreparedStmtDB struct {
|
|||||||
ConnPool
|
ConnPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newPrepareStmtCache(prepareStmtLruConfig *PrepareStmtLruConfig) *StmtStore {
|
||||||
|
var stmts StmtStore
|
||||||
|
if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open {
|
||||||
|
if prepareStmtLruConfig.Size <= 0 {
|
||||||
|
panic("LRU prepareStmtLruConfig.Size must > 0")
|
||||||
|
}
|
||||||
|
lru := &LruStmtStore{}
|
||||||
|
lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL)
|
||||||
|
stmts = lru
|
||||||
|
} else {
|
||||||
|
defaultStmtStore := &DefaultStmtStore{}
|
||||||
|
stmts = defaultStmtStore.init()
|
||||||
|
}
|
||||||
|
return &stmts
|
||||||
|
}
|
||||||
func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB {
|
func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB {
|
||||||
return &PreparedStmtDB{
|
return &PreparedStmtDB{
|
||||||
ConnPool: connPool,
|
ConnPool: connPool,
|
||||||
Stmts: func() StmtStore {
|
Stmts: *newPrepareStmtCache(prepareStmtLruConfig),
|
||||||
var stmts StmtStore
|
Mux: &sync.RWMutex{},
|
||||||
if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open {
|
|
||||||
lru := &LruStmtStore{}
|
|
||||||
lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL)
|
|
||||||
stmts = lru
|
|
||||||
} else {
|
|
||||||
stmts = &DefaultStmtStore{}
|
|
||||||
}
|
|
||||||
return stmts
|
|
||||||
}(),
|
|
||||||
Mux: &sync.RWMutex{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,6 +62,9 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
|||||||
func (db *PreparedStmtDB) Close() {
|
func (db *PreparedStmtDB) Close() {
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer db.Mux.Unlock()
|
||||||
|
if db.Stmts == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, stmt := range db.Stmts.AllMap() {
|
for _, stmt := range db.Stmts.AllMap() {
|
||||||
go func(s *Stmt) {
|
go func(s *Stmt) {
|
||||||
@ -74,7 +82,9 @@ func (db *PreparedStmtDB) Close() {
|
|||||||
func (sdb *PreparedStmtDB) Reset() {
|
func (sdb *PreparedStmtDB) Reset() {
|
||||||
sdb.Mux.Lock()
|
sdb.Mux.Lock()
|
||||||
defer sdb.Mux.Unlock()
|
defer sdb.Mux.Unlock()
|
||||||
|
if sdb.Stmts == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, stmt := range sdb.Stmts.AllMap() {
|
for _, stmt := range sdb.Stmts.AllMap() {
|
||||||
go func(s *Stmt) {
|
go func(s *Stmt) {
|
||||||
// make sure the stmt must finish preparation first
|
// make sure the stmt must finish preparation first
|
||||||
@ -84,34 +94,40 @@ func (sdb *PreparedStmtDB) Reset() {
|
|||||||
}
|
}
|
||||||
}(stmt)
|
}(stmt)
|
||||||
}
|
}
|
||||||
sdb.Stmts = &DefaultStmtStore{}
|
defaultStmt := &DefaultStmtStore{}
|
||||||
|
defaultStmt.init()
|
||||||
|
sdb.Stmts = defaultStmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||||
db.Mux.RLock()
|
db.Mux.RLock()
|
||||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
if db.Stmts != nil {
|
||||||
db.Mux.RUnlock()
|
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||||
// wait for other goroutines prepared
|
db.Mux.RUnlock()
|
||||||
<-stmt.prepared
|
// wait for other goroutines prepared
|
||||||
if stmt.prepareErr != nil {
|
<-stmt.prepared
|
||||||
return Stmt{}, stmt.prepareErr
|
if stmt.prepareErr != nil {
|
||||||
}
|
return Stmt{}, stmt.prepareErr
|
||||||
|
}
|
||||||
|
|
||||||
return *stmt, nil
|
return *stmt, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
db.Mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
|
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
// double check
|
if db.Stmts != nil {
|
||||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
// double check
|
||||||
db.Mux.Unlock()
|
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||||
// wait for other goroutines prepared
|
db.Mux.Unlock()
|
||||||
<-stmt.prepared
|
// wait for other goroutines prepared
|
||||||
if stmt.prepareErr != nil {
|
<-stmt.prepared
|
||||||
return Stmt{}, stmt.prepareErr
|
if stmt.prepareErr != nil {
|
||||||
}
|
return Stmt{}, stmt.prepareErr
|
||||||
|
}
|
||||||
|
|
||||||
return *stmt, nil
|
return *stmt, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
|
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
|
||||||
// which cause by calling Close and executing SQL concurrently
|
// which cause by calling Close and executing SQL concurrently
|
||||||
@ -295,11 +311,15 @@ type StmtStore interface {
|
|||||||
AllMap() map[string]*Stmt
|
AllMap() map[string]*Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// 默认的 map 实现
|
|
||||||
type DefaultStmtStore struct {
|
type DefaultStmtStore struct {
|
||||||
defaultStmt map[string]*Stmt
|
defaultStmt map[string]*Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultStmtStore) init() *DefaultStmtStore {
|
||||||
|
s.defaultStmt = make(map[string]*Stmt)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultStmtStore) AllMap() map[string]*Stmt {
|
func (s *DefaultStmtStore) AllMap() map[string]*Stmt {
|
||||||
return s.defaultStmt
|
return s.defaultStmt
|
||||||
}
|
}
|
||||||
|
@ -92,6 +92,48 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
|||||||
tx2.Commit()
|
tx2.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPreparedStmtLruFromTransaction(t *testing.T) {
|
||||||
|
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtLruConfig: &gorm.PrepareStmtLruConfig{10, 20 * time.Second, true}})
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err := tx.Error; err != nil {
|
||||||
|
t.Errorf("Failed to start transaction, got error %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
t.Errorf("Failed to commit transaction, got error %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
|
||||||
|
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2 := db.Begin()
|
||||||
|
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
|
||||||
|
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
tx2.Commit()
|
||||||
|
time.Sleep(time.Second * 40)
|
||||||
|
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||||
|
AssertEqual(t, ok, true)
|
||||||
|
AssertEqual(t, len(conn.Stmts.AllMap()), 0)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||||
tx, err := OpenTestConnection(&gorm.Config{})
|
tx, err := OpenTestConnection(&gorm.Config{})
|
||||||
AssertEqual(t, err, nil)
|
AssertEqual(t, err, nil)
|
||||||
@ -117,8 +159,8 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
|||||||
|
|
||||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||||
AssertEqual(t, ok, true)
|
AssertEqual(t, ok, true)
|
||||||
AssertEqual(t, len(conn.Stmts), 2)
|
AssertEqual(t, len(conn.Stmts.AllMap()), 2)
|
||||||
for _, stmt := range conn.Stmts {
|
for _, stmt := range conn.Stmts.AllMap() {
|
||||||
if stmt == nil {
|
if stmt == nil {
|
||||||
t.Fatalf("stmt cannot bee nil")
|
t.Fatalf("stmt cannot bee nil")
|
||||||
}
|
}
|
||||||
@ -155,7 +197,7 @@ func TestPreparedStmtReset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pdb.Mux.Lock()
|
pdb.Mux.Lock()
|
||||||
if len(pdb.Stmts) == 0 {
|
if len(pdb.Stmts.AllMap()) == 0 {
|
||||||
pdb.Mux.Unlock()
|
pdb.Mux.Unlock()
|
||||||
t.Fatalf("prepared stmt can not be empty")
|
t.Fatalf("prepared stmt can not be empty")
|
||||||
}
|
}
|
||||||
@ -164,7 +206,7 @@ func TestPreparedStmtReset(t *testing.T) {
|
|||||||
pdb.Reset()
|
pdb.Reset()
|
||||||
pdb.Mux.Lock()
|
pdb.Mux.Lock()
|
||||||
defer pdb.Mux.Unlock()
|
defer pdb.Mux.Unlock()
|
||||||
if len(pdb.Stmts) != 0 {
|
if len(pdb.Stmts.AllMap()) != 0 {
|
||||||
t.Fatalf("prepared stmt should be empty")
|
t.Fatalf("prepared stmt should be empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user