test: add tests for BeforeFind hook
This commit is contained in:
		
							parent
							
								
									9cdaf44650
								
							
						
					
					
						commit
						4777b42b5b
					
				| @ -12,7 +12,7 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func BeforeQuery(db *gorm.DB) { | ||||
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind { | ||||
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeFind { | ||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) bool { | ||||
| 			if i, ok := value.(BeforeFindInterface); ok { | ||||
| 				db.AddError(i.BeforeFind(tx)) | ||||
|  | ||||
| @ -31,7 +31,7 @@ func TestCallback(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { | ||||
| 	for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "BeforeFind", "AfterFind"} { | ||||
| 		if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { | ||||
| 			t.Errorf("%v should be false", str) | ||||
| 		} | ||||
|  | ||||
| @ -609,3 +609,96 @@ func TestPropagateUnscoped(t *testing.T) { | ||||
| 		t.Fatalf("unscoped did not propagate") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Product7 struct { | ||||
| 	gorm.Model | ||||
| 	Code                string | ||||
| 	Price               float64 | ||||
| 	BeforeFindCallTimes int64 `gorm:"-"` | ||||
| } | ||||
| 
 | ||||
| func (s *Product7) BeforeFind(tx *gorm.DB) error { | ||||
| 	s.BeforeFindCallTimes++ | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Modifies transient field
 | ||||
| func TestBeforeFindHookCallCount(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&Product7{}) | ||||
| 	DB.AutoMigrate(&Product7{}) | ||||
| 
 | ||||
| 	p := Product7{Code: "before_find_count", Price: 100} | ||||
| 	DB.Save(&p) | ||||
| 
 | ||||
| 	var result Product7 | ||||
| 
 | ||||
| 	DB.First(&result, "code = ?", "before_find_count") | ||||
| 	if result.BeforeFindCallTimes != 1 { | ||||
| 		t.Errorf("Expected 1, got %d", result.BeforeFindCallTimes) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.First(&result, "code = ?", "before_find_count") | ||||
| 	if result.BeforeFindCallTimes != 2 { | ||||
| 		t.Errorf("Expected 2, got %d", result.BeforeFindCallTimes) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Product8 struct { | ||||
| 	gorm.Model | ||||
| 	Code  string | ||||
| 	Price float64 | ||||
| } | ||||
| 
 | ||||
| func (s *Product8) BeforeFind(tx *gorm.DB) error { | ||||
| 	tx.Statement.Where("price > ?", 50) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Fails for postgres
 | ||||
| // ERROR: invalid input syntax for type bigint: "t" (SQLSTATE 22P02)
 | ||||
| func TestBeforeFindModifiesQuery(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&Product8{}) | ||||
| 	DB.AutoMigrate(&Product8{}) | ||||
| 
 | ||||
| 	products := []Product8{ | ||||
| 		{Code: "A", Price: 100}, | ||||
| 		{Code: "B", Price: 30}, | ||||
| 	} | ||||
| 	DB.Create(&products) | ||||
| 
 | ||||
| 	var results []Product8 | ||||
| 
 | ||||
| 	// Without condition, hooks will be skipped
 | ||||
| 	DB.Find(&results, true) | ||||
| 
 | ||||
| 	if len(results) != 1 || results[0].Code != "A" { | ||||
| 		t.Errorf("BeforeFind should filter results, got %v", results) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Product9 struct { | ||||
| 	gorm.Model | ||||
| 	Code  string | ||||
| 	Price float64 | ||||
| } | ||||
| 
 | ||||
| func (s *Product9) BeforeFind(tx *gorm.DB) error { | ||||
| 	s.Price = 200 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestDatabaseOverwritesBeforeFindChanges(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&Product9{}) | ||||
| 	DB.AutoMigrate(&Product9{}) | ||||
| 
 | ||||
| 	p := Product9{Code: "price_overwrite", Price: 100} | ||||
| 	DB.Save(&p) | ||||
| 
 | ||||
| 	var result Product9 | ||||
| 	DB.First(&result, "code = ?", "price_overwrite") | ||||
| 
 | ||||
| 	if result.Price != 100 { | ||||
| 		t.Errorf("Price should be loaded from database, got %f", result.Price) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 MhmdGol
						MhmdGol