Fix null setting on reused objects

Starting from v1.25.1 gorm seems to not reset null fields, causing
incosistent behavior with previous behavior.
This commit is contained in:
Peter Turi 2024-02-13 19:22:33 +01:00
parent 8fb9a31775
commit 2c60bfe3bc
No known key found for this signature in database
GPG Key ID: 0C1EB70972B011CA
7 changed files with 190 additions and 2 deletions

View File

@ -830,6 +830,8 @@ func (field *Field) setupValuerAndSetter() {
case **time.Time:
if data != nil && *data != nil {
field.Set(ctx, value, *data)
} else {
field.Set(ctx, value, time.Time{})
}
case time.Time:
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
@ -856,6 +858,8 @@ func (field *Field) setupValuerAndSetter() {
case **time.Time:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
} else {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf((*time.Time)(nil)))
}
case time.Time:
fieldValue := field.ReflectValueOf(ctx, value)
@ -890,7 +894,11 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { // regression: https://github.com/go-gorm/gorm/pull/6311
if field.FieldType.Kind() == reflect.Pointer {
// If we have a pointer on the destination side, let's make sure we set it to null
field.ReflectValueOf(ctx, value).Set(reflect.Zero(field.FieldType))
}
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
@ -917,6 +925,7 @@ func (field *Field) setupValuerAndSetter() {
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)

View File

@ -332,3 +332,71 @@ func TestTypeAliasField(t *testing.T) {
checkSchemaField(t, alias, f, func(f *schema.Field) {})
}
}
func TestScannerNullSupport(t *testing.T) {
schema, err := schema.Parse(&tests.NullValue{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
}
// Given that we have an object with non-default values
dest := &tests.NullValue{
ScannerValue: tests.NewDummyString("test"),
NullScannerValue: &tests.DummyString{},
IntValue: 1,
NullIntValue: tests.ToPointer(int(1)),
TimeValue: time.Now(),
NullTimeValue: tests.ToPointer(time.Now()),
}
reflectValue := reflect.ValueOf(dest)
// let's assert that we are returning a pointer to a pointer
entry := schema.FieldsByName["NullScannerValue"].NewValuePool.Get()
_, ok := entry.(**tests.DummyString)
if !ok {
t.Fatalf("scanner pointers are not returned as double pointers, thus null scanning will fail. Scanner: %v", entry)
}
validateScannerNullSetting[tests.DummyString](
t, schema, reflectValue,
"ScannerValue", &dest.ScannerValue, tests.DummyString{},
"NullScannerValue", &dest.NullScannerValue,
)
// This was not faulty
validateScannerNullSetting[int](
t, schema, reflectValue,
"IntValue", &dest.IntValue, 0,
"NullIntValue", &dest.NullIntValue,
)
validateScannerNullSetting[time.Time](
t, schema, reflectValue,
"TimeValue", &dest.TimeValue, time.Time{},
"NullTimeValue", &dest.NullTimeValue,
)
}
func validateScannerNullSetting[T comparable](t *testing.T, schema *schema.Schema, reflectValue reflect.Value,
directFieldName string, directFieldPtr *T, directZeroValue T,
indirectFieldName string, indirectFieldPtr **T) {
var tPtrPtr **T // used to scan into an *T field
var tPtr *T // used to scan into an T field
err := schema.FieldsByName[directFieldName].Set(context.TODO(), reflectValue, tPtr)
if err != nil {
t.Fatalf("error setting field: %s", directFieldName)
}
if *directFieldPtr != directZeroValue {
t.Fatalf("value didn't got reset to its default value: %s", directFieldName)
}
err = schema.FieldsByName[indirectFieldName].Set(context.TODO(), reflectValue, tPtrPtr)
if err != nil {
t.Fatalf("error setting field: %s", indirectFieldName)
}
if *indirectFieldPtr != nil {
t.Fatalf("ptr value didn't got reset to its default value: %s", indirectFieldName)
}
}

View File

@ -1,10 +1,12 @@
package tests_test
import (
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
@ -240,3 +242,62 @@ func TestScanToEmbedded(t *testing.T) {
err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error
AssertEqual(t, err, nil)
}
func TestScanNilHandling(t *testing.T) {
err := DB.Exec(`DELETE FROM null_values`).Error
if err != nil {
t.Fatalf("error emptying null values: %v", err)
}
for i := 0; i < 100; i++ {
if i%2 == 0 {
err := DB.Exec(
`INSERT INTO null_values (id) VALUES (?)`, i,
).Error
if err != nil {
t.Fatalf("cannot insert record %v", err)
}
} else {
err := DB.Save(&NullValue{
Model: gorm.Model{
ID: uint(i),
},
ScannerValue: NewDummyString(fmt.Sprintf("%d", i)),
NullScannerValue: ToPointer(NewDummyString(fmt.Sprintf("%d", i))),
IntValue: i,
NullIntValue: ToPointer(i),
TimeValue: time.Now(),
NullTimeValue: ToPointer(time.Now()),
}).Error
if err != nil {
t.Fatalf("cannot insert record %v", err)
}
}
}
rows, err := DB.Model(&NullValue{}).Order("id asc").Rows()
if err != nil {
t.Fatalf("cannot query nullvalues: %v", err)
}
defer rows.Close()
test := NullValue{}
for rows.Next() {
if err := DB.ScanRows(rows, &test); err != nil {
t.Fatalf("cannot scan nullvalues: %v", err)
}
if test.ID%2 == 0 {
AssertEqual(t, test.ScannerValue, NewDummyString(""))
AssertEqual(t, test.NullScannerValue, nil)
AssertEqual(t, test.IntValue, int(0))
AssertEqual(t, test.NullIntValue, nil)
AssertEqual(t, test.TimeValue, time.Time{})
AssertEqual(t, test.NullTimeValue, nil)
}
}
}

View File

@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
func RunMigrations() {
var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}, &NullValue{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

View File

@ -0,0 +1,35 @@
package tests
import (
"database/sql/driver"
"fmt"
)
type DummyString struct {
value string
}
func NewDummyString(s string) DummyString {
return DummyString{
value: s,
}
}
func (d *DummyString) Scan(value interface{}) error {
switch v := value.(type) {
case string:
d.value = v
default:
d.value = fmt.Sprintf("%v", value)
}
return nil
}
func (d DummyString) Value() (driver.Value, error) {
return d.value, nil
}
func (d DummyString) String() string {
return d.value
}

View File

@ -30,6 +30,7 @@ type User struct {
Languages []Language `gorm:"many2many:UserSpeak;"`
Friends []*User `gorm:"many2many:user_friends;"`
Active bool
Motto *DummyString
}
type Account struct {
@ -102,3 +103,13 @@ type Child struct {
ParentID *uint
Parent *Parent
}
type NullValue struct {
gorm.Model
ScannerValue DummyString
NullScannerValue *DummyString
IntValue int
NullIntValue *int
TimeValue time.Time
NullTimeValue *time.Time
}

View File

@ -126,3 +126,7 @@ func Now() *time.Time {
now := time.Now()
return &now
}
func ToPointer[T any](v T) *T {
return &v
}