test: add tests for BeforeFind hook
This commit is contained in:
parent
9cdaf44650
commit
4777b42b5b
@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func BeforeQuery(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(BeforeFindInterface); ok {
|
||||
db.AddError(i.BeforeFind(tx))
|
||||
|
@ -31,7 +31,7 @@ func TestCallback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
|
||||
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "BeforeFind", "AfterFind"} {
|
||||
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
|
||||
t.Errorf("%v should be false", str)
|
||||
}
|
||||
|
@ -609,3 +609,96 @@ func TestPropagateUnscoped(t *testing.T) {
|
||||
t.Fatalf("unscoped did not propagate")
|
||||
}
|
||||
}
|
||||
|
||||
type Product7 struct {
|
||||
gorm.Model
|
||||
Code string
|
||||
Price float64
|
||||
BeforeFindCallTimes int64 `gorm:"-"`
|
||||
}
|
||||
|
||||
func (s *Product7) BeforeFind(tx *gorm.DB) error {
|
||||
s.BeforeFindCallTimes++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modifies transient field
|
||||
func TestBeforeFindHookCallCount(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product7{})
|
||||
DB.AutoMigrate(&Product7{})
|
||||
|
||||
p := Product7{Code: "before_find_count", Price: 100}
|
||||
DB.Save(&p)
|
||||
|
||||
var result Product7
|
||||
|
||||
DB.First(&result, "code = ?", "before_find_count")
|
||||
if result.BeforeFindCallTimes != 1 {
|
||||
t.Errorf("Expected 1, got %d", result.BeforeFindCallTimes)
|
||||
}
|
||||
|
||||
DB.First(&result, "code = ?", "before_find_count")
|
||||
if result.BeforeFindCallTimes != 2 {
|
||||
t.Errorf("Expected 2, got %d", result.BeforeFindCallTimes)
|
||||
}
|
||||
}
|
||||
|
||||
type Product8 struct {
|
||||
gorm.Model
|
||||
Code string
|
||||
Price float64
|
||||
}
|
||||
|
||||
func (s *Product8) BeforeFind(tx *gorm.DB) error {
|
||||
tx.Statement.Where("price > ?", 50)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fails for postgres
|
||||
// ERROR: invalid input syntax for type bigint: "t" (SQLSTATE 22P02)
|
||||
func TestBeforeFindModifiesQuery(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product8{})
|
||||
DB.AutoMigrate(&Product8{})
|
||||
|
||||
products := []Product8{
|
||||
{Code: "A", Price: 100},
|
||||
{Code: "B", Price: 30},
|
||||
}
|
||||
DB.Create(&products)
|
||||
|
||||
var results []Product8
|
||||
|
||||
// Without condition, hooks will be skipped
|
||||
DB.Find(&results, true)
|
||||
|
||||
if len(results) != 1 || results[0].Code != "A" {
|
||||
t.Errorf("BeforeFind should filter results, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
type Product9 struct {
|
||||
gorm.Model
|
||||
Code string
|
||||
Price float64
|
||||
}
|
||||
|
||||
func (s *Product9) BeforeFind(tx *gorm.DB) error {
|
||||
s.Price = 200
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDatabaseOverwritesBeforeFindChanges(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product9{})
|
||||
DB.AutoMigrate(&Product9{})
|
||||
|
||||
p := Product9{Code: "price_overwrite", Price: 100}
|
||||
DB.Save(&p)
|
||||
|
||||
var result Product9
|
||||
DB.First(&result, "code = ?", "price_overwrite")
|
||||
|
||||
if result.Price != 100 {
|
||||
t.Errorf("Price should be loaded from database, got %f", result.Price)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user