Add stmt_store
This commit is contained in:
		
							parent
							
								
									dfa1b81f65
								
							
						
					
					
						commit
						886a406556
					
				
							
								
								
									
										106
									
								
								internal/stmt_store/stmt_store.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								internal/stmt_store/stmt_store.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,106 @@ | ||||
| package stmt_store | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/internal/lru" | ||||
| ) | ||||
| 
 | ||||
| type Stmt struct { | ||||
| 	*sql.Stmt | ||||
| 	Transaction bool | ||||
| 	prepared    chan struct{} | ||||
| 	prepareErr  error | ||||
| } | ||||
| 
 | ||||
| func NewStmt(isTransaction bool) *Stmt { | ||||
| 	return &Stmt{ | ||||
| 		Transaction: isTransaction, | ||||
| 		prepared:    make(chan struct{}), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) Done() { | ||||
| 	close(stmt.prepared) | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) AddError(err error) { | ||||
| 	stmt.prepareErr = err | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) Error() error { | ||||
| 	<-stmt.prepared | ||||
| 
 | ||||
| 	return stmt.prepareErr | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) Close() error { | ||||
| 	<-stmt.prepared | ||||
| 
 | ||||
| 	if stmt.Stmt != nil { | ||||
| 		return stmt.Stmt.Close() | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type Store interface { | ||||
| 	Get(key string) (*Stmt, bool) | ||||
| 	Set(key string, value *Stmt) | ||||
| 	Delete(key string) | ||||
| 	AllMap() map[string]*Stmt | ||||
| } | ||||
| 
 | ||||
| type StmtStore struct { | ||||
| 	lru *lru.LRU[string, *Stmt] | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
| 	defaultMaxSize = (1 << 63) - 1 | ||||
| 	defaultTTL     = time.Hour * 24 | ||||
| ) | ||||
| 
 | ||||
| func New(size int, ttl time.Duration) Store { | ||||
| 	if size <= 0 { | ||||
| 		size = defaultMaxSize | ||||
| 	} | ||||
| 
 | ||||
| 	if ttl <= 0 { | ||||
| 		ttl = defaultTTL | ||||
| 	} | ||||
| 
 | ||||
| 	onEvicted := func(k string, v *Stmt) { | ||||
| 		if v != nil { | ||||
| 			go func() { | ||||
| 				defer func() { | ||||
| 					if r := recover(); r != nil { | ||||
| 						fmt.Print("close stmt err panic ") | ||||
| 					} | ||||
| 				}() | ||||
| 				err := v.Close() | ||||
| 				if err != nil { | ||||
| 					fmt.Print("close stmt err: ", err.Error()) | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
| 	} | ||||
| 	return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) AllMap() map[string]*Stmt { | ||||
| 	return s.lru.KeyValues() | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) Get(key string) (*Stmt, bool) { | ||||
| 	stmt, ok := s.lru.Get(key) | ||||
| 	return stmt, ok | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) Set(key string, value *Stmt) { | ||||
| 	s.lru.Add(key, value) | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) Delete(key string) { | ||||
| 	s.lru.Remove(key) | ||||
| } | ||||
							
								
								
									
										200
									
								
								prepare_stmt.go
									
									
									
									
									
								
							
							
						
						
									
										200
									
								
								prepare_stmt.go
									
									
									
									
									
								
							| @ -5,67 +5,24 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm/internal/lru" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/internal/stmt_store" | ||||
| ) | ||||
| 
 | ||||
| type Stmt struct { | ||||
| 	*sql.Stmt | ||||
| 	Transaction bool | ||||
| 	prepared    chan struct{} | ||||
| 	prepareErr  error | ||||
| } | ||||
| 
 | ||||
| type PreparedStmtDB struct { | ||||
| 	Stmts StmtStore | ||||
| 	Stmts stmt_store.Store | ||||
| 	Mux   *sync.RWMutex | ||||
| 	ConnPool | ||||
| } | ||||
| 
 | ||||
| const default_max_size = (1 << 63) - 1 | ||||
| const default_ttl = time.Hour * 24 | ||||
| 
 | ||||
| // newPrepareStmtCache creates a new statement cache with the specified maximum size and time-to-live (TTL).
 | ||||
| // Parameters:
 | ||||
| //   - PrepareStmtMaxSize: An integer specifying the maximum number of prepared statements to cache.
 | ||||
| //     If this value is less than or equal to 0, the function will panic.
 | ||||
| //   - PrepareStmtTTL: A time.Duration specifying the TTL for cached statements.
 | ||||
| //     If this value differs from the default TTL, it will be used instead.
 | ||||
| //
 | ||||
| // Returns:
 | ||||
| //   - A pointer to a store.StmtStore instance configured with the provided parameters.
 | ||||
| //
 | ||||
| // The function initializes an LRU (Least Recently Used) cache for prepared statements,
 | ||||
| // using either the provided size and TTL or default values
 | ||||
| func newPrepareStmtCache(PrepareStmtMaxSize int, | ||||
| 	PrepareStmtTTL time.Duration) *StmtStore { | ||||
| 	var lru_size = default_max_size | ||||
| 	var lru_ttl = default_ttl | ||||
| 	var stmts StmtStore | ||||
| 	if PrepareStmtMaxSize < 0 { | ||||
| 		panic("PrepareStmtMaxSize must > 0") | ||||
| 	} | ||||
| 	if PrepareStmtMaxSize != 0 { | ||||
| 		lru_size = PrepareStmtMaxSize | ||||
| 	} | ||||
| 	if PrepareStmtTTL != default_ttl { | ||||
| 		lru_ttl = PrepareStmtTTL | ||||
| 	} | ||||
| 	lru := &LruStmtStore{} | ||||
| 	lru.newLru(lru_size, lru_ttl) | ||||
| 	stmts = lru | ||||
| 	return &stmts | ||||
| } | ||||
| func NewPreparedStmtDB(connPool ConnPool, PrepareStmtMaxSize int, | ||||
| 	PrepareStmtTTL time.Duration) *PreparedStmtDB { | ||||
| func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB { | ||||
| 	return &PreparedStmtDB{ | ||||
| 		ConnPool: connPool, | ||||
| 		Stmts: *newPrepareStmtCache(PrepareStmtMaxSize, | ||||
| 			PrepareStmtTTL), | ||||
| 		Mux: &sync.RWMutex{}, | ||||
| 		Stmts:    stmt_store.New(maxSize, ttl), | ||||
| 		Mux:      &sync.RWMutex{}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -89,13 +46,7 @@ func (db *PreparedStmtDB) Close() { | ||||
| 	} | ||||
| 
 | ||||
| 	for _, stmt := range db.Stmts.AllMap() { | ||||
| 		go func(s *Stmt) { | ||||
| 			// make sure the stmt must finish preparation first
 | ||||
| 			<-s.prepared | ||||
| 			if s.Stmt != nil { | ||||
| 				_ = s.Close() | ||||
| 			} | ||||
| 		}(stmt) | ||||
| 		go stmt.Close() | ||||
| 	} | ||||
| 	// setting db.Stmts to nil to avoid further using
 | ||||
| 	db.Stmts = nil | ||||
| @ -108,28 +59,20 @@ func (sdb *PreparedStmtDB) Reset() { | ||||
| 		return | ||||
| 	} | ||||
| 	for _, stmt := range sdb.Stmts.AllMap() { | ||||
| 		go func(s *Stmt) { | ||||
| 			// make sure the stmt must finish preparation first
 | ||||
| 			<-s.prepared | ||||
| 			if s.Stmt != nil { | ||||
| 				_ = s.Close() | ||||
| 			} | ||||
| 		}(stmt) | ||||
| 		go stmt.Close() | ||||
| 	} | ||||
| 	//Migrator
 | ||||
| 	defaultStmt := newPrepareStmtCache(0, 0) | ||||
| 	sdb.Stmts = *defaultStmt | ||||
| 
 | ||||
| 	// Migrator
 | ||||
| 	sdb.Stmts = stmt_store.New(0, 0) | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) { | ||||
| 	db.Mux.RLock() | ||||
| 	if db.Stmts != nil { | ||||
| 		if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 			db.Mux.RUnlock() | ||||
| 			// wait for other goroutines prepared
 | ||||
| 			<-stmt.prepared | ||||
| 			if stmt.prepareErr != nil { | ||||
| 				return Stmt{}, stmt.prepareErr | ||||
| 			if err := stmt.Error(); err != nil { | ||||
| 				return stmt_store.Stmt{}, err | ||||
| 			} | ||||
| 
 | ||||
| 			return *stmt, nil | ||||
| @ -140,12 +83,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | ||||
| 	db.Mux.Lock() | ||||
| 	if db.Stmts != nil { | ||||
| 		// double check
 | ||||
| 		if stmt, ok := db.Stmts.get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 			db.Mux.Unlock() | ||||
| 			// wait for other goroutines prepared
 | ||||
| 			<-stmt.prepared | ||||
| 			if stmt.prepareErr != nil { | ||||
| 				return Stmt{}, stmt.prepareErr | ||||
| 			if err := stmt.Error(); err != nil { | ||||
| 				return stmt_store.Stmt{}, err | ||||
| 			} | ||||
| 
 | ||||
| 			return *stmt, nil | ||||
| @ -155,15 +96,15 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | ||||
| 	// which cause by calling Close and executing SQL concurrently
 | ||||
| 	if db.Stmts == nil { | ||||
| 		db.Mux.Unlock() | ||||
| 		return Stmt{}, ErrInvalidDB | ||||
| 		return stmt_store.Stmt{}, ErrInvalidDB | ||||
| 	} | ||||
| 	// cache preparing stmt first
 | ||||
| 	cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} | ||||
| 	db.Stmts.set(query, &cacheStmt) | ||||
| 	cacheStmt := stmt_store.NewStmt(isTransaction) | ||||
| 	db.Stmts.Set(query, cacheStmt) | ||||
| 	db.Mux.Unlock() | ||||
| 
 | ||||
| 	// prepare completed
 | ||||
| 	defer close(cacheStmt.prepared) | ||||
| 	defer cacheStmt.Done() | ||||
| 
 | ||||
| 	// Reason why cannot lock conn.PrepareContext
 | ||||
| 	// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
 | ||||
| @ -172,19 +113,18 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | ||||
| 	// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
 | ||||
| 	stmt, err := conn.PrepareContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		cacheStmt.prepareErr = err | ||||
| 		cacheStmt.AddError(err) | ||||
| 		db.Mux.Lock() | ||||
| 		db.Stmts.delete(query) | ||||
| 		//delete(db.Stmts.AllMap(), query)
 | ||||
| 		db.Stmts.Delete(query) | ||||
| 		db.Mux.Unlock() | ||||
| 		return Stmt{}, err | ||||
| 		return stmt_store.Stmt{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	db.Mux.Lock() | ||||
| 	cacheStmt.Stmt = stmt | ||||
| 	db.Mux.Unlock() | ||||
| 
 | ||||
| 	return cacheStmt, nil | ||||
| 	return *cacheStmt, nil | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { | ||||
| @ -216,8 +156,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 			go stmt.Close() | ||||
| 			db.Stmts.delete(query) | ||||
| 			//delete(db.Stmts.AllMap(), query)
 | ||||
| 			db.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| 	return result, err | ||||
| @ -232,8 +171,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . | ||||
| 			defer db.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			db.Stmts.delete(query) | ||||
| 			//delete(db.Stmts.AllMap(), query)
 | ||||
| 			db.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| 	return rows, err | ||||
| @ -287,8 +225,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			tx.PreparedStmtDB.Stmts.delete(query) | ||||
| 			//delete(tx.PreparedStmtDB.Stmts.AllMap(), query)
 | ||||
| 			tx.PreparedStmtDB.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| 	return result, err | ||||
| @ -303,8 +240,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			tx.PreparedStmtDB.Stmts.delete(query) | ||||
| 			//delete(tx.PreparedStmtDB.Stmts.AllMap(), query)
 | ||||
| 			tx.PreparedStmtDB.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| 	return rows, err | ||||
| @ -325,79 +261,3 @@ func (tx *PreparedStmtTX) Ping() error { | ||||
| 	} | ||||
| 	return conn.Ping() | ||||
| } | ||||
| 
 | ||||
| type StmtStore interface { | ||||
| 	get(key string) (*Stmt, bool) | ||||
| 	set(key string, value *Stmt) | ||||
| 	delete(key string) | ||||
| 	AllMap() map[string]*Stmt | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	type DefaultStmtStore struct { | ||||
| 		defaultStmt map[string]*Stmt | ||||
| 	} | ||||
| 
 | ||||
| 	func (s *DefaultStmtStore) Init() *DefaultStmtStore { | ||||
| 		s.defaultStmt = make(map[string]*Stmt) | ||||
| 		return s | ||||
| 	} | ||||
| 
 | ||||
| 	func (s *DefaultStmtStore) AllMap() map[string]*Stmt { | ||||
| 		return s.defaultStmt | ||||
| 	} | ||||
| 
 | ||||
| 	func (s *DefaultStmtStore) Get(key string) (*Stmt, bool) { | ||||
| 		stmt, ok := s.defaultStmt[key] | ||||
| 		return stmt, ok | ||||
| 	} | ||||
| 
 | ||||
| 	func (s *DefaultStmtStore) Set(key string, value *Stmt) { | ||||
| 		s.defaultStmt[key] = value | ||||
| 	} | ||||
| 
 | ||||
| 	func (s *DefaultStmtStore) Delete(key string) { | ||||
| 		delete(s.defaultStmt, key) | ||||
| 	} | ||||
| */ | ||||
| type LruStmtStore struct { | ||||
| 	lru *lru.LRU[string, *Stmt] | ||||
| } | ||||
| 
 | ||||
| func (s *LruStmtStore) newLru(size int, ttl time.Duration) { | ||||
| 	onEvicted := func(k string, v *Stmt) { | ||||
| 		if v != nil { | ||||
| 			go func() { | ||||
| 				defer func() { | ||||
| 					if r := recover(); r != nil { | ||||
| 						fmt.Print("close stmt err panic ") | ||||
| 					} | ||||
| 				}() | ||||
| 				if v != nil { | ||||
| 					err := v.Close() | ||||
| 					if err != nil { | ||||
| 						//
 | ||||
| 						fmt.Print("close stmt err: ", err.Error()) | ||||
| 					} | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
| 	} | ||||
| 	s.lru = lru.NewLRU[string, *Stmt](size, onEvicted, ttl) | ||||
| } | ||||
| 
 | ||||
| func (s *LruStmtStore) AllMap() map[string]*Stmt { | ||||
| 	return s.lru.KeyValues() | ||||
| } | ||||
| func (s *LruStmtStore) get(key string) (*Stmt, bool) { | ||||
| 	stmt, ok := s.lru.Get(key) | ||||
| 	return stmt, ok | ||||
| } | ||||
| 
 | ||||
| func (s *LruStmtStore) set(key string, value *Stmt) { | ||||
| 	s.lru.Add(key, value) | ||||
| } | ||||
| 
 | ||||
| func (s *LruStmtStore) delete(key string) { | ||||
| 	s.lru.Remove(key) | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu