optimize skiphook code.
This commit is contained in:
parent
f21ce964fc
commit
100d74a00c
@ -127,9 +127,10 @@ func (p *processor) Execute(db *DB) *DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range p.callbacks {
|
for _, c := range p.callbacks {
|
||||||
if !stmt.ShouldSkipHook(c) {
|
if stmt.CanSkip(c) && stmt.ShouldSkip(c) {
|
||||||
c.handler(db)
|
continue
|
||||||
}
|
}
|
||||||
|
c.handler(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||||
|
@ -20,64 +20,107 @@ type Config struct {
|
|||||||
DeleteClauses []string
|
DeleteClauses []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
//transaction callback names
|
||||||
|
BeforeTransactionCk = "gorm:begin_transaction"
|
||||||
|
CommitOrRollbackCk = "gorm:commit_or_rollback_transaction"
|
||||||
|
|
||||||
|
// create callback names
|
||||||
|
BeforeCreateCk = "gorm:before_create"
|
||||||
|
SaveBeforeAssociationsCk = "gorm:save_before_associations"
|
||||||
|
CreateCk = "gorm:create"
|
||||||
|
SaveAfterAssociationsCk = "gorm:save_after_associations"
|
||||||
|
AfterCreateCk = "gorm:after_create"
|
||||||
|
|
||||||
|
// query callback names
|
||||||
|
QueryCk = "gorm:query"
|
||||||
|
PreloadCk = "gorm:preload"
|
||||||
|
AfterQueryCk = "gorm:after_query"
|
||||||
|
|
||||||
|
// delete callback names
|
||||||
|
BeforeDeleteCk = "gorm:before_delete"
|
||||||
|
DeleteBeforeAssociationsCk = "gorm:delete_before_associations"
|
||||||
|
DeleteCk = "gorm:delete"
|
||||||
|
AfterDeleteCk = "gorm:after_delete"
|
||||||
|
|
||||||
|
// update callback names
|
||||||
|
SetUpReflectValueCk = "gorm:setup_reflect_value"
|
||||||
|
BeforeUpdateCk = "gorm:before_update"
|
||||||
|
UpdateCk = "gorm:update"
|
||||||
|
AfterUpdateCk = "gorm:after_update"
|
||||||
|
|
||||||
|
// row callback names
|
||||||
|
RowCk = "gorm:row"
|
||||||
|
|
||||||
|
// raw callback names
|
||||||
|
RawCk = "gorm:raw"
|
||||||
|
|
||||||
|
CoreCallbackNames = [...]string{BeforeTransactionCk, CommitOrRollbackCk,
|
||||||
|
SaveBeforeAssociationsCk, SaveAfterAssociationsCk,
|
||||||
|
CreateCk, QueryCk, PreloadCk,
|
||||||
|
DeleteBeforeAssociationsCk, DeleteCk,
|
||||||
|
SetUpReflectValueCk, UpdateCk,
|
||||||
|
RowCk, RawCk}
|
||||||
|
)
|
||||||
|
|
||||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
enableTransaction := func(db *gorm.DB) bool {
|
enableTransaction := func(db *gorm.DB) bool {
|
||||||
return !db.SkipDefaultTransaction
|
return !db.SkipDefaultTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
createCallback := db.Callback().Create()
|
createCallback := db.Callback().Create()
|
||||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
createCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
|
||||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
createCallback.Register(BeforeCreateCk, BeforeCreate)
|
||||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
|
createCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(true))
|
||||||
createCallback.Register("gorm:create", Create(config))
|
createCallback.Register(CreateCk, Create(config))
|
||||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
createCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(true))
|
||||||
createCallback.Register("gorm:after_create", AfterCreate)
|
createCallback.Register(AfterCreateCk, AfterCreate)
|
||||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
createCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
|
||||||
if len(config.CreateClauses) == 0 {
|
if len(config.CreateClauses) == 0 {
|
||||||
config.CreateClauses = createClauses
|
config.CreateClauses = createClauses
|
||||||
}
|
}
|
||||||
createCallback.Clauses = config.CreateClauses
|
createCallback.Clauses = config.CreateClauses
|
||||||
|
|
||||||
queryCallback := db.Callback().Query()
|
queryCallback := db.Callback().Query()
|
||||||
queryCallback.Register("gorm:query", Query)
|
queryCallback.Register(QueryCk, Query)
|
||||||
queryCallback.Register("gorm:preload", Preload)
|
queryCallback.Register(PreloadCk, Preload)
|
||||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
queryCallback.Register(AfterQueryCk, AfterQuery)
|
||||||
if len(config.QueryClauses) == 0 {
|
if len(config.QueryClauses) == 0 {
|
||||||
config.QueryClauses = queryClauses
|
config.QueryClauses = queryClauses
|
||||||
}
|
}
|
||||||
queryCallback.Clauses = config.QueryClauses
|
queryCallback.Clauses = config.QueryClauses
|
||||||
|
|
||||||
deleteCallback := db.Callback().Delete()
|
deleteCallback := db.Callback().Delete()
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
deleteCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
|
||||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
deleteCallback.Register(BeforeDeleteCk, BeforeDelete)
|
||||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
deleteCallback.Register(DeleteBeforeAssociationsCk, DeleteBeforeAssociations)
|
||||||
deleteCallback.Register("gorm:delete", Delete)
|
deleteCallback.Register(DeleteCk, Delete)
|
||||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
deleteCallback.Register(AfterDeleteCk, AfterDelete)
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
deleteCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
|
||||||
if len(config.DeleteClauses) == 0 {
|
if len(config.DeleteClauses) == 0 {
|
||||||
config.DeleteClauses = deleteClauses
|
config.DeleteClauses = deleteClauses
|
||||||
}
|
}
|
||||||
deleteCallback.Clauses = config.DeleteClauses
|
deleteCallback.Clauses = config.DeleteClauses
|
||||||
|
|
||||||
updateCallback := db.Callback().Update()
|
updateCallback := db.Callback().Update()
|
||||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
updateCallback.Match(enableTransaction).Register(BeforeTransactionCk, BeginTransaction)
|
||||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
updateCallback.Register(SetUpReflectValueCk, SetupUpdateReflectValue)
|
||||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
updateCallback.Register(BeforeUpdateCk, BeforeUpdate)
|
||||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
updateCallback.Register(SaveBeforeAssociationsCk, SaveBeforeAssociations(false))
|
||||||
updateCallback.Register("gorm:update", Update)
|
updateCallback.Register(UpdateCk, Update)
|
||||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
updateCallback.Register(SaveAfterAssociationsCk, SaveAfterAssociations(false))
|
||||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
updateCallback.Register(AfterUpdateCk, AfterUpdate)
|
||||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
updateCallback.Match(enableTransaction).Register(CommitOrRollbackCk, CommitOrRollbackTransaction)
|
||||||
if len(config.UpdateClauses) == 0 {
|
if len(config.UpdateClauses) == 0 {
|
||||||
config.UpdateClauses = updateClauses
|
config.UpdateClauses = updateClauses
|
||||||
}
|
}
|
||||||
updateCallback.Clauses = config.UpdateClauses
|
updateCallback.Clauses = config.UpdateClauses
|
||||||
|
|
||||||
rowCallback := db.Callback().Row()
|
rowCallback := db.Callback().Row()
|
||||||
rowCallback.Register("gorm:row", RowQuery)
|
rowCallback.Register(RowCk, RowQuery)
|
||||||
rowCallback.Clauses = config.QueryClauses
|
rowCallback.Clauses = config.QueryClauses
|
||||||
|
|
||||||
rawCallback := db.Callback().Raw()
|
rawCallback := db.Callback().Raw()
|
||||||
rawCallback.Register("gorm:raw", RawExec)
|
rawCallback.Register(RawCk, RawExec)
|
||||||
rawCallback.Clauses = config.QueryClauses
|
rawCallback.Clauses = config.QueryClauses
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func BeforeCreate(db *gorm.DB) {
|
func BeforeCreate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(BeforeSaveInterface); ok {
|
if i, ok := value.(BeforeSaveInterface); ok {
|
||||||
@ -205,7 +205,7 @@ func CreateWithReturning(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AfterCreate(db *gorm.DB) {
|
func AfterCreate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func BeforeDelete(db *gorm.DB) {
|
func BeforeDelete(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||||
db.AddError(i.BeforeDelete(tx))
|
db.AddError(i.BeforeDelete(tx))
|
||||||
@ -156,7 +156,7 @@ func Delete(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AfterDelete(db *gorm.DB) {
|
func AfterDelete(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(AfterDeleteInterface); ok {
|
if i, ok := value.(AfterDeleteInterface); ok {
|
||||||
db.AddError(i.AfterDelete(tx))
|
db.AddError(i.AfterDelete(tx))
|
||||||
|
@ -216,7 +216,7 @@ func Preload(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AfterQuery(db *gorm.DB) {
|
func AfterQuery(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(AfterFindInterface); ok {
|
if i, ok := value.(AfterFindInterface); ok {
|
||||||
db.AddError(i.AfterFind(tx))
|
db.AddError(i.AfterFind(tx))
|
||||||
|
@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BeforeUpdate(db *gorm.DB) {
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(BeforeSaveInterface); ok {
|
if i, ok := value.(BeforeSaveInterface); ok {
|
||||||
@ -87,7 +87,7 @@ func Update(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AfterUpdate(db *gorm.DB) {
|
func AfterUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
|
19
statement.go
19
statement.go
@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
callbacks2 "gorm.io/gorm/callbacks"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -673,8 +674,10 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
|
|||||||
return results, !notRestricted && len(stmt.Selects) > 0
|
return results, !notRestricted && len(stmt.Selects) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// determine
|
// determine weather the hook should be skipped or not
|
||||||
func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) {
|
// return true if should skip
|
||||||
|
func (stmt *Statement) ShouldSkip(c *callback) (skip bool) {
|
||||||
|
skip = false
|
||||||
if stmt.SkipHooks {
|
if stmt.SkipHooks {
|
||||||
// skip all
|
// skip all
|
||||||
skip = true
|
skip = true
|
||||||
@ -691,3 +694,15 @@ func (stmt *Statement) ShouldSkipHook(c *callback) (skip bool) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// to avoid skipping core hook.
|
||||||
|
func (stmt *Statement) CanSkip(c *callback) (canSkip bool) {
|
||||||
|
ckName := c.name
|
||||||
|
canSkip = true
|
||||||
|
for _, name := range callbacks2.CoreCallbackNames {
|
||||||
|
if ckName == name {
|
||||||
|
canSkip = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"gorm.io/gorm/callbacks"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -498,7 +499,7 @@ func TestSkipHookByName(t *testing.T) {
|
|||||||
product := Product3{Name: "Product", Price: 0}
|
product := Product3{Name: "Product", Price: 0}
|
||||||
DB.AutoMigrate(&Product3{})
|
DB.AutoMigrate(&Product3{})
|
||||||
// expect price = 0
|
// expect price = 0
|
||||||
DB.SkipHookByName("gorm:before_create").Create(&product)
|
DB.SkipHookByName(callbacks.BeforeCreateCk).Create(&product)
|
||||||
product2 := Product3{Name: "Product", Price: 0}
|
product2 := Product3{Name: "Product", Price: 0}
|
||||||
// expect price = 100
|
// expect price = 100
|
||||||
DB.Create(&product2)
|
DB.Create(&product2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user