test: add tests for BeforeFind hook
This commit is contained in:
parent
9cdaf44650
commit
4777b42b5b
@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func BeforeQuery(db *gorm.DB) {
|
func BeforeQuery(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(BeforeFindInterface); ok {
|
if i, ok := value.(BeforeFindInterface); ok {
|
||||||
db.AddError(i.BeforeFind(tx))
|
db.AddError(i.BeforeFind(tx))
|
||||||
|
@ -31,7 +31,7 @@ func TestCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
|
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "BeforeFind", "AfterFind"} {
|
||||||
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
|
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
|
||||||
t.Errorf("%v should be false", str)
|
t.Errorf("%v should be false", str)
|
||||||
}
|
}
|
||||||
|
@ -609,3 +609,96 @@ func TestPropagateUnscoped(t *testing.T) {
|
|||||||
t.Fatalf("unscoped did not propagate")
|
t.Fatalf("unscoped did not propagate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Product7 struct {
|
||||||
|
gorm.Model
|
||||||
|
Code string
|
||||||
|
Price float64
|
||||||
|
BeforeFindCallTimes int64 `gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product7) BeforeFind(tx *gorm.DB) error {
|
||||||
|
s.BeforeFindCallTimes++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modifies transient field
|
||||||
|
func TestBeforeFindHookCallCount(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&Product7{})
|
||||||
|
DB.AutoMigrate(&Product7{})
|
||||||
|
|
||||||
|
p := Product7{Code: "before_find_count", Price: 100}
|
||||||
|
DB.Save(&p)
|
||||||
|
|
||||||
|
var result Product7
|
||||||
|
|
||||||
|
DB.First(&result, "code = ?", "before_find_count")
|
||||||
|
if result.BeforeFindCallTimes != 1 {
|
||||||
|
t.Errorf("Expected 1, got %d", result.BeforeFindCallTimes)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&result, "code = ?", "before_find_count")
|
||||||
|
if result.BeforeFindCallTimes != 2 {
|
||||||
|
t.Errorf("Expected 2, got %d", result.BeforeFindCallTimes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Product8 struct {
|
||||||
|
gorm.Model
|
||||||
|
Code string
|
||||||
|
Price float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product8) BeforeFind(tx *gorm.DB) error {
|
||||||
|
tx.Statement.Where("price > ?", 50)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fails for postgres
|
||||||
|
// ERROR: invalid input syntax for type bigint: "t" (SQLSTATE 22P02)
|
||||||
|
func TestBeforeFindModifiesQuery(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&Product8{})
|
||||||
|
DB.AutoMigrate(&Product8{})
|
||||||
|
|
||||||
|
products := []Product8{
|
||||||
|
{Code: "A", Price: 100},
|
||||||
|
{Code: "B", Price: 30},
|
||||||
|
}
|
||||||
|
DB.Create(&products)
|
||||||
|
|
||||||
|
var results []Product8
|
||||||
|
|
||||||
|
// Without condition, hooks will be skipped
|
||||||
|
DB.Find(&results, true)
|
||||||
|
|
||||||
|
if len(results) != 1 || results[0].Code != "A" {
|
||||||
|
t.Errorf("BeforeFind should filter results, got %v", results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Product9 struct {
|
||||||
|
gorm.Model
|
||||||
|
Code string
|
||||||
|
Price float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product9) BeforeFind(tx *gorm.DB) error {
|
||||||
|
s.Price = 200
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseOverwritesBeforeFindChanges(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&Product9{})
|
||||||
|
DB.AutoMigrate(&Product9{})
|
||||||
|
|
||||||
|
p := Product9{Code: "price_overwrite", Price: 100}
|
||||||
|
DB.Save(&p)
|
||||||
|
|
||||||
|
var result Product9
|
||||||
|
DB.First(&result, "code = ?", "price_overwrite")
|
||||||
|
|
||||||
|
if result.Price != 100 {
|
||||||
|
t.Errorf("Price should be loaded from database, got %f", result.Price)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user