use callbacks.Begin as a hook to handle begin operation

make transaction support hooks
This commit is contained in:
Joe 2022-07-06 16:43:59 +08:00
parent 46bce170ca
commit 3b79a192cd
3 changed files with 38 additions and 12 deletions

View File

@ -79,19 +79,10 @@ func (cs *callbacks) Transaction() *processor {
} }
func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
var err error tx.InstanceSet("gorm:transaction_options", opt)
switch beginner := tx.Statement.ConnPool.(type) { for _, f := range p.fns {
case TxBeginner: f(tx)
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}
if err != nil {
_ = tx.AddError(err)
} }
return tx return tx

View File

@ -80,4 +80,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
rawCallback := db.Callback().Raw() rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec) rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses rawCallback.Clauses = config.QueryClauses
transactionCallback := db.Callback().Transaction()
transactionCallback.Register("gorm:begin", Begin)
} }

View File

@ -1,6 +1,8 @@
package callbacks package callbacks
import ( import (
"database/sql"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -30,3 +32,33 @@ func CommitOrRollbackTransaction(db *gorm.DB) {
} }
} }
} }
func Begin(tx *gorm.DB) {
err := tx.Error
if err != nil {
return
}
var opt *sql.TxOptions
if v, ok := tx.InstanceGet("gorm:transaction_options"); ok {
if txOpts, ok := v.(*sql.TxOptions); ok {
opt = txOpts
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case gorm.TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case gorm.ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = gorm.ErrInvalidTransaction
}
if err != nil {
_ = tx.AddError(err)
}
}