Refactor update record (#4679)
This commit is contained in:
		
							parent
							
								
									6c94b07e98
								
							
						
					
					
						commit
						ba16b2368f
					
				| @ -23,38 +23,11 @@ func SetupUpdateReflectValue(db *gorm.DB) { | |||||||
| 						rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) | 						rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 			} else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct { |  | ||||||
| 				db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem()) |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func findType(target interface{}) reflect.Type { |  | ||||||
| 	t := reflect.TypeOf(target) |  | ||||||
| 	if t.Kind() == reflect.Ptr { |  | ||||||
| 		return t.Elem() |  | ||||||
| 	} |  | ||||||
| 	return t |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func transToModel(from, to reflect.Value) interface{} { |  | ||||||
| 	if from.String() == to.String() { |  | ||||||
| 		return from.Interface() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	fromType := from.Type() |  | ||||||
| 	for i := 0; i < fromType.NumField(); i++ { |  | ||||||
| 		fieldName := fromType.Field(i).Name |  | ||||||
| 		fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName) |  | ||||||
| 		if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		toField.Set(fromField) |  | ||||||
| 	} |  | ||||||
| 	return to.Interface() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func BeforeUpdate(db *gorm.DB) { | func BeforeUpdate(db *gorm.DB) { | ||||||
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | ||||||
| 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | 		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { | ||||||
| @ -249,35 +222,45 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
|  | 		var updatingSchema = stmt.Schema | ||||||
|  | 		if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | ||||||
|  | 			// different schema
 | ||||||
|  | 			updatingStmt := &gorm.Statement{DB: stmt.DB} | ||||||
|  | 			if err := updatingStmt.Parse(stmt.Dest); err == nil { | ||||||
|  | 				updatingSchema = updatingStmt.Schema | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		switch updatingValue.Kind() { | 		switch updatingValue.Kind() { | ||||||
| 		case reflect.Struct: | 		case reflect.Struct: | ||||||
| 			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) | 			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) | ||||||
| 			for _, dbName := range stmt.Schema.DBNames { | 			for _, dbName := range stmt.Schema.DBNames { | ||||||
| 				field := stmt.Schema.LookUpField(dbName) | 				if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable { | ||||||
| 				if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | 					if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | ||||||
| 					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { | 						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { | ||||||
| 						value, isZero := field.ValueOf(updatingValue) | 							value, isZero := field.ValueOf(updatingValue) | ||||||
| 						if !stmt.SkipHooks && field.AutoUpdateTime > 0 { | 							if !stmt.SkipHooks && field.AutoUpdateTime > 0 { | ||||||
| 							if field.AutoUpdateTime == schema.UnixNanosecond { | 								if field.AutoUpdateTime == schema.UnixNanosecond { | ||||||
| 								value = stmt.DB.NowFunc().UnixNano() | 									value = stmt.DB.NowFunc().UnixNano() | ||||||
| 							} else if field.AutoUpdateTime == schema.UnixMillisecond { | 								} else if field.AutoUpdateTime == schema.UnixMillisecond { | ||||||
| 								value = stmt.DB.NowFunc().UnixNano() / 1e6 | 									value = stmt.DB.NowFunc().UnixNano() / 1e6 | ||||||
| 							} else if field.GORMDataType == schema.Time { | 								} else if field.GORMDataType == schema.Time { | ||||||
| 								value = stmt.DB.NowFunc() | 									value = stmt.DB.NowFunc() | ||||||
| 							} else { | 								} else { | ||||||
| 								value = stmt.DB.NowFunc().Unix() | 									value = stmt.DB.NowFunc().Unix() | ||||||
|  | 								} | ||||||
|  | 								isZero = false | ||||||
| 							} | 							} | ||||||
| 							isZero = false |  | ||||||
| 						} |  | ||||||
| 
 | 
 | ||||||
| 						if ok || !isZero { | 							if ok || !isZero { | ||||||
| 							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) | 								set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) | ||||||
| 							assignValue(field, value) | 								assignValue(field, value) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} else { | ||||||
|  | 						if value, isZero := field.ValueOf(updatingValue); !isZero { | ||||||
|  | 							stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) | ||||||
| 						} | 						} | ||||||
| 					} |  | ||||||
| 				} else { |  | ||||||
| 					if value, isZero := field.ValueOf(updatingValue); !isZero { |  | ||||||
| 						stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) |  | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | |||||||
| @ -651,14 +651,16 @@ func TestSave(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	user3.Name = "save3_" | 	user3.Name = "save3_" | ||||||
| 	DB.Model(User{Model: user3.Model}).Save(&user3) | 	if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to save user, got %v", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	var result2 User | 	var result2 User | ||||||
| 	if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { | 	if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { | ||||||
| 		t.Fatalf("failed to find updated user") | 		t.Fatalf("failed to find updated user, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	DB.Debug().Model(User{Model: user3.Model}).Save(&struct { | 	if err := DB.Model(User{Model: user3.Model}).Save(&struct { | ||||||
| 		gorm.Model | 		gorm.Model | ||||||
| 		Placeholder string | 		Placeholder string | ||||||
| 		Name        string | 		Name        string | ||||||
| @ -666,7 +668,9 @@ func TestSave(t *testing.T) { | |||||||
| 		Model:       user3.Model, | 		Model:       user3.Model, | ||||||
| 		Placeholder: "placeholder", | 		Placeholder: "placeholder", | ||||||
| 		Name:        "save3__", | 		Name:        "save3__", | ||||||
| 	}) | 	}).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to update user, got %v", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	var result3 User | 	var result3 User | ||||||
| 	if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { | 	if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu