Add TxConnPoolBeginner and Tx interface
This commit is contained in:
		
							parent
							
								
									43a72b369e
								
							
						
					
					
						commit
						649061adea
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -3,3 +3,4 @@ documents
 | 
			
		||||
coverage.txt
 | 
			
		||||
_book
 | 
			
		||||
.idea
 | 
			
		||||
vendor
 | 
			
		||||
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | 
			
		||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
			
		||||
	queryTx := db.Limit(1).Order(clause.OrderByColumn{
 | 
			
		||||
@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
 | 
			
		||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
			
		||||
	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
 | 
			
		||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
			
		||||
	} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok {
 | 
			
		||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = ErrInvalidTransaction
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -50,12 +50,25 @@ type ConnPoolBeginner interface {
 | 
			
		||||
	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TxConnPoolBeginner tx conn pool beginner
 | 
			
		||||
type TxConnPoolBeginner interface {
 | 
			
		||||
	BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TxCommitter tx committer
 | 
			
		||||
type TxCommitter interface {
 | 
			
		||||
	Commit() error
 | 
			
		||||
	Rollback() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tx sql.Tx interface
 | 
			
		||||
type Tx interface {
 | 
			
		||||
	ConnPool
 | 
			
		||||
	Commit() error
 | 
			
		||||
	Rollback() error
 | 
			
		||||
	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Valuer gorm valuer interface
 | 
			
		||||
type Valuer interface {
 | 
			
		||||
	GormValue(context.Context, *DB) clause.Expr
 | 
			
		||||
 | 
			
		||||
@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
 | 
			
		||||
	if beginner, ok := db.ConnPool.(TxBeginner); ok {
 | 
			
		||||
		tx, err := beginner.BeginTx(ctx, opt)
 | 
			
		||||
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
			
		||||
	} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok {
 | 
			
		||||
		tx, err := beginner.BeginTx(ctx, opt)
 | 
			
		||||
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
			
		||||
	}
 | 
			
		||||
	return nil, ErrInvalidTransaction
 | 
			
		||||
}
 | 
			
		||||
@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PreparedStmtTX struct {
 | 
			
		||||
	*sql.Tx
 | 
			
		||||
	Tx
 | 
			
		||||
	PreparedStmtDB *PreparedStmtDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
			
		||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
			
		||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
 | 
			
		||||
		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
			
		||||
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,181 @@
 | 
			
		||||
package tests_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/driver/mysql"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	. "gorm.io/gorm/utils/tests"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type wrapperTx struct {
 | 
			
		||||
	*sql.Tx
 | 
			
		||||
	conn *wrapperConnPool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.PrepareContext(ctx, query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.ExecContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.QueryContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.QueryRowContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type wrapperConnPool struct {
 | 
			
		||||
	db     *sql.DB
 | 
			
		||||
	got    []string
 | 
			
		||||
	expect []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) Ping() error {
 | 
			
		||||
	return c.db.Ping()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
 | 
			
		||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
 | 
			
		||||
//	 return c.db.BeginTx(ctx, opts)
 | 
			
		||||
// }
 | 
			
		||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | 
			
		||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) {
 | 
			
		||||
	tx, err := c.db.BeginTx(ctx, opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &wrapperTx{Tx: tx, conn: c}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.PrepareContext(ctx, query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.ExecContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.QueryContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.QueryRowContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConnPoolWrapper(t *testing.T) {
 | 
			
		||||
	dialect := os.Getenv("GORM_DIALECT")
 | 
			
		||||
	if dialect != "mysql" {
 | 
			
		||||
		t.SkipNow()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbDSN := os.Getenv("GORM_DSN")
 | 
			
		||||
	if dbDSN == "" {
 | 
			
		||||
		dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
 | 
			
		||||
	}
 | 
			
		||||
	nativeDB, err := sql.Open("mysql", dbDSN)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Should open db success, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn := &wrapperConnPool{
 | 
			
		||||
		db: nativeDB,
 | 
			
		||||
		expect: []string{
 | 
			
		||||
			"SELECT VERSION()",
 | 
			
		||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if !reflect.DeepEqual(conn.got, conn.expect) {
 | 
			
		||||
			t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
 | 
			
		||||
		SlowThreshold:             200 * time.Millisecond,
 | 
			
		||||
		LogLevel:                  logger.Info,
 | 
			
		||||
		IgnoreRecordNotFoundError: false,
 | 
			
		||||
		Colorful:                  true,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Should open db success, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := db.Begin()
 | 
			
		||||
	user := *GetUser("transaction", Config{})
 | 
			
		||||
 | 
			
		||||
	if err = tx.Save(&user).Error; err != nil {
 | 
			
		||||
		t.Fatalf("No error should raise, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should find saved record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user1 := *GetUser("transaction1-1", Config{})
 | 
			
		||||
 | 
			
		||||
	if err = tx.Save(&user1).Error; err != nil {
 | 
			
		||||
		t.Fatalf("No error should raise, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should find saved record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
 | 
			
		||||
		t.Fatalf("Should return the underlying sql.Tx")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
 | 
			
		||||
		t.Fatalf("Should not find record after rollback, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	txDB := db.Where("fake_name = ?", "fake_name")
 | 
			
		||||
	tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
 | 
			
		||||
	user2 := *GetUser("transaction-2", Config{})
 | 
			
		||||
	if err = tx2.Save(&user2).Error; err != nil {
 | 
			
		||||
		t.Fatalf("No error should raise, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should find saved record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx2.Commit()
 | 
			
		||||
 | 
			
		||||
	if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should be able to find committed record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -8,7 +8,7 @@ require (
 | 
			
		||||
	github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
			
		||||
	github.com/jinzhu/now v1.1.4
 | 
			
		||||
	github.com/lib/pq v1.10.4
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.12 // indirect
 | 
			
		||||
	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
			
		||||
	gorm.io/driver/mysql v1.3.2
 | 
			
		||||
	gorm.io/driver/postgres v1.3.1
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user