Add SetColumn, Changed method
This commit is contained in:
		
							parent
							
								
									e308b103c0
								
							
						
					
					
						commit
						66dcd7e3ca
					
				| @ -11,7 +11,7 @@ import ( | ||||
| 
 | ||||
| func SaveBeforeAssociations(db *gorm.DB) { | ||||
| 	if db.Error == nil && db.Statement.Schema != nil { | ||||
| 		selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) | ||||
| 		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) | ||||
| 
 | ||||
| 		// Save Belongs To associations
 | ||||
| 		for _, rel := range db.Statement.Schema.Relationships.BelongsTo { | ||||
| @ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | ||||
| 
 | ||||
| func SaveAfterAssociations(db *gorm.DB) { | ||||
| 	if db.Error == nil && db.Statement.Schema != nil { | ||||
| 		selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) | ||||
| 		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) | ||||
| 
 | ||||
| 		// Save Has One associations
 | ||||
| 		for _, rel := range db.Statement.Schema.Relationships.HasOne { | ||||
|  | ||||
| @ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||
| 		values = ConvertSliceOfMapToValuesForCreate(stmt, value) | ||||
| 	default: | ||||
| 		var ( | ||||
| 			selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) | ||||
| 			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | ||||
| 			curTime                   = stmt.DB.NowFunc() | ||||
| 			isZero                    bool | ||||
| 		) | ||||
|  | ||||
| @ -7,64 +7,10 @@ import ( | ||||
| 	"gorm.io/gorm/clause" | ||||
| ) | ||||
| 
 | ||||
| // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | ||||
| func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { | ||||
| 	results := map[string]bool{} | ||||
| 	notRestricted := false | ||||
| 
 | ||||
| 	// select columns
 | ||||
| 	for _, column := range stmt.Selects { | ||||
| 		if column == "*" { | ||||
| 			notRestricted = true | ||||
| 			for _, dbName := range stmt.Schema.DBNames { | ||||
| 				results[dbName] = true | ||||
| 			} | ||||
| 		} else if column == clause.Associations { | ||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||
| 				results[rel.Name] = true | ||||
| 			} | ||||
| 		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { | ||||
| 			results[field.DBName] = true | ||||
| 		} else { | ||||
| 			results[column] = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// omit columns
 | ||||
| 	for _, omit := range stmt.Omits { | ||||
| 		if omit == clause.Associations { | ||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||
| 				results[rel.Name] = false | ||||
| 			} | ||||
| 		} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { | ||||
| 			results[field.DBName] = false | ||||
| 		} else { | ||||
| 			results[omit] = false | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt.Schema != nil { | ||||
| 		for _, field := range stmt.Schema.Fields { | ||||
| 			name := field.DBName | ||||
| 			if name == "" { | ||||
| 				name = field.Name | ||||
| 			} | ||||
| 
 | ||||
| 			if requireCreate && !field.Creatable { | ||||
| 				results[name] = false | ||||
| 			} else if requireUpdate && !field.Updatable { | ||||
| 				results[name] = false | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return results, !notRestricted && len(stmt.Selects) > 0 | ||||
| } | ||||
| 
 | ||||
| // ConvertMapToValuesForCreate convert map to values
 | ||||
| func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { | ||||
| 	columns := make([]string, 0, len(mapValue)) | ||||
| 	selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) | ||||
| 	selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) | ||||
| 
 | ||||
| 	var keys []string | ||||
| 	for k := range mapValue { | ||||
| @ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st | ||||
| 	var ( | ||||
| 		columns                   = []string{} | ||||
| 		result                    = map[string][]interface{}{} | ||||
| 		selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) | ||||
| 		selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | ||||
| 	) | ||||
| 
 | ||||
| 	for idx, mapValue := range mapValues { | ||||
|  | ||||
| @ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) { | ||||
| // ConvertToAssignments convert to update assignments
 | ||||
| func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 	var ( | ||||
| 		selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) | ||||
| 		selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) | ||||
| 		assignValue               func(field *schema.Field, value interface{}) | ||||
| 	) | ||||
| 
 | ||||
|  | ||||
| @ -29,4 +29,6 @@ var ( | ||||
| 	ErrUnsupportedDriver = errors.New("unsupported driver") | ||||
| 	// ErrRegistered registered
 | ||||
| 	ErrRegistered = errors.New("registered") | ||||
| 	// ErrInvalidField invalid field
 | ||||
| 	ErrInvalidField = errors.New("invalid field") | ||||
| ) | ||||
|  | ||||
							
								
								
									
										117
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										117
									
								
								statement.go
									
									
									
									
									
								
							| @ -12,6 +12,7 @@ import ( | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| // Statement statement
 | ||||
| @ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement { | ||||
| 
 | ||||
| 	return newStmt | ||||
| } | ||||
| 
 | ||||
| // Helpers
 | ||||
| // SetColumn set column's value
 | ||||
| func (stmt *Statement) SetColumn(name string, value interface{}) { | ||||
| 	if v, ok := stmt.Dest.(map[string]interface{}); ok { | ||||
| 		v[name] = value | ||||
| 	} else if stmt.Schema != nil { | ||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 			field.Set(stmt.ReflectValue, value) | ||||
| 		} else { | ||||
| 			stmt.AddError(ErrInvalidField) | ||||
| 		} | ||||
| 	} else { | ||||
| 		stmt.AddError(ErrInvalidField) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Changed check model changed or not when updating
 | ||||
| func (stmt *Statement) Changed(fields ...string) bool { | ||||
| 	modelValue := reflect.ValueOf(stmt.Model) | ||||
| 	for modelValue.Kind() == reflect.Ptr { | ||||
| 		modelValue = modelValue.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) | ||||
| 	changed := func(field *schema.Field) bool { | ||||
| 		fieldValue, isZero := field.ValueOf(modelValue) | ||||
| 		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 			if v, ok := stmt.Dest.(map[string]interface{}); ok { | ||||
| 				if fv, ok := v[field.Name]; ok { | ||||
| 					return !utils.AssertEqual(fv, fieldValue) | ||||
| 				} else if fv, ok := v[field.DBName]; ok { | ||||
| 					return !utils.AssertEqual(fv, fieldValue) | ||||
| 				} else if isZero { | ||||
| 					return true | ||||
| 				} | ||||
| 			} else { | ||||
| 				changedValue, _ := field.ValueOf(stmt.ReflectValue) | ||||
| 				return !utils.AssertEqual(changedValue, fieldValue) | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if len(fields) == 0 { | ||||
| 		for _, field := range stmt.Schema.FieldsByDBName { | ||||
| 			if changed(field) { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		for _, name := range fields { | ||||
| 			if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 				if changed(field) { | ||||
| 					return true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | ||||
| func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { | ||||
| 	results := map[string]bool{} | ||||
| 	notRestricted := false | ||||
| 
 | ||||
| 	// select columns
 | ||||
| 	for _, column := range stmt.Selects { | ||||
| 		if column == "*" { | ||||
| 			notRestricted = true | ||||
| 			for _, dbName := range stmt.Schema.DBNames { | ||||
| 				results[dbName] = true | ||||
| 			} | ||||
| 		} else if column == clause.Associations { | ||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||
| 				results[rel.Name] = true | ||||
| 			} | ||||
| 		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { | ||||
| 			results[field.DBName] = true | ||||
| 		} else { | ||||
| 			results[column] = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// omit columns
 | ||||
| 	for _, omit := range stmt.Omits { | ||||
| 		if omit == clause.Associations { | ||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||
| 				results[rel.Name] = false | ||||
| 			} | ||||
| 		} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { | ||||
| 			results[field.DBName] = false | ||||
| 		} else { | ||||
| 			results[omit] = false | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt.Schema != nil { | ||||
| 		for _, field := range stmt.Schema.Fields { | ||||
| 			name := field.DBName | ||||
| 			if name == "" { | ||||
| 				name = field.Name | ||||
| 			} | ||||
| 
 | ||||
| 			if requireCreate && !field.Creatable { | ||||
| 				results[name] = false | ||||
| 			} else if requireUpdate && !field.Updatable { | ||||
| 				results[name] = false | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return results, !notRestricted && len(stmt.Selects) > 0 | ||||
| } | ||||
|  | ||||
| @ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) { | ||||
| 		t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Product3 struct { | ||||
| 	gorm.Model | ||||
| 	Name  string | ||||
| 	Code  string | ||||
| 	Price int64 | ||||
| 	Owner string | ||||
| } | ||||
| 
 | ||||
| func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { | ||||
| 	tx.Statement.SetColumn("Price", s.Price+100) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { | ||||
| 	if tx.Statement.Changed() { | ||||
| 		tx.Statement.SetColumn("Price", s.Price+10) | ||||
| 	} | ||||
| 
 | ||||
| 	if tx.Statement.Changed("Code") { | ||||
| 		s.Price += 20 | ||||
| 		tx.Statement.SetColumn("Price", s.Price+30) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestSetColumn(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&Product3{}) | ||||
| 	DB.AutoMigrate(&Product3{}) | ||||
| 
 | ||||
| 	product := Product3{Name: "Product", Price: 0} | ||||
| 	DB.Create(&product) | ||||
| 
 | ||||
| 	if product.Price != 100 { | ||||
| 		t.Errorf("invalid price after create, got %+v", product) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) | ||||
| 
 | ||||
| 	if product.Price != 150 || product.Code != "L1212" { | ||||
| 		t.Errorf("invalid data after update, got %+v", product) | ||||
| 	} | ||||
| 
 | ||||
| 	// Code not changed, price should not change
 | ||||
| 	DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) | ||||
| 
 | ||||
| 	if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { | ||||
| 		t.Errorf("invalid data after update, got %+v", product) | ||||
| 	} | ||||
| 
 | ||||
| 	// Code changed, but not selected, price should not change
 | ||||
| 	DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) | ||||
| 
 | ||||
| 	if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" { | ||||
| 		t.Errorf("invalid data after update, got %+v", product) | ||||
| 	} | ||||
| 
 | ||||
| 	// Code changed, price should changed
 | ||||
| 	DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) | ||||
| 
 | ||||
| 	if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { | ||||
| 		t.Errorf("invalid data after update, got %+v", product) | ||||
| 	} | ||||
| 
 | ||||
| 	var result Product3 | ||||
| 	DB.First(&result, product.ID) | ||||
| 
 | ||||
| 	AssertEqual(t, result, product) | ||||
| 
 | ||||
| 	// Code changed, price not selected, price should not change
 | ||||
| 	DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) | ||||
| 
 | ||||
| 	if product.Price != 220 || product.Code != "L1213" { | ||||
| 		t.Errorf("invalid data after update, got %+v", product) | ||||
| 	} | ||||
| 
 | ||||
| 	var result2 Product3 | ||||
| 	DB.First(&result2, product.ID) | ||||
| 
 | ||||
| 	AssertEqual(t, result2, product) | ||||
| } | ||||
|  | ||||
| @ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string { | ||||
| 
 | ||||
| 	return strings.Join(results, "_") | ||||
| } | ||||
| 
 | ||||
| func AssertEqual(src, dst interface{}) bool { | ||||
| 	if !reflect.DeepEqual(src, dst) { | ||||
| 		if valuer, ok := src.(driver.Valuer); ok { | ||||
| 			src, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := dst.(driver.Valuer); ok { | ||||
| 			dst, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		return reflect.DeepEqual(src, dst) | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu