optimize skiphook code.

This commit is contained in:
wangdong 2021-08-03 21:14:42 +08:00
parent f21ce964fc
commit 100d74a00c
8 changed files with 98 additions and 38 deletions

View File

@ -127,9 +127,10 @@ func (p *processor) Execute(db *DB) *DB {
} }
for _, c := range p.callbacks { for _, c := range p.callbacks {
if !stmt.ShouldSkipHook(c) { if stmt.CanSkip(c) && stmt.ShouldSkip(c) {
c.handler(db) continue
} }
c.handler(db)
} }
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {

View File

@ -20,64 +20,107 @@ type Config struct {
DeleteClauses []string DeleteClauses []string
} }
var (
//transaction callback names
BeforeTransactionCk = "gorm:begin_transaction"
CommitOrRollbackCk = "gorm:commit_or_rollback_transaction"
// create callback names
BeforeCreateCk = "gorm:before_create"
SaveBeforeAssociationsCk = "gorm:save_before_associations"
CreateCk = "gorm:create"
SaveAfterAssociationsCk = "gorm:save_after_associations"
AfterCreateCk = "gorm:after_create"
// query callback names
QueryCk = "gorm:query"
PreloadCk = "gorm:preload"
AfterQueryCk = "gorm:after_query"
// delete callback names
BeforeDeleteCk = "gorm:before_delete"
DeleteBeforeAssociationsCk = "gorm:delete_before_associations"
DeleteCk = "gorm:delete"
AfterDeleteCk = "gorm:after_delete"
// update callback names
SetUpReflectValueCk = "gorm:setup_reflect_value"
BeforeUpdateCk = "gorm:before_update"
UpdateCk = "gorm:update"
AfterUpdateCk = "gorm:after_update"
// row callback names
RowCk = "gorm:row"
// raw callback names
RawCk = "gorm:raw"
CoreCallbackNames = [...]string{BeforeTransactionCk, CommitOrRollbackCk,
SaveBeforeAssociationsCk, SaveAfterAssociationsCk,
CreateCk, QueryCk, PreloadCk,
DeleteBeforeAssociationsCk, DeleteCk,
SetUpReflectValueCk, UpdateCk,
RowCk, RawCk}
)
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
enableTransaction := func(db *gorm.DB) bool { enableTransaction := func(db *gorm.DB) bool {
return !db.SkipDefaultTransaction return !db.SkipDefaultTransaction
} }
createCallback := db.Callback().Create() createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register(BeforeCreateCk, BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config)) createCallback.Register(CreateCk, Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register(AfterCreateCk, AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 { if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses config.CreateClauses = createClauses
} }
createCallback.Clauses = config.CreateClauses createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query() queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query) queryCallback.Register(QueryCk, Query)
queryCallback.Register("gorm:preload", Preload) queryCallback.Register(PreloadCk, Preload)
queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Register(AfterQueryCk, AfterQuery)
if len(config.QueryClauses) == 0 { if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses config.QueryClauses = queryClauses
} }
queryCallback.Clauses = config.QueryClauses queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete() deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register(BeforeDeleteCk, BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register(DeleteBeforeAssociationsCk, DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register(DeleteCk, Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Register(AfterDeleteCk, AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 { if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses config.DeleteClauses = deleteClauses
} }
deleteCallback.Clauses = config.DeleteClauses deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update() updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register(SetUpReflectValueCk, SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register(BeforeUpdateCk, BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update) updateCallback.Register(UpdateCk, Update)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Register(AfterUpdateCk, AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 { if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses config.UpdateClauses = updateClauses
} }
updateCallback.Clauses = config.UpdateClauses updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row() rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery) rowCallback.Register(RowCk, RowQuery)
rowCallback.Clauses = config.QueryClauses rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw() rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec) rawCallback.Register(RawCk, RawExec)
rawCallback.Clauses = config.QueryClauses rawCallback.Clauses = config.QueryClauses
} }

View File

@ -10,7 +10,7 @@ import (
) )
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok { if i, ok := value.(BeforeSaveInterface); ok {
@ -205,7 +205,7 @@ func CreateWithReturning(db *gorm.DB) {
} }
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {

View File

@ -10,7 +10,7 @@ import (
) )
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeDeleteInterface); ok { if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx)) db.AddError(i.BeforeDelete(tx))
@ -156,7 +156,7 @@ func Delete(db *gorm.DB) {
} }
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterDeleteInterface); ok { if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx)) db.AddError(i.AfterDelete(tx))

View File

@ -216,7 +216,7 @@ func Preload(db *gorm.DB) {
} }
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok { if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx)) db.AddError(i.AfterFind(tx))

View File

@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
} }
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok { if i, ok := value.(BeforeSaveInterface); ok {
@ -87,7 +87,7 @@ func Update(db *gorm.DB) {
} }
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
callbacks2 "gorm.io/gorm/callbacks"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -673,8 +674,10 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
return results, !notRestricted && len(stmt.Selects) > 0 return results, !notRestricted && len(stmt.Selects) > 0
} }
// determine // determine weather the hook should be skipped or not
func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) { // return true if should skip
func (stmt *Statement) ShouldSkip(c *callback) (skip bool) {
skip = false
if stmt.SkipHooks { if stmt.SkipHooks {
// skip all // skip all
skip = true skip = true
@ -691,3 +694,15 @@ func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) {
} }
return return
} }
// to avoid skipping core hook.
func (stmt *Statement) CanSkip(c *callback) (canSkip bool) {
ckName := c.name
canSkip = true
for _, name := range callbacks2.CoreCallbackNames {
if ckName == name {
canSkip = false
}
}
return
}

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"errors" "errors"
"gorm.io/gorm/callbacks"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -498,7 +499,7 @@ func TestSkipHookByName(t *testing.T) {
product := Product3{Name: "Product", Price: 0} product := Product3{Name: "Product", Price: 0}
DB.AutoMigrate(&Product3{}) DB.AutoMigrate(&Product3{})
// expect price = 0 // expect price = 0
DB.SkipHookByName("gorm:before_create").Create(&product) DB.SkipHookByName(callbacks.BeforeCreateCk).Create(&product)
product2 := Product3{Name: "Product", Price: 0} product2 := Product3{Name: "Product", Price: 0}
// expect price = 100 // expect price = 100
DB.Create(&product2) DB.Create(&product2)