Add TxConnPoolBeginner and Tx interface
This commit is contained in:
		
							parent
							
								
									43a72b369e
								
							
						
					
					
						commit
						649061adea
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -3,3 +3,4 @@ documents
 | 
				
			|||||||
coverage.txt
 | 
					coverage.txt
 | 
				
			||||||
_book
 | 
					_book
 | 
				
			||||||
.idea
 | 
					.idea
 | 
				
			||||||
 | 
					vendor
 | 
				
			||||||
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | 
					// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | 
				
			||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
					func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			||||||
	queryTx := db.Limit(1).Order(clause.OrderByColumn{
 | 
						queryTx := db.Limit(1).Order(clause.OrderByColumn{
 | 
				
			||||||
@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
 | 
				
			|||||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
 | 
						} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
 | 
				
			||||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
 | 
						} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok {
 | 
				
			||||||
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		err = ErrInvalidTransaction
 | 
							err = ErrInvalidTransaction
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -50,12 +50,25 @@ type ConnPoolBeginner interface {
 | 
				
			|||||||
	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
 | 
						BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TxConnPoolBeginner tx conn pool beginner
 | 
				
			||||||
 | 
					type TxConnPoolBeginner interface {
 | 
				
			||||||
 | 
						BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TxCommitter tx committer
 | 
					// TxCommitter tx committer
 | 
				
			||||||
type TxCommitter interface {
 | 
					type TxCommitter interface {
 | 
				
			||||||
	Commit() error
 | 
						Commit() error
 | 
				
			||||||
	Rollback() error
 | 
						Rollback() error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Tx sql.Tx interface
 | 
				
			||||||
 | 
					type Tx interface {
 | 
				
			||||||
 | 
						ConnPool
 | 
				
			||||||
 | 
						Commit() error
 | 
				
			||||||
 | 
						Rollback() error
 | 
				
			||||||
 | 
						StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Valuer gorm valuer interface
 | 
					// Valuer gorm valuer interface
 | 
				
			||||||
type Valuer interface {
 | 
					type Valuer interface {
 | 
				
			||||||
	GormValue(context.Context, *DB) clause.Expr
 | 
						GormValue(context.Context, *DB) clause.Expr
 | 
				
			||||||
 | 
				
			|||||||
@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
 | 
				
			|||||||
	if beginner, ok := db.ConnPool.(TxBeginner); ok {
 | 
						if beginner, ok := db.ConnPool.(TxBeginner); ok {
 | 
				
			||||||
		tx, err := beginner.BeginTx(ctx, opt)
 | 
							tx, err := beginner.BeginTx(ctx, opt)
 | 
				
			||||||
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
							return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
				
			||||||
 | 
						} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok {
 | 
				
			||||||
 | 
							tx, err := beginner.BeginTx(ctx, opt)
 | 
				
			||||||
 | 
							return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, ErrInvalidTransaction
 | 
						return nil, ErrInvalidTransaction
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type PreparedStmtTX struct {
 | 
					type PreparedStmtTX struct {
 | 
				
			||||||
	*sql.Tx
 | 
						Tx
 | 
				
			||||||
	PreparedStmtDB *PreparedStmtDB
 | 
						PreparedStmtDB *PreparedStmtDB
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
				
			|||||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
					func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
				
			||||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
 | 
							rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
								tx.PreparedStmtDB.Mux.Lock()
 | 
				
			||||||
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
								defer tx.PreparedStmtDB.Mux.Unlock()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,181 @@
 | 
				
			|||||||
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"gorm.io/driver/mysql"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						"gorm.io/gorm/logger"
 | 
				
			||||||
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type wrapperTx struct {
 | 
				
			||||||
 | 
						*sql.Tx
 | 
				
			||||||
 | 
						conn *wrapperConnPool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.PrepareContext(ctx, query)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.ExecContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.QueryContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.QueryRowContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type wrapperConnPool struct {
 | 
				
			||||||
 | 
						db     *sql.DB
 | 
				
			||||||
 | 
						got    []string
 | 
				
			||||||
 | 
						expect []string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) Ping() error {
 | 
				
			||||||
 | 
						return c.db.Ping()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
 | 
				
			||||||
 | 
					// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
 | 
				
			||||||
 | 
					//	 return c.db.BeginTx(ctx, opts)
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) {
 | 
				
			||||||
 | 
						tx, err := c.db.BeginTx(ctx, opts)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &wrapperTx{Tx: tx, conn: c}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.PrepareContext(ctx, query)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.ExecContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.QueryContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.QueryRowContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestConnPoolWrapper(t *testing.T) {
 | 
				
			||||||
 | 
						dialect := os.Getenv("GORM_DIALECT")
 | 
				
			||||||
 | 
						if dialect != "mysql" {
 | 
				
			||||||
 | 
							t.SkipNow()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dbDSN := os.Getenv("GORM_DSN")
 | 
				
			||||||
 | 
						if dbDSN == "" {
 | 
				
			||||||
 | 
							dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						nativeDB, err := sql.Open("mysql", dbDSN)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should open db success, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn := &wrapperConnPool{
 | 
				
			||||||
 | 
							db: nativeDB,
 | 
				
			||||||
 | 
							expect: []string{
 | 
				
			||||||
 | 
								"SELECT VERSION()",
 | 
				
			||||||
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							if !reflect.DeepEqual(conn.got, conn.expect) {
 | 
				
			||||||
 | 
								t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
 | 
				
			||||||
 | 
							SlowThreshold:             200 * time.Millisecond,
 | 
				
			||||||
 | 
							LogLevel:                  logger.Info,
 | 
				
			||||||
 | 
							IgnoreRecordNotFoundError: false,
 | 
				
			||||||
 | 
							Colorful:                  true,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should open db success, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx := db.Begin()
 | 
				
			||||||
 | 
						user := *GetUser("transaction", Config{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.Save(&user).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("No error should raise, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should find saved record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						user1 := *GetUser("transaction1-1", Config{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.Save(&user1).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("No error should raise, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should find saved record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should return the underlying sql.Tx")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx.Rollback()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should not find record after rollback, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						txDB := db.Where("fake_name = ?", "fake_name")
 | 
				
			||||||
 | 
						tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
 | 
				
			||||||
 | 
						user2 := *GetUser("transaction-2", Config{})
 | 
				
			||||||
 | 
						if err = tx2.Save(&user2).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("No error should raise, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should find saved record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx2.Commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should be able to find committed record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -8,7 +8,7 @@ require (
 | 
				
			|||||||
	github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
						github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
				
			||||||
	github.com/jinzhu/now v1.1.4
 | 
						github.com/jinzhu/now v1.1.4
 | 
				
			||||||
	github.com/lib/pq v1.10.4
 | 
						github.com/lib/pq v1.10.4
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
						github.com/mattn/go-sqlite3 v1.14.12 // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
						golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
				
			||||||
	gorm.io/driver/mysql v1.3.2
 | 
						gorm.io/driver/mysql v1.3.2
 | 
				
			||||||
	gorm.io/driver/postgres v1.3.1
 | 
						gorm.io/driver/postgres v1.3.1
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user