Refactor update record
This commit is contained in:
		
							parent
							
								
									6c94b07e98
								
							
						
					
					
						commit
						2b6789c2d7
					
				@ -23,38 +23,11 @@ func SetupUpdateReflectValue(db *gorm.DB) {
 | 
			
		||||
						rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			} else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct {
 | 
			
		||||
				db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func findType(target interface{}) reflect.Type {
 | 
			
		||||
	t := reflect.TypeOf(target)
 | 
			
		||||
	if t.Kind() == reflect.Ptr {
 | 
			
		||||
		return t.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func transToModel(from, to reflect.Value) interface{} {
 | 
			
		||||
	if from.String() == to.String() {
 | 
			
		||||
		return from.Interface()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fromType := from.Type()
 | 
			
		||||
	for i := 0; i < fromType.NumField(); i++ {
 | 
			
		||||
		fieldName := fromType.Field(i).Name
 | 
			
		||||
		fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName)
 | 
			
		||||
		if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		toField.Set(fromField)
 | 
			
		||||
	}
 | 
			
		||||
	return to.Interface()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BeforeUpdate(db *gorm.DB) {
 | 
			
		||||
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
 | 
			
		||||
		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
 | 
			
		||||
@ -249,35 +222,45 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		var updatingSchema = stmt.Schema
 | 
			
		||||
		if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
 | 
			
		||||
			// different schema
 | 
			
		||||
			updatingStmt := &gorm.Statement{DB: stmt.DB}
 | 
			
		||||
			if err := updatingStmt.Parse(stmt.Dest); err == nil {
 | 
			
		||||
				updatingSchema = updatingStmt.Schema
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		switch updatingValue.Kind() {
 | 
			
		||||
		case reflect.Struct:
 | 
			
		||||
			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
 | 
			
		||||
			for _, dbName := range stmt.Schema.DBNames {
 | 
			
		||||
				field := stmt.Schema.LookUpField(dbName)
 | 
			
		||||
				if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
 | 
			
		||||
					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
 | 
			
		||||
						value, isZero := field.ValueOf(updatingValue)
 | 
			
		||||
						if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
 | 
			
		||||
							if field.AutoUpdateTime == schema.UnixNanosecond {
 | 
			
		||||
								value = stmt.DB.NowFunc().UnixNano()
 | 
			
		||||
							} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
			
		||||
								value = stmt.DB.NowFunc().UnixNano() / 1e6
 | 
			
		||||
							} else if field.GORMDataType == schema.Time {
 | 
			
		||||
								value = stmt.DB.NowFunc()
 | 
			
		||||
							} else {
 | 
			
		||||
								value = stmt.DB.NowFunc().Unix()
 | 
			
		||||
				if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable {
 | 
			
		||||
					if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
 | 
			
		||||
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
 | 
			
		||||
							value, isZero := field.ValueOf(updatingValue)
 | 
			
		||||
							if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
 | 
			
		||||
								if field.AutoUpdateTime == schema.UnixNanosecond {
 | 
			
		||||
									value = stmt.DB.NowFunc().UnixNano()
 | 
			
		||||
								} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
			
		||||
									value = stmt.DB.NowFunc().UnixNano() / 1e6
 | 
			
		||||
								} else if field.GORMDataType == schema.Time {
 | 
			
		||||
									value = stmt.DB.NowFunc()
 | 
			
		||||
								} else {
 | 
			
		||||
									value = stmt.DB.NowFunc().Unix()
 | 
			
		||||
								}
 | 
			
		||||
								isZero = false
 | 
			
		||||
							}
 | 
			
		||||
							isZero = false
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						if ok || !isZero {
 | 
			
		||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
 | 
			
		||||
							assignValue(field, value)
 | 
			
		||||
							if ok || !isZero {
 | 
			
		||||
								set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
 | 
			
		||||
								assignValue(field, value)
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					} else {
 | 
			
		||||
						if value, isZero := field.ValueOf(updatingValue); !isZero {
 | 
			
		||||
							stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					if value, isZero := field.ValueOf(updatingValue); !isZero {
 | 
			
		||||
						stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -651,14 +651,16 @@ func TestSave(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user3.Name = "save3_"
 | 
			
		||||
	DB.Model(User{Model: user3.Model}).Save(&user3)
 | 
			
		||||
	if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil {
 | 
			
		||||
		t.Fatalf("failed to save user, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var result2 User
 | 
			
		||||
	if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID {
 | 
			
		||||
		t.Fatalf("failed to find updated user")
 | 
			
		||||
		t.Fatalf("failed to find updated user, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Debug().Model(User{Model: user3.Model}).Save(&struct {
 | 
			
		||||
	if err := DB.Model(User{Model: user3.Model}).Save(&struct {
 | 
			
		||||
		gorm.Model
 | 
			
		||||
		Placeholder string
 | 
			
		||||
		Name        string
 | 
			
		||||
@ -666,7 +668,9 @@ func TestSave(t *testing.T) {
 | 
			
		||||
		Model:       user3.Model,
 | 
			
		||||
		Placeholder: "placeholder",
 | 
			
		||||
		Name:        "save3__",
 | 
			
		||||
	})
 | 
			
		||||
	}).Error; err != nil {
 | 
			
		||||
		t.Fatalf("failed to update user, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var result3 User
 | 
			
		||||
	if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user