Fix prepared statement in transaction mode can't be shared in normal operations, close #3927
This commit is contained in:
		
							parent
							
								
									7302c8a136
								
							
						
					
					
						commit
						fe553a7c1a
					
				
							
								
								
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							| @ -126,7 +126,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | ||||
| 
 | ||||
| 	preparedStmt := &PreparedStmtDB{ | ||||
| 		ConnPool:    db.ConnPool, | ||||
| 		Stmts:       map[string]*sql.Stmt{}, | ||||
| 		Stmts:       map[string]Stmt{}, | ||||
| 		Mux:         &sync.RWMutex{}, | ||||
| 		PreparedSQL: make([]string, 0, 100), | ||||
| 	} | ||||
|  | ||||
| @ -6,8 +6,13 @@ import ( | ||||
| 	"sync" | ||||
| ) | ||||
| 
 | ||||
| type Stmt struct { | ||||
| 	*sql.Stmt | ||||
| 	Transaction bool | ||||
| } | ||||
| 
 | ||||
| type PreparedStmtDB struct { | ||||
| 	Stmts       map[string]*sql.Stmt | ||||
| 	Stmts       map[string]Stmt | ||||
| 	PreparedSQL []string | ||||
| 	Mux         *sync.RWMutex | ||||
| 	ConnPool | ||||
| @ -25,9 +30,9 @@ func (db *PreparedStmtDB) Close() { | ||||
| 	db.Mux.Unlock() | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query string) (*sql.Stmt, error) { | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { | ||||
| 	db.Mux.RLock() | ||||
| 	if stmt, ok := db.Stmts[query]; ok { | ||||
| 	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { | ||||
| 		db.Mux.RUnlock() | ||||
| 		return stmt, nil | ||||
| 	} | ||||
| @ -35,19 +40,21 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query stri | ||||
| 
 | ||||
| 	db.Mux.Lock() | ||||
| 	// double check
 | ||||
| 	if stmt, ok := db.Stmts[query]; ok { | ||||
| 	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { | ||||
| 		db.Mux.Unlock() | ||||
| 		return stmt, nil | ||||
| 	} else if ok { | ||||
| 		stmt.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	stmt, err := conn.PrepareContext(ctx, query) | ||||
| 	if err == nil { | ||||
| 		db.Stmts[query] = stmt | ||||
| 		db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} | ||||
| 		db.PreparedSQL = append(db.PreparedSQL, query) | ||||
| 	} | ||||
| 	db.Mux.Unlock() | ||||
| 
 | ||||
| 	return stmt, err | ||||
| 	return db.Stmts[query], err | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { | ||||
| @ -59,7 +66,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, query) | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, false, query) | ||||
| 	if err == nil { | ||||
| 		result, err = stmt.ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| @ -73,7 +80,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, query) | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, false, query) | ||||
| 	if err == nil { | ||||
| 		rows, err = stmt.QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| @ -87,7 +94,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, query) | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, false, query) | ||||
| 	if err == nil { | ||||
| 		return stmt.QueryRowContext(ctx, args...) | ||||
| 	} | ||||
| @ -114,9 +121,9 @@ func (tx *PreparedStmtTX) Rollback() error { | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||
| 	if err == nil { | ||||
| 		result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) | ||||
| 		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			stmt.Close() | ||||
| @ -128,9 +135,9 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||
| 	if err == nil { | ||||
| 		rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) | ||||
| 		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			stmt.Close() | ||||
| @ -142,9 +149,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||
| 	if err == nil { | ||||
| 		return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) | ||||
| 		return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) | ||||
| 	} | ||||
| 	return &sql.Row{} | ||||
| } | ||||
|  | ||||
| @ -50,3 +50,41 @@ func TestPreparedStmt(t *testing.T) { | ||||
| 		t.Fatalf("no error should happen but got %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtFromTransaction(t *testing.T) { | ||||
| 	db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) | ||||
| 
 | ||||
| 	tx := db.Begin() | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			tx.Rollback() | ||||
| 		} | ||||
| 	}() | ||||
| 	if err := tx.Error; err != nil { | ||||
| 		t.Errorf("Failed to start transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { | ||||
| 		tx.Rollback() | ||||
| 		t.Errorf("Failed to run one transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { | ||||
| 		tx.Rollback() | ||||
| 		t.Errorf("Failed to run one transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Commit().Error; err != nil { | ||||
| 		t.Errorf("Failed to commit transaction, got error %v\n", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { | ||||
| 		t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) | ||||
| 	} | ||||
| 
 | ||||
| 	tx2 := db.Begin() | ||||
| 	if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { | ||||
| 		t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) | ||||
| 	} | ||||
| 	tx2.Commit() | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu