diff --git a/callbacks.go b/callbacks.go index d4b61f73..297ba925 100644 --- a/callbacks.go +++ b/callbacks.go @@ -127,9 +127,10 @@ func (p *processor) Execute(db *DB) *DB { } for _, c := range p.callbacks { - if !stmt.ShouldSkipHook(c) { - c.handler(db) + if stmt.CanSkip(c) && stmt.ShouldSkip(c) { + continue } + c.handler(db) } db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d85c1928..82ef2e1a 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -20,64 +20,107 @@ type Config struct { DeleteClauses []string } +var ( + //transaction callback names + BeforeTransactionCk = "gorm:begin_transaction" + CommitOrRollbackCk = "gorm:commit_or_rollback_transaction" + + // create callback names + BeforeCreateCk = "gorm:before_create" + SaveBeforeAssociationsCk = "gorm:save_before_associations" + CreateCk = "gorm:create" + SaveAfterAssociationsCk = "gorm:save_after_associations" + AfterCreateCk = "gorm:after_create" + + // query callback names + QueryCk = "gorm:query" + PreloadCk = "gorm:preload" + AfterQueryCk = "gorm:after_query" + + // delete callback names + BeforeDeleteCk = "gorm:before_delete" + DeleteBeforeAssociationsCk = "gorm:delete_before_associations" + DeleteCk = "gorm:delete" + AfterDeleteCk = "gorm:after_delete" + + // update callback names + SetUpReflectValueCk = "gorm:setup_reflect_value" + BeforeUpdateCk = "gorm:before_update" + UpdateCk = "gorm:update" + AfterUpdateCk = "gorm:after_update" + + // row callback names + RowCk = "gorm:row" + + // raw callback names + RawCk = "gorm:raw" + + CoreCallbackNames = [...]string{BeforeTransactionCk, CommitOrRollbackCk, + SaveBeforeAssociationsCk, SaveAfterAssociationsCk, + CreateCk, QueryCk, PreloadCk, + DeleteBeforeAssociationsCk, DeleteCk, + SetUpReflectValueCk, UpdateCk, + RowCk, RawCk} +) + func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { enableTransaction := func(db *gorm.DB) bool { return !db.SkipDefaultTransaction } createCallback := db.Callback().Create() - createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) - createCallback.Register("gorm:before_create", BeforeCreate) - createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) - createCallback.Register("gorm:create", Create(config)) - createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) - createCallback.Register("gorm:after_create", AfterCreate) - createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + createCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction) + createCallback.Register(BeforeCreateCk, BeforeCreate) + createCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(true)) + createCallback.Register(CreateCk, Create(config)) + createCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(true)) + createCallback.Register(AfterCreateCk, AfterCreate) + createCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction) if len(config.CreateClauses) == 0 { config.CreateClauses = createClauses } createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() - queryCallback.Register("gorm:query", Query) - queryCallback.Register("gorm:preload", Preload) - queryCallback.Register("gorm:after_query", AfterQuery) + queryCallback.Register(QueryCk, Query) + queryCallback.Register(PreloadCk, Preload) + queryCallback.Register(AfterQueryCk, AfterQuery) if len(config.QueryClauses) == 0 { config.QueryClauses = queryClauses } queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() - deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) - deleteCallback.Register("gorm:before_delete", BeforeDelete) - deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) - deleteCallback.Register("gorm:delete", Delete) - deleteCallback.Register("gorm:after_delete", AfterDelete) - deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + deleteCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction) + deleteCallback.Register(BeforeDeleteCk, BeforeDelete) + deleteCallback.Register(DeleteBeforeAssociationsCk, DeleteBeforeAssociations) + deleteCallback.Register(DeleteCk, Delete) + deleteCallback.Register(AfterDeleteCk, AfterDelete) + deleteCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction) if len(config.DeleteClauses) == 0 { config.DeleteClauses = deleteClauses } deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() - updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) - updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) - updateCallback.Register("gorm:before_update", BeforeUpdate) - updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) - updateCallback.Register("gorm:update", Update) - updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) - updateCallback.Register("gorm:after_update", AfterUpdate) - updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + updateCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction) + updateCallback.Register(SetUpReflectValueCk, SetupUpdateReflectValue) + updateCallback.Register(BeforeUpdateCk, BeforeUpdate) + updateCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(false)) + updateCallback.Register(UpdateCk, Update) + updateCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(false)) + updateCallback.Register(AfterUpdateCk, AfterUpdate) + updateCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction) if len(config.UpdateClauses) == 0 { config.UpdateClauses = updateClauses } updateCallback.Clauses = config.UpdateClauses rowCallback := db.Callback().Row() - rowCallback.Register("gorm:row", RowQuery) + rowCallback.Register(RowCk, RowQuery) rowCallback.Clauses = config.QueryClauses rawCallback := db.Callback().Raw() - rawCallback.Register("gorm:raw", RawExec) + rawCallback.Register(RawCk, RawExec) rawCallback.Clauses = config.QueryClauses } diff --git a/callbacks/create.go b/callbacks/create.go index 8a3c593c..c867c60a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,7 +10,7 @@ import ( ) func BeforeCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -205,7 +205,7 @@ func CreateWithReturning(db *gorm.DB) { } func AfterCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { diff --git a/callbacks/delete.go b/callbacks/delete.go index 91659c51..b462ff87 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -10,7 +10,7 @@ import ( ) func BeforeDelete(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) @@ -156,7 +156,7 @@ func Delete(db *gorm.DB) { } func AfterDelete(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) diff --git a/callbacks/query.go b/callbacks/query.go index 3299d015..31c9c097 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -216,7 +216,7 @@ func Preload(db *gorm.DB) { } func AfterQuery(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) diff --git a/callbacks/update.go b/callbacks/update.go index 75bb02db..fe580c33 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { } func BeforeUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -87,7 +87,7 @@ func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { diff --git a/statement.go b/statement.go index d2755610..dc14c50d 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + callbacks2 "gorm.io/gorm/callbacks" "reflect" "sort" "strconv" @@ -673,8 +674,10 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( return results, !notRestricted && len(stmt.Selects) > 0 } -// determine -func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) { +// determine weather the hook should be skipped or not +// return true if should skip +func (stmt *Statement) ShouldSkip(c *callback) (skip bool) { + skip = false if stmt.SkipHooks { // skip all skip = true @@ -691,3 +694,15 @@ func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) { } return } + +// to avoid skipping core hook. +func (stmt *Statement) CanSkip(c *callback) (canSkip bool) { + ckName := c.name + canSkip = true + for _, name := range callbacks2.CoreCallbackNames { + if ckName == name { + canSkip = false + } + } + return +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 2fec2d4d..3d145e65 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "gorm.io/gorm/callbacks" "reflect" "strings" "testing" @@ -498,7 +499,7 @@ func TestSkipHookByName(t *testing.T) { product := Product3{Name: "Product", Price: 0} DB.AutoMigrate(&Product3{}) // expect price = 0 - DB.SkipHookByName("gorm:before_create").Create(&product) + DB.SkipHookByName(callbacks.BeforeCreateCk).Create(&product) product2 := Product3{Name: "Product", Price: 0} // expect price = 100 DB.Create(&product2)