From 3f922a78c8f2d608eff504b544fd18bb56c1c0a3 Mon Sep 17 00:00:00 2001 From: xiezhaodong Date: Thu, 24 Apr 2025 16:00:01 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=AA=E4=BD=BF=E7=94=A8lru?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gorm.go | 8 +- lru.go => internal/lru/lru.go | 2 +- internal/store/stmt_store.go | 84 +++++++++++++++++++++ prepare_stmt.go | 134 ++++++++++------------------------ 4 files changed, 130 insertions(+), 98 deletions(-) rename lru.go => internal/lru/lru.go (99%) create mode 100644 internal/store/stmt_store.go diff --git a/gorm.go b/gorm.go index 9a100169..1a93b83a 100644 --- a/gorm.go +++ b/gorm.go @@ -35,7 +35,9 @@ type Config struct { // PrepareStmt executes the given query in cached statement PrepareStmt bool // PrepareStmt cache support LRU expired - PrepareStmtLruConfig *PrepareStmtLruConfig + PrepareStmtMaxSize int + PrepareStmtTTL time.Duration + // DisableAutomaticPing DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating @@ -204,7 +206,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.PrepareStmt { - preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtLruConfig) + preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL) db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -275,7 +277,7 @@ func (db *DB) Session(config *Session) *DB { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt = v.(*PreparedStmtDB) } else { - preparedStmt = NewPreparedStmtDB(db.ConnPool, db.Config.PrepareStmtLruConfig) + preparedStmt = NewPreparedStmtDB(db.ConnPool, db.Config.PrepareStmtMaxSize, db.Config.PrepareStmtTTL) db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } diff --git a/lru.go b/internal/lru/lru.go similarity index 99% rename from lru.go rename to internal/lru/lru.go index 6dcd1e62..fbdffbba 100644 --- a/lru.go +++ b/internal/lru/lru.go @@ -1,4 +1,4 @@ -package gorm +package lru // golang -lru //https://github.com/hashicorp/golang-lru diff --git a/internal/store/stmt_store.go b/internal/store/stmt_store.go new file mode 100644 index 00000000..94d5a7b2 --- /dev/null +++ b/internal/store/stmt_store.go @@ -0,0 +1,84 @@ +package store + +import ( + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/internal/lru" + "time" +) + +type StmtStore interface { + Get(key string) (*gorm.Stmt, bool) + Set(key string, value *gorm.Stmt) + Delete(key string) + AllMap() map[string]*gorm.Stmt +} + +/* + type DefaultStmtStore struct { + defaultStmt map[string]*gorm.Stmt + } + + func (s *DefaultStmtStore) Init() *DefaultStmtStore { + s.defaultStmt = make(map[string]*gorm.Stmt) + return s + } + + func (s *DefaultStmtStore) AllMap() map[string]*gorm.Stmt { + return s.defaultStmt + } + + func (s *DefaultStmtStore) Get(key string) (*gorm.Stmt, bool) { + stmt, ok := s.defaultStmt[key] + return stmt, ok + } + + func (s *DefaultStmtStore) Set(key string, value *gorm.Stmt) { + s.defaultStmt[key] = value + } + + func (s *DefaultStmtStore) Delete(key string) { + delete(s.defaultStmt, key) + } +*/ +type LruStmtStore struct { + lru *lru.LRU[string, *gorm.Stmt] +} + +func (s *LruStmtStore) NewLru(size int, ttl time.Duration) { + onEvicted := func(k string, v *gorm.Stmt) { + if v != nil { + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Print("close stmt err panic ") + } + }() + if v != nil { + err := v.Close() + if err != nil { + // + fmt.Print("close stmt err: ", err.Error()) + } + } + }() + } + } + s.lru = lru.NewLRU[string, *gorm.Stmt](size, onEvicted, ttl) +} + +func (s *LruStmtStore) AllMap() map[string]*gorm.Stmt { + return s.lru.KeyValues() +} +func (s *LruStmtStore) Get(key string) (*gorm.Stmt, bool) { + stmt, ok := s.lru.Get(key) + return stmt, ok +} + +func (s *LruStmtStore) Set(key string, value *gorm.Stmt) { + s.lru.Add(key, value) +} + +func (s *LruStmtStore) Delete(key string) { + s.lru.Remove(key) +} diff --git a/prepare_stmt.go b/prepare_stmt.go index 2836ed83..c5c59626 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -5,7 +5,7 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" + "gorm.io/gorm/internal/store" "reflect" "sync" "time" @@ -19,31 +19,52 @@ type Stmt struct { } type PreparedStmtDB struct { - Stmts StmtStore + Stmts store.StmtStore Mux *sync.RWMutex ConnPool } -func newPrepareStmtCache(prepareStmtLruConfig *PrepareStmtLruConfig) *StmtStore { - var stmts StmtStore - if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open { - if prepareStmtLruConfig.Size <= 0 { - panic("LRU prepareStmtLruConfig.Size must > 0") - } - lru := &LruStmtStore{} - lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL) - stmts = lru - } else { - defaultStmtStore := &DefaultStmtStore{} - stmts = defaultStmtStore.init() +const DEFAULT_MAX_SIZE = (1 << 63) - 1 +const DEFAULT_TTL = time.Hour * 24 + +// newPrepareStmtCache creates a new statement cache with the specified maximum size and time-to-live (TTL). +// Parameters: +// - PrepareStmtMaxSize: An integer specifying the maximum number of prepared statements to cache. +// If this value is less than or equal to 0, the function will panic. +// - PrepareStmtTTL: A time.Duration specifying the TTL for cached statements. +// If this value differs from the default TTL, it will be used instead. +// +// Returns: +// - A pointer to a store.StmtStore instance configured with the provided parameters. +// +// The function initializes an LRU (Least Recently Used) cache for prepared statements, +// using either the provided size and TTL or default values +func newPrepareStmtCache(PrepareStmtMaxSize int, + PrepareStmtTTL time.Duration) *store.StmtStore { + var lru_size = DEFAULT_MAX_SIZE + var lru_ttl = DEFAULT_TTL + var stmts store.StmtStore + if PrepareStmtMaxSize <= 0 { + panic("PrepareStmtMaxSize must > 0") } + if PrepareStmtMaxSize != 0 { + lru_size = PrepareStmtMaxSize + } + if PrepareStmtTTL != DEFAULT_TTL { + lru_ttl = PrepareStmtTTL + } + lru := &store.LruStmtStore{} + lru.NewLru(lru_size, lru_ttl) + stmts = lru return &stmts } -func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB { +func NewPreparedStmtDB(connPool ConnPool, PrepareStmtMaxSize int, + PrepareStmtTTL time.Duration) *PreparedStmtDB { return &PreparedStmtDB{ ConnPool: connPool, - Stmts: *newPrepareStmtCache(prepareStmtLruConfig), - Mux: &sync.RWMutex{}, + Stmts: *newPrepareStmtCache(PrepareStmtMaxSize, + PrepareStmtTTL), + Mux: &sync.RWMutex{}, } } @@ -94,9 +115,8 @@ func (sdb *PreparedStmtDB) Reset() { } }(stmt) } - defaultStmt := &DefaultStmtStore{} - defaultStmt.init() - sdb.Stmts = defaultStmt + defaultStmt := newPrepareStmtCache(0, 0) + sdb.Stmts = *defaultStmt } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -303,77 +323,3 @@ func (tx *PreparedStmtTX) Ping() error { } return conn.Ping() } - -type StmtStore interface { - Get(key string) (*Stmt, bool) - Set(key string, value *Stmt) - Delete(key string) - AllMap() map[string]*Stmt -} - -type DefaultStmtStore struct { - defaultStmt map[string]*Stmt -} - -func (s *DefaultStmtStore) init() *DefaultStmtStore { - s.defaultStmt = make(map[string]*Stmt) - return s -} - -func (s *DefaultStmtStore) AllMap() map[string]*Stmt { - return s.defaultStmt -} -func (s *DefaultStmtStore) Get(key string) (*Stmt, bool) { - stmt, ok := s.defaultStmt[key] - return stmt, ok -} - -func (s *DefaultStmtStore) Set(key string, value *Stmt) { - s.defaultStmt[key] = value -} - -func (s *DefaultStmtStore) Delete(key string) { - delete(s.defaultStmt, key) -} - -type LruStmtStore struct { - lru *LRU[string, *Stmt] -} - -func (s *LruStmtStore) NewLru(size int, ttl time.Duration) { - onEvicted := func(k string, v *Stmt) { - if v != nil { - go func() { - defer func() { - if r := recover(); r != nil { - fmt.Print("close stmt err panic ") - } - }() - if v != nil { - err := v.Close() - if err != nil { - // - fmt.Print("close stmt err: ", err.Error()) - } - } - }() - } - } - s.lru = NewLRU[string, *Stmt](size, onEvicted, ttl) -} - -func (s *LruStmtStore) AllMap() map[string]*Stmt { - return s.lru.KeyValues() -} -func (s *LruStmtStore) Get(key string) (*Stmt, bool) { - stmt, ok := s.lru.Get(key) - return stmt, ok -} - -func (s *LruStmtStore) Set(key string, value *Stmt) { - s.lru.Add(key, value) -} - -func (s *LruStmtStore) Delete(key string) { - s.lru.Remove(key) -}