From 4777b42b5b83d6c3c9682cfe2b77820ff13da9d1 Mon Sep 17 00:00:00 2001 From: MhmdGol Date: Sat, 15 Feb 2025 01:06:02 +0330 Subject: [PATCH] test: add tests for BeforeFind hook --- callbacks/query.go | 2 +- schema/callbacks_test.go | 2 +- tests/hooks_test.go | 93 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 2 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index c70058ec..a3654516 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -12,7 +12,7 @@ import ( ) func BeforeQuery(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(BeforeFindInterface); ok { db.AddError(i.BeforeFind(tx)) diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index 4583a207..34d43a9e 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -31,7 +31,7 @@ func TestCallback(t *testing.T) { } } - for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { + for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "BeforeFind", "AfterFind"} { if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { t.Errorf("%v should be false", str) } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 04f62bde..637504a2 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -609,3 +609,96 @@ func TestPropagateUnscoped(t *testing.T) { t.Fatalf("unscoped did not propagate") } } + +type Product7 struct { + gorm.Model + Code string + Price float64 + BeforeFindCallTimes int64 `gorm:"-"` +} + +func (s *Product7) BeforeFind(tx *gorm.DB) error { + s.BeforeFindCallTimes++ + return nil +} + +// Modifies transient field +func TestBeforeFindHookCallCount(t *testing.T) { + DB.Migrator().DropTable(&Product7{}) + DB.AutoMigrate(&Product7{}) + + p := Product7{Code: "before_find_count", Price: 100} + DB.Save(&p) + + var result Product7 + + DB.First(&result, "code = ?", "before_find_count") + if result.BeforeFindCallTimes != 1 { + t.Errorf("Expected 1, got %d", result.BeforeFindCallTimes) + } + + DB.First(&result, "code = ?", "before_find_count") + if result.BeforeFindCallTimes != 2 { + t.Errorf("Expected 2, got %d", result.BeforeFindCallTimes) + } +} + +type Product8 struct { + gorm.Model + Code string + Price float64 +} + +func (s *Product8) BeforeFind(tx *gorm.DB) error { + tx.Statement.Where("price > ?", 50) + + return nil +} + +// Fails for postgres +// ERROR: invalid input syntax for type bigint: "t" (SQLSTATE 22P02) +func TestBeforeFindModifiesQuery(t *testing.T) { + DB.Migrator().DropTable(&Product8{}) + DB.AutoMigrate(&Product8{}) + + products := []Product8{ + {Code: "A", Price: 100}, + {Code: "B", Price: 30}, + } + DB.Create(&products) + + var results []Product8 + + // Without condition, hooks will be skipped + DB.Find(&results, true) + + if len(results) != 1 || results[0].Code != "A" { + t.Errorf("BeforeFind should filter results, got %v", results) + } +} + +type Product9 struct { + gorm.Model + Code string + Price float64 +} + +func (s *Product9) BeforeFind(tx *gorm.DB) error { + s.Price = 200 + return nil +} + +func TestDatabaseOverwritesBeforeFindChanges(t *testing.T) { + DB.Migrator().DropTable(&Product9{}) + DB.AutoMigrate(&Product9{}) + + p := Product9{Code: "price_overwrite", Price: 100} + DB.Save(&p) + + var result Product9 + DB.First(&result, "code = ?", "price_overwrite") + + if result.Price != 100 { + t.Errorf("Price should be loaded from database, got %f", result.Price) + } +}