支持lru淘汰preparestmt cache
This commit is contained in:
		
							parent
							
								
									5225c20309
								
							
						
					
					
						commit
						3ae5fdee0c
					
				| @ -24,21 +24,26 @@ type PreparedStmtDB struct { | ||||
| 	ConnPool | ||||
| } | ||||
| 
 | ||||
| func newPrepareStmtCache(prepareStmtLruConfig *PrepareStmtLruConfig) *StmtStore { | ||||
| 	var stmts StmtStore | ||||
| 	if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open { | ||||
| 		if prepareStmtLruConfig.Size <= 0 { | ||||
| 			panic("LRU prepareStmtLruConfig.Size must > 0") | ||||
| 		} | ||||
| 		lru := &LruStmtStore{} | ||||
| 		lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL) | ||||
| 		stmts = lru | ||||
| 	} else { | ||||
| 		defaultStmtStore := &DefaultStmtStore{} | ||||
| 		stmts = defaultStmtStore.init() | ||||
| 	} | ||||
| 	return &stmts | ||||
| } | ||||
| func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB { | ||||
| 	return &PreparedStmtDB{ | ||||
| 		ConnPool: connPool, | ||||
| 		Stmts: func() StmtStore { | ||||
| 			var stmts StmtStore | ||||
| 			if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open { | ||||
| 				lru := &LruStmtStore{} | ||||
| 				lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL) | ||||
| 				stmts = lru | ||||
| 			} else { | ||||
| 				stmts = &DefaultStmtStore{} | ||||
| 			} | ||||
| 			return stmts | ||||
| 		}(), | ||||
| 		Mux: &sync.RWMutex{}, | ||||
| 		Stmts:    *newPrepareStmtCache(prepareStmtLruConfig), | ||||
| 		Mux:      &sync.RWMutex{}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -57,6 +62,9 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { | ||||
| func (db *PreparedStmtDB) Close() { | ||||
| 	db.Mux.Lock() | ||||
| 	defer db.Mux.Unlock() | ||||
| 	if db.Stmts == nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	for _, stmt := range db.Stmts.AllMap() { | ||||
| 		go func(s *Stmt) { | ||||
| @ -74,7 +82,9 @@ func (db *PreparedStmtDB) Close() { | ||||
| func (sdb *PreparedStmtDB) Reset() { | ||||
| 	sdb.Mux.Lock() | ||||
| 	defer sdb.Mux.Unlock() | ||||
| 
 | ||||
| 	if sdb.Stmts == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	for _, stmt := range sdb.Stmts.AllMap() { | ||||
| 		go func(s *Stmt) { | ||||
| 			// make sure the stmt must finish preparation first
 | ||||
| @ -84,34 +94,40 @@ func (sdb *PreparedStmtDB) Reset() { | ||||
| 			} | ||||
| 		}(stmt) | ||||
| 	} | ||||
| 	sdb.Stmts = &DefaultStmtStore{} | ||||
| 	defaultStmt := &DefaultStmtStore{} | ||||
| 	defaultStmt.init() | ||||
| 	sdb.Stmts = defaultStmt | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { | ||||
| 	db.Mux.RLock() | ||||
| 	if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 		db.Mux.RUnlock() | ||||
| 		// wait for other goroutines prepared
 | ||||
| 		<-stmt.prepared | ||||
| 		if stmt.prepareErr != nil { | ||||
| 			return Stmt{}, stmt.prepareErr | ||||
| 		} | ||||
| 	if db.Stmts != nil { | ||||
| 		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 			db.Mux.RUnlock() | ||||
| 			// wait for other goroutines prepared
 | ||||
| 			<-stmt.prepared | ||||
| 			if stmt.prepareErr != nil { | ||||
| 				return Stmt{}, stmt.prepareErr | ||||
| 			} | ||||
| 
 | ||||
| 		return *stmt, nil | ||||
| 			return *stmt, nil | ||||
| 		} | ||||
| 	} | ||||
| 	db.Mux.RUnlock() | ||||
| 
 | ||||
| 	db.Mux.Lock() | ||||
| 	// double check
 | ||||
| 	if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 		db.Mux.Unlock() | ||||
| 		// wait for other goroutines prepared
 | ||||
| 		<-stmt.prepared | ||||
| 		if stmt.prepareErr != nil { | ||||
| 			return Stmt{}, stmt.prepareErr | ||||
| 		} | ||||
| 	if db.Stmts != nil { | ||||
| 		// double check
 | ||||
| 		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 			db.Mux.Unlock() | ||||
| 			// wait for other goroutines prepared
 | ||||
| 			<-stmt.prepared | ||||
| 			if stmt.prepareErr != nil { | ||||
| 				return Stmt{}, stmt.prepareErr | ||||
| 			} | ||||
| 
 | ||||
| 		return *stmt, nil | ||||
| 			return *stmt, nil | ||||
| 		} | ||||
| 	} | ||||
| 	// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
 | ||||
| 	// which cause by calling Close and executing SQL concurrently
 | ||||
| @ -295,11 +311,15 @@ type StmtStore interface { | ||||
| 	AllMap() map[string]*Stmt | ||||
| } | ||||
| 
 | ||||
| // 默认的 map 实现
 | ||||
| type DefaultStmtStore struct { | ||||
| 	defaultStmt map[string]*Stmt | ||||
| } | ||||
| 
 | ||||
| func (s *DefaultStmtStore) init() *DefaultStmtStore { | ||||
| 	s.defaultStmt = make(map[string]*Stmt) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| func (s *DefaultStmtStore) AllMap() map[string]*Stmt { | ||||
| 	return s.defaultStmt | ||||
| } | ||||
|  | ||||
| @ -92,6 +92,48 @@ func TestPreparedStmtFromTransaction(t *testing.T) { | ||||
| 	tx2.Commit() | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtLruFromTransaction(t *testing.T) { | ||||
| 	db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtLruConfig: &gorm.PrepareStmtLruConfig{10, 20 * time.Second, true}}) | ||||
| 
 | ||||
| 	tx := db.Begin() | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			tx.Rollback() | ||||
| 		} | ||||
| 	}() | ||||
| 	if err := tx.Error; err != nil { | ||||
| 		t.Errorf("Failed to start transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { | ||||
| 		tx.Rollback() | ||||
| 		t.Errorf("Failed to run one transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { | ||||
| 		tx.Rollback() | ||||
| 		t.Errorf("Failed to run one transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Commit().Error; err != nil { | ||||
| 		t.Errorf("Failed to commit transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { | ||||
| 		t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) | ||||
| 	} | ||||
| 
 | ||||
| 	tx2 := db.Begin() | ||||
| 	if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { | ||||
| 		t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) | ||||
| 	} | ||||
| 	tx2.Commit() | ||||
| 	time.Sleep(time.Second * 40) | ||||
| 	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| 	AssertEqual(t, ok, true) | ||||
| 	AssertEqual(t, len(conn.Stmts.AllMap()), 0) | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtDeadlock(t *testing.T) { | ||||
| 	tx, err := OpenTestConnection(&gorm.Config{}) | ||||
| 	AssertEqual(t, err, nil) | ||||
| @ -117,8 +159,8 @@ func TestPreparedStmtDeadlock(t *testing.T) { | ||||
| 
 | ||||
| 	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| 	AssertEqual(t, ok, true) | ||||
| 	AssertEqual(t, len(conn.Stmts), 2) | ||||
| 	for _, stmt := range conn.Stmts { | ||||
| 	AssertEqual(t, len(conn.Stmts.AllMap()), 2) | ||||
| 	for _, stmt := range conn.Stmts.AllMap() { | ||||
| 		if stmt == nil { | ||||
| 			t.Fatalf("stmt cannot bee nil") | ||||
| 		} | ||||
| @ -155,7 +197,7 @@ func TestPreparedStmtReset(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	pdb.Mux.Lock() | ||||
| 	if len(pdb.Stmts) == 0 { | ||||
| 	if len(pdb.Stmts.AllMap()) == 0 { | ||||
| 		pdb.Mux.Unlock() | ||||
| 		t.Fatalf("prepared stmt can not be empty") | ||||
| 	} | ||||
| @ -164,7 +206,7 @@ func TestPreparedStmtReset(t *testing.T) { | ||||
| 	pdb.Reset() | ||||
| 	pdb.Mux.Lock() | ||||
| 	defer pdb.Mux.Unlock() | ||||
| 	if len(pdb.Stmts) != 0 { | ||||
| 	if len(pdb.Stmts.AllMap()) != 0 { | ||||
| 		t.Fatalf("prepared stmt should be empty") | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 xiezhaodong
						xiezhaodong