fix memory leaks in PrepareStatementDB (#7142)
* fix memory leaks in PrepareStatementDB * Fix CR: 1) Fix potential Segmentation Fault in Reset function 2) Setting db.Stmts to nil map when Close to avoid further using * Add Test: 1) TestPreparedStmtConcurrentReset 2) TestPreparedStmtConcurrentClose * Fix test, create new connection to keep away from other tests --------- Co-authored-by: Zehui Chen <zehui@ssc-hn.com>
This commit is contained in:
		
							parent
							
								
									4a50b36f63
								
							
						
					
					
						commit
						0dbfda5d7e
					
				| @ -17,18 +17,16 @@ type Stmt struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type PreparedStmtDB struct { | type PreparedStmtDB struct { | ||||||
| 	Stmts       map[string]*Stmt | 	Stmts map[string]*Stmt | ||||||
| 	PreparedSQL []string | 	Mux   *sync.RWMutex | ||||||
| 	Mux         *sync.RWMutex |  | ||||||
| 	ConnPool | 	ConnPool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { | func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { | ||||||
| 	return &PreparedStmtDB{ | 	return &PreparedStmtDB{ | ||||||
| 		ConnPool:    connPool, | 		ConnPool: connPool, | ||||||
| 		Stmts:       make(map[string]*Stmt), | 		Stmts:    make(map[string]*Stmt), | ||||||
| 		Mux:         &sync.RWMutex{}, | 		Mux:      &sync.RWMutex{}, | ||||||
| 		PreparedSQL: make([]string, 0, 100), |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -48,12 +46,17 @@ func (db *PreparedStmtDB) Close() { | |||||||
| 	db.Mux.Lock() | 	db.Mux.Lock() | ||||||
| 	defer db.Mux.Unlock() | 	defer db.Mux.Unlock() | ||||||
| 
 | 
 | ||||||
| 	for _, query := range db.PreparedSQL { | 	for _, stmt := range db.Stmts { | ||||||
| 		if stmt, ok := db.Stmts[query]; ok { | 		go func(s *Stmt) { | ||||||
| 			delete(db.Stmts, query) | 			// make sure the stmt must finish preparation first
 | ||||||
| 			go stmt.Close() | 			<-s.prepared | ||||||
| 		} | 			if s.Stmt != nil { | ||||||
|  | 				_ = s.Close() | ||||||
|  | 			} | ||||||
|  | 		}(stmt) | ||||||
| 	} | 	} | ||||||
|  | 	// setting db.Stmts to nil to avoid further using
 | ||||||
|  | 	db.Stmts = nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (sdb *PreparedStmtDB) Reset() { | func (sdb *PreparedStmtDB) Reset() { | ||||||
| @ -61,9 +64,14 @@ func (sdb *PreparedStmtDB) Reset() { | |||||||
| 	defer sdb.Mux.Unlock() | 	defer sdb.Mux.Unlock() | ||||||
| 
 | 
 | ||||||
| 	for _, stmt := range sdb.Stmts { | 	for _, stmt := range sdb.Stmts { | ||||||
| 		go stmt.Close() | 		go func(s *Stmt) { | ||||||
|  | 			// make sure the stmt must finish preparation first
 | ||||||
|  | 			<-s.prepared | ||||||
|  | 			if s.Stmt != nil { | ||||||
|  | 				_ = s.Close() | ||||||
|  | 			} | ||||||
|  | 		}(stmt) | ||||||
| 	} | 	} | ||||||
| 	sdb.PreparedSQL = make([]string, 0, 100) |  | ||||||
| 	sdb.Stmts = make(map[string]*Stmt) | 	sdb.Stmts = make(map[string]*Stmt) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | |||||||
| 
 | 
 | ||||||
| 		return *stmt, nil | 		return *stmt, nil | ||||||
| 	} | 	} | ||||||
| 
 | 	// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
 | ||||||
|  | 	// which cause by calling Close and executing SQL concurrently
 | ||||||
|  | 	if db.Stmts == nil { | ||||||
|  | 		db.Mux.Unlock() | ||||||
|  | 		return Stmt{}, ErrInvalidDB | ||||||
|  | 	} | ||||||
| 	// cache preparing stmt first
 | 	// cache preparing stmt first
 | ||||||
| 	cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} | 	cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} | ||||||
| 	db.Stmts[query] = &cacheStmt | 	db.Stmts[query] = &cacheStmt | ||||||
| @ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | |||||||
| 
 | 
 | ||||||
| 	db.Mux.Lock() | 	db.Mux.Lock() | ||||||
| 	cacheStmt.Stmt = stmt | 	cacheStmt.Stmt = stmt | ||||||
| 	db.PreparedSQL = append(db.PreparedSQL, query) |  | ||||||
| 	db.Mux.Unlock() | 	db.Mux.Unlock() | ||||||
| 
 | 
 | ||||||
| 	return cacheStmt, nil | 	return cacheStmt, nil | ||||||
|  | |||||||
| @ -4,6 +4,7 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"sync" | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| @ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) { | |||||||
| 		t.Fatalf("prepared stmt should be empty") | 		t.Fatalf("prepared stmt should be empty") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func isUsingClosedConnError(err error) bool { | ||||||
|  | 	// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
 | ||||||
|  | 	return err.Error() == "sql: statement is closed" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
 | ||||||
|  | // this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
 | ||||||
|  | func TestPreparedStmtConcurrentReset(t *testing.T) { | ||||||
|  | 	name := "prepared_stmt_concurrent_reset" | ||||||
|  | 	user := *GetUser(name, Config{}) | ||||||
|  | 	createTx := DB.Session(&gorm.Session{}).Create(&user) | ||||||
|  | 	if createTx.Error != nil { | ||||||
|  | 		t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// create a new connection to keep away from other tests
 | ||||||
|  | 	tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to open test connection due to %s", err) | ||||||
|  | 	} | ||||||
|  | 	pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	loopCount := 100 | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  | 	var unexpectedError bool | ||||||
|  | 	writerFinish := make(chan struct{}) | ||||||
|  | 
 | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	go func(id uint) { | ||||||
|  | 		defer wg.Done() | ||||||
|  | 		defer close(writerFinish) | ||||||
|  | 
 | ||||||
|  | 		for j := 0; j < loopCount; j++ { | ||||||
|  | 			var tmp User | ||||||
|  | 			err := tx.Session(&gorm.Session{}).First(&tmp, id).Error | ||||||
|  | 			if err == nil || isUsingClosedConnError(err) { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err) | ||||||
|  | 			unexpectedError = true | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	}(user.ID) | ||||||
|  | 
 | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	go func() { | ||||||
|  | 		defer wg.Done() | ||||||
|  | 		<-writerFinish | ||||||
|  | 		pdb.Reset() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	wg.Wait() | ||||||
|  | 
 | ||||||
|  | 	if unexpectedError { | ||||||
|  | 		t.Fatalf("should is a unexpected error") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
 | ||||||
|  | // for example: one goroutine found error and just close the database, and others are executing SQL
 | ||||||
|  | // this test making sure that the gorm would not get a Segmentation Fault,
 | ||||||
|  | // and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
 | ||||||
|  | // and all of the goroutine must got gorm.ErrInvalidDB after database close
 | ||||||
|  | func TestPreparedStmtConcurrentClose(t *testing.T) { | ||||||
|  | 	name := "prepared_stmt_concurrent_close" | ||||||
|  | 	user := *GetUser(name, Config{}) | ||||||
|  | 	createTx := DB.Session(&gorm.Session{}).Create(&user) | ||||||
|  | 	if createTx.Error != nil { | ||||||
|  | 		t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// create a new connection to keep away from other tests
 | ||||||
|  | 	tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to open test connection due to %s", err) | ||||||
|  | 	} | ||||||
|  | 	pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	loopCount := 100 | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  | 	var lastErr error | ||||||
|  | 	closeValid := make(chan struct{}, loopCount) | ||||||
|  | 	closeStartIdx := loopCount / 2 // close the database at the middle of the execution
 | ||||||
|  | 	var lastRunIndex int | ||||||
|  | 	var closeFinishedAt int64 | ||||||
|  | 
 | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	go func(id uint) { | ||||||
|  | 		defer wg.Done() | ||||||
|  | 		defer close(closeValid) | ||||||
|  | 		for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { | ||||||
|  | 			if lastRunIndex == closeStartIdx { | ||||||
|  | 				closeValid <- struct{}{} | ||||||
|  | 			} | ||||||
|  | 			var tmp User | ||||||
|  | 			now := time.Now().UnixNano() | ||||||
|  | 			err := tx.Session(&gorm.Session{}).First(&tmp, id).Error | ||||||
|  | 			if err == nil { | ||||||
|  | 				closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) | ||||||
|  | 				if (closeFinishedAt != 0) && (now > closeFinishedAt) { | ||||||
|  | 					lastErr = errors.New("must got error after database closed") | ||||||
|  | 					break | ||||||
|  | 				} | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			lastErr = err | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	}(user.ID) | ||||||
|  | 
 | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	go func() { | ||||||
|  | 		defer wg.Done() | ||||||
|  | 		for range closeValid { | ||||||
|  | 			for i := 0; i < loopCount; i++ { | ||||||
|  | 				pdb.Close() // the Close method must can be call multiple times
 | ||||||
|  | 				atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	wg.Wait() | ||||||
|  | 	var tmp User | ||||||
|  | 	err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error | ||||||
|  | 	if err != gorm.ErrInvalidDB { | ||||||
|  | 		t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// must be error
 | ||||||
|  | 	if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { | ||||||
|  | 		t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) | ||||||
|  | 	} | ||||||
|  | 	if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { | ||||||
|  | 		t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) | ||||||
|  | 	} | ||||||
|  | 	if pdb.Stmts != nil { | ||||||
|  | 		t.Fatalf("stmts must be nil") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 ivila
						ivila