Refactor Tx interface
This commit is contained in:
		
							parent
							
								
									996b96e812
								
							
						
					
					
						commit
						4e523499d1
					
				| @ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | |||||||
| 		opt = opts[0] | 		opt = opts[0] | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { | 	switch beginner := tx.Statement.ConnPool.(type) { | ||||||
|  | 	case TxBeginner: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { | 	case ConnPoolBeginner: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { | 	default: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) |  | ||||||
| 	} else { |  | ||||||
| 		err = ErrInvalidTransaction | 		err = ErrInvalidTransaction | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -50,11 +50,6 @@ type ConnPoolBeginner interface { | |||||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // TxConnPoolBeginner tx conn pool beginner
 |  | ||||||
| type TxConnPoolBeginner interface { |  | ||||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // TxCommitter tx committer
 | // TxCommitter tx committer
 | ||||||
| type TxCommitter interface { | type TxCommitter interface { | ||||||
| 	Commit() error | 	Commit() error | ||||||
| @ -64,8 +59,7 @@ type TxCommitter interface { | |||||||
| // Tx sql.Tx interface
 | // Tx sql.Tx interface
 | ||||||
| type Tx interface { | type Tx interface { | ||||||
| 	ConnPool | 	ConnPool | ||||||
| 	Commit() error | 	TxCommitter | ||||||
| 	Rollback() error |  | ||||||
| 	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt | 	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn | |||||||
| 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | ||||||
| 		tx, err := beginner.BeginTx(ctx, opt) | 		tx, err := beginner.BeginTx(ctx, opt) | ||||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||||
| 	} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { |  | ||||||
| 		tx, err := beginner.BeginTx(ctx, opt) |  | ||||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err |  | ||||||
| 	} | 	} | ||||||
| 	return nil, ErrInvalidTransaction | 	return nil, ErrInvalidTransaction | ||||||
| } | } | ||||||
|  | |||||||
| @ -3,15 +3,12 @@ package tests_test | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"log" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" |  | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/driver/mysql" | 	"gorm.io/driver/mysql" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/logger" |  | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error { | |||||||
| //	 return c.db.BeginTx(ctx, opts)
 | //	 return c.db.BeginTx(ctx, opts)
 | ||||||
| // }
 | // }
 | ||||||
| // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | ||||||
| func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { | func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { | ||||||
| 	tx, err := c.db.BeginTx(ctx, opts) | 	tx, err := c.db.BeginTx(ctx, opts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ | 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) | ||||||
| 		SlowThreshold:             200 * time.Millisecond, |  | ||||||
| 		LogLevel:                  logger.Info, |  | ||||||
| 		IgnoreRecordNotFoundError: false, |  | ||||||
| 		Colorful:                  true, |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Should open db success, but got %v", err) | 		t.Fatalf("Should open db success, but got %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu