diff --git a/internal/stmt_store/stmt_store.go b/internal/stmt_store/stmt_store.go index ae6a79b6..606ff761 100644 --- a/internal/stmt_store/stmt_store.go +++ b/internal/stmt_store/stmt_store.go @@ -1,8 +1,9 @@ package stmt_store import ( + "context" "database/sql" - "fmt" + "sync" "time" "gorm.io/gorm/internal/lru" @@ -15,24 +16,7 @@ type Stmt struct { prepareErr error } -func NewStmt(isTransaction bool) *Stmt { - return &Stmt{ - Transaction: isTransaction, - prepared: make(chan struct{}), - } -} - -func (stmt *Stmt) Done() { - close(stmt.prepared) -} - -func (stmt *Stmt) AddError(err error) { - stmt.prepareErr = err -} - func (stmt *Stmt) Error() error { - <-stmt.prepared - return stmt.prepareErr } @@ -46,13 +30,14 @@ func (stmt *Stmt) Close() error { } type Store interface { + New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error) + Keys() []string Get(key string) (*Stmt, bool) Set(key string, value *Stmt) Delete(key string) - AllMap() map[string]*Stmt } -type StmtStore struct { +type LRUStore struct { lru *lru.LRU[string, *Stmt] } @@ -72,35 +57,52 @@ func New(size int, ttl time.Duration) Store { onEvicted := func(k string, v *Stmt) { if v != nil { - go func() { - defer func() { - if r := recover(); r != nil { - fmt.Print("close stmt err panic ") - } - }() - err := v.Close() - if err != nil { - fmt.Print("close stmt err: ", err.Error()) - } - }() + go v.Close() } } - return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} + return &LRUStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} } -func (s *StmtStore) AllMap() map[string]*Stmt { - return s.lru.KeyValues() +func (s *LRUStore) Keys() []string { + return s.lru.Keys() } -func (s *StmtStore) Get(key string) (*Stmt, bool) { +func (s *LRUStore) Get(key string) (*Stmt, bool) { stmt, ok := s.lru.Get(key) + if ok && stmt != nil { + <-stmt.prepared + } return stmt, ok } -func (s *StmtStore) Set(key string, value *Stmt) { +func (s *LRUStore) Set(key string, value *Stmt) { s.lru.Add(key, value) } -func (s *StmtStore) Delete(key string) { +func (s *LRUStore) Delete(key string) { s.lru.Remove(key) } + +type ConnPool interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +func (s *LRUStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) { + cacheStmt := &Stmt{ + Transaction: isTransaction, + prepared: make(chan struct{}), + } + s.Set(key, cacheStmt) + locker.Unlock() + + defer close(cacheStmt.prepared) + + cacheStmt.Stmt, err = conn.PrepareContext(ctx, key) + if err != nil { + cacheStmt.prepareErr = err + s.Delete(key) + return &Stmt{}, err + } + + return cacheStmt, nil +} diff --git a/prepare_stmt.go b/prepare_stmt.go index 5b5bf68b..68c7ba69 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -18,6 +18,7 @@ type PreparedStmtDB struct { ConnPool } +// NewPreparedStmtDB creates a new PreparedStmtDB instance func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB { return &PreparedStmtDB{ ConnPool: connPool, @@ -26,6 +27,7 @@ func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *Prepa } } +// GetDBConn returns the underlying *sql.DB connection func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil @@ -38,93 +40,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { return nil, ErrInvalidDB } +// Close closes all prepared statements in the store func (db *PreparedStmtDB) Close() { db.Mux.Lock() defer db.Mux.Unlock() - if db.Stmts == nil { - return - } - for _, stmt := range db.Stmts.AllMap() { - go stmt.Close() + for _, key := range db.Stmts.Keys() { + db.Stmts.Delete(key) } - // setting db.Stmts to nil to avoid further using - db.Stmts = nil } -func (sdb *PreparedStmtDB) Reset() { - sdb.Mux.Lock() - defer sdb.Mux.Unlock() - if sdb.Stmts == nil { - return - } - for _, stmt := range sdb.Stmts.AllMap() { - go stmt.Close() - } - - // Migrator - sdb.Stmts = stmt_store.New(0, 0) +// Reset Deprecated use Close instead +func (db *PreparedStmtDB) Reset() { + db.Close() } -func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) { db.Mux.RLock() if db.Stmts != nil { if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - if err := stmt.Error(); err != nil { - return stmt_store.Stmt{}, err - } - - return *stmt, nil + return stmt, stmt.Error() } } db.Mux.RUnlock() + // retry db.Mux.Lock() if db.Stmts != nil { - // double check if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { db.Mux.Unlock() - if err := stmt.Error(); err != nil { - return stmt_store.Stmt{}, err - } - - return *stmt, nil + return stmt, stmt.Error() } } - // check db.Stmts first to avoid Segmentation Fault(setting value to nil map) - // which cause by calling Close and executing SQL concurrently - if db.Stmts == nil { - db.Mux.Unlock() - return stmt_store.Stmt{}, ErrInvalidDB - } - // cache preparing stmt first - cacheStmt := stmt_store.NewStmt(isTransaction) - db.Stmts.Set(query, cacheStmt) - db.Mux.Unlock() - // prepare completed - defer cacheStmt.Done() - - // Reason why cannot lock conn.PrepareContext - // suppose the maxopen is 1, g1 is creating record and g2 is querying record. - // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. - // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. - // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. - stmt, err := conn.PrepareContext(ctx, query) - if err != nil { - cacheStmt.AddError(err) - db.Mux.Lock() - db.Stmts.Delete(query) - db.Mux.Unlock() - return stmt_store.Stmt{}, err - } - - db.Mux.Lock() - cacheStmt.Stmt = stmt - db.Mux.Unlock() - - return *cacheStmt, nil + return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux) } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -153,9 +103,6 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = stmt.ExecContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - db.Mux.Lock() - defer db.Mux.Unlock() - go stmt.Close() db.Stmts.Delete(query) } } @@ -167,10 +114,6 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = stmt.QueryContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - db.Mux.Lock() - defer db.Mux.Unlock() - - go stmt.Close() db.Stmts.Delete(query) } } @@ -221,10 +164,6 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - tx.PreparedStmtDB.Mux.Lock() - defer tx.PreparedStmtDB.Mux.Unlock() - - go stmt.Close() tx.PreparedStmtDB.Stmts.Delete(query) } } @@ -236,10 +175,6 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - tx.PreparedStmtDB.Mux.Lock() - defer tx.PreparedStmtDB.Mux.Unlock() - - go stmt.Close() tx.PreparedStmtDB.Stmts.Delete(query) } } diff --git a/tests/go.mod b/tests/go.mod index 30143433..778e3bca 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,15 +1,17 @@ module gorm.io/gorm/tests -go 1.18 +go 1.23.0 + +toolchain go1.24.2 require ( github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 gorm.io/driver/mysql v1.5.7 - gorm.io/driver/postgres v1.5.10 - gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 gorm.io/driver/sqlserver v1.5.4 gorm.io/gorm v1.25.12 ) @@ -17,7 +19,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-sql-driver/mysql v1.9.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -25,12 +27,12 @@ require ( github.com/jackc/pgx/v5 v5.7.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect - github.com/mattn/go-sqlite3 v1.14.24 // indirect + github.com/mattn/go-sqlite3 v1.14.28 // indirect github.com/microsoft/go-mssqldb v1.7.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.29.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/text v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c6597594..3f2b1608 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "sync" - "sync/atomic" "testing" "time" @@ -130,13 +129,13 @@ func TestPreparedStmtLruFromTransaction(t *testing.T) { tx2.Commit() conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) - lens := len(conn.Stmts.AllMap()) + lens := len(conn.Stmts.Keys()) if lens == 0 { t.Fatalf("lru should not be empty") } time.Sleep(time.Second * 40) AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts.AllMap()), 0) + AssertEqual(t, len(conn.Stmts.Keys()), 0) } func TestPreparedStmtDeadlock(t *testing.T) { @@ -164,9 +163,9 @@ func TestPreparedStmtDeadlock(t *testing.T) { conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts.AllMap()), 2) - for _, stmt := range conn.Stmts.AllMap() { - if stmt == nil { + AssertEqual(t, len(conn.Stmts.Keys()), 2) + for _, stmt := range conn.Stmts.Keys() { + if stmt == "" { t.Fatalf("stmt cannot bee nil") } } @@ -190,10 +189,10 @@ func TestPreparedStmtInTransaction(t *testing.T) { } } -func TestPreparedStmtReset(t *testing.T) { +func TestPreparedStmtClose(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) - user := *GetUser("prepared_stmt_reset", Config{}) + user := *GetUser("prepared_stmt_close", Config{}) tx = tx.Create(&user) pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) @@ -202,16 +201,16 @@ func TestPreparedStmtReset(t *testing.T) { } pdb.Mux.Lock() - if len(pdb.Stmts.AllMap()) == 0 { + if len(pdb.Stmts.Keys()) == 0 { pdb.Mux.Unlock() t.Fatalf("prepared stmt can not be empty") } pdb.Mux.Unlock() - pdb.Reset() + pdb.Close() pdb.Mux.Lock() defer pdb.Mux.Unlock() - if len(pdb.Stmts.AllMap()) != 0 { + if len(pdb.Stmts.Keys()) != 0 { t.Fatalf("prepared stmt should be empty") } } @@ -221,10 +220,10 @@ func isUsingClosedConnError(err error) bool { return err.Error() == "sql: statement is closed" } -// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently +// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently // this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt -func TestPreparedStmtConcurrentReset(t *testing.T) { - name := "prepared_stmt_concurrent_reset" +func TestPreparedStmtConcurrentClose(t *testing.T) { + name := "prepared_stmt_concurrent_close" user := *GetUser(name, Config{}) createTx := DB.Session(&gorm.Session{}).Create(&user) if createTx.Error != nil { @@ -267,7 +266,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) { go func() { defer wg.Done() <-writerFinish - pdb.Reset() + pdb.Close() }() wg.Wait() @@ -276,88 +275,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) { t.Fatalf("should is a unexpected error") } } - -// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently -// for example: one goroutine found error and just close the database, and others are executing SQL -// this test making sure that the gorm would not get a Segmentation Fault, -// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB -// and all of the goroutine must got gorm.ErrInvalidDB after database close -func TestPreparedStmtConcurrentClose(t *testing.T) { - name := "prepared_stmt_concurrent_close" - user := *GetUser(name, Config{}) - createTx := DB.Session(&gorm.Session{}).Create(&user) - if createTx.Error != nil { - t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) - } - - // create a new connection to keep away from other tests - tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) - if err != nil { - t.Fatalf("failed to open test connection due to %s", err) - } - pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) - if !ok { - t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") - } - - loopCount := 100 - var wg sync.WaitGroup - var lastErr error - closeValid := make(chan struct{}, loopCount) - closeStartIdx := loopCount / 2 // close the database at the middle of the execution - var lastRunIndex int - var closeFinishedAt int64 - - wg.Add(1) - go func(id uint) { - defer wg.Done() - defer close(closeValid) - for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { - if lastRunIndex == closeStartIdx { - closeValid <- struct{}{} - } - var tmp User - now := time.Now().UnixNano() - err := tx.Session(&gorm.Session{}).First(&tmp, id).Error - if err == nil { - closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) - if (closeFinishedAt != 0) && (now > closeFinishedAt) { - lastErr = errors.New("must got error after database closed") - break - } - continue - } - lastErr = err - break - } - }(user.ID) - - wg.Add(1) - go func() { - defer wg.Done() - for range closeValid { - for i := 0; i < loopCount; i++ { - pdb.Close() // the Close method must can be call multiple times - atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) - } - } - }() - - wg.Wait() - var tmp User - err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error - if err != gorm.ErrInvalidDB { - t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) - } - - // must be error - if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { - t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) - } - if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { - t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) - } - if pdb.Stmts != nil { - t.Fatalf("stmts must be nil") - } -}