Add Update test
This commit is contained in:
		
							parent
							
								
									0c34123796
								
							
						
					
					
						commit
						cbd55dbcd5
					
				| @ -44,13 +44,14 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) ( | |||||||
| 	sort.Strings(keys) | 	sort.Strings(keys) | ||||||
| 
 | 
 | ||||||
| 	for _, k := range keys { | 	for _, k := range keys { | ||||||
|  | 		value := mapValue[k] | ||||||
| 		if field := stmt.Schema.LookUpField(k); field != nil { | 		if field := stmt.Schema.LookUpField(k); field != nil { | ||||||
| 			k = field.DBName | 			k = field.DBName | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | 		if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||||
| 			columns = append(columns, k) | 			columns = append(columns, k) | ||||||
| 			values.Values[0] = append(values.Values[0], mapValue[k]) | 			values.Values[0] = append(values.Values[0], value) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
|  | |||||||
| @ -2,8 +2,10 @@ package callbacks | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"sort" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/clause" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func BeforeUpdate(db *gorm.DB) { | func BeforeUpdate(db *gorm.DB) { | ||||||
| @ -40,6 +42,17 @@ func BeforeUpdate(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Update(db *gorm.DB) { | func Update(db *gorm.DB) { | ||||||
|  | 	db.Statement.AddClauseIfNotExists(clause.Update{}) | ||||||
|  | 	db.Statement.AddClause(ConvertToAssignments(db.Statement)) | ||||||
|  | 	db.Statement.Build("UPDATE", "SET", "WHERE") | ||||||
|  | 
 | ||||||
|  | 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 
 | ||||||
|  | 	if err == nil { | ||||||
|  | 		db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 	} else { | ||||||
|  | 		db.AddError(err) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AfterUpdate(db *gorm.DB) { | func AfterUpdate(db *gorm.DB) { | ||||||
| @ -74,3 +87,48 @@ func AfterUpdate(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ConvertToAssignments convert to update assignments
 | ||||||
|  | func ConvertToAssignments(stmt *gorm.Statement) clause.Set { | ||||||
|  | 	selectColumns, restricted := SelectAndOmitColumns(stmt) | ||||||
|  | 	reflectModelValue := reflect.ValueOf(stmt.Model) | ||||||
|  | 
 | ||||||
|  | 	switch value := stmt.Dest.(type) { | ||||||
|  | 	case map[string]interface{}: | ||||||
|  | 		var set clause.Set = make([]clause.Assignment, 0, len(value)) | ||||||
|  | 
 | ||||||
|  | 		var keys []string | ||||||
|  | 		for k, _ := range value { | ||||||
|  | 			keys = append(keys, k) | ||||||
|  | 		} | ||||||
|  | 		sort.Strings(keys) | ||||||
|  | 
 | ||||||
|  | 		for _, k := range keys { | ||||||
|  | 			if field := stmt.Schema.LookUpField(k); field != nil { | ||||||
|  | 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
|  | 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) | ||||||
|  | 					field.Set(reflectModelValue, value[k]) | ||||||
|  | 				} | ||||||
|  | 			} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||||
|  | 				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return set | ||||||
|  | 	default: | ||||||
|  | 		switch stmt.ReflectValue.Kind() { | ||||||
|  | 		case reflect.Struct: | ||||||
|  | 			var set clause.Set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) | ||||||
|  | 			for _, field := range stmt.Schema.FieldsByDBName { | ||||||
|  | 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
|  | 					value, _ := field.ValueOf(stmt.ReflectValue) | ||||||
|  | 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) | ||||||
|  | 					field.Set(reflectModelValue, value) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return set | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return clause.Set{} | ||||||
|  | } | ||||||
|  | |||||||
| @ -18,11 +18,11 @@ func (limit Limit) Build(builder Builder) { | |||||||
| 	if limit.Limit > 0 { | 	if limit.Limit > 0 { | ||||||
| 		builder.Write("LIMIT ") | 		builder.Write("LIMIT ") | ||||||
| 		builder.Write(strconv.Itoa(limit.Limit)) | 		builder.Write(strconv.Itoa(limit.Limit)) | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if limit.Offset > 0 { | 		if limit.Offset > 0 { | ||||||
| 		builder.Write(" OFFSET ") | 			builder.Write(" OFFSET ") | ||||||
| 		builder.Write(strconv.Itoa(limit.Offset)) | 			builder.Write(strconv.Itoa(limit.Offset)) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -22,11 +22,13 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // First find first record that match given conditions, order by primary key
 | // First find first record that match given conditions, order by primary key
 | ||||||
| func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	// TODO handle where
 |  | ||||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||||
| 	}) | 	}) | ||||||
|  | 	if len(conds) > 0 { | ||||||
|  | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
|  | 	} | ||||||
| 	tx.Statement.RaiseErrorOnNotFound = true | 	tx.Statement.RaiseErrorOnNotFound = true | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| @ -34,8 +36,11 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Take return a record that match given conditions, the order will depend on the database implementation
 | // Take return a record that match given conditions, the order will depend on the database implementation
 | ||||||
| func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { | func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance().Limit(1) | 	tx = db.getInstance().Limit(1) | ||||||
|  | 	if len(conds) > 0 { | ||||||
|  | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
|  | 	} | ||||||
| 	tx.Statement.RaiseErrorOnNotFound = true | 	tx.Statement.RaiseErrorOnNotFound = true | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| @ -43,11 +48,14 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Last find last record that match given conditions, order by primary key
 | // Last find last record that match given conditions, order by primary key
 | ||||||
| func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { | func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||||
| 		Desc:   true, | 		Desc:   true, | ||||||
| 	}) | 	}) | ||||||
|  | 	if len(conds) > 0 { | ||||||
|  | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
|  | 	} | ||||||
| 	tx.Statement.RaiseErrorOnNotFound = true | 	tx.Statement.RaiseErrorOnNotFound = true | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| @ -55,8 +63,11 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Find find records that match given conditions
 | // Find find records that match given conditions
 | ||||||
| func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { | func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	if len(conds) > 0 { | ||||||
|  | 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||||
|  | 	} | ||||||
| 	tx.Statement.Dest = out | 	tx.Statement.Dest = out | ||||||
| 	tx.callbacks.Query().Execute(tx) | 	tx.callbacks.Query().Execute(tx) | ||||||
| 	return | 	return | ||||||
| @ -75,22 +86,30 @@ func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { | |||||||
| // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||||
| func (db *DB) Update(column string, value interface{}) (tx *DB) { | func (db *DB) Update(column string, value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	tx.Statement.Dest = map[string]interface{}{column: value} | ||||||
|  | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||||
| func (db *DB) Updates(values interface{}) (tx *DB) { | func (db *DB) Updates(values interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	tx.Statement.Dest = values | ||||||
|  | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { | func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	tx.Statement.Dest = map[string]interface{}{column: value} | ||||||
|  | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) UpdateColumns(values interface{}) (tx *DB) { | func (db *DB) UpdateColumns(values interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	tx.Statement.Dest = values | ||||||
|  | 	tx.callbacks.Update().Execute(tx) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ func Now() *time.Time { | |||||||
| func RunTestsSuit(t *testing.T, db *gorm.DB) { | func RunTestsSuit(t *testing.T, db *gorm.DB) { | ||||||
| 	TestCreate(t, db) | 	TestCreate(t, db) | ||||||
| 	TestFind(t, db) | 	TestFind(t, db) | ||||||
|  | 	TestUpdate(t, db) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestCreate(t *testing.T, db *gorm.DB) { | func TestCreate(t *testing.T, db *gorm.DB) { | ||||||
| @ -133,3 +134,62 @@ func TestFind(t *testing.T, db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestUpdate(t *testing.T, db *gorm.DB) { | ||||||
|  | 	db.Migrator().DropTable(&User{}) | ||||||
|  | 	db.AutoMigrate(&User{}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Update", func(t *testing.T) { | ||||||
|  | 		var user = User{ | ||||||
|  | 			Name:     "create", | ||||||
|  | 			Age:      18, | ||||||
|  | 			Birthday: Now(), | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := db.Create(&user).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when create: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := db.Model(&user).Update("Age", 10).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when update: %v", err) | ||||||
|  | 		} else if user.Age != 10 { | ||||||
|  | 			t.Errorf("Age should equals to 10, but got %v", user.Age) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var result User | ||||||
|  | 		if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when query: %v", err) | ||||||
|  | 		} else { | ||||||
|  | 			AssertObjEqual(t, result, user, "Name", "Age", "Birthday") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		values := map[string]interface{}{"Active": true, "age": 5} | ||||||
|  | 		if err := db.Model(&user).Updates(values).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when update: %v", err) | ||||||
|  | 		} else if user.Age != 5 { | ||||||
|  | 			t.Errorf("Age should equals to 5, but got %v", user.Age) | ||||||
|  | 		} else if user.Active != true { | ||||||
|  | 			t.Errorf("Active should be true, but got %v", user.Active) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var result2 User | ||||||
|  | 		if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when query: %v", err) | ||||||
|  | 		} else { | ||||||
|  | 			AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := db.Model(&user).Updates(User{Age: 2}).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when update: %v", err) | ||||||
|  | 		} else if user.Age != 2 { | ||||||
|  | 			t.Errorf("Age should equals to 2, but got %v", user.Age) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var result3 User | ||||||
|  | 		if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when query: %v", err) | ||||||
|  | 		} else { | ||||||
|  | 			AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu