Refactor Scope updatedAttrsWithValues
This commit is contained in:
		
							parent
							
								
									a0aa21aec5
								
							
						
					
					
						commit
						8de97c2883
					
				@ -21,12 +21,10 @@ func init() {
 | 
				
			|||||||
// assignUpdatingAttributesCallback assign updating attributes to model
 | 
					// assignUpdatingAttributesCallback assign updating attributes to model
 | 
				
			||||||
func assignUpdatingAttributesCallback(scope *Scope) {
 | 
					func assignUpdatingAttributesCallback(scope *Scope) {
 | 
				
			||||||
	if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
 | 
						if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
 | 
				
			||||||
		if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
 | 
							if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
 | 
				
			||||||
			if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
 | 
								scope.InstanceSet("gorm:update_attrs", updateMaps)
 | 
				
			||||||
				scope.InstanceSet("gorm:update_attrs", updateMaps)
 | 
							} else {
 | 
				
			||||||
			} else {
 | 
								scope.SkipLeft()
 | 
				
			||||||
				scope.SkipLeft()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								main.go
									
									
									
									
									
								
							@ -310,7 +310,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		c.NewScope(out).inlineCondition(where...).initialize()
 | 
							c.NewScope(out).inlineCondition(where...).initialize()
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
 | 
							c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return c
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										62
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								scope.go
									
									
									
									
									
								
							@ -793,27 +793,55 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
 | 
					func convertInterfaceToMap(values interface{}) map[string]interface{} {
 | 
				
			||||||
 | 
						var attrs = map[string]interface{}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch value := values.(type) {
 | 
				
			||||||
 | 
						case map[string]interface{}:
 | 
				
			||||||
 | 
							return value
 | 
				
			||||||
 | 
						case []interface{}:
 | 
				
			||||||
 | 
							for _, v := range value {
 | 
				
			||||||
 | 
								for key, value := range convertInterfaceToMap(v) {
 | 
				
			||||||
 | 
									attrs[key] = value
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case interface{}:
 | 
				
			||||||
 | 
							reflectValue := reflect.ValueOf(values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							switch reflectValue.Kind() {
 | 
				
			||||||
 | 
							case reflect.Map:
 | 
				
			||||||
 | 
								for _, key := range reflectValue.MapKeys() {
 | 
				
			||||||
 | 
									attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								for _, field := range (&Scope{Value: values}).Fields() {
 | 
				
			||||||
 | 
									if !field.IsBlank {
 | 
				
			||||||
 | 
										attrs[field.DBName] = field.Field.Interface()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return attrs
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
 | 
				
			||||||
	if scope.IndirectValue().Kind() != reflect.Struct {
 | 
						if scope.IndirectValue().Kind() != reflect.Struct {
 | 
				
			||||||
		return values, true
 | 
							return convertInterfaceToMap(value), true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	results = map[string]interface{}{}
 | 
						results = map[string]interface{}{}
 | 
				
			||||||
	for key, value := range values {
 | 
					
 | 
				
			||||||
 | 
						for key, value := range convertInterfaceToMap(value) {
 | 
				
			||||||
		if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
 | 
							if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
 | 
				
			||||||
			if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
 | 
								if _, ok := value.(*expr); ok {
 | 
				
			||||||
				if _, ok := value.(*expr); ok {
 | 
									hasUpdate = true
 | 
				
			||||||
					hasUpdate = true
 | 
									results[field.DBName] = value
 | 
				
			||||||
					results[field.DBName] = value
 | 
					 | 
				
			||||||
				} else if !equalAsString(field.Field.Interface(), value) {
 | 
					 | 
				
			||||||
					field.Set(value)
 | 
					 | 
				
			||||||
					if field.IsNormal {
 | 
					 | 
				
			||||||
						hasUpdate = true
 | 
					 | 
				
			||||||
						results[field.DBName] = field.Field.Interface()
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				field.Set(value)
 | 
									field.Set(value)
 | 
				
			||||||
 | 
									if field.IsNormal {
 | 
				
			||||||
 | 
										hasUpdate = true
 | 
				
			||||||
 | 
										results[field.DBName] = field.Field.Interface()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -836,10 +864,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) initialize() *Scope {
 | 
					func (scope *Scope) initialize() *Scope {
 | 
				
			||||||
	for _, clause := range scope.Search.whereConditions {
 | 
						for _, clause := range scope.Search.whereConditions {
 | 
				
			||||||
		scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
 | 
							scope.updatedAttrsWithValues(clause["query"])
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
 | 
						scope.updatedAttrsWithValues(scope.Search.initAttrs)
 | 
				
			||||||
	scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
 | 
						scope.updatedAttrsWithValues(scope.Search.assignAttrs)
 | 
				
			||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -20,13 +20,6 @@ func TestUpdate(t *testing.T) {
 | 
				
			|||||||
	DB.First(&product1, product1.Id)
 | 
						DB.First(&product1, product1.Id)
 | 
				
			||||||
	DB.First(&product2, product2.Id)
 | 
						DB.First(&product2, product2.Id)
 | 
				
			||||||
	updatedAt1 := product1.UpdatedAt
 | 
						updatedAt1 := product1.UpdatedAt
 | 
				
			||||||
	updatedAt2 := product2.UpdatedAt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var product3 Product
 | 
					 | 
				
			||||||
	DB.First(&product3, product2.Id).Update("code", "product2newcode")
 | 
					 | 
				
			||||||
	if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
 | 
					 | 
				
			||||||
		t.Errorf("updatedAt should not be updated if nothing changed")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
 | 
						if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
 | 
				
			||||||
		t.Errorf("Product1 should not be updated")
 | 
							t.Errorf("Product1 should not be updated")
 | 
				
			||||||
@ -135,19 +128,8 @@ func TestUpdates(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	DB.First(&product1, product1.Id)
 | 
						DB.First(&product1, product1.Id)
 | 
				
			||||||
	DB.First(&product2, product2.Id)
 | 
						DB.First(&product2, product2.Id)
 | 
				
			||||||
	updatedAt1 := product1.UpdatedAt
 | 
					 | 
				
			||||||
	updatedAt2 := product2.UpdatedAt
 | 
						updatedAt2 := product2.UpdatedAt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var product3 Product
 | 
					 | 
				
			||||||
	DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
 | 
					 | 
				
			||||||
	if product3.Code != "product1newcode" || product3.Price != 100 {
 | 
					 | 
				
			||||||
		t.Errorf("Record should be updated with struct")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
 | 
					 | 
				
			||||||
		t.Errorf("updatedAt should not be updated if nothing changed")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
 | 
						if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
 | 
				
			||||||
		t.Errorf("Product2 should not be updated")
 | 
							t.Errorf("Product2 should not be updated")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										31
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								utils.go
									
									
									
									
									
								
							@ -199,37 +199,6 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
 | 
					 | 
				
			||||||
	attrs := map[string]interface{}{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	switch value := values.(type) {
 | 
					 | 
				
			||||||
	case map[string]interface{}:
 | 
					 | 
				
			||||||
		return value
 | 
					 | 
				
			||||||
	case []interface{}:
 | 
					 | 
				
			||||||
		for _, v := range value {
 | 
					 | 
				
			||||||
			for key, value := range convertInterfaceToMap(v) {
 | 
					 | 
				
			||||||
				attrs[key] = value
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case interface{}:
 | 
					 | 
				
			||||||
		reflectValue := reflect.ValueOf(values)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		switch reflectValue.Kind() {
 | 
					 | 
				
			||||||
		case reflect.Map:
 | 
					 | 
				
			||||||
			for _, key := range reflectValue.MapKeys() {
 | 
					 | 
				
			||||||
				attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		default:
 | 
					 | 
				
			||||||
			for _, field := range (&Scope{Value: values}).Fields() {
 | 
					 | 
				
			||||||
				if !field.IsBlank {
 | 
					 | 
				
			||||||
					attrs[field.DBName] = field.Field.Interface()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return attrs
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func equalAsString(a interface{}, b interface{}) bool {
 | 
					func equalAsString(a interface{}, b interface{}) bool {
 | 
				
			||||||
	return toString(a) == toString(b)
 | 
						return toString(a) == toString(b)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user