Add tests for callbacks
This commit is contained in:
		
							parent
							
								
									c5b0908b22
								
							
						
					
					
						commit
						2600e1099e
					
				| @ -3,7 +3,6 @@ | ||||
| Yet Another ORM library for Go, aims for developer friendly | ||||
| 
 | ||||
| ## TODO | ||||
| * After/Before Save/Update/Create/Delete | ||||
| * Soft Delete | ||||
| * Better First method (First(&user, primary_key, where conditions)) | ||||
| * Even more complex where query (with map or struct) | ||||
|  | ||||
							
								
								
									
										94
									
								
								orm_test.go
									
									
									
									
									
								
							
							
						
						
									
										94
									
								
								orm_test.go
									
									
									
									
									
								
							| @ -1,6 +1,7 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| @ -27,6 +28,8 @@ type Product struct { | ||||
| 	AfterUpdateCallTimes  int64 | ||||
| 	BeforeSaveCallTimes   int64 | ||||
| 	AfterSaveCallTimes    int64 | ||||
| 	BeforeDeleteCallTimes int64 | ||||
| 	AfterDeleteCallTimes  int64 | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| @ -329,16 +332,28 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeCreate() { | ||||
| func (s *Product) BeforeCreate() (err error) { | ||||
| 	if s.Code == "Invalid" { | ||||
| 		err = errors.New("invalid product") | ||||
| 	} | ||||
| 	s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeUpdate() { | ||||
| func (s *Product) BeforeUpdate() (err error) { | ||||
| 	if s.Code == "dont_update" { | ||||
| 		err = errors.New("Can't update") | ||||
| 	} | ||||
| 	s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeSave() { | ||||
| func (s *Product) BeforeSave() (err error) { | ||||
| 	if s.Code == "dont_save" { | ||||
| 		err = errors.New("Can't save") | ||||
| 	} | ||||
| 	s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterCreate() { | ||||
| @ -353,30 +368,93 @@ func (s *Product) AfterSave() { | ||||
| 	s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeDelete() (err error) { | ||||
| 	if s.Code == "dont_delete" { | ||||
| 		err = errors.New("Can't delete") | ||||
| 	} | ||||
| 	s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterDelete() { | ||||
| 	s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 | ||||
| } | ||||
| func (p *Product) GetCallTimes() []int64 { | ||||
| 	return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes} | ||||
| 	return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes} | ||||
| } | ||||
| 
 | ||||
| func TestRunCallbacks(t *testing.T) { | ||||
| 	p := Product{Code: "unique_code", Price: 100} | ||||
| 	db.Save(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) { | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0}) { | ||||
| 		t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Where("Code = ?", "unique_code").First(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) { | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0, 0, 0}) { | ||||
| 		t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	p.Price = 200 | ||||
| 	db.Save(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) { | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1, 0, 0}) { | ||||
| 		t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Where("Code = ?", "unique_code").First(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) { | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 0, 0}) { | ||||
| 		t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Delete(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 1, 1}) { | ||||
| 		t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	if db.Where("Code = ?", "unique_code").First(&p).Error == nil { | ||||
| 		t.Errorf("Should get error when find an deleted record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRunCallbacksAndGetErrors(t *testing.T) { | ||||
| 	p := Product{Code: "Invalid", Price: 100} | ||||
| 	if db.Save(&p).Error == nil { | ||||
| 		t.Errorf("An error from create callbacks expected when create") | ||||
| 	} | ||||
| 
 | ||||
| 	if db.Where("code = ?", "Invalid").First(&Product{}).Error == nil { | ||||
| 		t.Errorf("Should not save records that have errors") | ||||
| 	} | ||||
| 
 | ||||
| 	if db.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { | ||||
| 		t.Errorf("An error from create callbacks expected when create") | ||||
| 	} | ||||
| 
 | ||||
| 	p2 := Product{Code: "update_callback", Price: 100} | ||||
| 	db.Save(&p2) | ||||
| 	p2.Code = "dont_update" | ||||
| 	if db.Save(&p2).Error == nil { | ||||
| 		t.Errorf("An error from callbacks expected when update") | ||||
| 	} | ||||
| 	if db.Where("code = ?", "update_callback").First(&Product{}).Error != nil { | ||||
| 		t.Errorf("Record Should not be updated due to errors happened in callback") | ||||
| 	} | ||||
| 	if db.Where("code = ?", "dont_update").First(&Product{}).Error == nil { | ||||
| 		t.Errorf("Record Should not be updated due to errors happened in callback") | ||||
| 	} | ||||
| 
 | ||||
| 	p2.Code = "dont_save" | ||||
| 	if db.Save(&p2).Error == nil { | ||||
| 		t.Errorf("An error from before save callbacks expected when update") | ||||
| 	} | ||||
| 
 | ||||
| 	p3 := Product{Code: "dont_delete", Price: 100} | ||||
| 	db.Save(&p3) | ||||
| 	if db.Delete(&p3).Error == nil { | ||||
| 		t.Errorf("An error from before delete callbacks expected when delete") | ||||
| 	} | ||||
| 
 | ||||
| 	if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil { | ||||
| 		t.Errorf("Should not delete record due to errors happened in callback") | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										37
									
								
								sql.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								sql.go
									
									
									
									
									
								
							| @ -132,20 +132,23 @@ func (s *Orm) create(value interface{}) { | ||||
| 	s.err(s.model.callMethod("BeforeCreate")) | ||||
| 	s.err(s.model.callMethod("BeforeSave")) | ||||
| 	s.explain(value, "Create") | ||||
| 	if s.driver == "postgres" { | ||||
| 		s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) | ||||
| 	} else { | ||||
| 		var err error | ||||
| 		s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) | ||||
| 		s.err(err) | ||||
| 		id, err = s.SqlResult.LastInsertId() | ||||
| 		s.err(err) | ||||
| 	} | ||||
| 	result := reflect.ValueOf(s.model.Data).Elem() | ||||
| 	result.FieldByName(s.model.PrimaryKey()).SetInt(id) | ||||
| 
 | ||||
| 	s.err(s.model.callMethod("AfterCreate")) | ||||
| 	s.err(s.model.callMethod("AfterSave")) | ||||
| 	if len(s.Errors) == 0 { | ||||
| 		if s.driver == "postgres" { | ||||
| 			s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) | ||||
| 		} else { | ||||
| 			var err error | ||||
| 			s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) | ||||
| 			s.err(err) | ||||
| 			id, err = s.SqlResult.LastInsertId() | ||||
| 			s.err(err) | ||||
| 		} | ||||
| 		result := reflect.ValueOf(s.model.Data).Elem() | ||||
| 		result.FieldByName(s.model.PrimaryKey()).SetInt(id) | ||||
| 
 | ||||
| 		s.err(s.model.callMethod("AfterCreate")) | ||||
| 		s.err(s.model.callMethod("AfterSave")) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Orm) updateSql(value interface{}) { | ||||
| @ -168,7 +171,9 @@ func (s *Orm) updateSql(value interface{}) { | ||||
| func (s *Orm) update(value interface{}) { | ||||
| 	s.err(s.model.callMethod("BeforeUpdate")) | ||||
| 	s.err(s.model.callMethod("BeforeSave")) | ||||
| 	s.explain(value, "Update").Exec() | ||||
| 	if len(s.Errors) == 0 { | ||||
| 		s.explain(value, "Update").Exec() | ||||
| 	} | ||||
| 	s.err(s.model.callMethod("AfterUpdate")) | ||||
| 	s.err(s.model.callMethod("AfterSave")) | ||||
| 	return | ||||
| @ -181,7 +186,9 @@ func (s *Orm) deleteSql(value interface{}) { | ||||
| 
 | ||||
| func (s *Orm) delete(value interface{}) { | ||||
| 	s.err(s.model.callMethod("BeforeDelete")) | ||||
| 	s.Exec() | ||||
| 	if len(s.Errors) == 0 { | ||||
| 		s.Exec() | ||||
| 	} | ||||
| 	s.err(s.model.callMethod("AfterDelete")) | ||||
| } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu