Add tests for sub model
This commit is contained in:
parent
67de7a8af8
commit
4e34a6d21b
@ -458,20 +458,12 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
|||||||
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
||||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||||
v = reflect.Indirect(v)
|
v = reflect.Indirect(v)
|
||||||
if v.Type() != modelType {
|
|
||||||
fieldValue := v.FieldByName(field.Name)
|
|
||||||
return fieldValue.Interface(), fieldValue.IsZero()
|
|
||||||
}
|
|
||||||
fieldValue := v.Field(fieldIndex)
|
fieldValue := v.Field(fieldIndex)
|
||||||
return fieldValue.Interface(), fieldValue.IsZero()
|
return fieldValue.Interface(), fieldValue.IsZero()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||||
v = reflect.Indirect(v)
|
v = reflect.Indirect(v)
|
||||||
if v.Type() != modelType {
|
|
||||||
fieldValue := v.FieldByName(field.Name)
|
|
||||||
return fieldValue.Interface(), fieldValue.IsZero()
|
|
||||||
}
|
|
||||||
for _, fieldIdx := range field.StructField.Index {
|
for _, fieldIdx := range field.StructField.Index {
|
||||||
if fieldIdx >= 0 {
|
if fieldIdx >= 0 {
|
||||||
v = v.Field(fieldIdx)
|
v = v.Field(fieldIdx)
|
||||||
@ -516,17 +508,11 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
|||||||
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
||||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||||
v = reflect.Indirect(v)
|
v = reflect.Indirect(v)
|
||||||
if v.Type() != modelType {
|
|
||||||
return v.FieldByName(field.Name)
|
|
||||||
}
|
|
||||||
return v.Field(fieldIndex)
|
return v.Field(fieldIndex)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||||
v = reflect.Indirect(v)
|
v = reflect.Indirect(v)
|
||||||
if v.Type() != modelType {
|
|
||||||
return v.FieldByName(field.Name)
|
|
||||||
}
|
|
||||||
for idx, fieldIdx := range field.StructField.Index {
|
for idx, fieldIdx := range field.StructField.Index {
|
||||||
if fieldIdx >= 0 {
|
if fieldIdx >= 0 {
|
||||||
v = v.Field(fieldIdx)
|
v = v.Field(fieldIdx)
|
||||||
|
13
statement.go
13
statement.go
@ -658,12 +658,15 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
|||||||
for destValue.Kind() == reflect.Ptr {
|
for destValue.Kind() == reflect.Ptr {
|
||||||
destValue = destValue.Elem()
|
destValue = destValue.Elem()
|
||||||
}
|
}
|
||||||
|
if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||||
changedValue, zero := field.ValueOf(stmt.Context, destValue)
|
if destField := descSchema.LookUpField(field.DBName); destField != nil {
|
||||||
if v {
|
changedValue, zero := destField.ValueOf(stmt.Context, destValue)
|
||||||
return !utils.AssertEqual(changedValue, fieldValue)
|
if v {
|
||||||
|
return !utils.AssertEqual(changedValue, fieldValue)
|
||||||
|
}
|
||||||
|
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -1,88 +0,0 @@
|
|||||||
package tests_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Man struct {
|
|
||||||
ID int
|
|
||||||
Age int
|
|
||||||
Name string
|
|
||||||
Detail string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Panic-safe BeforeUpdate hook that checks for Changed("age")
|
|
||||||
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !tx.Statement.Changed("age") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Man) update(data interface{}) error {
|
|
||||||
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBeforeUpdateStatementChanged(t *testing.T) {
|
|
||||||
DB.AutoMigrate(&Man{})
|
|
||||||
type TestCase struct {
|
|
||||||
BaseObjects Man
|
|
||||||
change interface{}
|
|
||||||
expectError bool
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []TestCase{
|
|
||||||
{
|
|
||||||
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
|
|
||||||
change: struct {
|
|
||||||
Age int
|
|
||||||
}{Age: 20},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
|
|
||||||
change: struct {
|
|
||||||
Name string
|
|
||||||
}{Name: "name-only"},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
|
|
||||||
change: struct {
|
|
||||||
Name string
|
|
||||||
Age int
|
|
||||||
}{Name: "name-only", Age: 20},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
DB.Create(&test.BaseObjects)
|
|
||||||
|
|
||||||
// below comment is stored for future reference
|
|
||||||
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
|
|
||||||
err := test.BaseObjects.update(test.change)
|
|
||||||
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
|
|
||||||
if !test.expectError {
|
|
||||||
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if test.expectError {
|
|
||||||
t.Errorf("expected panic did not occur for input: %+v", test.change)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unexpected GORM error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -24,7 +24,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMigrate(t *testing.T) {
|
func TestMigrate(t *testing.T) {
|
||||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}}
|
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}, &Man{}}
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||||
DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")
|
DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")
|
||||||
|
45
tests/submodel_test.go
Normal file
45
tests/submodel_test.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Man struct {
|
||||||
|
ID int
|
||||||
|
Age int
|
||||||
|
Name string
|
||||||
|
Detail string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Panic-safe BeforeUpdate hook that checks for Changed("age")
|
||||||
|
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
|
||||||
|
if !tx.Statement.Changed("age") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubModel(t *testing.T) {
|
||||||
|
man := Man{Age: 18, Name: "random-name"}
|
||||||
|
if err := DB.Create(&man).Error; err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&man).Where("id = ?", man.ID).Updates(struct {
|
||||||
|
Age int
|
||||||
|
}{Age: 20}).Error; err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result = struct{
|
||||||
|
ID int
|
||||||
|
Age int
|
||||||
|
}{}
|
||||||
|
if err := DB.Model(&man).Where("id = ?", man.ID).Find(&result).Error; err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.ID != man.ID || result.Age != 20 {
|
||||||
|
t.Fatalf("expected ID %d and Age 20, got ID %d and age", result.ID, result.Age)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user