From a827495be126f896c0f744043410642d70bbac1b Mon Sep 17 00:00:00 2001 From: Zhaodong Xie <837199685@qq.com> Date: Fri, 25 Apr 2025 16:22:26 +0800 Subject: [PATCH] Preparestmt use LRU Map instead default map (#7435) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 支持lru淘汰preparestmt cache * 支持lru淘汰preparestmt cache * 支持lru淘汰preparestmt cache * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * change const export * Add stmt_store * refact prepare stmt store * Rename lru store * change const export * ADD UT * format code and add session level prepare stmt config * code format according to golinter ci * ADD UT --------- Co-authored-by: xiezhaodong Co-authored-by: Jinzhu --- gorm.go | 11 +- internal/lru/lru.go | 493 ++++++++++++++++++++++++++ internal/stmt_store/stmt_store.go | 182 ++++++++++ prepare_stmt.go | 144 +++----- tests/go.mod | 18 +- tests/lru_test.go | 561 ++++++++++++++++++++++++++++++ tests/prepared_stmt_test.go | 169 ++++----- 7 files changed, 1365 insertions(+), 213 deletions(-) create mode 100644 internal/lru/lru.go create mode 100644 internal/stmt_store/stmt_store.go create mode 100644 tests/lru_test.go diff --git a/gorm.go b/gorm.go index bc6d6db3..d253736d 100644 --- a/gorm.go +++ b/gorm.go @@ -34,6 +34,11 @@ type Config struct { DryRun bool // PrepareStmt executes the given query in cached statement PrepareStmt bool + // PrepareStmt cache support LRU expired, + // default maxsize=int64 Max value and ttl=1h + PrepareStmtMaxSize int + PrepareStmtTTL time.Duration + // DisableAutomaticPing DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating @@ -105,6 +110,8 @@ type DB struct { type Session struct { DryRun bool PrepareStmt bool + PrepareStmtMaxSize int + PrepareStmtTTL time.Duration NewDB bool Initialized bool SkipHooks bool @@ -197,7 +204,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.PrepareStmt { - preparedStmt := NewPreparedStmtDB(db.ConnPool) + preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL) db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -268,7 +275,7 @@ func (db *DB) Session(config *Session) *DB { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt = v.(*PreparedStmtDB) } else { - preparedStmt = NewPreparedStmtDB(db.ConnPool) + preparedStmt = NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL) db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } diff --git a/internal/lru/lru.go b/internal/lru/lru.go new file mode 100644 index 00000000..4f21589a --- /dev/null +++ b/internal/lru/lru.go @@ -0,0 +1,493 @@ +package lru + +// golang -lru +// https://github.com/hashicorp/golang-lru +import ( + "sync" + "time" +) + +// EvictCallback is used to get a callback when a cache entry is evicted +type EvictCallback[K comparable, V any] func(key K, value V) + +// LRU implements a thread-safe LRU with expirable entries. +type LRU[K comparable, V any] struct { + size int + evictList *LruList[K, V] + items map[K]*Entry[K, V] + onEvict EvictCallback[K, V] + + // expirable options + mu sync.Mutex + ttl time.Duration + done chan struct{} + + // buckets for expiration + buckets []bucket[K, V] + // uint8 because it's number between 0 and numBuckets + nextCleanupBucket uint8 +} + +// bucket is a container for holding entries to be expired +type bucket[K comparable, V any] struct { + entries map[K]*Entry[K, V] + newestEntry time.Time +} + +// noEvictionTTL - very long ttl to prevent eviction +const noEvictionTTL = time.Hour * 24 * 365 * 10 + +// because of uint8 usage for nextCleanupBucket, should not exceed 256. +// casting it as uint8 explicitly requires type conversions in multiple places +const numBuckets = 100 + +// NewLRU returns a new thread-safe cache with expirable entries. +// +// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off. +// +// Providing 0 TTL turns expiring off. +// +// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely. +func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] { + if size < 0 { + size = 0 + } + if ttl <= 0 { + ttl = noEvictionTTL + } + + res := LRU[K, V]{ + ttl: ttl, + size: size, + evictList: NewList[K, V](), + items: make(map[K]*Entry[K, V]), + onEvict: onEvict, + done: make(chan struct{}), + } + + // initialize the buckets + res.buckets = make([]bucket[K, V], numBuckets) + for i := 0; i < numBuckets; i++ { + res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])} + } + + // enable deleteExpired() running in separate goroutine for cache with non-zero TTL + // + // Important: done channel is never closed, so deleteExpired() goroutine will never exit, + // it's decided to add functionality to close it in the version later than v2. + if res.ttl != noEvictionTTL { + go func(done <-chan struct{}) { + ticker := time.NewTicker(res.ttl / numBuckets) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + res.deleteExpired() + } + } + }(res.done) + } + return &res +} + +// Purge clears the cache completely. +// onEvict is called for each evicted key. +func (c *LRU[K, V]) Purge() { + c.mu.Lock() + defer c.mu.Unlock() + for k, v := range c.items { + if c.onEvict != nil { + c.onEvict(k, v.Value) + } + delete(c.items, k) + } + for _, b := range c.buckets { + for _, ent := range b.entries { + delete(b.entries, ent.Key) + } + } + c.evictList.Init() +} + +// Add adds a value to the cache. Returns true if an eviction occurred. +// Returns false if there was no eviction: the item was already in the cache, +// or the size was not exceeded. +func (c *LRU[K, V]) Add(key K, value V) (evicted bool) { + c.mu.Lock() + defer c.mu.Unlock() + now := time.Now() + + // Check for existing item + if ent, ok := c.items[key]; ok { + c.evictList.MoveToFront(ent) + c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed + ent.Value = value + ent.ExpiresAt = now.Add(c.ttl) + c.addToBucket(ent) + return false + } + + // Add new item + ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl)) + c.items[key] = ent + c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket + + evict := c.size > 0 && c.evictList.Length() > c.size + // Verify size not exceeded + if evict { + c.removeOldest() + } + return evict +} + +// Get looks up a key's value from the cache. +func (c *LRU[K, V]) Get(key K) (value V, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + var ent *Entry[K, V] + if ent, ok = c.items[key]; ok { + // Expired item check + if time.Now().After(ent.ExpiresAt) { + return value, false + } + c.evictList.MoveToFront(ent) + return ent.Value, true + } + return +} + +// Contains checks if a key is in the cache, without updating the recent-ness +// or deleting it for being stale. +func (c *LRU[K, V]) Contains(key K) (ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + _, ok = c.items[key] + return ok +} + +// Peek returns the key value (or undefined if not found) without updating +// the "recently used"-ness of the key. +func (c *LRU[K, V]) Peek(key K) (value V, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + var ent *Entry[K, V] + if ent, ok = c.items[key]; ok { + // Expired item check + if time.Now().After(ent.ExpiresAt) { + return value, false + } + return ent.Value, true + } + return +} + +// Remove removes the provided key from the cache, returning if the +// key was contained. +func (c *LRU[K, V]) Remove(key K) bool { + c.mu.Lock() + defer c.mu.Unlock() + if ent, ok := c.items[key]; ok { + c.removeElement(ent) + return true + } + return false +} + +// RemoveOldest removes the oldest item from the cache. +func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ent := c.evictList.Back(); ent != nil { + c.removeElement(ent) + return ent.Key, ent.Value, true + } + return +} + +// GetOldest returns the oldest entry +func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ent := c.evictList.Back(); ent != nil { + return ent.Key, ent.Value, true + } + return +} + +func (c *LRU[K, V]) KeyValues() map[K]V { + c.mu.Lock() + defer c.mu.Unlock() + maps := make(map[K]V) + now := time.Now() + for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() { + if now.After(ent.ExpiresAt) { + continue + } + maps[ent.Key] = ent.Value + // keys = append(keys, ent.Key) + } + return maps +} + +// Keys returns a slice of the keys in the cache, from oldest to newest. +// Expired entries are filtered out. +func (c *LRU[K, V]) Keys() []K { + c.mu.Lock() + defer c.mu.Unlock() + keys := make([]K, 0, len(c.items)) + now := time.Now() + for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() { + if now.After(ent.ExpiresAt) { + continue + } + keys = append(keys, ent.Key) + } + return keys +} + +// Values returns a slice of the values in the cache, from oldest to newest. +// Expired entries are filtered out. +func (c *LRU[K, V]) Values() []V { + c.mu.Lock() + defer c.mu.Unlock() + values := make([]V, 0, len(c.items)) + now := time.Now() + for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() { + if now.After(ent.ExpiresAt) { + continue + } + values = append(values, ent.Value) + } + return values +} + +// Len returns the number of items in the cache. +func (c *LRU[K, V]) Len() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.evictList.Length() +} + +// Resize changes the cache size. Size of 0 means unlimited. +func (c *LRU[K, V]) Resize(size int) (evicted int) { + c.mu.Lock() + defer c.mu.Unlock() + if size <= 0 { + c.size = 0 + return 0 + } + diff := c.evictList.Length() - size + if diff < 0 { + diff = 0 + } + for i := 0; i < diff; i++ { + c.removeOldest() + } + c.size = size + return diff +} + +// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close(). +// func (c *LRU[K, V]) Close() { +// c.mu.Lock() +// defer c.mu.Unlock() +// select { +// case <-c.done: +// return +// default: +// } +// close(c.done) +// } + +// removeOldest removes the oldest item from the cache. Has to be called with lock! +func (c *LRU[K, V]) removeOldest() { + if ent := c.evictList.Back(); ent != nil { + c.removeElement(ent) + } +} + +// removeElement is used to remove a given list element from the cache. Has to be called with lock! +func (c *LRU[K, V]) removeElement(e *Entry[K, V]) { + c.evictList.Remove(e) + delete(c.items, e.Key) + c.removeFromBucket(e) + if c.onEvict != nil { + c.onEvict(e.Key, e.Value) + } +} + +// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry +// in it to expire first. +func (c *LRU[K, V]) deleteExpired() { + c.mu.Lock() + bucketIdx := c.nextCleanupBucket + timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry) + // wait for newest entry to expire before cleanup without holding lock + if timeToExpire > 0 { + c.mu.Unlock() + time.Sleep(timeToExpire) + c.mu.Lock() + } + for _, ent := range c.buckets[bucketIdx].entries { + c.removeElement(ent) + } + c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets + c.mu.Unlock() +} + +// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock! +func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) { + bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets + e.ExpireBucket = bucketID + c.buckets[bucketID].entries[e.Key] = e + if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) { + c.buckets[bucketID].newestEntry = e.ExpiresAt + } +} + +// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock! +func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) { + delete(c.buckets[e.ExpireBucket].entries, e.Key) +} + +// Cap returns the capacity of the cache +func (c *LRU[K, V]) Cap() int { + return c.size +} + +// Entry is an LRU Entry +type Entry[K comparable, V any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Entry[K, V] + + // The list to which this element belongs. + list *LruList[K, V] + + // The LRU Key of this element. + Key K + + // The Value stored with this element. + Value V + + // The time this element would be cleaned up, optional + ExpiresAt time.Time + + // The expiry bucket item was put in, optional + ExpireBucket uint8 +} + +// PrevEntry returns the previous list element or nil. +func (e *Entry[K, V]) PrevEntry() *Entry[K, V] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// LruList represents a doubly linked list. +// The zero Value for LruList is an empty list ready to use. +type LruList[K comparable, V any] struct { + root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list Length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *LruList[K, V]) Init() *LruList[K, V] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewList returns an initialized list. +func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() } + +// Length returns the number of elements of list l. +// The complexity is O(1). +func (l *LruList[K, V]) Length() int { return l.len } + +// Back returns the last element of list l or nil if the list is empty. +func (l *LruList[K, V]) Back() *Entry[K, V] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List Value. +func (l *LruList[K, V]) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at). +func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] { + return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at) +} + +// Remove removes e from its list, decrements l.len +func (l *LruList[K, V]) Remove(e *Entry[K, V]) V { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + + return e.Value +} + +// move moves e to next to at. +func (l *LruList[K, V]) move(e, at *Entry[K, V]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] { + l.lazyInit() + return l.insertValue(k, v, time.Time{}, &l.root) +} + +// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e. +func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] { + l.lazyInit() + return l.insertValue(k, v, expiresAt, &l.root) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) +} diff --git a/internal/stmt_store/stmt_store.go b/internal/stmt_store/stmt_store.go new file mode 100644 index 00000000..7068419d --- /dev/null +++ b/internal/stmt_store/stmt_store.go @@ -0,0 +1,182 @@ +package stmt_store + +import ( + "context" + "database/sql" + "sync" + "time" + + "gorm.io/gorm/internal/lru" +) + +type Stmt struct { + *sql.Stmt + Transaction bool + prepared chan struct{} + prepareErr error +} + +func (stmt *Stmt) Error() error { + return stmt.prepareErr +} + +func (stmt *Stmt) Close() error { + <-stmt.prepared + + if stmt.Stmt != nil { + return stmt.Stmt.Close() + } + return nil +} + +// Store defines an interface for managing the caching operations of SQL statements (Stmt). +// This interface provides methods for creating new statements, retrieving all cache keys, +// getting cached statements, setting cached statements, and deleting cached statements. +type Store interface { + // New creates a new Stmt object and caches it. + // Parameters: + // ctx: The context for the request, which can carry deadlines, cancellation signals, etc. + // key: The key representing the SQL query, used for caching and preparing the statement. + // isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy. + // connPool: A connection pool that provides database connections. + // locker: A synchronization lock that is unlocked after initialization to avoid deadlocks. + // Returns: + // *Stmt: A newly created statement object for executing SQL operations. + // error: An error if the statement preparation fails. + New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error) + + // Keys returns a slice of all cache keys in the store. + Keys() []string + + // Get retrieves a Stmt object from the store based on the given key. + // Parameters: + // key: The key used to look up the Stmt object. + // Returns: + // *Stmt: The found Stmt object, or nil if not found. + // bool: Indicates whether the corresponding Stmt object was successfully found. + Get(key string) (*Stmt, bool) + + // Set stores the given Stmt object in the store and associates it with the specified key. + // Parameters: + // key: The key used to associate the Stmt object. + // value: The Stmt object to be stored. + Set(key string, value *Stmt) + + // Delete removes the Stmt object corresponding to the specified key from the store. + // Parameters: + // key: The key associated with the Stmt object to be deleted. + Delete(key string) +} + +// defaultMaxSize defines the default maximum capacity of the cache. +// Its value is the maximum value of the int64 type, which means that when the cache size is not specified, +// the cache can theoretically store as many elements as possible. +// (1 << 63) - 1 is the maximum value that an int64 type can represent. +const ( + defaultMaxSize = (1 << 63) - 1 + // defaultTTL defines the default time-to-live (TTL) for each cache entry. + // When the TTL for cache entries is not specified, each cache entry will expire after 24 hours. + defaultTTL = time.Hour * 24 +) + +// New creates and returns a new Store instance. +// +// Parameters: +// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0, +// it defaults to defaultMaxSize. +// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0, +// it defaults to defaultTTL. +// +// This function defines an onEvicted callback that is invoked when a cache entry is evicted. +// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously +// to release associated resources. +// +// Returns: +// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size, +// eviction callback, and TTL. +func New(size int, ttl time.Duration) Store { + if size <= 0 { + size = defaultMaxSize + } + + if ttl <= 0 { + ttl = defaultTTL + } + + onEvicted := func(k string, v *Stmt) { + if v != nil { + go v.Close() + } + } + return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} +} + +type lruStore struct { + lru *lru.LRU[string, *Stmt] +} + +func (s *lruStore) Keys() []string { + return s.lru.Keys() +} + +func (s *lruStore) Get(key string) (*Stmt, bool) { + stmt, ok := s.lru.Get(key) + if ok && stmt != nil { + <-stmt.prepared + } + return stmt, ok +} + +func (s *lruStore) Set(key string, value *Stmt) { + s.lru.Add(key, value) +} + +func (s *lruStore) Delete(key string) { + s.lru.Remove(key) +} + +type ConnPool interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// New creates a new Stmt object for executing SQL queries. +// It caches the Stmt object for future use and handles preparation and error states. +// Parameters: +// +// ctx: Context for the request, used to carry deadlines, cancellation signals, etc. +// key: The key representing the SQL query, used for caching and preparing the statement. +// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy. +// conn: A connection pool that provides database connections. +// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks. +// +// Returns: +// +// *Stmt: A newly created statement object for executing SQL operations. +// error: An error if the statement preparation fails. +func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) { + // Create a Stmt object and set its Transaction property. + // The prepared channel is used to synchronize the statement preparation state. + cacheStmt := &Stmt{ + Transaction: isTransaction, + prepared: make(chan struct{}), + } + // Cache the Stmt object with the associated key. + s.Set(key, cacheStmt) + // Unlock after completing initialization to prevent deadlocks. + locker.Unlock() + + // Ensure the prepared channel is closed after the function execution completes. + defer close(cacheStmt.prepared) + + // Prepare the SQL statement using the provided connection. + cacheStmt.Stmt, err = conn.PrepareContext(ctx, key) + if err != nil { + // If statement preparation fails, record the error and remove the invalid Stmt object from the cache. + cacheStmt.prepareErr = err + s.Delete(key) + return &Stmt{}, err + } + + // Return the successfully prepared Stmt object. + return cacheStmt, nil +} diff --git a/prepare_stmt.go b/prepare_stmt.go index 094bb477..799df5bc 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -7,29 +7,35 @@ import ( "errors" "reflect" "sync" + "time" + + "gorm.io/gorm/internal/stmt_store" ) -type Stmt struct { - *sql.Stmt - Transaction bool - prepared chan struct{} - prepareErr error -} - type PreparedStmtDB struct { - Stmts map[string]*Stmt + Stmts stmt_store.Store Mux *sync.RWMutex ConnPool } -func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { +// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB. +// +// Parameters: +// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections. +// - maxSize: The maximum number of prepared statements that can be stored in the statement store. +// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed. +// +// Returns: +// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration. +func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB { return &PreparedStmtDB{ - ConnPool: connPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, + ConnPool: connPool, // Assigns the provided connection pool to manage database connections. + Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL. + Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store. } } +// GetDBConn returns the underlying *sql.DB connection func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil @@ -42,98 +48,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { return nil, ErrInvalidDB } +// Close closes all prepared statements in the store func (db *PreparedStmtDB) Close() { db.Mux.Lock() defer db.Mux.Unlock() - for _, stmt := range db.Stmts { - go func(s *Stmt) { - // make sure the stmt must finish preparation first - <-s.prepared - if s.Stmt != nil { - _ = s.Close() - } - }(stmt) + for _, key := range db.Stmts.Keys() { + db.Stmts.Delete(key) } - // setting db.Stmts to nil to avoid further using - db.Stmts = nil } -func (sdb *PreparedStmtDB) Reset() { - sdb.Mux.Lock() - defer sdb.Mux.Unlock() - - for _, stmt := range sdb.Stmts { - go func(s *Stmt) { - // make sure the stmt must finish preparation first - <-s.prepared - if s.Stmt != nil { - _ = s.Close() - } - }(stmt) - } - sdb.Stmts = make(map[string]*Stmt) +// Reset Deprecated use Close instead +func (db *PreparedStmtDB) Reset() { + db.Close() } -func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) { db.Mux.RLock() - if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - db.Mux.RUnlock() - // wait for other goroutines prepared - <-stmt.prepared - if stmt.prepareErr != nil { - return Stmt{}, stmt.prepareErr + if db.Stmts != nil { + if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { + db.Mux.RUnlock() + return stmt, stmt.Error() } - - return *stmt, nil } db.Mux.RUnlock() + // retry db.Mux.Lock() - // double check - if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - db.Mux.Unlock() - // wait for other goroutines prepared - <-stmt.prepared - if stmt.prepareErr != nil { - return Stmt{}, stmt.prepareErr + if db.Stmts != nil { + if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { + db.Mux.Unlock() + return stmt, stmt.Error() } - - return *stmt, nil - } - // check db.Stmts first to avoid Segmentation Fault(setting value to nil map) - // which cause by calling Close and executing SQL concurrently - if db.Stmts == nil { - db.Mux.Unlock() - return Stmt{}, ErrInvalidDB - } - // cache preparing stmt first - cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} - db.Stmts[query] = &cacheStmt - db.Mux.Unlock() - - // prepare completed - defer close(cacheStmt.prepared) - - // Reason why cannot lock conn.PrepareContext - // suppose the maxopen is 1, g1 is creating record and g2 is querying record. - // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. - // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. - // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. - stmt, err := conn.PrepareContext(ctx, query) - if err != nil { - cacheStmt.prepareErr = err - db.Mux.Lock() - delete(db.Stmts, query) - db.Mux.Unlock() - return Stmt{}, err } - db.Mux.Lock() - cacheStmt.Stmt = stmt - db.Mux.Unlock() - - return cacheStmt, nil + return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux) } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -162,10 +111,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = stmt.ExecContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - db.Mux.Lock() - defer db.Mux.Unlock() - go stmt.Close() - delete(db.Stmts, query) + db.Stmts.Delete(query) } } return result, err @@ -176,11 +122,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = stmt.QueryContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - db.Mux.Lock() - defer db.Mux.Unlock() - - go stmt.Close() - delete(db.Stmts, query) + db.Stmts.Delete(query) } } return rows, err @@ -230,11 +172,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - tx.PreparedStmtDB.Mux.Lock() - defer tx.PreparedStmtDB.Mux.Unlock() - - go stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.Stmts.Delete(query) } } return result, err @@ -245,11 +183,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if errors.Is(err, driver.ErrBadConn) { - tx.PreparedStmtDB.Mux.Lock() - defer tx.PreparedStmtDB.Mux.Unlock() - - go stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.Stmts.Delete(query) } } return rows, err diff --git a/tests/go.mod b/tests/go.mod index 30143433..778e3bca 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,15 +1,17 @@ module gorm.io/gorm/tests -go 1.18 +go 1.23.0 + +toolchain go1.24.2 require ( github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 gorm.io/driver/mysql v1.5.7 - gorm.io/driver/postgres v1.5.10 - gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 gorm.io/driver/sqlserver v1.5.4 gorm.io/gorm v1.25.12 ) @@ -17,7 +19,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-sql-driver/mysql v1.9.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -25,12 +27,12 @@ require ( github.com/jackc/pgx/v5 v5.7.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect - github.com/mattn/go-sqlite3 v1.14.24 // indirect + github.com/mattn/go-sqlite3 v1.14.28 // indirect github.com/microsoft/go-mssqldb v1.7.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.29.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/text v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/lru_test.go b/tests/lru_test.go new file mode 100644 index 00000000..3eaef5de --- /dev/null +++ b/tests/lru_test.go @@ -0,0 +1,561 @@ +package tests_test + +import ( + "crypto/rand" + "fmt" + "gorm.io/gorm/internal/lru" + "math" + "math/big" + "reflect" + "sync" + "testing" + "time" +) + +func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) { + lru := lru.NewLRU[string, int](10, nil, time.Hour) + lru.Add("key1", 1) + lru.Add("key1", 2) + + if value, ok := lru.Get("key1"); !ok || value != 2 { + t.Errorf("Expected value to be updated to 2, got %v", value) + } +} + +func TestLRU_Add_NewKey_AddsEntry(t *testing.T) { + lru := lru.NewLRU[string, int](10, nil, time.Hour) + lru.Add("key1", 1) + + if value, ok := lru.Get("key1"); !ok || value != 1 { + t.Errorf("Expected key1 to be added with value 1, got %v", value) + } +} + +func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) { + lru := lru.NewLRU[string, int](2, nil, time.Hour) + lru.Add("key1", 1) + lru.Add("key2", 2) + lru.Add("key3", 3) + + if _, ok := lru.Get("key1"); ok { + t.Errorf("Expected key1 to be removed, but it still exists") + } +} + +func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) { + lru := lru.NewLRU[string, int](0, nil, time.Hour) + lru.Add("key1", 1) + lru.Add("key2", 2) + lru.Add("key3", 3) + + if _, ok := lru.Get("key1"); !ok { + t.Errorf("Expected key1 to exist, but it was evicted") + } +} + +func TestLRU_Add_Eviction(t *testing.T) { + lru := lru.NewLRU[string, int](0, nil, time.Second*2) + lru.Add("key1", 1) + lru.Add("key2", 2) + lru.Add("key3", 3) + time.Sleep(time.Second * 3) + if lru.Cap() != 0 { + t.Errorf("Expected lru to be empty, but it was not") + } + +} + +func BenchmarkLRU_Rand_NoExpire(b *testing.B) { + l := lru.NewLRU[int64, int64](8192, nil, 0) + + trace := make([]int64, b.N*2) + for i := 0; i < b.N*2; i++ { + trace[i] = getRand(b) % 32768 + } + + b.ResetTimer() + + var hit, miss int + for i := 0; i < 2*b.N; i++ { + if i%2 == 0 { + l.Add(trace[i], trace[i]) + } else { + if _, ok := l.Get(trace[i]); ok { + hit++ + } else { + miss++ + } + } + } + b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss)) +} + +func BenchmarkLRU_Freq_NoExpire(b *testing.B) { + l := lru.NewLRU[int64, int64](8192, nil, 0) + + trace := make([]int64, b.N*2) + for i := 0; i < b.N*2; i++ { + if i%2 == 0 { + trace[i] = getRand(b) % 16384 + } else { + trace[i] = getRand(b) % 32768 + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + l.Add(trace[i], trace[i]) + } + var hit, miss int + for i := 0; i < b.N; i++ { + if _, ok := l.Get(trace[i]); ok { + hit++ + } else { + miss++ + } + } + b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss)) +} + +func BenchmarkLRU_Rand_WithExpire(b *testing.B) { + l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10) + + trace := make([]int64, b.N*2) + for i := 0; i < b.N*2; i++ { + trace[i] = getRand(b) % 32768 + } + + b.ResetTimer() + + var hit, miss int + for i := 0; i < 2*b.N; i++ { + if i%2 == 0 { + l.Add(trace[i], trace[i]) + } else { + if _, ok := l.Get(trace[i]); ok { + hit++ + } else { + miss++ + } + } + } + b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss)) +} + +func BenchmarkLRU_Freq_WithExpire(b *testing.B) { + l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10) + + trace := make([]int64, b.N*2) + for i := 0; i < b.N*2; i++ { + if i%2 == 0 { + trace[i] = getRand(b) % 16384 + } else { + trace[i] = getRand(b) % 32768 + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + l.Add(trace[i], trace[i]) + } + var hit, miss int + for i := 0; i < b.N; i++ { + if _, ok := l.Get(trace[i]); ok { + hit++ + } else { + miss++ + } + } + b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss)) +} + +func TestLRUNoPurge(t *testing.T) { + lc := lru.NewLRU[string, string](10, nil, 0) + + lc.Add("key1", "val1") + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } + + v, ok := lc.Peek("key1") + if v != "val1" { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + + if !lc.Contains("key1") { + t.Fatalf("should contain key1") + } + if lc.Contains("key2") { + t.Fatalf("should not contain key2") + } + + v, ok = lc.Peek("key2") + if v != "" { + t.Fatalf("should be empty") + } + if ok { + t.Fatalf("should be false") + } + + if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) { + t.Fatalf("value differs from expected") + } + + if lc.Resize(0) != 0 { + t.Fatalf("evicted count differs from expected") + } + if lc.Resize(2) != 0 { + t.Fatalf("evicted count differs from expected") + } + lc.Add("key2", "val2") + if lc.Resize(1) != 1 { + t.Fatalf("evicted count differs from expected") + } +} + +func TestLRUEdgeCases(t *testing.T) { + lc := lru.NewLRU[string, *string](2, nil, 0) + + // Adding a nil value + lc.Add("key1", nil) + + value, exists := lc.Get("key1") + if value != nil || !exists { + t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists) + } + + // Adding an entry with the same key but different value + newVal := "val1" + lc.Add("key1", &newVal) + + value, exists = lc.Get("key1") + if value != &newVal || !exists { + t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists) + } +} + +func TestLRU_Values(t *testing.T) { + lc := lru.NewLRU[string, string](3, nil, 0) + + lc.Add("key1", "val1") + lc.Add("key2", "val2") + lc.Add("key3", "val3") + + values := lc.Values() + if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) { + t.Fatalf("values differs from expected") + } +} + +// func TestExpirableMultipleClose(_ *testing.T) { +// lc :=lru.NewLRU[string, string](10, nil, 0) +// lc.Close() +// // should not panic +// lc.Close() +// } + +func TestLRUWithPurge(t *testing.T) { + var evicted []string + lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond) + + k, v, ok := lc.GetOldest() + if k != "" { + t.Fatalf("should be empty") + } + if v != "" { + t.Fatalf("should be empty") + } + if ok { + t.Fatalf("should be false") + } + + lc.Add("key1", "val1") + + time.Sleep(100 * time.Millisecond) // not enough to expire + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } + + v, ok = lc.Get("key1") + if v != "val1" { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + + time.Sleep(200 * time.Millisecond) // expire + v, ok = lc.Get("key1") + if ok { + t.Fatalf("should be false") + } + if v != "" { + t.Fatalf("should be nil") + } + + if lc.Len() != 0 { + t.Fatalf("length differs from expected") + } + if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) { + t.Fatalf("value differs from expected") + } + + // add new entry + lc.Add("key2", "val2") + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } + + k, v, ok = lc.GetOldest() + if k != "key2" { + t.Fatalf("value differs from expected") + } + if v != "val2" { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + +} + +func TestLRUWithPurgeEnforcedBySize(t *testing.T) { + lc := lru.NewLRU[string, string](10, nil, time.Hour) + + for i := 0; i < 100; i++ { + i := i + lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i)) + v, ok := lc.Get(fmt.Sprintf("key%d", i)) + if v != fmt.Sprintf("val%d", i) { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + if lc.Len() > 20 { + t.Fatalf("length should be less than 20") + } + } + + if lc.Len() != 10 { + t.Fatalf("length differs from expected") + } +} + +func TestLRUConcurrency(t *testing.T) { + lc := lru.NewLRU[string, string](0, nil, 0) + wg := sync.WaitGroup{} + wg.Add(1000) + for i := 0; i < 1000; i++ { + go func(i int) { + lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10)) + wg.Done() + }(i) + } + wg.Wait() + if lc.Len() != 100 { + t.Fatalf("length differs from expected") + } +} + +func TestLRUInvalidateAndEvict(t *testing.T) { + var evicted int + lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0) + + lc.Add("key1", "val1") + lc.Add("key2", "val2") + + val, ok := lc.Get("key1") + if !ok { + t.Fatalf("should be true") + } + if val != "val1" { + t.Fatalf("value differs from expected") + } + if evicted != 0 { + t.Fatalf("value differs from expected") + } + + lc.Remove("key1") + if evicted != 1 { + t.Fatalf("value differs from expected") + } + val, ok = lc.Get("key1") + if val != "" { + t.Fatalf("should be empty") + } + if ok { + t.Fatalf("should be false") + } +} + +func TestLoadingExpired(t *testing.T) { + lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5) + + lc.Add("key1", "val1") + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } + + v, ok := lc.Peek("key1") + if v != "val1" { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + + v, ok = lc.Get("key1") + if v != "val1" { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + + for { + result, ok := lc.Get("key1") + if ok && result == "" { + t.Fatalf("ok should return a result") + } + if !ok { + break + } + } + + time.Sleep(time.Millisecond * 100) // wait for expiration reaper + if lc.Len() != 0 { + t.Fatalf("length differs from expected") + } + + v, ok = lc.Peek("key1") + if v != "" { + t.Fatalf("should be empty") + } + if ok { + t.Fatalf("should be false") + } + + v, ok = lc.Get("key1") + if v != "" { + t.Fatalf("should be empty") + } + if ok { + t.Fatalf("should be false") + } +} + +func TestLRURemoveOldest(t *testing.T) { + lc := lru.NewLRU[string, string](2, nil, 0) + + if lc.Cap() != 2 { + t.Fatalf("expect cap is 2") + } + + k, v, ok := lc.RemoveOldest() + if k != "" { + t.Fatalf("should be empty") + } + if v != "" { + t.Fatalf("should be empty") + } + if ok { + t.Fatalf("should be false") + } + + ok = lc.Remove("non_existent") + if ok { + t.Fatalf("should be false") + } + + lc.Add("key1", "val1") + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } + + v, ok = lc.Get("key1") + if !ok { + t.Fatalf("should be true") + } + if v != "val1" { + t.Fatalf("value differs from expected") + } + + if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) { + t.Fatalf("value differs from expected") + } + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } + + lc.Add("key2", "val2") + if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) { + t.Fatalf("value differs from expected") + } + if lc.Len() != 2 { + t.Fatalf("length differs from expected") + } + + k, v, ok = lc.RemoveOldest() + if k != "key1" { + t.Fatalf("value differs from expected") + } + if v != "val1" { + t.Fatalf("value differs from expected") + } + if !ok { + t.Fatalf("should be true") + } + + if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) { + t.Fatalf("value differs from expected") + } + if lc.Len() != 1 { + t.Fatalf("length differs from expected") + } +} + +func ExampleLRU() { + // make cache with 10ms TTL and 5 max keys + cache := lru.NewLRU[string, string](5, nil, time.Millisecond*10) + + // set value under key1. + cache.Add("key1", "val1") + + // get value under key1 + r, ok := cache.Get("key1") + + // check for OK value + if ok { + fmt.Printf("value before expiration is found: %v, value: %q\n", ok, r) + } + + // wait for cache to expire + time.Sleep(time.Millisecond * 100) + + // get value under key1 after key expiration + r, ok = cache.Get("key1") + fmt.Printf("value after expiration is found: %v, value: %q\n", ok, r) + + // set value under key2, would evict old entry because it is already expired. + cache.Add("key2", "val2") + + fmt.Printf("Cache len: %d\n", cache.Len()) + // Output: + // value before expiration is found: true, value: "val1" + // value after expiration is found: false, value: "" + // Cache len: 1 +} + +func getRand(tb testing.TB) int64 { + out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + tb.Fatal(err) + } + return out.Int64() +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 20a4f730..16a29108 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "sync" - "sync/atomic" "testing" "time" @@ -92,6 +91,65 @@ func TestPreparedStmtFromTransaction(t *testing.T) { tx2.Commit() } +func TestPreparedStmtLruFromTransaction(t *testing.T) { + db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2.Commit() + // Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type. + // If the conversion is successful, ok will be true and conn will be the converted object; + // otherwise, ok will be false and conn will be nil. + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + // Get the number of statement keys stored in the PreparedStmtDB. + lens := len(conn.Stmts.Keys()) + // Check if the number of stored statement keys is 0. + if lens == 0 { + // If the number is 0, it means there are no statements stored in the LRU cache. + // The test fails and an error message is output. + t.Fatalf("lru should not be empty") + } + // Wait for 40 seconds to give the statements in the cache enough time to expire. + time.Sleep(time.Second * 40) + // Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type. + AssertEqual(t, ok, true) + // Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds. + // If it is not 0, it means the statements in the cache have not expired as expected. + AssertEqual(t, len(conn.Stmts.Keys()), 0) + +} + func TestPreparedStmtDeadlock(t *testing.T) { tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) @@ -117,9 +175,9 @@ func TestPreparedStmtDeadlock(t *testing.T) { conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts), 2) - for _, stmt := range conn.Stmts { - if stmt == nil { + AssertEqual(t, len(conn.Stmts.Keys()), 2) + for _, stmt := range conn.Stmts.Keys() { + if stmt == "" { t.Fatalf("stmt cannot bee nil") } } @@ -143,10 +201,10 @@ func TestPreparedStmtInTransaction(t *testing.T) { } } -func TestPreparedStmtReset(t *testing.T) { +func TestPreparedStmtClose(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) - user := *GetUser("prepared_stmt_reset", Config{}) + user := *GetUser("prepared_stmt_close", Config{}) tx = tx.Create(&user) pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) @@ -155,16 +213,16 @@ func TestPreparedStmtReset(t *testing.T) { } pdb.Mux.Lock() - if len(pdb.Stmts) == 0 { + if len(pdb.Stmts.Keys()) == 0 { pdb.Mux.Unlock() t.Fatalf("prepared stmt can not be empty") } pdb.Mux.Unlock() - pdb.Reset() + pdb.Close() pdb.Mux.Lock() defer pdb.Mux.Unlock() - if len(pdb.Stmts) != 0 { + if len(pdb.Stmts.Keys()) != 0 { t.Fatalf("prepared stmt should be empty") } } @@ -174,10 +232,10 @@ func isUsingClosedConnError(err error) bool { return err.Error() == "sql: statement is closed" } -// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently +// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently // this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt -func TestPreparedStmtConcurrentReset(t *testing.T) { - name := "prepared_stmt_concurrent_reset" +func TestPreparedStmtConcurrentClose(t *testing.T) { + name := "prepared_stmt_concurrent_close" user := *GetUser(name, Config{}) createTx := DB.Session(&gorm.Session{}).Create(&user) if createTx.Error != nil { @@ -220,7 +278,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) { go func() { defer wg.Done() <-writerFinish - pdb.Reset() + pdb.Close() }() wg.Wait() @@ -229,88 +287,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) { t.Fatalf("should is a unexpected error") } } - -// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently -// for example: one goroutine found error and just close the database, and others are executing SQL -// this test making sure that the gorm would not get a Segmentation Fault, -// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB -// and all of the goroutine must got gorm.ErrInvalidDB after database close -func TestPreparedStmtConcurrentClose(t *testing.T) { - name := "prepared_stmt_concurrent_close" - user := *GetUser(name, Config{}) - createTx := DB.Session(&gorm.Session{}).Create(&user) - if createTx.Error != nil { - t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) - } - - // create a new connection to keep away from other tests - tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) - if err != nil { - t.Fatalf("failed to open test connection due to %s", err) - } - pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) - if !ok { - t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") - } - - loopCount := 100 - var wg sync.WaitGroup - var lastErr error - closeValid := make(chan struct{}, loopCount) - closeStartIdx := loopCount / 2 // close the database at the middle of the execution - var lastRunIndex int - var closeFinishedAt int64 - - wg.Add(1) - go func(id uint) { - defer wg.Done() - defer close(closeValid) - for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { - if lastRunIndex == closeStartIdx { - closeValid <- struct{}{} - } - var tmp User - now := time.Now().UnixNano() - err := tx.Session(&gorm.Session{}).First(&tmp, id).Error - if err == nil { - closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) - if (closeFinishedAt != 0) && (now > closeFinishedAt) { - lastErr = errors.New("must got error after database closed") - break - } - continue - } - lastErr = err - break - } - }(user.ID) - - wg.Add(1) - go func() { - defer wg.Done() - for range closeValid { - for i := 0; i < loopCount; i++ { - pdb.Close() // the Close method must can be call multiple times - atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) - } - } - }() - - wg.Wait() - var tmp User - err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error - if err != gorm.ErrInvalidDB { - t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) - } - - // must be error - if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { - t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) - } - if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { - t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) - } - if pdb.Stmts != nil { - t.Fatalf("stmts must be nil") - } -}