diff --git a/callbacks.go b/callbacks.go index f835e504..47145ed9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -79,19 +79,19 @@ func (cs *callbacks) Transaction() *processor { } func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { - var err error - - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction + // call scopes + for len(tx.Statement.scopes) > 0 { + scopes := tx.Statement.scopes + tx.Statement.scopes = nil + for _, scope := range scopes { + tx = scope(tx) + } } - if err != nil { - _ = tx.AddError(err) + tx.InstanceSet("gorm:transaction_options", opt) + + for _, f := range p.fns { + f(tx) } return tx diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..6a7ca1c3 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -80,4 +80,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { rawCallback := db.Callback().Raw() rawCallback.Register("gorm:raw", RawExec) rawCallback.Clauses = config.QueryClauses + + transactionCallback := db.Callback().Transaction() + _ = transactionCallback.Register("gorm:begin", Begin) } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 50887ccc..3fcd0391 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -1,6 +1,8 @@ package callbacks import ( + "database/sql" + "gorm.io/gorm" ) @@ -30,3 +32,33 @@ func CommitOrRollbackTransaction(db *gorm.DB) { } } } + +func Begin(tx *gorm.DB) { + err := tx.Error + + if err != nil { + return + } + + var opt *sql.TxOptions + + if v, ok := tx.InstanceGet("gorm:transaction_options"); ok { + if txOpts, ok := v.(*sql.TxOptions); ok { + opt = txOpts + } + } + + switch beginner := tx.Statement.ConnPool.(type) { + case gorm.TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case gorm.ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = gorm.ErrInvalidTransaction + } + + if err != nil { + _ = tx.AddError(err) + } + +}