Merge 33558ed56e9d6b9d2ea808bc46c0b77e7952d813 into 46bce170cae701615e2b2f8b2448b54524be9648

This commit is contained in:
sunfuze 2022-07-07 10:34:56 +08:00 committed by GitHub
commit d1874b0c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 11 deletions

View File

@ -79,19 +79,19 @@ func (cs *callbacks) Transaction() *processor {
} }
func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
var err error // call scopes
for len(tx.Statement.scopes) > 0 {
switch beginner := tx.Statement.ConnPool.(type) { scopes := tx.Statement.scopes
case TxBeginner: tx.Statement.scopes = nil
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) for _, scope := range scopes {
case ConnPoolBeginner: tx = scope(tx)
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) }
default:
err = ErrInvalidTransaction
} }
if err != nil { tx.InstanceSet("gorm:transaction_options", opt)
_ = tx.AddError(err)
for _, f := range p.fns {
f(tx)
} }
return tx return tx

View File

@ -80,4 +80,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
rawCallback := db.Callback().Raw() rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec) rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses rawCallback.Clauses = config.QueryClauses
transactionCallback := db.Callback().Transaction()
_ = transactionCallback.Register("gorm:begin", Begin)
} }

View File

@ -1,6 +1,8 @@
package callbacks package callbacks
import ( import (
"database/sql"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -30,3 +32,33 @@ func CommitOrRollbackTransaction(db *gorm.DB) {
} }
} }
} }
func Begin(tx *gorm.DB) {
err := tx.Error
if err != nil {
return
}
var opt *sql.TxOptions
if v, ok := tx.InstanceGet("gorm:transaction_options"); ok {
if txOpts, ok := v.(*sql.TxOptions); ok {
opt = txOpts
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case gorm.TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case gorm.ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = gorm.ErrInvalidTransaction
}
if err != nil {
_ = tx.AddError(err)
}
}