From 9cdaf446508837e14d469ab45ce6437893c81c06 Mon Sep 17 00:00:00 2001 From: MhmdGol Date: Thu, 13 Feb 2025 20:22:12 +0330 Subject: [PATCH] feat: add before find hook --- callbacks/callbacks.go | 1 + callbacks/interfaces.go | 4 ++++ callbacks/query.go | 12 ++++++++++++ schema/schema.go | 7 +++++-- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..db5b865e 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -48,6 +48,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() + queryCallback.Register("gorm:before_query", BeforeQuery) queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go index 2302470f..825519c5 100644 --- a/callbacks/interfaces.go +++ b/callbacks/interfaces.go @@ -34,6 +34,10 @@ type AfterDeleteInterface interface { AfterDelete(*gorm.DB) error } +type BeforeFindInterface interface { + BeforeFind(*gorm.DB) error +} + type AfterFindInterface interface { AfterFind(*gorm.DB) error } diff --git a/callbacks/query.go b/callbacks/query.go index bbf238a9..c70058ec 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -11,6 +11,18 @@ import ( "gorm.io/gorm/utils" ) +func BeforeQuery(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(BeforeFindInterface); ok { + db.AddError(i.BeforeFind(tx)) + return true + } + return false + }) + } +} + func Query(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) diff --git a/schema/schema.go b/schema/schema.go index db236797..5a3ebedc 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -24,6 +24,7 @@ const ( callbackTypeAfterSave callbackType = "AfterSave" callbackTypeBeforeDelete callbackType = "BeforeDelete" callbackTypeAfterDelete callbackType = "AfterDelete" + callbackTypeBeforeFind callbackType = "BeforeFind" callbackTypeAfterFind callbackType = "AfterFind" ) @@ -52,7 +53,7 @@ type Schema struct { BeforeUpdate, AfterUpdate bool BeforeDelete, AfterDelete bool BeforeSave, AfterSave bool - AfterFind bool + BeforeFind, AfterFind bool err error initialized chan struct{} namer Namer @@ -308,7 +309,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam callbackTypeBeforeUpdate, callbackTypeAfterUpdate, callbackTypeBeforeSave, callbackTypeAfterSave, callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, + callbackTypeBeforeFind, callbackTypeAfterFind, } for _, cbName := range callbackTypes { if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { @@ -396,6 +397,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect return modelType.MethodByName(string(callbackTypeBeforeDelete)) case callbackTypeAfterDelete: return modelType.MethodByName(string(callbackTypeAfterDelete)) + case callbackTypeBeforeFind: + return modelType.MethodByName(string(callbackTypeBeforeFind)) case callbackTypeAfterFind: return modelType.MethodByName(string(callbackTypeAfterFind)) default: