Add TxConnPoolBeginner and Tx interface
This commit is contained in:
		
							parent
							
								
									e2e802b837
								
							
						
					
					
						commit
						996b96e812
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -3,3 +3,4 @@ documents | |||||||
| coverage.txt | coverage.txt | ||||||
| _book | _book | ||||||
| .idea | .idea | ||||||
|  | vendor | ||||||
| @ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
| // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | ||||||
| func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | ||||||
| @ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | |||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { | 	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||||
|  | 	} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { | ||||||
|  | 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||||
| 	} else { | 	} else { | ||||||
| 		err = ErrInvalidTransaction | 		err = ErrInvalidTransaction | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -50,12 +50,25 @@ type ConnPoolBeginner interface { | |||||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // TxConnPoolBeginner tx conn pool beginner
 | ||||||
|  | type TxConnPoolBeginner interface { | ||||||
|  | 	BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // TxCommitter tx committer
 | // TxCommitter tx committer
 | ||||||
| type TxCommitter interface { | type TxCommitter interface { | ||||||
| 	Commit() error | 	Commit() error | ||||||
| 	Rollback() error | 	Rollback() error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Tx sql.Tx interface
 | ||||||
|  | type Tx interface { | ||||||
|  | 	ConnPool | ||||||
|  | 	Commit() error | ||||||
|  | 	Rollback() error | ||||||
|  | 	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Valuer gorm valuer interface
 | // Valuer gorm valuer interface
 | ||||||
| type Valuer interface { | type Valuer interface { | ||||||
| 	GormValue(context.Context, *DB) clause.Expr | 	GormValue(context.Context, *DB) clause.Expr | ||||||
|  | |||||||
| @ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn | |||||||
| 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | ||||||
| 		tx, err := beginner.BeginTx(ctx, opt) | 		tx, err := beginner.BeginTx(ctx, opt) | ||||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||||
|  | 	} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { | ||||||
|  | 		tx, err := beginner.BeginTx(ctx, opt) | ||||||
|  | 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||||
| 	} | 	} | ||||||
| 	return nil, ErrInvalidTransaction | 	return nil, ErrInvalidTransaction | ||||||
| } | } | ||||||
| @ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type PreparedStmtTX struct { | type PreparedStmtTX struct { | ||||||
| 	*sql.Tx | 	Tx | ||||||
| 	PreparedStmtDB *PreparedStmtDB | 	PreparedStmtDB *PreparedStmtDB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | |||||||
| func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { | func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { | ||||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) | 		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			tx.PreparedStmtDB.Mux.Lock() | 			tx.PreparedStmtDB.Mux.Lock() | ||||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||||
|  | |||||||
							
								
								
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,181 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"log" | ||||||
|  | 	"os" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"gorm.io/driver/mysql" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/logger" | ||||||
|  | 	. "gorm.io/gorm/utils/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type wrapperTx struct { | ||||||
|  | 	*sql.Tx | ||||||
|  | 	conn *wrapperConnPool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||||
|  | 	c.conn.got = append(c.conn.got, query) | ||||||
|  | 	return c.Tx.PrepareContext(ctx, query) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	c.conn.got = append(c.conn.got, query) | ||||||
|  | 	return c.Tx.ExecContext(ctx, query, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||||
|  | 	c.conn.got = append(c.conn.got, query) | ||||||
|  | 	return c.Tx.QueryContext(ctx, query, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||||
|  | 	c.conn.got = append(c.conn.got, query) | ||||||
|  | 	return c.Tx.QueryRowContext(ctx, query, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type wrapperConnPool struct { | ||||||
|  | 	db     *sql.DB | ||||||
|  | 	got    []string | ||||||
|  | 	expect []string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperConnPool) Ping() error { | ||||||
|  | 	return c.db.Ping() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
 | ||||||
|  | // func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
 | ||||||
|  | //	 return c.db.BeginTx(ctx, opts)
 | ||||||
|  | // }
 | ||||||
|  | // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | ||||||
|  | func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { | ||||||
|  | 	tx, err := c.db.BeginTx(ctx, opts) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &wrapperTx{Tx: tx, conn: c}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||||
|  | 	c.got = append(c.got, query) | ||||||
|  | 	return c.db.PrepareContext(ctx, query) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	c.got = append(c.got, query) | ||||||
|  | 	return c.db.ExecContext(ctx, query, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||||
|  | 	c.got = append(c.got, query) | ||||||
|  | 	return c.db.QueryContext(ctx, query, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||||
|  | 	c.got = append(c.got, query) | ||||||
|  | 	return c.db.QueryRowContext(ctx, query, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestConnPoolWrapper(t *testing.T) { | ||||||
|  | 	dialect := os.Getenv("GORM_DIALECT") | ||||||
|  | 	if dialect != "mysql" { | ||||||
|  | 		t.SkipNow() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	dbDSN := os.Getenv("GORM_DSN") | ||||||
|  | 	if dbDSN == "" { | ||||||
|  | 		dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" | ||||||
|  | 	} | ||||||
|  | 	nativeDB, err := sql.Open("mysql", dbDSN) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Should open db success, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	conn := &wrapperConnPool{ | ||||||
|  | 		db: nativeDB, | ||||||
|  | 		expect: []string{ | ||||||
|  | 			"SELECT VERSION()", | ||||||
|  | 			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", | ||||||
|  | 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||||
|  | 			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", | ||||||
|  | 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||||
|  | 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||||
|  | 			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", | ||||||
|  | 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||||
|  | 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	defer func() { | ||||||
|  | 		if !reflect.DeepEqual(conn.got, conn.expect) { | ||||||
|  | 			t.Errorf("expect %#v but got %#v", conn.expect, conn.got) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ | ||||||
|  | 		SlowThreshold:             200 * time.Millisecond, | ||||||
|  | 		LogLevel:                  logger.Info, | ||||||
|  | 		IgnoreRecordNotFoundError: false, | ||||||
|  | 		Colorful:                  true, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Should open db success, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tx := db.Begin() | ||||||
|  | 	user := *GetUser("transaction", Config{}) | ||||||
|  | 
 | ||||||
|  | 	if err = tx.Save(&user).Error; err != nil { | ||||||
|  | 		t.Fatalf("No error should raise, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { | ||||||
|  | 		t.Fatalf("Should find saved record, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user1 := *GetUser("transaction1-1", Config{}) | ||||||
|  | 
 | ||||||
|  | 	if err = tx.Save(&user1).Error; err != nil { | ||||||
|  | 		t.Fatalf("No error should raise, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { | ||||||
|  | 		t.Fatalf("Should find saved record, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { | ||||||
|  | 		t.Fatalf("Should return the underlying sql.Tx") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tx.Rollback() | ||||||
|  | 
 | ||||||
|  | 	if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { | ||||||
|  | 		t.Fatalf("Should not find record after rollback, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	txDB := db.Where("fake_name = ?", "fake_name") | ||||||
|  | 	tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() | ||||||
|  | 	user2 := *GetUser("transaction-2", Config{}) | ||||||
|  | 	if err = tx2.Save(&user2).Error; err != nil { | ||||||
|  | 		t.Fatalf("No error should raise, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { | ||||||
|  | 		t.Fatalf("Should find saved record, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tx2.Commit() | ||||||
|  | 
 | ||||||
|  | 	if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { | ||||||
|  | 		t.Fatalf("Should be able to find committed record, but got %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 lianghuan
						lianghuan