feat: add before find hook

This commit is contained in:
MhmdGol 2025-02-13 20:22:12 +03:30
parent 9f273777f5
commit 9cdaf44650
4 changed files with 22 additions and 2 deletions

View File

@ -48,6 +48,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:before_query", BeforeQuery)
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)

View File

@ -34,6 +34,10 @@ type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type BeforeFindInterface interface {
BeforeFind(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}

View File

@ -11,6 +11,18 @@ import (
"gorm.io/gorm/utils"
)
func BeforeQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeFindInterface); ok {
db.AddError(i.BeforeFind(tx))
return true
}
return false
})
}
}
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)

View File

@ -24,6 +24,7 @@ const (
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeBeforeFind callbackType = "BeforeFind"
callbackTypeAfterFind callbackType = "AfterFind"
)
@ -52,7 +53,7 @@ type Schema struct {
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
BeforeFind, AfterFind bool
err error
initialized chan struct{}
namer Namer
@ -308,7 +309,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
callbackTypeBeforeFind, callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
@ -396,6 +397,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeBeforeFind:
return modelType.MethodByName(string(callbackTypeBeforeFind))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default: