Merge 33558ed56e9d6b9d2ea808bc46c0b77e7952d813 into 46bce170cae701615e2b2f8b2448b54524be9648

This commit is contained in:
sunfuze 2022-07-07 10:34:56 +08:00 committed by GitHub
commit d1874b0c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 11 deletions

View File

@ -79,19 +79,19 @@ func (cs *callbacks) Transaction() *processor {
}
func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
var err error
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
// call scopes
for len(tx.Statement.scopes) > 0 {
scopes := tx.Statement.scopes
tx.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
}
}
if err != nil {
_ = tx.AddError(err)
tx.InstanceSet("gorm:transaction_options", opt)
for _, f := range p.fns {
f(tx)
}
return tx

View File

@ -80,4 +80,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
transactionCallback := db.Callback().Transaction()
_ = transactionCallback.Register("gorm:begin", Begin)
}

View File

@ -1,6 +1,8 @@
package callbacks
import (
"database/sql"
"gorm.io/gorm"
)
@ -30,3 +32,33 @@ func CommitOrRollbackTransaction(db *gorm.DB) {
}
}
}
func Begin(tx *gorm.DB) {
err := tx.Error
if err != nil {
return
}
var opt *sql.TxOptions
if v, ok := tx.InstanceGet("gorm:transaction_options"); ok {
if txOpts, ok := v.(*sql.TxOptions); ok {
opt = txOpts
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case gorm.TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case gorm.ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = gorm.ErrInvalidTransaction
}
if err != nil {
_ = tx.AddError(err)
}
}