Yay, callbacks works
This commit is contained in:
		
							parent
							
								
									b551fee276
								
							
						
					
					
						commit
						c5b0908b22
					
				
							
								
								
									
										10
									
								
								model.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								model.go
									
									
									
									
									
								
							| @ -122,12 +122,14 @@ func (m *Model) TableName() string { | ||||
| 	return reg.ReplaceAllString(toSnake(t.Name()), "s") | ||||
| } | ||||
| 
 | ||||
| func (model *Model) callMethod(method string) error { | ||||
| 	fm := reflect.ValueOf(model).MethodByName(method) | ||||
| func (m *Model) callMethod(method string) error { | ||||
| 	fm := reflect.ValueOf(m.Data).MethodByName(method) | ||||
| 	if fm.IsValid() { | ||||
| 		v := fm.Call([]reflect.Value{}) | ||||
| 		if verr, ok := v[0].Interface().(error); ok { | ||||
| 			return verr | ||||
| 		if len(v) > 0 { | ||||
| 			if verr, ok := v[0].Interface().(error); ok { | ||||
| 				return verr | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
|  | ||||
							
								
								
									
										4
									
								
								orm.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								orm.go
									
									
									
									
									
								
							| @ -110,9 +110,9 @@ func (s *Orm) Select(value interface{}) *Orm { | ||||
| func (s *Orm) Save(value interface{}) *Orm { | ||||
| 	s.Model(value) | ||||
| 	if s.model.PrimaryKeyIsEmpty() { | ||||
| 		s.explain(value, "Create").create(value) | ||||
| 		s.create(value) | ||||
| 	} else { | ||||
| 		s.explain(value, "Update").update(value) | ||||
| 		s.update(value) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
							
								
								
									
										69
									
								
								orm_test.go
									
									
									
									
									
								
							
							
						
						
									
										69
									
								
								orm_test.go
									
									
									
									
									
								
							| @ -1,6 +1,7 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| @ -14,6 +15,20 @@ type User struct { | ||||
| 	UpdatedAt time.Time | ||||
| } | ||||
| 
 | ||||
| type Product struct { | ||||
| 	Id                    int64 | ||||
| 	Code                  string | ||||
| 	Price                 int64 | ||||
| 	CreatedAt             time.Time | ||||
| 	UpdatedAt             time.Time | ||||
| 	BeforeCreateCallTimes int64 | ||||
| 	AfterCreateCallTimes  int64 | ||||
| 	BeforeUpdateCallTimes int64 | ||||
| 	AfterUpdateCallTimes  int64 | ||||
| 	BeforeSaveCallTimes   int64 | ||||
| 	AfterSaveCallTimes    int64 | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	db                 DB | ||||
| 	t1, t2, t3, t4, t5 time.Time | ||||
| @ -22,11 +37,13 @@ var ( | ||||
| func init() { | ||||
| 	db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable") | ||||
| 	db.Exec("drop table users;") | ||||
| 	db.Exec("drop table products;") | ||||
| 
 | ||||
| 	orm := db.CreateTable(&User{}) | ||||
| 	if orm.Error != nil { | ||||
| 		panic("No error should raise when create table") | ||||
| 	} | ||||
| 	db.CreateTable(&Product{}) | ||||
| 
 | ||||
| 	var shortForm = "2006-01-02 15:04:05" | ||||
| 	t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") | ||||
| @ -311,3 +328,55 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) { | ||||
| 		t.Errorf("Updated At should be changed after update") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeCreate() { | ||||
| 	s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeUpdate() { | ||||
| 	s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) BeforeSave() { | ||||
| 	s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterCreate() { | ||||
| 	s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterUpdate() { | ||||
| 	s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (s *Product) AfterSave() { | ||||
| 	s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 | ||||
| } | ||||
| 
 | ||||
| func (p *Product) GetCallTimes() []int64 { | ||||
| 	return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes} | ||||
| } | ||||
| 
 | ||||
| func TestRunCallbacks(t *testing.T) { | ||||
| 	p := Product{Code: "unique_code", Price: 100} | ||||
| 	db.Save(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) { | ||||
| 		t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Where("Code = ?", "unique_code").First(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) { | ||||
| 		t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	p.Price = 200 | ||||
| 	db.Save(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) { | ||||
| 		t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Where("Code = ?", "unique_code").First(&p) | ||||
| 	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) { | ||||
| 		t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										9
									
								
								sql.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								sql.go
									
									
									
									
									
								
							| @ -131,7 +131,7 @@ func (s *Orm) create(value interface{}) { | ||||
| 	var id int64 | ||||
| 	s.err(s.model.callMethod("BeforeCreate")) | ||||
| 	s.err(s.model.callMethod("BeforeSave")) | ||||
| 
 | ||||
| 	s.explain(value, "Create") | ||||
| 	if s.driver == "postgres" { | ||||
| 		s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) | ||||
| 	} else { | ||||
| @ -141,12 +141,11 @@ func (s *Orm) create(value interface{}) { | ||||
| 		id, err = s.SqlResult.LastInsertId() | ||||
| 		s.err(err) | ||||
| 	} | ||||
| 	result := reflect.ValueOf(s.model.Data).Elem() | ||||
| 	result.FieldByName(s.model.PrimaryKey()).SetInt(id) | ||||
| 
 | ||||
| 	s.err(s.model.callMethod("AfterCreate")) | ||||
| 	s.err(s.model.callMethod("AfterSave")) | ||||
| 
 | ||||
| 	result := reflect.ValueOf(s.model.Data).Elem() | ||||
| 	result.FieldByName(s.model.PrimaryKey()).SetInt(id) | ||||
| } | ||||
| 
 | ||||
| func (s *Orm) updateSql(value interface{}) { | ||||
| @ -169,7 +168,7 @@ func (s *Orm) updateSql(value interface{}) { | ||||
| func (s *Orm) update(value interface{}) { | ||||
| 	s.err(s.model.callMethod("BeforeUpdate")) | ||||
| 	s.err(s.model.callMethod("BeforeSave")) | ||||
| 	s.Exec() | ||||
| 	s.explain(value, "Update").Exec() | ||||
| 	s.err(s.model.callMethod("AfterUpdate")) | ||||
| 	s.err(s.model.callMethod("AfterSave")) | ||||
| 	return | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu