Add PrepareStmt support
This commit is contained in:
		
							parent
							
								
									9934207c42
								
							
						
					
					
						commit
						c8e7878b3e
					
				| @ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er | ||||
| } | ||||
| 
 | ||||
| // Begin begins a transaction
 | ||||
| func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { | ||||
| 		var opt *sql.TxOptions | ||||
| 		var err error | ||||
| 		if len(opts) > 0 { | ||||
| 			opt = opts[0] | ||||
| 		} | ||||
| func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | ||||
| 	var ( | ||||
| 		tx  = db.getInstance() | ||||
| 		opt *sql.TxOptions | ||||
| 		err error | ||||
| 	) | ||||
| 
 | ||||
| 		if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { | ||||
| 			tx.AddError(err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		tx.AddError(ErrInvalidTransaction) | ||||
| 	if len(opts) > 0 { | ||||
| 		opt = opts[0] | ||||
| 	} | ||||
| 	return | ||||
| 
 | ||||
| 	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 	} else { | ||||
| 		err = ErrInvalidTransaction | ||||
| 	} | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		tx.AddError(err) | ||||
| 	} | ||||
| 
 | ||||
| 	return tx | ||||
| } | ||||
| 
 | ||||
| // Commit commit a transaction
 | ||||
| func (db *DB) Commit() *DB { | ||||
| 	if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { | ||||
| 		db.AddError(comminter.Commit()) | ||||
| 	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { | ||||
| 		db.AddError(committer.Commit()) | ||||
| 	} else { | ||||
| 		db.AddError(ErrInvalidTransaction) | ||||
| 	} | ||||
| @ -340,8 +348,8 @@ func (db *DB) Commit() *DB { | ||||
| 
 | ||||
| // Rollback rollback a transaction
 | ||||
| func (db *DB) Rollback() *DB { | ||||
| 	if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { | ||||
| 		db.AddError(comminter.Rollback()) | ||||
| 	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { | ||||
| 		db.AddError(committer.Rollback()) | ||||
| 	} else { | ||||
| 		db.AddError(ErrInvalidTransaction) | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										49
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								gorm.go
									
									
									
									
									
								
							| @ -2,6 +2,7 @@ package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @ -25,6 +26,9 @@ type Config struct { | ||||
| 	// DryRun generate sql without execute
 | ||||
| 	DryRun bool | ||||
| 
 | ||||
| 	// PrepareStmt executes the given query in cached statement
 | ||||
| 	PrepareStmt bool | ||||
| 
 | ||||
| 	// ClauseBuilders clause builder
 | ||||
| 	ClauseBuilders map[string]clause.ClauseBuilder | ||||
| 	// ConnPool db conn pool
 | ||||
| @ -48,6 +52,7 @@ type DB struct { | ||||
| // Session session config when create session with Session() method
 | ||||
| type Session struct { | ||||
| 	DryRun         bool | ||||
| 	PrepareStmt    bool | ||||
| 	WithConditions bool | ||||
| 	Context        context.Context | ||||
| 	Logger         logger.Interface | ||||
| @ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | ||||
| 		err = dialector.Initialize(db) | ||||
| 	} | ||||
| 
 | ||||
| 	if config.PrepareStmt { | ||||
| 		db.ConnPool = &PreparedStmtDB{ | ||||
| 			ConnPool: db.ConnPool, | ||||
| 			stmts:    map[string]*sql.Stmt{}, | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if db.Statement == nil { | ||||
| 		db.Statement = &Statement{ | ||||
| 			DB:       db, | ||||
| 			ConnPool: db.ConnPool, | ||||
| 			Context:  context.Background(), | ||||
| 			Clauses:  map[string]clause.Clause{}, | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { | ||||
| 			err = pinger.Ping() | ||||
| @ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB { | ||||
| 		tx.Statement.Context = config.Context | ||||
| 	} | ||||
| 
 | ||||
| 	if config.PrepareStmt { | ||||
| 		tx.Statement.ConnPool = &PreparedStmtDB{ | ||||
| 			ConnPool: db.Config.ConnPool, | ||||
| 			stmts:    map[string]*sql.Stmt{}, | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if config.WithConditions { | ||||
| 		tx.clone = 3 | ||||
| 	} | ||||
| @ -256,6 +284,12 @@ func (db *DB) getInstance() *DB { | ||||
| 
 | ||||
| 		switch db.clone { | ||||
| 		case 1: // clone with new statement
 | ||||
| 			tx.Statement = &Statement{ | ||||
| 				DB:       tx, | ||||
| 				ConnPool: db.Statement.ConnPool, | ||||
| 				Context:  db.Statement.Context, | ||||
| 				Clauses:  map[string]clause.Clause{}, | ||||
| 			} | ||||
| 		case 2: // with old statement, generate new statement for future call, used to pass to callbacks
 | ||||
| 			db.clone = 1 | ||||
| 			tx.Statement = db.Statement | ||||
| @ -266,21 +300,6 @@ func (db *DB) getInstance() *DB { | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if tx.Statement == nil { | ||||
| 			tx.Statement = &Statement{ | ||||
| 				DB:      tx, | ||||
| 				Clauses: map[string]clause.Clause{}, | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if db.Statement != nil { | ||||
| 			tx.Statement.Context = db.Statement.Context | ||||
| 			tx.Statement.ConnPool = db.Statement.ConnPool | ||||
| 		} else { | ||||
| 			tx.Statement.Context = context.Background() | ||||
| 			tx.Statement.ConnPool = db.ConnPool | ||||
| 		} | ||||
| 
 | ||||
| 		return tx | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -21,8 +21,8 @@ type Dialector interface { | ||||
| 
 | ||||
| // ConnPool db conns pool interface
 | ||||
| type ConnPool interface { | ||||
| 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||
| 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||
| 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||
| 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||
| } | ||||
| @ -31,7 +31,11 @@ type TxBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||
| } | ||||
| 
 | ||||
| type TxCommiter interface { | ||||
| type ConnPoolBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) | ||||
| } | ||||
| 
 | ||||
| type TxCommitter interface { | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
| } | ||||
|  | ||||
							
								
								
									
										92
									
								
								prepare_stmt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								prepare_stmt.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,92 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"sync" | ||||
| ) | ||||
| 
 | ||||
| type PreparedStmtDB struct { | ||||
| 	stmts map[string]*sql.Stmt | ||||
| 	mux   sync.RWMutex | ||||
| 	ConnPool | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { | ||||
| 	db.mux.RLock() | ||||
| 	if stmt, ok := db.stmts[query]; ok { | ||||
| 		db.mux.RUnlock() | ||||
| 		return stmt, nil | ||||
| 	} | ||||
| 	db.mux.RUnlock() | ||||
| 
 | ||||
| 	db.mux.Lock() | ||||
| 	stmt, err := db.ConnPool.PrepareContext(context.Background(), query) | ||||
| 	if err == nil { | ||||
| 		db.stmts[query] = stmt | ||||
| 	} | ||||
| 	db.mux.Unlock() | ||||
| 
 | ||||
| 	return stmt, err | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { | ||||
| 	if beginner, ok := db.ConnPool.(TxBeginner); ok { | ||||
| 		tx, err := beginner.BeginTx(ctx, opt) | ||||
| 		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err | ||||
| 	} | ||||
| 	return nil, ErrInvalidTransaction | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	stmt, err := db.prepare(query) | ||||
| 	if err == nil { | ||||
| 		return stmt.ExecContext(ctx, args...) | ||||
| 	} | ||||
| 	return nil, err | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	stmt, err := db.prepare(query) | ||||
| 	if err == nil { | ||||
| 		return stmt.QueryContext(ctx, args...) | ||||
| 	} | ||||
| 	return nil, err | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	stmt, err := db.prepare(query) | ||||
| 	if err == nil { | ||||
| 		return stmt.QueryRowContext(ctx, args...) | ||||
| 	} | ||||
| 	return &sql.Row{} | ||||
| } | ||||
| 
 | ||||
| type PreparedStmtTX struct { | ||||
| 	*sql.Tx | ||||
| 	PreparedStmtDB *PreparedStmtDB | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(query) | ||||
| 	if err == nil { | ||||
| 		return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) | ||||
| 	} | ||||
| 	return nil, err | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(query) | ||||
| 	if err == nil { | ||||
| 		return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) | ||||
| 	} | ||||
| 	return nil, err | ||||
| } | ||||
| 
 | ||||
| func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(query) | ||||
| 	if err == nil { | ||||
| 		return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) | ||||
| 	} | ||||
| 	return &sql.Row{} | ||||
| } | ||||
| @ -1,7 +1,6 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 
 | ||||
| @ -21,7 +20,7 @@ func TestTransaction(t *testing.T) { | ||||
| 		t.Fatalf("Should find saved record, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { | ||||
| 	if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { | ||||
| 		t.Fatalf("Should return the underlying sql.Tx") | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu