Refactor Tx interface
This commit is contained in:
		
							parent
							
								
									996b96e812
								
							
						
					
					
						commit
						4e523499d1
					
				| @ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | ||||
| 		opt = opts[0] | ||||
| 	} | ||||
| 
 | ||||
| 	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { | ||||
| 	switch beginner := tx.Statement.ConnPool.(type) { | ||||
| 	case TxBeginner: | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { | ||||
| 	case ConnPoolBeginner: | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else { | ||||
| 	default: | ||||
| 		err = ErrInvalidTransaction | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -50,11 +50,6 @@ type ConnPoolBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | ||||
| } | ||||
| 
 | ||||
| // TxConnPoolBeginner tx conn pool beginner
 | ||||
| type TxConnPoolBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) | ||||
| } | ||||
| 
 | ||||
| // TxCommitter tx committer
 | ||||
| type TxCommitter interface { | ||||
| 	Commit() error | ||||
| @ -64,8 +59,7 @@ type TxCommitter interface { | ||||
| // Tx sql.Tx interface
 | ||||
| type Tx interface { | ||||
| 	ConnPool | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
| 	TxCommitter | ||||
| 	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn | ||||
| 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | ||||
| 		tx, err := beginner.BeginTx(ctx, opt) | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||
| 	} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { | ||||
| 		tx, err := beginner.BeginTx(ctx, opt) | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||
| 	} | ||||
| 	return nil, ErrInvalidTransaction | ||||
| } | ||||
|  | ||||
| @ -3,15 +3,12 @@ package tests_test | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/driver/mysql" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| @ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error { | ||||
| //	 return c.db.BeginTx(ctx, opts)
 | ||||
| // }
 | ||||
| // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | ||||
| func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { | ||||
| func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { | ||||
| 	tx, err := c.db.BeginTx(ctx, opts) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ | ||||
| 		SlowThreshold:             200 * time.Millisecond, | ||||
| 		LogLevel:                  logger.Info, | ||||
| 		IgnoreRecordNotFoundError: false, | ||||
| 		Colorful:                  true, | ||||
| 	}) | ||||
| 
 | ||||
| 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) | ||||
| 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Should open db success, but got %v", err) | ||||
| 	} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu