Fix null setting on reused objects
Starting from v1.25.1 gorm seems to not reset null fields, causing incosistent behavior with previous behavior.
This commit is contained in:
parent
8fb9a31775
commit
2c60bfe3bc
@ -830,6 +830,8 @@ func (field *Field) setupValuerAndSetter() {
|
||||
case **time.Time:
|
||||
if data != nil && *data != nil {
|
||||
field.Set(ctx, value, *data)
|
||||
} else {
|
||||
field.Set(ctx, value, time.Time{})
|
||||
}
|
||||
case time.Time:
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
||||
@ -856,6 +858,8 @@ func (field *Field) setupValuerAndSetter() {
|
||||
case **time.Time:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf((*time.Time)(nil)))
|
||||
}
|
||||
case time.Time:
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
@ -890,7 +894,11 @@ func (field *Field) setupValuerAndSetter() {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { // regression: https://github.com/go-gorm/gorm/pull/6311
|
||||
if field.FieldType.Kind() == reflect.Pointer {
|
||||
// If we have a pointer on the destination side, let's make sure we set it to null
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.Zero(field.FieldType))
|
||||
}
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
@ -917,6 +925,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
|
@ -332,3 +332,71 @@ func TestTypeAliasField(t *testing.T) {
|
||||
checkSchemaField(t, alias, f, func(f *schema.Field) {})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScannerNullSupport(t *testing.T) {
|
||||
schema, err := schema.Parse(&tests.NullValue{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
|
||||
}
|
||||
|
||||
// Given that we have an object with non-default values
|
||||
dest := &tests.NullValue{
|
||||
ScannerValue: tests.NewDummyString("test"),
|
||||
NullScannerValue: &tests.DummyString{},
|
||||
IntValue: 1,
|
||||
NullIntValue: tests.ToPointer(int(1)),
|
||||
TimeValue: time.Now(),
|
||||
NullTimeValue: tests.ToPointer(time.Now()),
|
||||
}
|
||||
reflectValue := reflect.ValueOf(dest)
|
||||
|
||||
// let's assert that we are returning a pointer to a pointer
|
||||
entry := schema.FieldsByName["NullScannerValue"].NewValuePool.Get()
|
||||
|
||||
_, ok := entry.(**tests.DummyString)
|
||||
if !ok {
|
||||
t.Fatalf("scanner pointers are not returned as double pointers, thus null scanning will fail. Scanner: %v", entry)
|
||||
}
|
||||
|
||||
validateScannerNullSetting[tests.DummyString](
|
||||
t, schema, reflectValue,
|
||||
"ScannerValue", &dest.ScannerValue, tests.DummyString{},
|
||||
"NullScannerValue", &dest.NullScannerValue,
|
||||
)
|
||||
// This was not faulty
|
||||
validateScannerNullSetting[int](
|
||||
t, schema, reflectValue,
|
||||
"IntValue", &dest.IntValue, 0,
|
||||
"NullIntValue", &dest.NullIntValue,
|
||||
)
|
||||
validateScannerNullSetting[time.Time](
|
||||
t, schema, reflectValue,
|
||||
"TimeValue", &dest.TimeValue, time.Time{},
|
||||
"NullTimeValue", &dest.NullTimeValue,
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
func validateScannerNullSetting[T comparable](t *testing.T, schema *schema.Schema, reflectValue reflect.Value,
|
||||
directFieldName string, directFieldPtr *T, directZeroValue T,
|
||||
indirectFieldName string, indirectFieldPtr **T) {
|
||||
var tPtrPtr **T // used to scan into an *T field
|
||||
var tPtr *T // used to scan into an T field
|
||||
|
||||
err := schema.FieldsByName[directFieldName].Set(context.TODO(), reflectValue, tPtr)
|
||||
if err != nil {
|
||||
t.Fatalf("error setting field: %s", directFieldName)
|
||||
}
|
||||
|
||||
if *directFieldPtr != directZeroValue {
|
||||
t.Fatalf("value didn't got reset to its default value: %s", directFieldName)
|
||||
}
|
||||
|
||||
err = schema.FieldsByName[indirectFieldName].Set(context.TODO(), reflectValue, tPtrPtr)
|
||||
if err != nil {
|
||||
t.Fatalf("error setting field: %s", indirectFieldName)
|
||||
}
|
||||
if *indirectFieldPtr != nil {
|
||||
t.Fatalf("ptr value didn't got reset to its default value: %s", indirectFieldName)
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,12 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -240,3 +242,62 @@ func TestScanToEmbedded(t *testing.T) {
|
||||
err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
func TestScanNilHandling(t *testing.T) {
|
||||
err := DB.Exec(`DELETE FROM null_values`).Error
|
||||
if err != nil {
|
||||
t.Fatalf("error emptying null values: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%2 == 0 {
|
||||
err := DB.Exec(
|
||||
`INSERT INTO null_values (id) VALUES (?)`, i,
|
||||
).Error
|
||||
if err != nil {
|
||||
t.Fatalf("cannot insert record %v", err)
|
||||
}
|
||||
} else {
|
||||
err := DB.Save(&NullValue{
|
||||
Model: gorm.Model{
|
||||
ID: uint(i),
|
||||
},
|
||||
ScannerValue: NewDummyString(fmt.Sprintf("%d", i)),
|
||||
NullScannerValue: ToPointer(NewDummyString(fmt.Sprintf("%d", i))),
|
||||
IntValue: i,
|
||||
NullIntValue: ToPointer(i),
|
||||
TimeValue: time.Now(),
|
||||
NullTimeValue: ToPointer(time.Now()),
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("cannot insert record %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rows, err := DB.Model(&NullValue{}).Order("id asc").Rows()
|
||||
if err != nil {
|
||||
t.Fatalf("cannot query nullvalues: %v", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
test := NullValue{}
|
||||
for rows.Next() {
|
||||
if err := DB.ScanRows(rows, &test); err != nil {
|
||||
t.Fatalf("cannot scan nullvalues: %v", err)
|
||||
}
|
||||
|
||||
if test.ID%2 == 0 {
|
||||
AssertEqual(t, test.ScannerValue, NewDummyString(""))
|
||||
AssertEqual(t, test.NullScannerValue, nil)
|
||||
AssertEqual(t, test.IntValue, int(0))
|
||||
AssertEqual(t, test.NullIntValue, nil)
|
||||
AssertEqual(t, test.TimeValue, time.Time{})
|
||||
AssertEqual(t, test.NullTimeValue, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
|
||||
func RunMigrations() {
|
||||
var err error
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}}
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}, &NullValue{}}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||
|
||||
|
35
utils/tests/dummy_scanner.go
Normal file
35
utils/tests/dummy_scanner.go
Normal file
@ -0,0 +1,35 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type DummyString struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func NewDummyString(s string) DummyString {
|
||||
return DummyString{
|
||||
value: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DummyString) Scan(value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
d.value = v
|
||||
default:
|
||||
d.value = fmt.Sprintf("%v", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d DummyString) Value() (driver.Value, error) {
|
||||
return d.value, nil
|
||||
}
|
||||
|
||||
func (d DummyString) String() string {
|
||||
return d.value
|
||||
}
|
@ -30,6 +30,7 @@ type User struct {
|
||||
Languages []Language `gorm:"many2many:UserSpeak;"`
|
||||
Friends []*User `gorm:"many2many:user_friends;"`
|
||||
Active bool
|
||||
Motto *DummyString
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
@ -102,3 +103,13 @@ type Child struct {
|
||||
ParentID *uint
|
||||
Parent *Parent
|
||||
}
|
||||
|
||||
type NullValue struct {
|
||||
gorm.Model
|
||||
ScannerValue DummyString
|
||||
NullScannerValue *DummyString
|
||||
IntValue int
|
||||
NullIntValue *int
|
||||
TimeValue time.Time
|
||||
NullTimeValue *time.Time
|
||||
}
|
||||
|
@ -126,3 +126,7 @@ func Now() *time.Time {
|
||||
now := time.Now()
|
||||
return &now
|
||||
}
|
||||
|
||||
func ToPointer[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user