refact prepare stmt store
This commit is contained in:
		
							parent
							
								
									886a406556
								
							
						
					
					
						commit
						14dc8ed9e0
					
				| @ -1,8 +1,9 @@ | ||||
| package stmt_store | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/internal/lru" | ||||
| @ -15,24 +16,7 @@ type Stmt struct { | ||||
| 	prepareErr  error | ||||
| } | ||||
| 
 | ||||
| func NewStmt(isTransaction bool) *Stmt { | ||||
| 	return &Stmt{ | ||||
| 		Transaction: isTransaction, | ||||
| 		prepared:    make(chan struct{}), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) Done() { | ||||
| 	close(stmt.prepared) | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) AddError(err error) { | ||||
| 	stmt.prepareErr = err | ||||
| } | ||||
| 
 | ||||
| func (stmt *Stmt) Error() error { | ||||
| 	<-stmt.prepared | ||||
| 
 | ||||
| 	return stmt.prepareErr | ||||
| } | ||||
| 
 | ||||
| @ -46,13 +30,14 @@ func (stmt *Stmt) Close() error { | ||||
| } | ||||
| 
 | ||||
| type Store interface { | ||||
| 	New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error) | ||||
| 	Keys() []string | ||||
| 	Get(key string) (*Stmt, bool) | ||||
| 	Set(key string, value *Stmt) | ||||
| 	Delete(key string) | ||||
| 	AllMap() map[string]*Stmt | ||||
| } | ||||
| 
 | ||||
| type StmtStore struct { | ||||
| type LRUStore struct { | ||||
| 	lru *lru.LRU[string, *Stmt] | ||||
| } | ||||
| 
 | ||||
| @ -72,35 +57,52 @@ func New(size int, ttl time.Duration) Store { | ||||
| 
 | ||||
| 	onEvicted := func(k string, v *Stmt) { | ||||
| 		if v != nil { | ||||
| 			go func() { | ||||
| 				defer func() { | ||||
| 					if r := recover(); r != nil { | ||||
| 						fmt.Print("close stmt err panic ") | ||||
| 					} | ||||
| 				}() | ||||
| 				err := v.Close() | ||||
| 				if err != nil { | ||||
| 					fmt.Print("close stmt err: ", err.Error()) | ||||
| 				} | ||||
| 			}() | ||||
| 			go v.Close() | ||||
| 		} | ||||
| 	} | ||||
| 	return &StmtStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} | ||||
| 	return &LRUStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) AllMap() map[string]*Stmt { | ||||
| 	return s.lru.KeyValues() | ||||
| func (s *LRUStore) Keys() []string { | ||||
| 	return s.lru.Keys() | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) Get(key string) (*Stmt, bool) { | ||||
| func (s *LRUStore) Get(key string) (*Stmt, bool) { | ||||
| 	stmt, ok := s.lru.Get(key) | ||||
| 	if ok && stmt != nil { | ||||
| 		<-stmt.prepared | ||||
| 	} | ||||
| 	return stmt, ok | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) Set(key string, value *Stmt) { | ||||
| func (s *LRUStore) Set(key string, value *Stmt) { | ||||
| 	s.lru.Add(key, value) | ||||
| } | ||||
| 
 | ||||
| func (s *StmtStore) Delete(key string) { | ||||
| func (s *LRUStore) Delete(key string) { | ||||
| 	s.lru.Remove(key) | ||||
| } | ||||
| 
 | ||||
| type ConnPool interface { | ||||
| 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||
| } | ||||
| 
 | ||||
| func (s *LRUStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) { | ||||
| 	cacheStmt := &Stmt{ | ||||
| 		Transaction: isTransaction, | ||||
| 		prepared:    make(chan struct{}), | ||||
| 	} | ||||
| 	s.Set(key, cacheStmt) | ||||
| 	locker.Unlock() | ||||
| 
 | ||||
| 	defer close(cacheStmt.prepared) | ||||
| 
 | ||||
| 	cacheStmt.Stmt, err = conn.PrepareContext(ctx, key) | ||||
| 	if err != nil { | ||||
| 		cacheStmt.prepareErr = err | ||||
| 		s.Delete(key) | ||||
| 		return &Stmt{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	return cacheStmt, nil | ||||
| } | ||||
|  | ||||
| @ -18,6 +18,7 @@ type PreparedStmtDB struct { | ||||
| 	ConnPool | ||||
| } | ||||
| 
 | ||||
| // NewPreparedStmtDB creates a new PreparedStmtDB instance
 | ||||
| func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB { | ||||
| 	return &PreparedStmtDB{ | ||||
| 		ConnPool: connPool, | ||||
| @ -26,6 +27,7 @@ func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *Prepa | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // GetDBConn returns the underlying *sql.DB connection
 | ||||
| func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { | ||||
| 	if sqldb, ok := db.ConnPool.(*sql.DB); ok { | ||||
| 		return sqldb, nil | ||||
| @ -38,93 +40,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { | ||||
| 	return nil, ErrInvalidDB | ||||
| } | ||||
| 
 | ||||
| // Close closes all prepared statements in the store
 | ||||
| func (db *PreparedStmtDB) Close() { | ||||
| 	db.Mux.Lock() | ||||
| 	defer db.Mux.Unlock() | ||||
| 	if db.Stmts == nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	for _, stmt := range db.Stmts.AllMap() { | ||||
| 		go stmt.Close() | ||||
| 	for _, key := range db.Stmts.Keys() { | ||||
| 		db.Stmts.Delete(key) | ||||
| 	} | ||||
| 	// setting db.Stmts to nil to avoid further using
 | ||||
| 	db.Stmts = nil | ||||
| } | ||||
| 
 | ||||
| func (sdb *PreparedStmtDB) Reset() { | ||||
| 	sdb.Mux.Lock() | ||||
| 	defer sdb.Mux.Unlock() | ||||
| 	if sdb.Stmts == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	for _, stmt := range sdb.Stmts.AllMap() { | ||||
| 		go stmt.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	// Migrator
 | ||||
| 	sdb.Stmts = stmt_store.New(0, 0) | ||||
| // Reset Deprecated use Close instead
 | ||||
| func (db *PreparedStmtDB) Reset() { | ||||
| 	db.Close() | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (stmt_store.Stmt, error) { | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) { | ||||
| 	db.Mux.RLock() | ||||
| 	if db.Stmts != nil { | ||||
| 		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 			db.Mux.RUnlock() | ||||
| 			if err := stmt.Error(); err != nil { | ||||
| 				return stmt_store.Stmt{}, err | ||||
| 			} | ||||
| 
 | ||||
| 			return *stmt, nil | ||||
| 			return stmt, stmt.Error() | ||||
| 		} | ||||
| 	} | ||||
| 	db.Mux.RUnlock() | ||||
| 
 | ||||
| 	// retry
 | ||||
| 	db.Mux.Lock() | ||||
| 	if db.Stmts != nil { | ||||
| 		// double check
 | ||||
| 		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { | ||||
| 			db.Mux.Unlock() | ||||
| 			if err := stmt.Error(); err != nil { | ||||
| 				return stmt_store.Stmt{}, err | ||||
| 			} | ||||
| 
 | ||||
| 			return *stmt, nil | ||||
| 			return stmt, stmt.Error() | ||||
| 		} | ||||
| 	} | ||||
| 	// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
 | ||||
| 	// which cause by calling Close and executing SQL concurrently
 | ||||
| 	if db.Stmts == nil { | ||||
| 		db.Mux.Unlock() | ||||
| 		return stmt_store.Stmt{}, ErrInvalidDB | ||||
| 	} | ||||
| 	// cache preparing stmt first
 | ||||
| 	cacheStmt := stmt_store.NewStmt(isTransaction) | ||||
| 	db.Stmts.Set(query, cacheStmt) | ||||
| 	db.Mux.Unlock() | ||||
| 
 | ||||
| 	// prepare completed
 | ||||
| 	defer cacheStmt.Done() | ||||
| 
 | ||||
| 	// Reason why cannot lock conn.PrepareContext
 | ||||
| 	// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
 | ||||
| 	// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
 | ||||
| 	// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
 | ||||
| 	// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
 | ||||
| 	stmt, err := conn.PrepareContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		cacheStmt.AddError(err) | ||||
| 		db.Mux.Lock() | ||||
| 		db.Stmts.Delete(query) | ||||
| 		db.Mux.Unlock() | ||||
| 		return stmt_store.Stmt{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	db.Mux.Lock() | ||||
| 	cacheStmt.Stmt = stmt | ||||
| 	db.Mux.Unlock() | ||||
| 
 | ||||
| 	return *cacheStmt, nil | ||||
| 	return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux) | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { | ||||
| @ -153,9 +103,6 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. | ||||
| 	if err == nil { | ||||
| 		result, err = stmt.ExecContext(ctx, args...) | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 			go stmt.Close() | ||||
| 			db.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| @ -167,10 +114,6 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . | ||||
| 	if err == nil { | ||||
| 		rows, err = stmt.QueryContext(ctx, args...) | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			db.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| @ -221,10 +164,6 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| 	if err == nil { | ||||
| 		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			tx.PreparedStmtDB.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
| @ -236,10 +175,6 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . | ||||
| 	if err == nil { | ||||
| 		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			tx.PreparedStmtDB.Stmts.Delete(query) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										18
									
								
								tests/go.mod
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								tests/go.mod
									
									
									
									
									
								
							| @ -1,15 +1,17 @@ | ||||
| module gorm.io/gorm/tests | ||||
| 
 | ||||
| go 1.18 | ||||
| go 1.23.0 | ||||
| 
 | ||||
| toolchain go1.24.2 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/google/uuid v1.6.0 | ||||
| 	github.com/jinzhu/now v1.1.5 | ||||
| 	github.com/lib/pq v1.10.9 | ||||
| 	github.com/stretchr/testify v1.9.0 | ||||
| 	github.com/stretchr/testify v1.10.0 | ||||
| 	gorm.io/driver/mysql v1.5.7 | ||||
| 	gorm.io/driver/postgres v1.5.10 | ||||
| 	gorm.io/driver/sqlite v1.5.6 | ||||
| 	gorm.io/driver/postgres v1.5.11 | ||||
| 	gorm.io/driver/sqlite v1.5.7 | ||||
| 	gorm.io/driver/sqlserver v1.5.4 | ||||
| 	gorm.io/gorm v1.25.12 | ||||
| ) | ||||
| @ -17,7 +19,7 @@ require ( | ||||
| require ( | ||||
| 	filippo.io/edwards25519 v1.1.0 // indirect | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.8.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.9.2 // indirect | ||||
| 	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect | ||||
| 	github.com/golang-sql/sqlexp v0.1.0 // indirect | ||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||
| @ -25,12 +27,12 @@ require ( | ||||
| 	github.com/jackc/pgx/v5 v5.7.1 // indirect | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/kr/text v0.2.0 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v1.14.24 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v1.14.28 // indirect | ||||
| 	github.com/microsoft/go-mssqldb v1.7.2 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/rogpeppe/go-internal v1.12.0 // indirect | ||||
| 	golang.org/x/crypto v0.29.0 // indirect | ||||
| 	golang.org/x/text v0.20.0 // indirect | ||||
| 	golang.org/x/crypto v0.37.0 // indirect | ||||
| 	golang.org/x/text v0.24.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
| 
 | ||||
|  | ||||
| @ -4,7 +4,6 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -130,13 +129,13 @@ func TestPreparedStmtLruFromTransaction(t *testing.T) { | ||||
| 
 | ||||
| 	tx2.Commit() | ||||
| 	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| 	lens := len(conn.Stmts.AllMap()) | ||||
| 	lens := len(conn.Stmts.Keys()) | ||||
| 	if lens == 0 { | ||||
| 		t.Fatalf("lru should not be empty") | ||||
| 	} | ||||
| 	time.Sleep(time.Second * 40) | ||||
| 	AssertEqual(t, ok, true) | ||||
| 	AssertEqual(t, len(conn.Stmts.AllMap()), 0) | ||||
| 	AssertEqual(t, len(conn.Stmts.Keys()), 0) | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtDeadlock(t *testing.T) { | ||||
| @ -164,9 +163,9 @@ func TestPreparedStmtDeadlock(t *testing.T) { | ||||
| 
 | ||||
| 	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| 	AssertEqual(t, ok, true) | ||||
| 	AssertEqual(t, len(conn.Stmts.AllMap()), 2) | ||||
| 	for _, stmt := range conn.Stmts.AllMap() { | ||||
| 		if stmt == nil { | ||||
| 	AssertEqual(t, len(conn.Stmts.Keys()), 2) | ||||
| 	for _, stmt := range conn.Stmts.Keys() { | ||||
| 		if stmt == "" { | ||||
| 			t.Fatalf("stmt cannot bee nil") | ||||
| 		} | ||||
| 	} | ||||
| @ -190,10 +189,10 @@ func TestPreparedStmtInTransaction(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtReset(t *testing.T) { | ||||
| func TestPreparedStmtClose(t *testing.T) { | ||||
| 	tx := DB.Session(&gorm.Session{PrepareStmt: true}) | ||||
| 
 | ||||
| 	user := *GetUser("prepared_stmt_reset", Config{}) | ||||
| 	user := *GetUser("prepared_stmt_close", Config{}) | ||||
| 	tx = tx.Create(&user) | ||||
| 
 | ||||
| 	pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| @ -202,16 +201,16 @@ func TestPreparedStmtReset(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	pdb.Mux.Lock() | ||||
| 	if len(pdb.Stmts.AllMap()) == 0 { | ||||
| 	if len(pdb.Stmts.Keys()) == 0 { | ||||
| 		pdb.Mux.Unlock() | ||||
| 		t.Fatalf("prepared stmt can not be empty") | ||||
| 	} | ||||
| 	pdb.Mux.Unlock() | ||||
| 
 | ||||
| 	pdb.Reset() | ||||
| 	pdb.Close() | ||||
| 	pdb.Mux.Lock() | ||||
| 	defer pdb.Mux.Unlock() | ||||
| 	if len(pdb.Stmts.AllMap()) != 0 { | ||||
| 	if len(pdb.Stmts.Keys()) != 0 { | ||||
| 		t.Fatalf("prepared stmt should be empty") | ||||
| 	} | ||||
| } | ||||
| @ -221,10 +220,10 @@ func isUsingClosedConnError(err error) bool { | ||||
| 	return err.Error() == "sql: statement is closed" | ||||
| } | ||||
| 
 | ||||
| // TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
 | ||||
| // TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
 | ||||
| // this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
 | ||||
| func TestPreparedStmtConcurrentReset(t *testing.T) { | ||||
| 	name := "prepared_stmt_concurrent_reset" | ||||
| func TestPreparedStmtConcurrentClose(t *testing.T) { | ||||
| 	name := "prepared_stmt_concurrent_close" | ||||
| 	user := *GetUser(name, Config{}) | ||||
| 	createTx := DB.Session(&gorm.Session{}).Create(&user) | ||||
| 	if createTx.Error != nil { | ||||
| @ -267,7 +266,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) { | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		<-writerFinish | ||||
| 		pdb.Reset() | ||||
| 		pdb.Close() | ||||
| 	}() | ||||
| 
 | ||||
| 	wg.Wait() | ||||
| @ -276,88 +275,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) { | ||||
| 		t.Fatalf("should is a unexpected error") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
 | ||||
| // for example: one goroutine found error and just close the database, and others are executing SQL
 | ||||
| // this test making sure that the gorm would not get a Segmentation Fault,
 | ||||
| // and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
 | ||||
| // and all of the goroutine must got gorm.ErrInvalidDB after database close
 | ||||
| func TestPreparedStmtConcurrentClose(t *testing.T) { | ||||
| 	name := "prepared_stmt_concurrent_close" | ||||
| 	user := *GetUser(name, Config{}) | ||||
| 	createTx := DB.Session(&gorm.Session{}).Create(&user) | ||||
| 	if createTx.Error != nil { | ||||
| 		t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) | ||||
| 	} | ||||
| 
 | ||||
| 	// create a new connection to keep away from other tests
 | ||||
| 	tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to open test connection due to %s", err) | ||||
| 	} | ||||
| 	pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| 	if !ok { | ||||
| 		t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") | ||||
| 	} | ||||
| 
 | ||||
| 	loopCount := 100 | ||||
| 	var wg sync.WaitGroup | ||||
| 	var lastErr error | ||||
| 	closeValid := make(chan struct{}, loopCount) | ||||
| 	closeStartIdx := loopCount / 2 // close the database at the middle of the execution
 | ||||
| 	var lastRunIndex int | ||||
| 	var closeFinishedAt int64 | ||||
| 
 | ||||
| 	wg.Add(1) | ||||
| 	go func(id uint) { | ||||
| 		defer wg.Done() | ||||
| 		defer close(closeValid) | ||||
| 		for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { | ||||
| 			if lastRunIndex == closeStartIdx { | ||||
| 				closeValid <- struct{}{} | ||||
| 			} | ||||
| 			var tmp User | ||||
| 			now := time.Now().UnixNano() | ||||
| 			err := tx.Session(&gorm.Session{}).First(&tmp, id).Error | ||||
| 			if err == nil { | ||||
| 				closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) | ||||
| 				if (closeFinishedAt != 0) && (now > closeFinishedAt) { | ||||
| 					lastErr = errors.New("must got error after database closed") | ||||
| 					break | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| 			lastErr = err | ||||
| 			break | ||||
| 		} | ||||
| 	}(user.ID) | ||||
| 
 | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		for range closeValid { | ||||
| 			for i := 0; i < loopCount; i++ { | ||||
| 				pdb.Close() // the Close method must can be call multiple times
 | ||||
| 				atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	wg.Wait() | ||||
| 	var tmp User | ||||
| 	err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error | ||||
| 	if err != gorm.ErrInvalidDB { | ||||
| 		t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// must be error
 | ||||
| 	if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { | ||||
| 		t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) | ||||
| 	} | ||||
| 	if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { | ||||
| 		t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) | ||||
| 	} | ||||
| 	if pdb.Stmts != nil { | ||||
| 		t.Fatalf("stmts must be nil") | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu