optimize skiphook code.
This commit is contained in:
parent
f21ce964fc
commit
100d74a00c
@ -127,9 +127,10 @@ func (p *processor) Execute(db *DB) *DB {
|
||||
}
|
||||
|
||||
for _, c := range p.callbacks {
|
||||
if !stmt.ShouldSkipHook(c) {
|
||||
c.handler(db)
|
||||
if stmt.CanSkip(c) && stmt.ShouldSkip(c) {
|
||||
continue
|
||||
}
|
||||
c.handler(db)
|
||||
}
|
||||
|
||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||
|
@ -20,64 +20,107 @@ type Config struct {
|
||||
DeleteClauses []string
|
||||
}
|
||||
|
||||
var (
|
||||
//transaction callback names
|
||||
BeforeTransactionCk = "gorm:begin_transaction"
|
||||
CommitOrRollbackCk = "gorm:commit_or_rollback_transaction"
|
||||
|
||||
// create callback names
|
||||
BeforeCreateCk = "gorm:before_create"
|
||||
SaveBeforeAssociationsCk = "gorm:save_before_associations"
|
||||
CreateCk = "gorm:create"
|
||||
SaveAfterAssociationsCk = "gorm:save_after_associations"
|
||||
AfterCreateCk = "gorm:after_create"
|
||||
|
||||
// query callback names
|
||||
QueryCk = "gorm:query"
|
||||
PreloadCk = "gorm:preload"
|
||||
AfterQueryCk = "gorm:after_query"
|
||||
|
||||
// delete callback names
|
||||
BeforeDeleteCk = "gorm:before_delete"
|
||||
DeleteBeforeAssociationsCk = "gorm:delete_before_associations"
|
||||
DeleteCk = "gorm:delete"
|
||||
AfterDeleteCk = "gorm:after_delete"
|
||||
|
||||
// update callback names
|
||||
SetUpReflectValueCk = "gorm:setup_reflect_value"
|
||||
BeforeUpdateCk = "gorm:before_update"
|
||||
UpdateCk = "gorm:update"
|
||||
AfterUpdateCk = "gorm:after_update"
|
||||
|
||||
// row callback names
|
||||
RowCk = "gorm:row"
|
||||
|
||||
// raw callback names
|
||||
RawCk = "gorm:raw"
|
||||
|
||||
CoreCallbackNames = [...]string{BeforeTransactionCk, CommitOrRollbackCk,
|
||||
SaveBeforeAssociationsCk, SaveAfterAssociationsCk,
|
||||
CreateCk, QueryCk, PreloadCk,
|
||||
DeleteBeforeAssociationsCk, DeleteCk,
|
||||
SetUpReflectValueCk, UpdateCk,
|
||||
RowCk, RawCk}
|
||||
)
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
enableTransaction := func(db *gorm.DB) bool {
|
||||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
||||
createCallback := db.Callback().Create()
|
||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
|
||||
createCallback.Register("gorm:create", Create(config))
|
||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
createCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
|
||||
createCallback.Register(BeforeCreateCk, BeforeCreate)
|
||||
createCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(true))
|
||||
createCallback.Register(CreateCk, Create(config))
|
||||
createCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(true))
|
||||
createCallback.Register(AfterCreateCk, AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
createCallback.Clauses = config.CreateClauses
|
||||
|
||||
queryCallback := db.Callback().Query()
|
||||
queryCallback.Register("gorm:query", Query)
|
||||
queryCallback.Register("gorm:preload", Preload)
|
||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||
queryCallback.Register(QueryCk, Query)
|
||||
queryCallback.Register(PreloadCk, Preload)
|
||||
queryCallback.Register(AfterQueryCk, AfterQuery)
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
queryCallback.Clauses = config.QueryClauses
|
||||
|
||||
deleteCallback := db.Callback().Delete()
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||
deleteCallback.Register("gorm:delete", Delete)
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
deleteCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
|
||||
deleteCallback.Register(BeforeDeleteCk, BeforeDelete)
|
||||
deleteCallback.Register(DeleteBeforeAssociationsCk, DeleteBeforeAssociations)
|
||||
deleteCallback.Register(DeleteCk, Delete)
|
||||
deleteCallback.Register(AfterDeleteCk, AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
deleteCallback.Clauses = config.DeleteClauses
|
||||
|
||||
updateCallback := db.Callback().Update()
|
||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||
updateCallback.Register("gorm:update", Update)
|
||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
updateCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
|
||||
updateCallback.Register(SetUpReflectValueCk, SetupUpdateReflectValue)
|
||||
updateCallback.Register(BeforeUpdateCk, BeforeUpdate)
|
||||
updateCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(false))
|
||||
updateCallback.Register(UpdateCk, Update)
|
||||
updateCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(false))
|
||||
updateCallback.Register(AfterUpdateCk, AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
updateCallback.Clauses = config.UpdateClauses
|
||||
|
||||
rowCallback := db.Callback().Row()
|
||||
rowCallback.Register("gorm:row", RowQuery)
|
||||
rowCallback.Register(RowCk, RowQuery)
|
||||
rowCallback.Clauses = config.QueryClauses
|
||||
|
||||
rawCallback := db.Callback().Raw()
|
||||
rawCallback.Register("gorm:raw", RawExec)
|
||||
rawCallback.Register(RawCk, RawExec)
|
||||
rawCallback.Clauses = config.QueryClauses
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
@ -205,7 +205,7 @@ func CreateWithReturning(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
@ -156,7 +156,7 @@ func Delete(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
|
@ -216,7 +216,7 @@ func Preload(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
|
@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
@ -87,7 +87,7 @@ func Update(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
|
19
statement.go
19
statement.go
@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
callbacks2 "gorm.io/gorm/callbacks"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
@ -673,8 +674,10 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
|
||||
return results, !notRestricted && len(stmt.Selects) > 0
|
||||
}
|
||||
|
||||
// determine
|
||||
func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) {
|
||||
// determine weather the hook should be skipped or not
|
||||
// return true if should skip
|
||||
func (stmt *Statement) ShouldSkip(c *callback) (skip bool) {
|
||||
skip = false
|
||||
if stmt.SkipHooks {
|
||||
// skip all
|
||||
skip = true
|
||||
@ -691,3 +694,15 @@ func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// to avoid skipping core hook.
|
||||
func (stmt *Statement) CanSkip(c *callback) (canSkip bool) {
|
||||
ckName := c.name
|
||||
canSkip = true
|
||||
for _, name := range callbacks2.CoreCallbackNames {
|
||||
if ckName == name {
|
||||
canSkip = false
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -498,7 +499,7 @@ func TestSkipHookByName(t *testing.T) {
|
||||
product := Product3{Name: "Product", Price: 0}
|
||||
DB.AutoMigrate(&Product3{})
|
||||
// expect price = 0
|
||||
DB.SkipHookByName("gorm:before_create").Create(&product)
|
||||
DB.SkipHookByName(callbacks.BeforeCreateCk).Create(&product)
|
||||
product2 := Product3{Name: "Product", Price: 0}
|
||||
// expect price = 100
|
||||
DB.Create(&product2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user