Add tests for sub model

This commit is contained in:
Jinzhu 2025-08-20 12:51:17 +08:00
parent 67de7a8af8
commit 4e34a6d21b
5 changed files with 54 additions and 108 deletions

View File

@ -458,20 +458,12 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
fieldValue := v.Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
@ -516,17 +508,11 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
return v.Field(fieldIndex)
}
default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)

View File

@ -658,12 +658,15 @@ func (stmt *Statement) Changed(fields ...string) bool {
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
if destField := descSchema.LookUpField(field.DBName); destField != nil {
changedValue, zero := destField.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
return false

View File

@ -1,88 +0,0 @@
package tests_test
import (
"fmt"
"strings"
"testing"
"gorm.io/gorm"
)
type Man struct {
ID int
Age int
Name string
Detail string
}
// Panic-safe BeforeUpdate hook that checks for Changed("age")
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
}
}()
if !tx.Statement.Changed("age") {
return nil
}
return nil
}
func (m *Man) update(data interface{}) error {
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
}
func TestBeforeUpdateStatementChanged(t *testing.T) {
DB.AutoMigrate(&Man{})
type TestCase struct {
BaseObjects Man
change interface{}
expectError bool
}
testCases := []TestCase{
{
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
change: struct {
Age int
}{Age: 20},
expectError: false,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
}{Name: "name-only"},
expectError: true,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
Age int
}{Name: "name-only", Age: 20},
expectError: false,
},
}
for _, test := range testCases {
DB.Create(&test.BaseObjects)
// below comment is stored for future reference
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
err := test.BaseObjects.update(test.change)
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
if !test.expectError {
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
}
} else {
if test.expectError {
t.Errorf("expected panic did not occur for input: %+v", test.change)
}
if err != nil {
t.Errorf("unexpected GORM error: %v", err)
}
}
}
}

View File

@ -24,7 +24,7 @@ import (
)
func TestMigrate(t *testing.T) {
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}, &Man{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")

45
tests/submodel_test.go Normal file
View File

@ -0,0 +1,45 @@
package tests_test
import (
"testing"
"gorm.io/gorm"
)
type Man struct {
ID int
Age int
Name string
Detail string
}
// Panic-safe BeforeUpdate hook that checks for Changed("age")
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
if !tx.Statement.Changed("age") {
return nil
}
return nil
}
func TestSubModel(t *testing.T) {
man := Man{Age: 18, Name: "random-name"}
if err := DB.Create(&man).Error; err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := DB.Model(&man).Where("id = ?", man.ID).Updates(struct {
Age int
}{Age: 20}).Error; err != nil {
t.Fatalf("unexpected error: %v", err)
}
var result = struct{
ID int
Age int
}{}
if err := DB.Model(&man).Where("id = ?", man.ID).Find(&result).Error; err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.ID != man.ID || result.Age != 20 {
t.Fatalf("expected ID %d and Age 20, got ID %d and age", result.ID, result.Age)
}
}