Add Update test
This commit is contained in:
		
							parent
							
								
									0c34123796
								
							
						
					
					
						commit
						cbd55dbcd5
					
				| @ -44,13 +44,14 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) ( | ||||
| 	sort.Strings(keys) | ||||
| 
 | ||||
| 	for _, k := range keys { | ||||
| 		value := mapValue[k] | ||||
| 		if field := stmt.Schema.LookUpField(k); field != nil { | ||||
| 			k = field.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||
| 			columns = append(columns, k) | ||||
| 			values.Values[0] = append(values.Values[0], mapValue[k]) | ||||
| 			values.Values[0] = append(values.Values[0], value) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
|  | ||||
| @ -2,8 +2,10 @@ package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| ) | ||||
| 
 | ||||
| func BeforeUpdate(db *gorm.DB) { | ||||
| @ -40,6 +42,17 @@ func BeforeUpdate(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Update(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Update{}) | ||||
| 	db.Statement.AddClause(ConvertToAssignments(db.Statement)) | ||||
| 	db.Statement.Build("UPDATE", "SET", "WHERE") | ||||
| 
 | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 	} else { | ||||
| 		db.AddError(err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func AfterUpdate(db *gorm.DB) { | ||||
| @ -74,3 +87,48 @@ func AfterUpdate(db *gorm.DB) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // ConvertToAssignments convert to update assignments
 | ||||
| func ConvertToAssignments(stmt *gorm.Statement) clause.Set { | ||||
| 	selectColumns, restricted := SelectAndOmitColumns(stmt) | ||||
| 	reflectModelValue := reflect.ValueOf(stmt.Model) | ||||
| 
 | ||||
| 	switch value := stmt.Dest.(type) { | ||||
| 	case map[string]interface{}: | ||||
| 		var set clause.Set = make([]clause.Assignment, 0, len(value)) | ||||
| 
 | ||||
| 		var keys []string | ||||
| 		for k, _ := range value { | ||||
| 			keys = append(keys, k) | ||||
| 		} | ||||
| 		sort.Strings(keys) | ||||
| 
 | ||||
| 		for _, k := range keys { | ||||
| 			if field := stmt.Schema.LookUpField(k); field != nil { | ||||
| 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) | ||||
| 					field.Set(reflectModelValue, value[k]) | ||||
| 				} | ||||
| 			} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||
| 				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		return set | ||||
| 	default: | ||||
| 		switch stmt.ReflectValue.Kind() { | ||||
| 		case reflect.Struct: | ||||
| 			var set clause.Set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) | ||||
| 			for _, field := range stmt.Schema.FieldsByDBName { | ||||
| 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 					value, _ := field.ValueOf(stmt.ReflectValue) | ||||
| 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) | ||||
| 					field.Set(reflectModelValue, value) | ||||
| 				} | ||||
| 			} | ||||
| 			return set | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return clause.Set{} | ||||
| } | ||||
|  | ||||
| @ -18,11 +18,11 @@ func (limit Limit) Build(builder Builder) { | ||||
| 	if limit.Limit > 0 { | ||||
| 		builder.Write("LIMIT ") | ||||
| 		builder.Write(strconv.Itoa(limit.Limit)) | ||||
| 	} | ||||
| 
 | ||||
| 	if limit.Offset > 0 { | ||||
| 		builder.Write(" OFFSET ") | ||||
| 		builder.Write(strconv.Itoa(limit.Offset)) | ||||
| 		if limit.Offset > 0 { | ||||
| 			builder.Write(" OFFSET ") | ||||
| 			builder.Write(strconv.Itoa(limit.Offset)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -22,11 +22,13 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| // First find first record that match given conditions, order by primary key
 | ||||
| func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	// TODO handle where
 | ||||
| func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 	}) | ||||
| 	if len(conds) > 0 { | ||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||
| 	} | ||||
| 	tx.Statement.RaiseErrorOnNotFound = true | ||||
| 	tx.Statement.Dest = out | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| @ -34,8 +36,11 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| // Take return a record that match given conditions, the order will depend on the database implementation
 | ||||
| func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { | ||||
| func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance().Limit(1) | ||||
| 	if len(conds) > 0 { | ||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||
| 	} | ||||
| 	tx.Statement.RaiseErrorOnNotFound = true | ||||
| 	tx.Statement.Dest = out | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| @ -43,11 +48,14 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| // Last find last record that match given conditions, order by primary key
 | ||||
| func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { | ||||
| func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Desc:   true, | ||||
| 	}) | ||||
| 	if len(conds) > 0 { | ||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||
| 	} | ||||
| 	tx.Statement.RaiseErrorOnNotFound = true | ||||
| 	tx.Statement.Dest = out | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| @ -55,8 +63,11 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { | ||||
| } | ||||
| 
 | ||||
| // Find find records that match given conditions
 | ||||
| func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { | ||||
| func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	if len(conds) > 0 { | ||||
| 		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) | ||||
| 	} | ||||
| 	tx.Statement.Dest = out | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| 	return | ||||
| @ -75,22 +86,30 @@ func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { | ||||
| // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||
| func (db *DB) Update(column string, value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Dest = map[string]interface{}{column: value} | ||||
| 	tx.callbacks.Update().Execute(tx) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
 | ||||
| func (db *DB) Updates(values interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Dest = values | ||||
| 	tx.callbacks.Update().Execute(tx) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Dest = map[string]interface{}{column: value} | ||||
| 	tx.callbacks.Update().Execute(tx) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UpdateColumns(values interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Dest = values | ||||
| 	tx.callbacks.Update().Execute(tx) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -18,6 +18,7 @@ func Now() *time.Time { | ||||
| func RunTestsSuit(t *testing.T, db *gorm.DB) { | ||||
| 	TestCreate(t, db) | ||||
| 	TestFind(t, db) | ||||
| 	TestUpdate(t, db) | ||||
| } | ||||
| 
 | ||||
| func TestCreate(t *testing.T, db *gorm.DB) { | ||||
| @ -133,3 +134,62 @@ func TestFind(t *testing.T, db *gorm.DB) { | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func TestUpdate(t *testing.T, db *gorm.DB) { | ||||
| 	db.Migrator().DropTable(&User{}) | ||||
| 	db.AutoMigrate(&User{}) | ||||
| 
 | ||||
| 	t.Run("Update", func(t *testing.T) { | ||||
| 		var user = User{ | ||||
| 			Name:     "create", | ||||
| 			Age:      18, | ||||
| 			Birthday: Now(), | ||||
| 		} | ||||
| 
 | ||||
| 		if err := db.Create(&user).Error; err != nil { | ||||
| 			t.Errorf("errors happened when create: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if err := db.Model(&user).Update("Age", 10).Error; err != nil { | ||||
| 			t.Errorf("errors happened when update: %v", err) | ||||
| 		} else if user.Age != 10 { | ||||
| 			t.Errorf("Age should equals to 10, but got %v", user.Age) | ||||
| 		} | ||||
| 
 | ||||
| 		var result User | ||||
| 		if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { | ||||
| 			t.Errorf("errors happened when query: %v", err) | ||||
| 		} else { | ||||
| 			AssertObjEqual(t, result, user, "Name", "Age", "Birthday") | ||||
| 		} | ||||
| 
 | ||||
| 		values := map[string]interface{}{"Active": true, "age": 5} | ||||
| 		if err := db.Model(&user).Updates(values).Error; err != nil { | ||||
| 			t.Errorf("errors happened when update: %v", err) | ||||
| 		} else if user.Age != 5 { | ||||
| 			t.Errorf("Age should equals to 5, but got %v", user.Age) | ||||
| 		} else if user.Active != true { | ||||
| 			t.Errorf("Active should be true, but got %v", user.Active) | ||||
| 		} | ||||
| 
 | ||||
| 		var result2 User | ||||
| 		if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { | ||||
| 			t.Errorf("errors happened when query: %v", err) | ||||
| 		} else { | ||||
| 			AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := db.Model(&user).Updates(User{Age: 2}).Error; err != nil { | ||||
| 			t.Errorf("errors happened when update: %v", err) | ||||
| 		} else if user.Age != 2 { | ||||
| 			t.Errorf("Age should equals to 2, but got %v", user.Age) | ||||
| 		} | ||||
| 
 | ||||
| 		var result3 User | ||||
| 		if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { | ||||
| 			t.Errorf("errors happened when query: %v", err) | ||||
| 		} else { | ||||
| 			AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu