From 3b79a192cd2e63326f091e9f6b924fa4d51e01fc Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 6 Jul 2022 16:43:59 +0800 Subject: [PATCH] use callbacks.Begin as a hook to handle begin operation make transaction support hooks --- callbacks.go | 15 +++------------ callbacks/callbacks.go | 3 +++ callbacks/transaction.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/callbacks.go b/callbacks.go index f835e504..d027ed34 100644 --- a/callbacks.go +++ b/callbacks.go @@ -79,19 +79,10 @@ func (cs *callbacks) Transaction() *processor { } func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { - var err error + tx.InstanceSet("gorm:transaction_options", opt) - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction - } - - if err != nil { - _ = tx.AddError(err) + for _, f := range p.fns { + f(tx) } return tx diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..04b01d7a 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -80,4 +80,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { rawCallback := db.Callback().Raw() rawCallback.Register("gorm:raw", RawExec) rawCallback.Clauses = config.QueryClauses + + transactionCallback := db.Callback().Transaction() + transactionCallback.Register("gorm:begin", Begin) } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 50887ccc..3fcd0391 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -1,6 +1,8 @@ package callbacks import ( + "database/sql" + "gorm.io/gorm" ) @@ -30,3 +32,33 @@ func CommitOrRollbackTransaction(db *gorm.DB) { } } } + +func Begin(tx *gorm.DB) { + err := tx.Error + + if err != nil { + return + } + + var opt *sql.TxOptions + + if v, ok := tx.InstanceGet("gorm:transaction_options"); ok { + if txOpts, ok := v.(*sql.TxOptions); ok { + opt = txOpts + } + } + + switch beginner := tx.Statement.ConnPool.(type) { + case gorm.TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case gorm.ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = gorm.ErrInvalidTransaction + } + + if err != nil { + _ = tx.AddError(err) + } + +}