feat(callback): add callback support for begin, commit, rollback

This commit is contained in:
Loona 2023-07-07 11:15:35 +07:00
parent 2066138684
commit 01b00f4cbd
2 changed files with 25 additions and 6 deletions

View File

@ -15,12 +15,15 @@ import (
func initializeCallbacks(db *DB) *callbacks { func initializeCallbacks(db *DB) *callbacks {
return &callbacks{ return &callbacks{
processors: map[string]*processor{ processors: map[string]*processor{
"create": {db: db}, "create": {db: db},
"query": {db: db}, "query": {db: db},
"update": {db: db}, "update": {db: db},
"delete": {db: db}, "delete": {db: db},
"row": {db: db}, "row": {db: db},
"raw": {db: db}, "raw": {db: db},
"begin": {db: db},
"rollback": {db: db},
"commit": {db: db},
}, },
} }
} }
@ -48,6 +51,18 @@ type callback struct {
processor *processor processor *processor
} }
func (cs *callbacks) Begin() *processor {
return cs.processors["begin"]
}
func (cs *callbacks) Rollback() *processor {
return cs.processors["rollback"]
}
func (cs *callbacks) Commit() *processor {
return cs.processors["commit"]
}
func (cs *callbacks) Create() *processor { func (cs *callbacks) Create() *processor {
return cs.processors["create"] return cs.processors["create"]
} }

View File

@ -670,8 +670,10 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
switch beginner := tx.Statement.ConnPool.(type) { switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner: case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
tx.callbacks.Begin().Execute(tx)
case ConnPoolBeginner: case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
tx.callbacks.Begin().Execute(tx)
default: default:
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }
@ -687,6 +689,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit()) db.AddError(committer.Commit())
db.callbacks.Commit().Execute(db)
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }
@ -698,6 +701,7 @@ func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() { if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback()) db.AddError(committer.Rollback())
db.callbacks.Rollback().Execute(db)
} }
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)