diff --git a/schema/field.go b/schema/field.go index 67e60f70..de797402 100644 --- a/schema/field.go +++ b/schema/field.go @@ -458,20 +458,12 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) { case len(field.StructField.Index) == 1 && fieldIndex >= 0: field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { v = reflect.Indirect(v) - if v.Type() != modelType { - fieldValue := v.FieldByName(field.Name) - return fieldValue.Interface(), fieldValue.IsZero() - } fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { v = reflect.Indirect(v) - if v.Type() != modelType { - fieldValue := v.FieldByName(field.Name) - return fieldValue.Interface(), fieldValue.IsZero() - } for _, fieldIdx := range field.StructField.Index { if fieldIdx >= 0 { v = v.Field(fieldIdx) @@ -516,17 +508,11 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) { case len(field.StructField.Index) == 1 && fieldIndex >= 0: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { v = reflect.Indirect(v) - if v.Type() != modelType { - return v.FieldByName(field.Name) - } return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { v = reflect.Indirect(v) - if v.Type() != modelType { - return v.FieldByName(field.Name) - } for idx, fieldIdx := range field.StructField.Index { if fieldIdx >= 0 { v = v.Field(fieldIdx) diff --git a/statement.go b/statement.go index ba5d3f18..74feaedd 100644 --- a/statement.go +++ b/statement.go @@ -658,12 +658,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/tests/check_subset_model_change_test.go b/tests/check_subset_model_change_test.go deleted file mode 100644 index 69bb5ebc..00000000 --- a/tests/check_subset_model_change_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package tests_test - -import ( - "fmt" - "strings" - "testing" - - "gorm.io/gorm" -) - -type Man struct { - ID int - Age int - Name string - Detail string -} - -// Panic-safe BeforeUpdate hook that checks for Changed("age") -func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic in BeforeUpdate: %v", r) - } - }() - - if !tx.Statement.Changed("age") { - return nil - } - return nil -} - -func (m *Man) update(data interface{}) error { - return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error -} - -func TestBeforeUpdateStatementChanged(t *testing.T) { - DB.AutoMigrate(&Man{}) - type TestCase struct { - BaseObjects Man - change interface{} - expectError bool - } - - testCases := []TestCase{ - { - BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"}, - change: struct { - Age int - }{Age: 20}, - expectError: false, - }, - { - BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"}, - change: struct { - Name string - }{Name: "name-only"}, - expectError: true, - }, - { - BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"}, - change: struct { - Name string - Age int - }{Name: "name-only", Age: 20}, - expectError: false, - }, - } - - for _, test := range testCases { - DB.Create(&test.BaseObjects) - - // below comment is stored for future reference - // err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error - err := test.BaseObjects.update(test.change) - if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") { - if !test.expectError { - t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err) - } - } else { - if test.expectError { - t.Errorf("expected panic did not occur for input: %+v", test.change) - } - if err != nil { - t.Errorf("unexpected GORM error: %v", err) - } - } - } -} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index e04a42fb..af86f8a8 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -24,7 +24,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}, &Man{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") diff --git a/tests/submodel_test.go b/tests/submodel_test.go new file mode 100644 index 00000000..31bfda4e --- /dev/null +++ b/tests/submodel_test.go @@ -0,0 +1,45 @@ +package tests_test + +import ( + "testing" + "gorm.io/gorm" +) + +type Man struct { + ID int + Age int + Name string + Detail string +} + +// Panic-safe BeforeUpdate hook that checks for Changed("age") +func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) { + if !tx.Statement.Changed("age") { + return nil + } + return nil +} + +func TestSubModel(t *testing.T) { + man := Man{Age: 18, Name: "random-name"} + if err := DB.Create(&man).Error; err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err := DB.Model(&man).Where("id = ?", man.ID).Updates(struct { + Age int + }{Age: 20}).Error; err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result = struct{ + ID int + Age int + }{} + if err := DB.Model(&man).Where("id = ?", man.ID).Find(&result).Error; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.ID != man.ID || result.Age != 20 { + t.Fatalf("expected ID %d and Age 20, got ID %d and age", result.ID, result.Age) + } +}