refact prepare stmt store
This commit is contained in:
parent
886a406556
commit
14dc8ed9e0
@ -1,8 +1,9 @@
|
||||
package stmt_store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/lru"
|
||||
@ -15,24 +16,7 @@ type Stmt struct {
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
func NewStmt(isTransaction bool) *Stmt {
|
||||
return &Stmt{
|
||||
Transaction: isTransaction,
|
||||
prepared: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Done() {
|
||||
close(stmt.prepared)
|
||||
}
|
||||
|
||||
func (stmt *Stmt) AddError(err error) {
|
||||
stmt.prepareErr = err
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Error() error {
|
||||
<-stmt.prepared
|
||||
|
||||
return stmt.prepareErr
|
||||
}
|
||||
|
||||
@ -46,13 +30,14 @@ func (stmt *Stmt) Close() error {
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
|
||||
Keys() []string
|
||||
Get(key string) (*Stmt, bool)
|
||||
Set(key string, value *Stmt)
|
||||
Delete(key string)
|
||||
AllMap() map[string]*Stmt
|
||||
}
|
||||
|
||||
type StmtStore struct {
|
||||
type LRUStore struct {
|
||||
lru *lru.LRU[string, *Stmt]
|
||||
}
|
||||
|
||||
@ -72,35 +57,52 @@ func New(size int, ttl time.Duration) Store {
|
||||
|
||||
onEvicted := func(k string, v *Stmt) {
|
||||
if v != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Print("close stmt err panic ")
|
||||
}
|
||||
}()
|
||||
err := v.Close()
|
||||
if err != nil {
|
||||
fmt.Print("close stmt err: ", err.Error())
|
||||
}
|
||||
}()
|
||||
go v.Close()
|
||||
}
|
||||
}
|
||||
return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||
return &LRUStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||
}
|
||||
|
||||
func (s *StmtStore) AllMap() map[string]*Stmt {
|
||||
return s.lru.KeyValues()
|
||||
func (s *LRUStore) Keys() []string {
|
||||
return s.lru.Keys()
|
||||
}
|
||||
|
||||
func (s *StmtStore) Get(key string) (*Stmt, bool) {
|
||||
func (s *LRUStore) Get(key string) (*Stmt, bool) {
|
||||
stmt, ok := s.lru.Get(key)
|
||||
if ok && stmt != nil {
|
||||
<-stmt.prepared
|
||||
}
|
||||
return stmt, ok
|
||||
}
|
||||
|
||||
func (s *StmtStore) Set(key string, value *Stmt) {
|
||||
func (s *LRUStore) Set(key string, value *Stmt) {
|
||||
s.lru.Add(key, value)
|
||||
}
|
||||
|
||||
func (s *StmtStore) Delete(key string) {
|
||||
func (s *LRUStore) Delete(key string) {
|
||||
s.lru.Remove(key)
|
||||
}
|
||||
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
func (s *LRUStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
|
||||
cacheStmt := &Stmt{
|
||||
Transaction: isTransaction,
|
||||
prepared: make(chan struct{}),
|
||||
}
|
||||
s.Set(key, cacheStmt)
|
||||
locker.Unlock()
|
||||
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
|
||||
if err != nil {
|
||||
cacheStmt.prepareErr = err
|
||||
s.Delete(key)
|
||||
return &Stmt{}, err
|
||||
}
|
||||
|
||||
return cacheStmt, nil
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ type PreparedStmtDB struct {
|
||||
ConnPool
|
||||
}
|
||||
|
||||
// NewPreparedStmtDB creates a new PreparedStmtDB instance
|
||||
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
@ -26,6 +27,7 @@ func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *Prepa
|
||||
}
|
||||
}
|
||||
|
||||
// GetDBConn returns the underlying *sql.DB connection
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
@ -38,93 +40,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
// Close closes all prepared statements in the store
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
if db.Stmts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, stmt := range db.Stmts.AllMap() {
|
||||
go stmt.Close()
|
||||
for _, key := range db.Stmts.Keys() {
|
||||
db.Stmts.Delete(key)
|
||||
}
|
||||
// setting db.Stmts to nil to avoid further using
|
||||
db.Stmts = nil
|
||||
}
|
||||
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
if sdb.Stmts == nil {
|
||||
return
|
||||
}
|
||||
for _, stmt := range sdb.Stmts.AllMap() {
|
||||
go stmt.Close()
|
||||
}
|
||||
|
||||
// Migrator
|
||||
sdb.Stmts = stmt_store.New(0, 0)
|
||||
// Reset Deprecated use Close instead
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) {
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
|
||||
db.Mux.RLock()
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
if err := stmt.Error(); err != nil {
|
||||
return stmt_store.Stmt{}, err
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
// retry
|
||||
db.Mux.Lock()
|
||||
if db.Stmts != nil {
|
||||
// double check
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
if err := stmt.Error(); err != nil {
|
||||
return stmt_store.Stmt{}, err
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
}
|
||||
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
|
||||
// which cause by calling Close and executing SQL concurrently
|
||||
if db.Stmts == nil {
|
||||
db.Mux.Unlock()
|
||||
return stmt_store.Stmt{}, ErrInvalidDB
|
||||
}
|
||||
// cache preparing stmt first
|
||||
cacheStmt := stmt_store.NewStmt(isTransaction)
|
||||
db.Stmts.Set(query, cacheStmt)
|
||||
db.Mux.Unlock()
|
||||
|
||||
// prepare completed
|
||||
defer cacheStmt.Done()
|
||||
|
||||
// Reason why cannot lock conn.PrepareContext
|
||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
cacheStmt.AddError(err)
|
||||
db.Mux.Lock()
|
||||
db.Stmts.Delete(query)
|
||||
db.Mux.Unlock()
|
||||
return stmt_store.Stmt{}, err
|
||||
}
|
||||
|
||||
db.Mux.Lock()
|
||||
cacheStmt.Stmt = stmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
return *cacheStmt, nil
|
||||
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
@ -153,9 +103,6 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
go stmt.Close()
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
@ -167,10 +114,6 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
@ -221,10 +164,6 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
@ -236,10 +175,6 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
|
18
tests/go.mod
18
tests/go.mod
@ -1,15 +1,17 @@
|
||||
module gorm.io/gorm/tests
|
||||
|
||||
go 1.18
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.5.10
|
||||
gorm.io/driver/sqlite v1.5.6
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/driver/sqlserver v1.5.4
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
@ -17,7 +19,7 @@ require (
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.2 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@ -25,12 +27,12 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.7.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.7.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/crypto v0.29.0 // indirect
|
||||
golang.org/x/text v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.37.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -130,13 +129,13 @@ func TestPreparedStmtLruFromTransaction(t *testing.T) {
|
||||
|
||||
tx2.Commit()
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
lens := len(conn.Stmts.AllMap())
|
||||
lens := len(conn.Stmts.Keys())
|
||||
if lens == 0 {
|
||||
t.Fatalf("lru should not be empty")
|
||||
}
|
||||
time.Sleep(time.Second * 40)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts.AllMap()), 0)
|
||||
AssertEqual(t, len(conn.Stmts.Keys()), 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
@ -164,9 +163,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts.AllMap()), 2)
|
||||
for _, stmt := range conn.Stmts.AllMap() {
|
||||
if stmt == nil {
|
||||
AssertEqual(t, len(conn.Stmts.Keys()), 2)
|
||||
for _, stmt := range conn.Stmts.Keys() {
|
||||
if stmt == "" {
|
||||
t.Fatalf("stmt cannot bee nil")
|
||||
}
|
||||
}
|
||||
@ -190,10 +189,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
func TestPreparedStmtClose(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
user := *GetUser("prepared_stmt_close", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
@ -202,16 +201,16 @@ func TestPreparedStmtReset(t *testing.T) {
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts.AllMap()) == 0 {
|
||||
if len(pdb.Stmts.Keys()) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Close()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts.AllMap()) != 0 {
|
||||
if len(pdb.Stmts.Keys()) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
@ -221,10 +220,10 @@ func isUsingClosedConnError(err error) bool {
|
||||
return err.Error() == "sql: statement is closed"
|
||||
}
|
||||
|
||||
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
|
||||
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
|
||||
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
|
||||
func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_reset"
|
||||
func TestPreparedStmtConcurrentClose(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_close"
|
||||
user := *GetUser(name, Config{})
|
||||
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||
if createTx.Error != nil {
|
||||
@ -267,7 +266,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-writerFinish
|
||||
pdb.Reset()
|
||||
pdb.Close()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
@ -276,88 +275,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||
t.Fatalf("should is a unexpected error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
|
||||
// for example: one goroutine found error and just close the database, and others are executing SQL
|
||||
// this test making sure that the gorm would not get a Segmentation Fault,
|
||||
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
|
||||
// and all of the goroutine must got gorm.ErrInvalidDB after database close
|
||||
func TestPreparedStmtConcurrentClose(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_close"
|
||||
user := *GetUser(name, Config{})
|
||||
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||
if createTx.Error != nil {
|
||||
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
|
||||
}
|
||||
|
||||
// create a new connection to keep away from other tests
|
||||
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test connection due to %s", err)
|
||||
}
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
loopCount := 100
|
||||
var wg sync.WaitGroup
|
||||
var lastErr error
|
||||
closeValid := make(chan struct{}, loopCount)
|
||||
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
|
||||
var lastRunIndex int
|
||||
var closeFinishedAt int64
|
||||
|
||||
wg.Add(1)
|
||||
go func(id uint) {
|
||||
defer wg.Done()
|
||||
defer close(closeValid)
|
||||
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
|
||||
if lastRunIndex == closeStartIdx {
|
||||
closeValid <- struct{}{}
|
||||
}
|
||||
var tmp User
|
||||
now := time.Now().UnixNano()
|
||||
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
|
||||
if err == nil {
|
||||
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
|
||||
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
|
||||
lastErr = errors.New("must got error after database closed")
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastErr = err
|
||||
break
|
||||
}
|
||||
}(user.ID)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range closeValid {
|
||||
for i := 0; i < loopCount; i++ {
|
||||
pdb.Close() // the Close method must can be call multiple times
|
||||
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
var tmp User
|
||||
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
|
||||
if err != gorm.ErrInvalidDB {
|
||||
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
|
||||
}
|
||||
|
||||
// must be error
|
||||
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
|
||||
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
|
||||
}
|
||||
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
|
||||
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
|
||||
}
|
||||
if pdb.Stmts != nil {
|
||||
t.Fatalf("stmts must be nil")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user