diff --git a/internal/stmt_store/stmt_store.go b/internal/stmt_store/stmt_store.go new file mode 100644 index 00000000..ae6a79b6 --- /dev/null +++ b/internal/stmt_store/stmt_store.go @@ -0,0 +1,106 @@ +package stmt_store + +import ( + "database/sql" + "fmt" + "time" + + "gorm.io/gorm/internal/lru" +) + +type Stmt struct { + *sql.Stmt + Transaction bool + prepared chan struct{} + prepareErr error +} + +func NewStmt(isTransaction bool) *Stmt { + return &Stmt{ + Transaction: isTransaction, + prepared: make(chan struct{}), + } +} + +func (stmt *Stmt) Done() { + close(stmt.prepared) +} + +func (stmt *Stmt) AddError(err error) { + stmt.prepareErr = err +} + +func (stmt *Stmt) Error() error { + <-stmt.prepared + + return stmt.prepareErr +} + +func (stmt *Stmt) Close() error { + <-stmt.prepared + + if stmt.Stmt != nil { + return stmt.Stmt.Close() + } + return nil +} + +type Store interface { + Get(key string) (*Stmt, bool) + Set(key string, value *Stmt) + Delete(key string) + AllMap() map[string]*Stmt +} + +type StmtStore struct { + lru *lru.LRU[string, *Stmt] +} + +const ( + defaultMaxSize = (1 << 63) - 1 + defaultTTL = time.Hour * 24 +) + +func New(size int, ttl time.Duration) Store { + if size <= 0 { + size = defaultMaxSize + } + + if ttl <= 0 { + ttl = defaultTTL + } + + onEvicted := func(k string, v *Stmt) { + if v != nil { + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Print("close stmt err panic ") + } + }() + err := v.Close() + if err != nil { + fmt.Print("close stmt err: ", err.Error()) + } + }() + } + } + return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} +} + +func (s *StmtStore) AllMap() map[string]*Stmt { + return s.lru.KeyValues() +} + +func (s *StmtStore) Get(key string) (*Stmt, bool) { + stmt, ok := s.lru.Get(key) + return stmt, ok +} + +func (s *StmtStore) Set(key string, value *Stmt) { + s.lru.Add(key, value) +} + +func (s *StmtStore) Delete(key string) { + s.lru.Remove(key) +} diff --git a/prepare_stmt.go b/prepare_stmt.go index f175690e..5b5bf68b 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -5,67 +5,24 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" - "gorm.io/gorm/internal/lru" "reflect" "sync" "time" + + "gorm.io/gorm/internal/stmt_store" ) -type Stmt struct { - *sql.Stmt - Transaction bool - prepared chan struct{} - prepareErr error -} - type PreparedStmtDB struct { - Stmts StmtStore + Stmts stmt_store.Store Mux *sync.RWMutex ConnPool } -const default_max_size = (1 << 63) - 1 -const default_ttl = time.Hour * 24 - -// newPrepareStmtCache creates a new statement cache with the specified maximum size and time-to-live (TTL). -// Parameters: -// - PrepareStmtMaxSize: An integer specifying the maximum number of prepared statements to cache. -// If this value is less than or equal to 0, the function will panic. -// - PrepareStmtTTL: A time.Duration specifying the TTL for cached statements. -// If this value differs from the default TTL, it will be used instead. -// -// Returns: -// - A pointer to a store.StmtStore instance configured with the provided parameters. -// -// The function initializes an LRU (Least Recently Used) cache for prepared statements, -// using either the provided size and TTL or default values -func newPrepareStmtCache(PrepareStmtMaxSize int, - PrepareStmtTTL time.Duration) *StmtStore { - var lru_size = default_max_size - var lru_ttl = default_ttl - var stmts StmtStore - if PrepareStmtMaxSize < 0 { - panic("PrepareStmtMaxSize must > 0") - } - if PrepareStmtMaxSize != 0 { - lru_size = PrepareStmtMaxSize - } - if PrepareStmtTTL != default_ttl { - lru_ttl = PrepareStmtTTL - } - lru := &LruStmtStore{} - lru.newLru(lru_size, lru_ttl) - stmts = lru - return &stmts -} -func NewPreparedStmtDB(connPool ConnPool, PrepareStmtMaxSize int, - PrepareStmtTTL time.Duration) *PreparedStmtDB { +func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB { return &PreparedStmtDB{ ConnPool: connPool, - Stmts: *newPrepareStmtCache(PrepareStmtMaxSize, - PrepareStmtTTL), - Mux: &sync.RWMutex{}, + Stmts: stmt_store.New(maxSize, ttl), + Mux: &sync.RWMutex{}, } } @@ -89,13 +46,7 @@ func (db *PreparedStmtDB) Close() { } for _, stmt := range db.Stmts.AllMap() { - go func(s *Stmt) { - // make sure the stmt must finish preparation first - <-s.prepared - if s.Stmt != nil { - _ = s.Close() - } - }(stmt) + go stmt.Close() } // setting db.Stmts to nil to avoid further using db.Stmts = nil @@ -108,28 +59,20 @@ func (sdb *PreparedStmtDB) Reset() { return } for _, stmt := range sdb.Stmts.AllMap() { - go func(s *Stmt) { - // make sure the stmt must finish preparation first - <-s.prepared - if s.Stmt != nil { - _ = s.Close() - } - }(stmt) + go stmt.Close() } - //Migrator - defaultStmt := newPrepareStmtCache(0, 0) - sdb.Stmts = *defaultStmt + + // Migrator + sdb.Stmts = stmt_store.New(0, 0) } -func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) { db.Mux.RLock() if db.Stmts != nil { - if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) { + if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - // wait for other goroutines prepared - <-stmt.prepared - if stmt.prepareErr != nil { - return Stmt{}, stmt.prepareErr + if err := stmt.Error(); err != nil { + return stmt_store.Stmt{}, err } return *stmt, nil @@ -140,12 +83,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.Lock() if db.Stmts != nil { // double check - if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) { + if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { db.Mux.Unlock() - // wait for other goroutines prepared - <-stmt.prepared - if stmt.prepareErr != nil { - return Stmt{}, stmt.prepareErr + if err := stmt.Error(); err != nil { + return stmt_store.Stmt{}, err } return *stmt, nil @@ -155,15 +96,15 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact // which cause by calling Close and executing SQL concurrently if db.Stmts == nil { db.Mux.Unlock() - return Stmt{}, ErrInvalidDB + return stmt_store.Stmt{}, ErrInvalidDB } // cache preparing stmt first - cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} - db.Stmts.set(query, &cacheStmt) + cacheStmt := stmt_store.NewStmt(isTransaction) + db.Stmts.Set(query, cacheStmt) db.Mux.Unlock() // prepare completed - defer close(cacheStmt.prepared) + defer cacheStmt.Done() // Reason why cannot lock conn.PrepareContext // suppose the maxopen is 1, g1 is creating record and g2 is querying record. @@ -172,19 +113,18 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) if err != nil { - cacheStmt.prepareErr = err + cacheStmt.AddError(err) db.Mux.Lock() - db.Stmts.delete(query) - //delete(db.Stmts.AllMap(), query) + db.Stmts.Delete(query) db.Mux.Unlock() - return Stmt{}, err + return stmt_store.Stmt{}, err } db.Mux.Lock() cacheStmt.Stmt = stmt db.Mux.Unlock() - return cacheStmt, nil + return *cacheStmt, nil } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -216,8 +156,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() - db.Stmts.delete(query) - //delete(db.Stmts.AllMap(), query) + db.Stmts.Delete(query) } } return result, err @@ -232,8 +171,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . defer db.Mux.Unlock() go stmt.Close() - db.Stmts.delete(query) - //delete(db.Stmts.AllMap(), query) + db.Stmts.Delete(query) } } return rows, err @@ -287,8 +225,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. defer tx.PreparedStmtDB.Mux.Unlock() go stmt.Close() - tx.PreparedStmtDB.Stmts.delete(query) - //delete(tx.PreparedStmtDB.Stmts.AllMap(), query) + tx.PreparedStmtDB.Stmts.Delete(query) } } return result, err @@ -303,8 +240,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . defer tx.PreparedStmtDB.Mux.Unlock() go stmt.Close() - tx.PreparedStmtDB.Stmts.delete(query) - //delete(tx.PreparedStmtDB.Stmts.AllMap(), query) + tx.PreparedStmtDB.Stmts.Delete(query) } } return rows, err @@ -325,79 +261,3 @@ func (tx *PreparedStmtTX) Ping() error { } return conn.Ping() } - -type StmtStore interface { - get(key string) (*Stmt, bool) - set(key string, value *Stmt) - delete(key string) - AllMap() map[string]*Stmt -} - -/* - type DefaultStmtStore struct { - defaultStmt map[string]*Stmt - } - - func (s *DefaultStmtStore) Init() *DefaultStmtStore { - s.defaultStmt = make(map[string]*Stmt) - return s - } - - func (s *DefaultStmtStore) AllMap() map[string]*Stmt { - return s.defaultStmt - } - - func (s *DefaultStmtStore) Get(key string) (*Stmt, bool) { - stmt, ok := s.defaultStmt[key] - return stmt, ok - } - - func (s *DefaultStmtStore) Set(key string, value *Stmt) { - s.defaultStmt[key] = value - } - - func (s *DefaultStmtStore) Delete(key string) { - delete(s.defaultStmt, key) - } -*/ -type LruStmtStore struct { - lru *lru.LRU[string, *Stmt] -} - -func (s *LruStmtStore) newLru(size int, ttl time.Duration) { - onEvicted := func(k string, v *Stmt) { - if v != nil { - go func() { - defer func() { - if r := recover(); r != nil { - fmt.Print("close stmt err panic ") - } - }() - if v != nil { - err := v.Close() - if err != nil { - // - fmt.Print("close stmt err: ", err.Error()) - } - } - }() - } - } - s.lru = lru.NewLRU[string, *Stmt](size, onEvicted, ttl) -} - -func (s *LruStmtStore) AllMap() map[string]*Stmt { - return s.lru.KeyValues() -} -func (s *LruStmtStore) get(key string) (*Stmt, bool) { - stmt, ok := s.lru.Get(key) - return stmt, ok -} - -func (s *LruStmtStore) set(key string, value *Stmt) { - s.lru.Add(key, value) -} - -func (s *LruStmtStore) delete(key string) { - s.lru.Remove(key) -}