diff --git a/callbacks.go b/callbacks.go index 195d1720..aab3d26a 100644 --- a/callbacks.go +++ b/callbacks.go @@ -15,12 +15,15 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": {db: db}, - "query": {db: db}, - "update": {db: db}, - "delete": {db: db}, - "row": {db: db}, - "raw": {db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + "begin": {db: db}, + "rollback": {db: db}, + "commit": {db: db}, }, } } @@ -48,6 +51,18 @@ type callback struct { processor *processor } +func (cs *callbacks) Begin() *processor { + return cs.processors["begin"] +} + +func (cs *callbacks) Rollback() *processor { + return cs.processors["rollback"] +} + +func (cs *callbacks) Commit() *processor { + return cs.processors["commit"] +} + func (cs *callbacks) Create() *processor { return cs.processors["create"] } diff --git a/finisher_api.go b/finisher_api.go index f80aa6c0..4450c36b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -670,8 +670,10 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { switch beginner := tx.Statement.ConnPool.(type) { case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.callbacks.Begin().Execute(tx) case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.callbacks.Begin().Execute(tx) default: err = ErrInvalidTransaction } @@ -687,6 +689,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) + db.callbacks.Commit().Execute(db) } else { db.AddError(ErrInvalidTransaction) } @@ -698,6 +701,7 @@ func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Rollback()) + db.callbacks.Rollback().Execute(db) } } else { db.AddError(ErrInvalidTransaction)