Add TxConnPoolBeginner and Tx interface
This commit is contained in:
		
							parent
							
								
									e2e802b837
								
							
						
					
					
						commit
						996b96e812
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -3,3 +3,4 @@ documents | ||||
| coverage.txt | ||||
| _book | ||||
| .idea | ||||
| vendor | ||||
| @ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | ||||
| func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | ||||
| @ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else { | ||||
| 		err = ErrInvalidTransaction | ||||
| 	} | ||||
|  | ||||
| @ -50,12 +50,25 @@ type ConnPoolBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | ||||
| } | ||||
| 
 | ||||
| // TxConnPoolBeginner tx conn pool beginner
 | ||||
| type TxConnPoolBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) | ||||
| } | ||||
| 
 | ||||
| // TxCommitter tx committer
 | ||||
| type TxCommitter interface { | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
| } | ||||
| 
 | ||||
| // Tx sql.Tx interface
 | ||||
| type Tx interface { | ||||
| 	ConnPool | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
| 	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt | ||||
| } | ||||
| 
 | ||||
| // Valuer gorm valuer interface
 | ||||
| type Valuer interface { | ||||
| 	GormValue(context.Context, *DB) clause.Expr | ||||
|  | ||||
| @ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn | ||||
| 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | ||||
| 		tx, err := beginner.BeginTx(ctx, opt) | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||
| 	} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { | ||||
| 		tx, err := beginner.BeginTx(ctx, opt) | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||
| 	} | ||||
| 	return nil, ErrInvalidTransaction | ||||
| } | ||||
| @ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg | ||||
| } | ||||
| 
 | ||||
| type PreparedStmtTX struct { | ||||
| 	*sql.Tx | ||||
| 	Tx | ||||
| 	PreparedStmtDB *PreparedStmtDB | ||||
| } | ||||
| 
 | ||||
| @ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||
| 	if err == nil { | ||||
| 		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) | ||||
| 		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
|  | ||||
							
								
								
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,181 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/driver/mysql" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| type wrapperTx struct { | ||||
| 	*sql.Tx | ||||
| 	conn *wrapperConnPool | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||
| 	c.conn.got = append(c.conn.got, query) | ||||
| 	return c.Tx.PrepareContext(ctx, query) | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	c.conn.got = append(c.conn.got, query) | ||||
| 	return c.Tx.ExecContext(ctx, query, args...) | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	c.conn.got = append(c.conn.got, query) | ||||
| 	return c.Tx.QueryContext(ctx, query, args...) | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	c.conn.got = append(c.conn.got, query) | ||||
| 	return c.Tx.QueryRowContext(ctx, query, args...) | ||||
| } | ||||
| 
 | ||||
| type wrapperConnPool struct { | ||||
| 	db     *sql.DB | ||||
| 	got    []string | ||||
| 	expect []string | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperConnPool) Ping() error { | ||||
| 	return c.db.Ping() | ||||
| } | ||||
| 
 | ||||
| // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
 | ||||
| // func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
 | ||||
| //	 return c.db.BeginTx(ctx, opts)
 | ||||
| // }
 | ||||
| // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | ||||
| func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { | ||||
| 	tx, err := c.db.BeginTx(ctx, opts) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &wrapperTx{Tx: tx, conn: c}, nil | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||
| 	c.got = append(c.got, query) | ||||
| 	return c.db.PrepareContext(ctx, query) | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	c.got = append(c.got, query) | ||||
| 	return c.db.ExecContext(ctx, query, args...) | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	c.got = append(c.got, query) | ||||
| 	return c.db.QueryContext(ctx, query, args...) | ||||
| } | ||||
| 
 | ||||
| func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	c.got = append(c.got, query) | ||||
| 	return c.db.QueryRowContext(ctx, query, args...) | ||||
| } | ||||
| 
 | ||||
| func TestConnPoolWrapper(t *testing.T) { | ||||
| 	dialect := os.Getenv("GORM_DIALECT") | ||||
| 	if dialect != "mysql" { | ||||
| 		t.SkipNow() | ||||
| 	} | ||||
| 
 | ||||
| 	dbDSN := os.Getenv("GORM_DSN") | ||||
| 	if dbDSN == "" { | ||||
| 		dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" | ||||
| 	} | ||||
| 	nativeDB, err := sql.Open("mysql", dbDSN) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Should open db success, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	conn := &wrapperConnPool{ | ||||
| 		db: nativeDB, | ||||
| 		expect: []string{ | ||||
| 			"SELECT VERSION()", | ||||
| 			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", | ||||
| 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||
| 			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", | ||||
| 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||
| 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||
| 			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", | ||||
| 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||
| 			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	defer func() { | ||||
| 		if !reflect.DeepEqual(conn.got, conn.expect) { | ||||
| 			t.Errorf("expect %#v but got %#v", conn.expect, conn.got) | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ | ||||
| 		SlowThreshold:             200 * time.Millisecond, | ||||
| 		LogLevel:                  logger.Info, | ||||
| 		IgnoreRecordNotFoundError: false, | ||||
| 		Colorful:                  true, | ||||
| 	}) | ||||
| 
 | ||||
| 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Should open db success, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	tx := db.Begin() | ||||
| 	user := *GetUser("transaction", Config{}) | ||||
| 
 | ||||
| 	if err = tx.Save(&user).Error; err != nil { | ||||
| 		t.Fatalf("No error should raise, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user1 := *GetUser("transaction1-1", Config{}) | ||||
| 
 | ||||
| 	if err = tx.Save(&user1).Error; err != nil { | ||||
| 		t.Fatalf("No error should raise, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { | ||||
| 		t.Fatalf("Should return the underlying sql.Tx") | ||||
| 	} | ||||
| 
 | ||||
| 	tx.Rollback() | ||||
| 
 | ||||
| 	if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { | ||||
| 		t.Fatalf("Should not find record after rollback, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	txDB := db.Where("fake_name = ?", "fake_name") | ||||
| 	tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() | ||||
| 	user2 := *GetUser("transaction-2", Config{}) | ||||
| 	if err = tx2.Save(&user2).Error; err != nil { | ||||
| 		t.Fatalf("No error should raise, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	tx2.Commit() | ||||
| 
 | ||||
| 	if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { | ||||
| 		t.Fatalf("Should be able to find committed record, but got %v", err) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 lianghuan
						lianghuan