From 2c60bfe3bc3be924755bd9822c2276a40f6795fc Mon Sep 17 00:00:00 2001 From: Peter Turi Date: Tue, 13 Feb 2024 19:22:33 +0100 Subject: [PATCH] Fix null setting on reused objects Starting from v1.25.1 gorm seems to not reset null fields, causing incosistent behavior with previous behavior. --- schema/field.go | 11 +++++- schema/field_test.go | 68 ++++++++++++++++++++++++++++++++++++ tests/scan_test.go | 61 ++++++++++++++++++++++++++++++++ tests/tests_test.go | 2 +- utils/tests/dummy_scanner.go | 35 +++++++++++++++++++ utils/tests/models.go | 11 ++++++ utils/tests/utils.go | 4 +++ 7 files changed, 190 insertions(+), 2 deletions(-) create mode 100644 utils/tests/dummy_scanner.go diff --git a/schema/field.go b/schema/field.go index 91e4c0ab..13b842b5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -830,6 +830,8 @@ func (field *Field) setupValuerAndSetter() { case **time.Time: if data != nil && *data != nil { field.Set(ctx, value, *data) + } else { + field.Set(ctx, value, time.Time{}) } case time.Time: field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) @@ -856,6 +858,8 @@ func (field *Field) setupValuerAndSetter() { case **time.Time: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } else { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf((*time.Time)(nil))) } case time.Time: fieldValue := field.ReflectValueOf(ctx, value) @@ -890,7 +894,11 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { // regression: https://github.com/go-gorm/gorm/pull/6311 + if field.FieldType.Kind() == reflect.Pointer { + // If we have a pointer on the destination side, let's make sure we set it to null + field.ReflectValueOf(ctx, value).Set(reflect.Zero(field.FieldType)) + } return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) @@ -917,6 +925,7 @@ func (field *Field) setupValuerAndSetter() { if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) diff --git a/schema/field_test.go b/schema/field_test.go index 300e375b..5a704323 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -332,3 +332,71 @@ func TestTypeAliasField(t *testing.T) { checkSchemaField(t, alias, f, func(f *schema.Field) {}) } } + +func TestScannerNullSupport(t *testing.T) { + schema, err := schema.Parse(&tests.NullValue{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err) + } + + // Given that we have an object with non-default values + dest := &tests.NullValue{ + ScannerValue: tests.NewDummyString("test"), + NullScannerValue: &tests.DummyString{}, + IntValue: 1, + NullIntValue: tests.ToPointer(int(1)), + TimeValue: time.Now(), + NullTimeValue: tests.ToPointer(time.Now()), + } + reflectValue := reflect.ValueOf(dest) + + // let's assert that we are returning a pointer to a pointer + entry := schema.FieldsByName["NullScannerValue"].NewValuePool.Get() + + _, ok := entry.(**tests.DummyString) + if !ok { + t.Fatalf("scanner pointers are not returned as double pointers, thus null scanning will fail. Scanner: %v", entry) + } + + validateScannerNullSetting[tests.DummyString]( + t, schema, reflectValue, + "ScannerValue", &dest.ScannerValue, tests.DummyString{}, + "NullScannerValue", &dest.NullScannerValue, + ) + // This was not faulty + validateScannerNullSetting[int]( + t, schema, reflectValue, + "IntValue", &dest.IntValue, 0, + "NullIntValue", &dest.NullIntValue, + ) + validateScannerNullSetting[time.Time]( + t, schema, reflectValue, + "TimeValue", &dest.TimeValue, time.Time{}, + "NullTimeValue", &dest.NullTimeValue, + ) + +} + +func validateScannerNullSetting[T comparable](t *testing.T, schema *schema.Schema, reflectValue reflect.Value, + directFieldName string, directFieldPtr *T, directZeroValue T, + indirectFieldName string, indirectFieldPtr **T) { + var tPtrPtr **T // used to scan into an *T field + var tPtr *T // used to scan into an T field + + err := schema.FieldsByName[directFieldName].Set(context.TODO(), reflectValue, tPtr) + if err != nil { + t.Fatalf("error setting field: %s", directFieldName) + } + + if *directFieldPtr != directZeroValue { + t.Fatalf("value didn't got reset to its default value: %s", directFieldName) + } + + err = schema.FieldsByName[indirectFieldName].Set(context.TODO(), reflectValue, tPtrPtr) + if err != nil { + t.Fatalf("error setting field: %s", indirectFieldName) + } + if *indirectFieldPtr != nil { + t.Fatalf("ptr value didn't got reset to its default value: %s", indirectFieldName) + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 6f2e9f54..7f03caa7 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -1,10 +1,12 @@ package tests_test import ( + "fmt" "reflect" "sort" "strings" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -240,3 +242,62 @@ func TestScanToEmbedded(t *testing.T) { err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error AssertEqual(t, err, nil) } + +func TestScanNilHandling(t *testing.T) { + err := DB.Exec(`DELETE FROM null_values`).Error + if err != nil { + t.Fatalf("error emptying null values: %v", err) + } + + for i := 0; i < 100; i++ { + if i%2 == 0 { + err := DB.Exec( + `INSERT INTO null_values (id) VALUES (?)`, i, + ).Error + if err != nil { + t.Fatalf("cannot insert record %v", err) + } + } else { + err := DB.Save(&NullValue{ + Model: gorm.Model{ + ID: uint(i), + }, + ScannerValue: NewDummyString(fmt.Sprintf("%d", i)), + NullScannerValue: ToPointer(NewDummyString(fmt.Sprintf("%d", i))), + IntValue: i, + NullIntValue: ToPointer(i), + TimeValue: time.Now(), + NullTimeValue: ToPointer(time.Now()), + }).Error + if err != nil { + t.Fatalf("cannot insert record %v", err) + } + } + + } + + rows, err := DB.Model(&NullValue{}).Order("id asc").Rows() + if err != nil { + t.Fatalf("cannot query nullvalues: %v", err) + } + + defer rows.Close() + + test := NullValue{} + for rows.Next() { + if err := DB.ScanRows(rows, &test); err != nil { + t.Fatalf("cannot scan nullvalues: %v", err) + } + + if test.ID%2 == 0 { + AssertEqual(t, test.ScannerValue, NewDummyString("")) + AssertEqual(t, test.NullScannerValue, nil) + AssertEqual(t, test.IntValue, int(0)) + AssertEqual(t, test.NullIntValue, nil) + AssertEqual(t, test.TimeValue, time.Time{}) + AssertEqual(t, test.NullTimeValue, nil) + } + + } + +} diff --git a/tests/tests_test.go b/tests/tests_test.go index a127734e..2724eb1a 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}, &NullValue{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/dummy_scanner.go b/utils/tests/dummy_scanner.go new file mode 100644 index 00000000..fc820c13 --- /dev/null +++ b/utils/tests/dummy_scanner.go @@ -0,0 +1,35 @@ +package tests + +import ( + "database/sql/driver" + "fmt" +) + +type DummyString struct { + value string +} + +func NewDummyString(s string) DummyString { + return DummyString{ + value: s, + } +} + +func (d *DummyString) Scan(value interface{}) error { + switch v := value.(type) { + case string: + d.value = v + default: + d.value = fmt.Sprintf("%v", value) + } + + return nil +} + +func (d DummyString) Value() (driver.Value, error) { + return d.value, nil +} + +func (d DummyString) String() string { + return d.value +} diff --git a/utils/tests/models.go b/utils/tests/models.go index f9f4f50e..5bca9128 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -30,6 +30,7 @@ type User struct { Languages []Language `gorm:"many2many:UserSpeak;"` Friends []*User `gorm:"many2many:user_friends;"` Active bool + Motto *DummyString } type Account struct { @@ -102,3 +103,13 @@ type Child struct { ParentID *uint Parent *Parent } + +type NullValue struct { + gorm.Model + ScannerValue DummyString + NullScannerValue *DummyString + IntValue int + NullIntValue *int + TimeValue time.Time + NullTimeValue *time.Time +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 49d01f2e..7f161a83 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -126,3 +126,7 @@ func Now() *time.Time { now := time.Now() return &now } + +func ToPointer[T any](v T) *T { + return &v +}