diff --git a/callbacks.go b/callbacks.go index 297ba925..f2d86c56 100644 --- a/callbacks.go +++ b/callbacks.go @@ -12,6 +12,49 @@ import ( "gorm.io/gorm/utils" ) +var ( + //transaction callback names + BeforeTransactionCk = "gorm:begin_transaction" + CommitOrRollbackCk = "gorm:commit_or_rollback_transaction" + + // create callback names + BeforeCreateCk = "gorm:before_create" + SaveBeforeAssociationsCk = "gorm:save_before_associations" + CreateCk = "gorm:create" + SaveAfterAssociationsCk = "gorm:save_after_associations" + AfterCreateCk = "gorm:after_create" + + // query callback names + QueryCk = "gorm:query" + PreloadCk = "gorm:preload" + AfterQueryCk = "gorm:after_query" + + // delete callback names + BeforeDeleteCk = "gorm:before_delete" + DeleteBeforeAssociationsCk = "gorm:delete_before_associations" + DeleteCk = "gorm:delete" + AfterDeleteCk = "gorm:after_delete" + + // update callback names + SetUpReflectValueCk = "gorm:setup_reflect_value" + BeforeUpdateCk = "gorm:before_update" + UpdateCk = "gorm:update" + AfterUpdateCk = "gorm:after_update" + + // row callback names + RowCk = "gorm:row" + + // raw callback names + RawCk = "gorm:raw" + + CoreCallbackNames = [...]string{BeforeTransactionCk, CommitOrRollbackCk, + SaveBeforeAssociationsCk, SaveAfterAssociationsCk, + CreateCk, QueryCk, PreloadCk, + DeleteBeforeAssociationsCk, DeleteCk, + SetUpReflectValueCk, UpdateCk, + RowCk, RawCk} +) + func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 82ef2e1a..4dab7ab8 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -20,107 +20,64 @@ type Config struct { DeleteClauses []string } -var ( - //transaction callback names - BeforeTransactionCk = "gorm:begin_transaction" - CommitOrRollbackCk = "gorm:commit_or_rollback_transaction" - - // create callback names - BeforeCreateCk = "gorm:before_create" - SaveBeforeAssociationsCk = "gorm:save_before_associations" - CreateCk = "gorm:create" - SaveAfterAssociationsCk = "gorm:save_after_associations" - AfterCreateCk = "gorm:after_create" - - // query callback names - QueryCk = "gorm:query" - PreloadCk = "gorm:preload" - AfterQueryCk = "gorm:after_query" - - // delete callback names - BeforeDeleteCk = "gorm:before_delete" - DeleteBeforeAssociationsCk = "gorm:delete_before_associations" - DeleteCk = "gorm:delete" - AfterDeleteCk = "gorm:after_delete" - - // update callback names - SetUpReflectValueCk = "gorm:setup_reflect_value" - BeforeUpdateCk = "gorm:before_update" - UpdateCk = "gorm:update" - AfterUpdateCk = "gorm:after_update" - - // row callback names - RowCk = "gorm:row" - - // raw callback names - RawCk = "gorm:raw" - - CoreCallbackNames = [...]string{BeforeTransactionCk, CommitOrRollbackCk, - SaveBeforeAssociationsCk, SaveAfterAssociationsCk, - CreateCk, QueryCk, PreloadCk, - DeleteBeforeAssociationsCk, DeleteCk, - SetUpReflectValueCk, UpdateCk, - RowCk, RawCk} -) - func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { enableTransaction := func(db *gorm.DB) bool { return !db.SkipDefaultTransaction } createCallback := db.Callback().Create() - createCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction) - createCallback.Register(BeforeCreateCk, BeforeCreate) - createCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(true)) - createCallback.Register(CreateCk, Create(config)) - createCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(true)) - createCallback.Register(AfterCreateCk, AfterCreate) - createCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction) + createCallback.Match(enableTransaction).Register(gorm.BeforeTransactionCk, BeginTransaction) + createCallback.Register(gorm.BeforeCreateCk, BeforeCreate) + createCallback.Register(gorm.SaveBeforeAssociationsCk, SaveBeforeAssociations(true)) + createCallback.Register(gorm.CreateCk, Create(config)) + createCallback.Register(gorm.SaveAfterAssociationsCk, SaveAfterAssociations(true)) + createCallback.Register(gorm.AfterCreateCk, AfterCreate) + createCallback.Match(enableTransaction).Register(gorm.CommitOrRollbackCk, CommitOrRollbackTransaction) if len(config.CreateClauses) == 0 { config.CreateClauses = createClauses } createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() - queryCallback.Register(QueryCk, Query) - queryCallback.Register(PreloadCk, Preload) - queryCallback.Register(AfterQueryCk, AfterQuery) + queryCallback.Register(gorm.QueryCk, Query) + queryCallback.Register(gorm.PreloadCk, Preload) + queryCallback.Register(gorm.AfterQueryCk, AfterQuery) if len(config.QueryClauses) == 0 { config.QueryClauses = queryClauses } queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() - deleteCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction) - deleteCallback.Register(BeforeDeleteCk, BeforeDelete) - deleteCallback.Register(DeleteBeforeAssociationsCk, DeleteBeforeAssociations) - deleteCallback.Register(DeleteCk, Delete) - deleteCallback.Register(AfterDeleteCk, AfterDelete) - deleteCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction) + deleteCallback.Match(enableTransaction).Register(gorm.BeforeTransactionCk, BeginTransaction) + deleteCallback.Register(gorm.BeforeDeleteCk, BeforeDelete) + deleteCallback.Register(gorm.DeleteBeforeAssociationsCk, DeleteBeforeAssociations) + deleteCallback.Register(gorm.DeleteCk, Delete) + deleteCallback.Register(gorm.AfterDeleteCk, AfterDelete) + deleteCallback.Match(enableTransaction).Register(gorm.CommitOrRollbackCk, CommitOrRollbackTransaction) if len(config.DeleteClauses) == 0 { config.DeleteClauses = deleteClauses } deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() - updateCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction) - updateCallback.Register(SetUpReflectValueCk, SetupUpdateReflectValue) - updateCallback.Register(BeforeUpdateCk, BeforeUpdate) - updateCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(false)) - updateCallback.Register(UpdateCk, Update) - updateCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(false)) - updateCallback.Register(AfterUpdateCk, AfterUpdate) - updateCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction) + updateCallback.Match(enableTransaction).Register(gorm.BeforeTransactionCk, BeginTransaction) + updateCallback.Register(gorm.SetUpReflectValueCk, SetupUpdateReflectValue) + updateCallback.Register(gorm.BeforeUpdateCk, BeforeUpdate) + updateCallback.Register(gorm.SaveBeforeAssociationsCk, SaveBeforeAssociations(false)) + updateCallback.Register(gorm.UpdateCk, Update) + updateCallback.Register(gorm.SaveAfterAssociationsCk, SaveAfterAssociations(false)) + updateCallback.Register(gorm.AfterUpdateCk, AfterUpdate) + updateCallback.Match(enableTransaction).Register(gorm.CommitOrRollbackCk, CommitOrRollbackTransaction) if len(config.UpdateClauses) == 0 { config.UpdateClauses = updateClauses } updateCallback.Clauses = config.UpdateClauses rowCallback := db.Callback().Row() - rowCallback.Register(RowCk, RowQuery) + rowCallback.Register(gorm.RowCk, RowQuery) rowCallback.Clauses = config.QueryClauses rawCallback := db.Callback().Raw() - rawCallback.Register(RawCk, RawExec) + rawCallback.Register(gorm.RawCk, RawExec) rawCallback.Clauses = config.QueryClauses } diff --git a/statement.go b/statement.go index dc14c50d..18545a73 100644 --- a/statement.go +++ b/statement.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - callbacks2 "gorm.io/gorm/callbacks" "reflect" "sort" "strconv" @@ -699,7 +698,7 @@ func (stmt *Statement) ShouldSkip(c *callback) (skip bool) { func (stmt *Statement) CanSkip(c *callback) (canSkip bool) { ckName := c.name canSkip = true - for _, name := range callbacks2.CoreCallbackNames { + for _, name := range CoreCallbackNames { if ckName == name { canSkip = false } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 3d145e65..bbb7d499 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -2,7 +2,6 @@ package tests_test import ( "errors" - "gorm.io/gorm/callbacks" "reflect" "strings" "testing" @@ -499,12 +498,14 @@ func TestSkipHookByName(t *testing.T) { product := Product3{Name: "Product", Price: 0} DB.AutoMigrate(&Product3{}) // expect price = 0 - DB.SkipHookByName(callbacks.BeforeCreateCk).Create(&product) + DB.SkipHookByName(gorm.BeforeCreateCk).Create(&product) product2 := Product3{Name: "Product", Price: 0} // expect price = 100 DB.Create(&product2) // expect code = code1 , price = 100 + 20(add in before update) + 30(add in before update) DB.Model(&product2).Update("code", "code1") // expect code = code2 , price not change - DB.Model(&product).SkipHookByName("gorm:before_update").Update("code", "code2") + DB.Model(&product).SkipHookByName(gorm.BeforeUpdateCk).Update("code", "code2") + // cant skip 'update',because update is core hook + DB.Model(&product).SkipHookByName(gorm.UpdateCk).Update("code", "code3") }