From 30a4f9bcaa41a200ee8047ce72014b61fef68d89 Mon Sep 17 00:00:00 2001 From: Gerhard Gruber Date: Thu, 29 Jun 2023 10:07:04 +0200 Subject: [PATCH] Added AfterCommit callbacks (#25) Added AfterCommit callbacks --- main.go | 95 ++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 63 insertions(+), 32 deletions(-) diff --git a/main.go b/main.go index d07b7978..b3cc0bb5 100644 --- a/main.go +++ b/main.go @@ -28,19 +28,23 @@ type DB struct { callbacks *Callback dialect Dialect singularTable bool + + afterCommitCallbacks []func(db *DB) } // Open initialize a new db connection, need to import driver first, e.g: // -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } +// import _ "github.com/go-sql-driver/mysql" +// func main() { +// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") +// } +// // GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" +// +// import _ "github.com/jinzhu/gorm/dialects/mysql" +// // import _ "github.com/jinzhu/gorm/dialects/postgres" +// // import _ "github.com/jinzhu/gorm/dialects/sqlite" +// // import _ "github.com/jinzhu/gorm/dialects/mssql" func Open(dialect string, args ...interface{}) (db *DB, err error) { if len(args) == 0 { err = errors.New("invalid database source") @@ -121,7 +125,9 @@ func (s *DB) Dialect() Dialect { } // Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) +// +// db.Callback().Create().Register("update_created_at", updateCreated) +// // Refer https://jinzhu.github.io/gorm/development.html#callbacks func (s *DB) Callback() *Callback { s.parent.callbacks = s.parent.callbacks.clone() @@ -224,9 +230,10 @@ func (s *DB) Offset(offset interface{}) *DB { } // Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +// +// db.Order("name DESC") +// db.Order("name DESC", true) // reorder +// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression func (s *DB) Order(value interface{}, reorder ...bool) *DB { return s.clone().search.Order(value, reorder...).db } @@ -253,23 +260,26 @@ func (s *DB) Having(query interface{}, values ...interface{}) *DB { } // Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (s *DB) Joins(query interface{}, args ...interface{}) *DB { return s.clone().search.Joins(query, args...).db } // Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // Refer https://jinzhu.github.io/gorm/crud.html#scopes func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { for _, f := range funcs { @@ -356,8 +366,9 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { } // Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) +// +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) func (s *DB) Pluck(column string, value interface{}) *DB { return s.NewScope(s.Value).pluck(column, value).db } @@ -454,7 +465,8 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB { } // Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) +// +// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) func (s *DB) Raw(sql string, values ...interface{}) *DB { return s.clone().search.Raw(true).Where(sql, values...).db } @@ -469,10 +481,11 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB { } // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") func (s *DB) Model(value interface{}) *DB { c := s.clone() c.Value = value @@ -525,6 +538,16 @@ func (s *DB) Rollback() *DB { return s } +// AfterCommit invoke callback after transaction committed +// If no transaction is currently running, it will be invoked immediately +func (s *DB) AfterCommit(f func(db *DB)) { + if db, ok := s.db.(sqlTx); ok && db != nil { + s.afterCommitCallbacks = append(s.afterCommitCallbacks, f) + } else { + f(s) + } +} + // WrapInTx wraps a method in a transaction func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { if _, ok := s.db.(*sql.Tx); ok { @@ -552,6 +575,11 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { err = f(tx) if err == nil { err = tx.Commit().Error + if err == nil { + for i := len(tx.afterCommitCallbacks) - 1; i >= 0; i-- { + tx.afterCommitCallbacks[i](s) + } + } } panicked = false return @@ -674,7 +702,8 @@ func (s *DB) RemoveIndex(indexName string) *DB { } // AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") +// +// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) @@ -682,7 +711,8 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate } // RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +// +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") func (s *DB) RemoveForeignKey(field string, dest string) *DB { scope := s.clone().NewScope(s.Value) scope.removeForeignKey(field, dest) @@ -712,7 +742,8 @@ func (s *DB) Association(column string) *Association { } // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (s *DB) Preload(column string, conditions ...interface{}) *DB { return s.clone().search.Preload(column, conditions...).db }