Preparestmt use LRU Map instead default map (#7435)
* 支持lru淘汰preparestmt cache * 支持lru淘汰preparestmt cache * 支持lru淘汰preparestmt cache * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * 只使用lru * change const export * Add stmt_store * refact prepare stmt store * Rename lru store * change const export * ADD UT * format code and add session level prepare stmt config * code format according to golinter ci * ADD UT --------- Co-authored-by: xiezhaodong <xiezhaodong@bytedance.com> Co-authored-by: Jinzhu <wosmvp@gmail.com>
This commit is contained in:
parent
489a563293
commit
a827495be1
11
gorm.go
11
gorm.go
@ -34,6 +34,11 @@ type Config struct {
|
||||
DryRun bool
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// PrepareStmt cache support LRU expired,
|
||||
// default maxsize=int64 Max value and ttl=1h
|
||||
PrepareStmtMaxSize int
|
||||
PrepareStmtTTL time.Duration
|
||||
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
@ -105,6 +110,8 @@ type DB struct {
|
||||
type Session struct {
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
PrepareStmtMaxSize int
|
||||
PrepareStmtTTL time.Duration
|
||||
NewDB bool
|
||||
Initialized bool
|
||||
SkipHooks bool
|
||||
@ -197,7 +204,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool)
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
@ -268,7 +275,7 @@ func (db *DB) Session(config *Session) *DB {
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool)
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
|
493
internal/lru/lru.go
Normal file
493
internal/lru/lru.go
Normal file
@ -0,0 +1,493 @@
|
||||
package lru
|
||||
|
||||
// golang -lru
|
||||
// https://github.com/hashicorp/golang-lru
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EvictCallback is used to get a callback when a cache entry is evicted
|
||||
type EvictCallback[K comparable, V any] func(key K, value V)
|
||||
|
||||
// LRU implements a thread-safe LRU with expirable entries.
|
||||
type LRU[K comparable, V any] struct {
|
||||
size int
|
||||
evictList *LruList[K, V]
|
||||
items map[K]*Entry[K, V]
|
||||
onEvict EvictCallback[K, V]
|
||||
|
||||
// expirable options
|
||||
mu sync.Mutex
|
||||
ttl time.Duration
|
||||
done chan struct{}
|
||||
|
||||
// buckets for expiration
|
||||
buckets []bucket[K, V]
|
||||
// uint8 because it's number between 0 and numBuckets
|
||||
nextCleanupBucket uint8
|
||||
}
|
||||
|
||||
// bucket is a container for holding entries to be expired
|
||||
type bucket[K comparable, V any] struct {
|
||||
entries map[K]*Entry[K, V]
|
||||
newestEntry time.Time
|
||||
}
|
||||
|
||||
// noEvictionTTL - very long ttl to prevent eviction
|
||||
const noEvictionTTL = time.Hour * 24 * 365 * 10
|
||||
|
||||
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
|
||||
// casting it as uint8 explicitly requires type conversions in multiple places
|
||||
const numBuckets = 100
|
||||
|
||||
// NewLRU returns a new thread-safe cache with expirable entries.
|
||||
//
|
||||
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
|
||||
//
|
||||
// Providing 0 TTL turns expiring off.
|
||||
//
|
||||
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
|
||||
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
|
||||
if size < 0 {
|
||||
size = 0
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = noEvictionTTL
|
||||
}
|
||||
|
||||
res := LRU[K, V]{
|
||||
ttl: ttl,
|
||||
size: size,
|
||||
evictList: NewList[K, V](),
|
||||
items: make(map[K]*Entry[K, V]),
|
||||
onEvict: onEvict,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// initialize the buckets
|
||||
res.buckets = make([]bucket[K, V], numBuckets)
|
||||
for i := 0; i < numBuckets; i++ {
|
||||
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
|
||||
}
|
||||
|
||||
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
|
||||
//
|
||||
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
|
||||
// it's decided to add functionality to close it in the version later than v2.
|
||||
if res.ttl != noEvictionTTL {
|
||||
go func(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(res.ttl / numBuckets)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
res.deleteExpired()
|
||||
}
|
||||
}
|
||||
}(res.done)
|
||||
}
|
||||
return &res
|
||||
}
|
||||
|
||||
// Purge clears the cache completely.
|
||||
// onEvict is called for each evicted key.
|
||||
func (c *LRU[K, V]) Purge() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for k, v := range c.items {
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(k, v.Value)
|
||||
}
|
||||
delete(c.items, k)
|
||||
}
|
||||
for _, b := range c.buckets {
|
||||
for _, ent := range b.entries {
|
||||
delete(b.entries, ent.Key)
|
||||
}
|
||||
}
|
||||
c.evictList.Init()
|
||||
}
|
||||
|
||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||
// Returns false if there was no eviction: the item was already in the cache,
|
||||
// or the size was not exceeded.
|
||||
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
// Check for existing item
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.evictList.MoveToFront(ent)
|
||||
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
|
||||
ent.Value = value
|
||||
ent.ExpiresAt = now.Add(c.ttl)
|
||||
c.addToBucket(ent)
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new item
|
||||
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
|
||||
c.items[key] = ent
|
||||
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
|
||||
|
||||
evict := c.size > 0 && c.evictList.Length() > c.size
|
||||
// Verify size not exceeded
|
||||
if evict {
|
||||
c.removeOldest()
|
||||
}
|
||||
return evict
|
||||
}
|
||||
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
c.evictList.MoveToFront(ent)
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Contains checks if a key is in the cache, without updating the recent-ness
|
||||
// or deleting it for being stale.
|
||||
func (c *LRU[K, V]) Contains(key K) (ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, ok = c.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Peek returns the key value (or undefined if not found) without updating
|
||||
// the "recently used"-ness of the key.
|
||||
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Remove removes the provided key from the cache, returning if the
|
||||
// key was contained.
|
||||
func (c *LRU[K, V]) Remove(key K) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.removeElement(ent)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RemoveOldest removes the oldest item from the cache.
|
||||
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOldest returns the oldest entry
|
||||
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *LRU[K, V]) KeyValues() map[K]V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
maps := make(map[K]V)
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
maps[ent.Key] = ent.Value
|
||||
// keys = append(keys, ent.Key)
|
||||
}
|
||||
return maps
|
||||
}
|
||||
|
||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Keys() []K {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
keys := make([]K, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, ent.Key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Values returns a slice of the values in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Values() []V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
values := make([]V, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
values = append(values, ent.Value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Len returns the number of items in the cache.
|
||||
func (c *LRU[K, V]) Len() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.evictList.Length()
|
||||
}
|
||||
|
||||
// Resize changes the cache size. Size of 0 means unlimited.
|
||||
func (c *LRU[K, V]) Resize(size int) (evicted int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if size <= 0 {
|
||||
c.size = 0
|
||||
return 0
|
||||
}
|
||||
diff := c.evictList.Length() - size
|
||||
if diff < 0 {
|
||||
diff = 0
|
||||
}
|
||||
for i := 0; i < diff; i++ {
|
||||
c.removeOldest()
|
||||
}
|
||||
c.size = size
|
||||
return diff
|
||||
}
|
||||
|
||||
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
|
||||
// func (c *LRU[K, V]) Close() {
|
||||
// c.mu.Lock()
|
||||
// defer c.mu.Unlock()
|
||||
// select {
|
||||
// case <-c.done:
|
||||
// return
|
||||
// default:
|
||||
// }
|
||||
// close(c.done)
|
||||
// }
|
||||
|
||||
// removeOldest removes the oldest item from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeOldest() {
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
}
|
||||
|
||||
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
|
||||
c.evictList.Remove(e)
|
||||
delete(c.items, e.Key)
|
||||
c.removeFromBucket(e)
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(e.Key, e.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
|
||||
// in it to expire first.
|
||||
func (c *LRU[K, V]) deleteExpired() {
|
||||
c.mu.Lock()
|
||||
bucketIdx := c.nextCleanupBucket
|
||||
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
|
||||
// wait for newest entry to expire before cleanup without holding lock
|
||||
if timeToExpire > 0 {
|
||||
c.mu.Unlock()
|
||||
time.Sleep(timeToExpire)
|
||||
c.mu.Lock()
|
||||
}
|
||||
for _, ent := range c.buckets[bucketIdx].entries {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
|
||||
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
|
||||
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
|
||||
e.ExpireBucket = bucketID
|
||||
c.buckets[bucketID].entries[e.Key] = e
|
||||
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
|
||||
c.buckets[bucketID].newestEntry = e.ExpiresAt
|
||||
}
|
||||
}
|
||||
|
||||
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
|
||||
delete(c.buckets[e.ExpireBucket].entries, e.Key)
|
||||
}
|
||||
|
||||
// Cap returns the capacity of the cache
|
||||
func (c *LRU[K, V]) Cap() int {
|
||||
return c.size
|
||||
}
|
||||
|
||||
// Entry is an LRU Entry
|
||||
type Entry[K comparable, V any] struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *Entry[K, V]
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *LruList[K, V]
|
||||
|
||||
// The LRU Key of this element.
|
||||
Key K
|
||||
|
||||
// The Value stored with this element.
|
||||
Value V
|
||||
|
||||
// The time this element would be cleaned up, optional
|
||||
ExpiresAt time.Time
|
||||
|
||||
// The expiry bucket item was put in, optional
|
||||
ExpireBucket uint8
|
||||
}
|
||||
|
||||
// PrevEntry returns the previous list element or nil.
|
||||
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LruList represents a doubly linked list.
|
||||
// The zero Value for LruList is an empty list ready to use.
|
||||
type LruList[K comparable, V any] struct {
|
||||
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list Length excluding (this) sentinel element
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *LruList[K, V]) Init() *LruList[K, V] {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// NewList returns an initialized list.
|
||||
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
|
||||
|
||||
// Length returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *LruList[K, V]) Length() int { return l.len }
|
||||
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *LruList[K, V]) Back() *Entry[K, V] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero List Value.
|
||||
func (l *LruList[K, V]) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
|
||||
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
|
||||
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
|
||||
}
|
||||
|
||||
// Remove removes e from its list, decrements l.len
|
||||
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
l.len--
|
||||
|
||||
return e.Value
|
||||
}
|
||||
|
||||
// move moves e to next to at.
|
||||
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
|
||||
if e == at {
|
||||
return
|
||||
}
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(k, v, time.Time{}, &l.root)
|
||||
}
|
||||
|
||||
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
|
||||
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(k, v, expiresAt, &l.root)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, &l.root)
|
||||
}
|
182
internal/stmt_store/stmt_store.go
Normal file
182
internal/stmt_store/stmt_store.go
Normal file
@ -0,0 +1,182 @@
|
||||
package stmt_store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/lru"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Error() error {
|
||||
return stmt.prepareErr
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Close() error {
|
||||
<-stmt.prepared
|
||||
|
||||
if stmt.Stmt != nil {
|
||||
return stmt.Stmt.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
|
||||
// This interface provides methods for creating new statements, retrieving all cache keys,
|
||||
// getting cached statements, setting cached statements, and deleting cached statements.
|
||||
type Store interface {
|
||||
// New creates a new Stmt object and caches it.
|
||||
// Parameters:
|
||||
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
|
||||
// key: The key representing the SQL query, used for caching and preparing the statement.
|
||||
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
|
||||
// connPool: A connection pool that provides database connections.
|
||||
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
|
||||
// Returns:
|
||||
// *Stmt: A newly created statement object for executing SQL operations.
|
||||
// error: An error if the statement preparation fails.
|
||||
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
|
||||
|
||||
// Keys returns a slice of all cache keys in the store.
|
||||
Keys() []string
|
||||
|
||||
// Get retrieves a Stmt object from the store based on the given key.
|
||||
// Parameters:
|
||||
// key: The key used to look up the Stmt object.
|
||||
// Returns:
|
||||
// *Stmt: The found Stmt object, or nil if not found.
|
||||
// bool: Indicates whether the corresponding Stmt object was successfully found.
|
||||
Get(key string) (*Stmt, bool)
|
||||
|
||||
// Set stores the given Stmt object in the store and associates it with the specified key.
|
||||
// Parameters:
|
||||
// key: The key used to associate the Stmt object.
|
||||
// value: The Stmt object to be stored.
|
||||
Set(key string, value *Stmt)
|
||||
|
||||
// Delete removes the Stmt object corresponding to the specified key from the store.
|
||||
// Parameters:
|
||||
// key: The key associated with the Stmt object to be deleted.
|
||||
Delete(key string)
|
||||
}
|
||||
|
||||
// defaultMaxSize defines the default maximum capacity of the cache.
|
||||
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
|
||||
// the cache can theoretically store as many elements as possible.
|
||||
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
|
||||
const (
|
||||
defaultMaxSize = (1 << 63) - 1
|
||||
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
|
||||
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
|
||||
defaultTTL = time.Hour * 24
|
||||
)
|
||||
|
||||
// New creates and returns a new Store instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
|
||||
// it defaults to defaultMaxSize.
|
||||
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
|
||||
// it defaults to defaultTTL.
|
||||
//
|
||||
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
|
||||
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
|
||||
// to release associated resources.
|
||||
//
|
||||
// Returns:
|
||||
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
|
||||
// eviction callback, and TTL.
|
||||
func New(size int, ttl time.Duration) Store {
|
||||
if size <= 0 {
|
||||
size = defaultMaxSize
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = defaultTTL
|
||||
}
|
||||
|
||||
onEvicted := func(k string, v *Stmt) {
|
||||
if v != nil {
|
||||
go v.Close()
|
||||
}
|
||||
}
|
||||
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||
}
|
||||
|
||||
type lruStore struct {
|
||||
lru *lru.LRU[string, *Stmt]
|
||||
}
|
||||
|
||||
func (s *lruStore) Keys() []string {
|
||||
return s.lru.Keys()
|
||||
}
|
||||
|
||||
func (s *lruStore) Get(key string) (*Stmt, bool) {
|
||||
stmt, ok := s.lru.Get(key)
|
||||
if ok && stmt != nil {
|
||||
<-stmt.prepared
|
||||
}
|
||||
return stmt, ok
|
||||
}
|
||||
|
||||
func (s *lruStore) Set(key string, value *Stmt) {
|
||||
s.lru.Add(key, value)
|
||||
}
|
||||
|
||||
func (s *lruStore) Delete(key string) {
|
||||
s.lru.Remove(key)
|
||||
}
|
||||
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// New creates a new Stmt object for executing SQL queries.
|
||||
// It caches the Stmt object for future use and handles preparation and error states.
|
||||
// Parameters:
|
||||
//
|
||||
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
|
||||
// key: The key representing the SQL query, used for caching and preparing the statement.
|
||||
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
|
||||
// conn: A connection pool that provides database connections.
|
||||
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// *Stmt: A newly created statement object for executing SQL operations.
|
||||
// error: An error if the statement preparation fails.
|
||||
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
|
||||
// Create a Stmt object and set its Transaction property.
|
||||
// The prepared channel is used to synchronize the statement preparation state.
|
||||
cacheStmt := &Stmt{
|
||||
Transaction: isTransaction,
|
||||
prepared: make(chan struct{}),
|
||||
}
|
||||
// Cache the Stmt object with the associated key.
|
||||
s.Set(key, cacheStmt)
|
||||
// Unlock after completing initialization to prevent deadlocks.
|
||||
locker.Unlock()
|
||||
|
||||
// Ensure the prepared channel is closed after the function execution completes.
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Prepare the SQL statement using the provided connection.
|
||||
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
|
||||
if err != nil {
|
||||
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
|
||||
cacheStmt.prepareErr = err
|
||||
s.Delete(key)
|
||||
return &Stmt{}, err
|
||||
}
|
||||
|
||||
// Return the successfully prepared Stmt object.
|
||||
return cacheStmt, nil
|
||||
}
|
144
prepare_stmt.go
144
prepare_stmt.go
@ -7,29 +7,35 @@ import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/stmt_store"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts map[string]*Stmt
|
||||
Stmts stmt_store.Store
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
|
||||
//
|
||||
// Parameters:
|
||||
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
|
||||
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
|
||||
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
|
||||
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
|
||||
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
|
||||
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
|
||||
}
|
||||
}
|
||||
|
||||
// GetDBConn returns the underlying *sql.DB connection
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
@ -42,98 +48,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
// Close closes all prepared statements in the store
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
for _, stmt := range db.Stmts {
|
||||
go func(s *Stmt) {
|
||||
// make sure the stmt must finish preparation first
|
||||
<-s.prepared
|
||||
if s.Stmt != nil {
|
||||
_ = s.Close()
|
||||
}
|
||||
}(stmt)
|
||||
for _, key := range db.Stmts.Keys() {
|
||||
db.Stmts.Delete(key)
|
||||
}
|
||||
// setting db.Stmts to nil to avoid further using
|
||||
db.Stmts = nil
|
||||
}
|
||||
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
|
||||
for _, stmt := range sdb.Stmts {
|
||||
go func(s *Stmt) {
|
||||
// make sure the stmt must finish preparation first
|
||||
<-s.prepared
|
||||
if s.Stmt != nil {
|
||||
_ = s.Close()
|
||||
}
|
||||
}(stmt)
|
||||
}
|
||||
sdb.Stmts = make(map[string]*Stmt)
|
||||
// Reset Deprecated use Close instead
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
// retry
|
||||
db.Mux.Lock()
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
|
||||
// which cause by calling Close and executing SQL concurrently
|
||||
if db.Stmts == nil {
|
||||
db.Mux.Unlock()
|
||||
return Stmt{}, ErrInvalidDB
|
||||
}
|
||||
// cache preparing stmt first
|
||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||
db.Stmts[query] = &cacheStmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
// prepare completed
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Reason why cannot lock conn.PrepareContext
|
||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
cacheStmt.prepareErr = err
|
||||
db.Mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.Mux.Unlock()
|
||||
return Stmt{}, err
|
||||
}
|
||||
|
||||
db.Mux.Lock()
|
||||
cacheStmt.Stmt = stmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
return cacheStmt, nil
|
||||
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
@ -162,10 +111,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -176,11 +122,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
@ -230,11 +172,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -245,11 +183,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
|
18
tests/go.mod
18
tests/go.mod
@ -1,15 +1,17 @@
|
||||
module gorm.io/gorm/tests
|
||||
|
||||
go 1.18
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.5.10
|
||||
gorm.io/driver/sqlite v1.5.6
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/driver/sqlserver v1.5.4
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
@ -17,7 +19,7 @@ require (
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.2 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@ -25,12 +27,12 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.7.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.7.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/crypto v0.29.0 // indirect
|
||||
golang.org/x/text v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.37.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
|
561
tests/lru_test.go
Normal file
561
tests/lru_test.go
Normal file
@ -0,0 +1,561 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"gorm.io/gorm/internal/lru"
|
||||
"math"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](10, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key1", 2)
|
||||
|
||||
if value, ok := lru.Get("key1"); !ok || value != 2 {
|
||||
t.Errorf("Expected value to be updated to 2, got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_NewKey_AddsEntry(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](10, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
|
||||
if value, ok := lru.Get("key1"); !ok || value != 1 {
|
||||
t.Errorf("Expected key1 to be added with value 1, got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](2, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key2", 2)
|
||||
lru.Add("key3", 3)
|
||||
|
||||
if _, ok := lru.Get("key1"); ok {
|
||||
t.Errorf("Expected key1 to be removed, but it still exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](0, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key2", 2)
|
||||
lru.Add("key3", 3)
|
||||
|
||||
if _, ok := lru.Get("key1"); !ok {
|
||||
t.Errorf("Expected key1 to exist, but it was evicted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_Eviction(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](0, nil, time.Second*2)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key2", 2)
|
||||
lru.Add("key3", 3)
|
||||
time.Sleep(time.Second * 3)
|
||||
if lru.Cap() != 0 {
|
||||
t.Errorf("Expected lru to be empty, but it was not")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Rand_NoExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, 0)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
var hit, miss int
|
||||
for i := 0; i < 2*b.N; i++ {
|
||||
if i%2 == 0 {
|
||||
l.Add(trace[i], trace[i])
|
||||
} else {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Freq_NoExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, 0)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
if i%2 == 0 {
|
||||
trace[i] = getRand(b) % 16384
|
||||
} else {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Add(trace[i], trace[i])
|
||||
}
|
||||
var hit, miss int
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Rand_WithExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
var hit, miss int
|
||||
for i := 0; i < 2*b.N; i++ {
|
||||
if i%2 == 0 {
|
||||
l.Add(trace[i], trace[i])
|
||||
} else {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Freq_WithExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
if i%2 == 0 {
|
||||
trace[i] = getRand(b) % 16384
|
||||
} else {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Add(trace[i], trace[i])
|
||||
}
|
||||
var hit, miss int
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func TestLRUNoPurge(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](10, nil, 0)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok := lc.Peek("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
if !lc.Contains("key1") {
|
||||
t.Fatalf("should contain key1")
|
||||
}
|
||||
if lc.Contains("key2") {
|
||||
t.Fatalf("should not contain key2")
|
||||
}
|
||||
|
||||
v, ok = lc.Peek("key2")
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
if lc.Resize(0) != 0 {
|
||||
t.Fatalf("evicted count differs from expected")
|
||||
}
|
||||
if lc.Resize(2) != 0 {
|
||||
t.Fatalf("evicted count differs from expected")
|
||||
}
|
||||
lc.Add("key2", "val2")
|
||||
if lc.Resize(1) != 1 {
|
||||
t.Fatalf("evicted count differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUEdgeCases(t *testing.T) {
|
||||
lc := lru.NewLRU[string, *string](2, nil, 0)
|
||||
|
||||
// Adding a nil value
|
||||
lc.Add("key1", nil)
|
||||
|
||||
value, exists := lc.Get("key1")
|
||||
if value != nil || !exists {
|
||||
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
|
||||
}
|
||||
|
||||
// Adding an entry with the same key but different value
|
||||
newVal := "val1"
|
||||
lc.Add("key1", &newVal)
|
||||
|
||||
value, exists = lc.Get("key1")
|
||||
if value != &newVal || !exists {
|
||||
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Values(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](3, nil, 0)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
lc.Add("key2", "val2")
|
||||
lc.Add("key3", "val3")
|
||||
|
||||
values := lc.Values()
|
||||
if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) {
|
||||
t.Fatalf("values differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestExpirableMultipleClose(_ *testing.T) {
|
||||
// lc :=lru.NewLRU[string, string](10, nil, 0)
|
||||
// lc.Close()
|
||||
// // should not panic
|
||||
// lc.Close()
|
||||
// }
|
||||
|
||||
func TestLRUWithPurge(t *testing.T) {
|
||||
var evicted []string
|
||||
lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond)
|
||||
|
||||
k, v, ok := lc.GetOldest()
|
||||
if k != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // not enough to expire
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond) // expire
|
||||
v, ok = lc.Get("key1")
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
if v != "" {
|
||||
t.Fatalf("should be nil")
|
||||
}
|
||||
|
||||
if lc.Len() != 0 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
// add new entry
|
||||
lc.Add("key2", "val2")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
k, v, ok = lc.GetOldest()
|
||||
if k != "key2" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if v != "val2" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestLRUWithPurgeEnforcedBySize(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](10, nil, time.Hour)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
i := i
|
||||
lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i))
|
||||
v, ok := lc.Get(fmt.Sprintf("key%d", i))
|
||||
if v != fmt.Sprintf("val%d", i) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
if lc.Len() > 20 {
|
||||
t.Fatalf("length should be less than 20")
|
||||
}
|
||||
}
|
||||
|
||||
if lc.Len() != 10 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUConcurrency(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](0, nil, 0)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
go func(i int) {
|
||||
lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10))
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
if lc.Len() != 100 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUInvalidateAndEvict(t *testing.T) {
|
||||
var evicted int
|
||||
lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
lc.Add("key2", "val2")
|
||||
|
||||
val, ok := lc.Get("key1")
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
if val != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if evicted != 0 {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
lc.Remove("key1")
|
||||
if evicted != 1 {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
val, ok = lc.Get("key1")
|
||||
if val != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingExpired(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok := lc.Peek("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
for {
|
||||
result, ok := lc.Get("key1")
|
||||
if ok && result == "" {
|
||||
t.Fatalf("ok should return a result")
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 100) // wait for expiration reaper
|
||||
if lc.Len() != 0 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok = lc.Peek("key1")
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRURemoveOldest(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](2, nil, 0)
|
||||
|
||||
if lc.Cap() != 2 {
|
||||
t.Fatalf("expect cap is 2")
|
||||
}
|
||||
|
||||
k, v, ok := lc.RemoveOldest()
|
||||
if k != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
ok = lc.Remove("non_existent")
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
lc.Add("key2", "val2")
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if lc.Len() != 2 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
k, v, ok = lc.RemoveOldest()
|
||||
if k != "key1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleLRU() {
|
||||
// make cache with 10ms TTL and 5 max keys
|
||||
cache := lru.NewLRU[string, string](5, nil, time.Millisecond*10)
|
||||
|
||||
// set value under key1.
|
||||
cache.Add("key1", "val1")
|
||||
|
||||
// get value under key1
|
||||
r, ok := cache.Get("key1")
|
||||
|
||||
// check for OK value
|
||||
if ok {
|
||||
fmt.Printf("value before expiration is found: %v, value: %q\n", ok, r)
|
||||
}
|
||||
|
||||
// wait for cache to expire
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
// get value under key1 after key expiration
|
||||
r, ok = cache.Get("key1")
|
||||
fmt.Printf("value after expiration is found: %v, value: %q\n", ok, r)
|
||||
|
||||
// set value under key2, would evict old entry because it is already expired.
|
||||
cache.Add("key2", "val2")
|
||||
|
||||
fmt.Printf("Cache len: %d\n", cache.Len())
|
||||
// Output:
|
||||
// value before expiration is found: true, value: "val1"
|
||||
// value after expiration is found: false, value: ""
|
||||
// Cache len: 1
|
||||
}
|
||||
|
||||
func getRand(tb testing.TB) int64 {
|
||||
out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return out.Int64()
|
||||
}
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -92,6 +91,65 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
||||
tx2.Commit()
|
||||
}
|
||||
|
||||
func TestPreparedStmtLruFromTransaction(t *testing.T) {
|
||||
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second})
|
||||
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
if err := tx.Error; err != nil {
|
||||
t.Errorf("Failed to start transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
t.Errorf("Failed to commit transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
|
||||
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
|
||||
tx2 := db.Begin()
|
||||
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
|
||||
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
|
||||
tx2.Commit()
|
||||
// Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type.
|
||||
// If the conversion is successful, ok will be true and conn will be the converted object;
|
||||
// otherwise, ok will be false and conn will be nil.
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
// Get the number of statement keys stored in the PreparedStmtDB.
|
||||
lens := len(conn.Stmts.Keys())
|
||||
// Check if the number of stored statement keys is 0.
|
||||
if lens == 0 {
|
||||
// If the number is 0, it means there are no statements stored in the LRU cache.
|
||||
// The test fails and an error message is output.
|
||||
t.Fatalf("lru should not be empty")
|
||||
}
|
||||
// Wait for 40 seconds to give the statements in the cache enough time to expire.
|
||||
time.Sleep(time.Second * 40)
|
||||
// Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type.
|
||||
AssertEqual(t, ok, true)
|
||||
// Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds.
|
||||
// If it is not 0, it means the statements in the cache have not expired as expected.
|
||||
AssertEqual(t, len(conn.Stmts.Keys()), 0)
|
||||
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
@ -117,9 +175,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 2)
|
||||
for _, stmt := range conn.Stmts {
|
||||
if stmt == nil {
|
||||
AssertEqual(t, len(conn.Stmts.Keys()), 2)
|
||||
for _, stmt := range conn.Stmts.Keys() {
|
||||
if stmt == "" {
|
||||
t.Fatalf("stmt cannot bee nil")
|
||||
}
|
||||
}
|
||||
@ -143,10 +201,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
func TestPreparedStmtClose(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
user := *GetUser("prepared_stmt_close", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
@ -155,16 +213,16 @@ func TestPreparedStmtReset(t *testing.T) {
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
if len(pdb.Stmts.Keys()) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Close()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
if len(pdb.Stmts.Keys()) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
@ -174,10 +232,10 @@ func isUsingClosedConnError(err error) bool {
|
||||
return err.Error() == "sql: statement is closed"
|
||||
}
|
||||
|
||||
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
|
||||
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
|
||||
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
|
||||
func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_reset"
|
||||
func TestPreparedStmtConcurrentClose(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_close"
|
||||
user := *GetUser(name, Config{})
|
||||
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||
if createTx.Error != nil {
|
||||
@ -220,7 +278,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-writerFinish
|
||||
pdb.Reset()
|
||||
pdb.Close()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
@ -229,88 +287,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||
t.Fatalf("should is a unexpected error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
|
||||
// for example: one goroutine found error and just close the database, and others are executing SQL
|
||||
// this test making sure that the gorm would not get a Segmentation Fault,
|
||||
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
|
||||
// and all of the goroutine must got gorm.ErrInvalidDB after database close
|
||||
func TestPreparedStmtConcurrentClose(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_close"
|
||||
user := *GetUser(name, Config{})
|
||||
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||
if createTx.Error != nil {
|
||||
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
|
||||
}
|
||||
|
||||
// create a new connection to keep away from other tests
|
||||
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test connection due to %s", err)
|
||||
}
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
loopCount := 100
|
||||
var wg sync.WaitGroup
|
||||
var lastErr error
|
||||
closeValid := make(chan struct{}, loopCount)
|
||||
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
|
||||
var lastRunIndex int
|
||||
var closeFinishedAt int64
|
||||
|
||||
wg.Add(1)
|
||||
go func(id uint) {
|
||||
defer wg.Done()
|
||||
defer close(closeValid)
|
||||
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
|
||||
if lastRunIndex == closeStartIdx {
|
||||
closeValid <- struct{}{}
|
||||
}
|
||||
var tmp User
|
||||
now := time.Now().UnixNano()
|
||||
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
|
||||
if err == nil {
|
||||
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
|
||||
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
|
||||
lastErr = errors.New("must got error after database closed")
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastErr = err
|
||||
break
|
||||
}
|
||||
}(user.ID)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range closeValid {
|
||||
for i := 0; i < loopCount; i++ {
|
||||
pdb.Close() // the Close method must can be call multiple times
|
||||
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
var tmp User
|
||||
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
|
||||
if err != gorm.ErrInvalidDB {
|
||||
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
|
||||
}
|
||||
|
||||
// must be error
|
||||
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
|
||||
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
|
||||
}
|
||||
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
|
||||
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
|
||||
}
|
||||
if pdb.Stmts != nil {
|
||||
t.Fatalf("stmts must be nil")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user