Fix null setting on reused objects
Starting from v1.25.1 gorm seems to not reset null fields, causing incosistent behavior with previous behavior.
This commit is contained in:
parent
8fb9a31775
commit
2c60bfe3bc
@ -830,6 +830,8 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **time.Time:
|
case **time.Time:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.Set(ctx, value, *data)
|
field.Set(ctx, value, *data)
|
||||||
|
} else {
|
||||||
|
field.Set(ctx, value, time.Time{})
|
||||||
}
|
}
|
||||||
case time.Time:
|
case time.Time:
|
||||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
||||||
@ -856,6 +858,8 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **time.Time:
|
case **time.Time:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||||
|
} else {
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf((*time.Time)(nil)))
|
||||||
}
|
}
|
||||||
case time.Time:
|
case time.Time:
|
||||||
fieldValue := field.ReflectValueOf(ctx, value)
|
fieldValue := field.ReflectValueOf(ctx, value)
|
||||||
@ -890,7 +894,11 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if !reflectV.IsValid() {
|
if !reflectV.IsValid() {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { // regression: https://github.com/go-gorm/gorm/pull/6311
|
||||||
|
if field.FieldType.Kind() == reflect.Pointer {
|
||||||
|
// If we have a pointer on the destination side, let's make sure we set it to null
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflect.Zero(field.FieldType))
|
||||||
|
}
|
||||||
return
|
return
|
||||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||||
@ -917,6 +925,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
if !reflectV.IsValid() {
|
if !reflectV.IsValid() {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
return
|
return
|
||||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||||
|
@ -332,3 +332,71 @@ func TestTypeAliasField(t *testing.T) {
|
|||||||
checkSchemaField(t, alias, f, func(f *schema.Field) {})
|
checkSchemaField(t, alias, f, func(f *schema.Field) {})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScannerNullSupport(t *testing.T) {
|
||||||
|
schema, err := schema.Parse(&tests.NullValue{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given that we have an object with non-default values
|
||||||
|
dest := &tests.NullValue{
|
||||||
|
ScannerValue: tests.NewDummyString("test"),
|
||||||
|
NullScannerValue: &tests.DummyString{},
|
||||||
|
IntValue: 1,
|
||||||
|
NullIntValue: tests.ToPointer(int(1)),
|
||||||
|
TimeValue: time.Now(),
|
||||||
|
NullTimeValue: tests.ToPointer(time.Now()),
|
||||||
|
}
|
||||||
|
reflectValue := reflect.ValueOf(dest)
|
||||||
|
|
||||||
|
// let's assert that we are returning a pointer to a pointer
|
||||||
|
entry := schema.FieldsByName["NullScannerValue"].NewValuePool.Get()
|
||||||
|
|
||||||
|
_, ok := entry.(**tests.DummyString)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("scanner pointers are not returned as double pointers, thus null scanning will fail. Scanner: %v", entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
validateScannerNullSetting[tests.DummyString](
|
||||||
|
t, schema, reflectValue,
|
||||||
|
"ScannerValue", &dest.ScannerValue, tests.DummyString{},
|
||||||
|
"NullScannerValue", &dest.NullScannerValue,
|
||||||
|
)
|
||||||
|
// This was not faulty
|
||||||
|
validateScannerNullSetting[int](
|
||||||
|
t, schema, reflectValue,
|
||||||
|
"IntValue", &dest.IntValue, 0,
|
||||||
|
"NullIntValue", &dest.NullIntValue,
|
||||||
|
)
|
||||||
|
validateScannerNullSetting[time.Time](
|
||||||
|
t, schema, reflectValue,
|
||||||
|
"TimeValue", &dest.TimeValue, time.Time{},
|
||||||
|
"NullTimeValue", &dest.NullTimeValue,
|
||||||
|
)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateScannerNullSetting[T comparable](t *testing.T, schema *schema.Schema, reflectValue reflect.Value,
|
||||||
|
directFieldName string, directFieldPtr *T, directZeroValue T,
|
||||||
|
indirectFieldName string, indirectFieldPtr **T) {
|
||||||
|
var tPtrPtr **T // used to scan into an *T field
|
||||||
|
var tPtr *T // used to scan into an T field
|
||||||
|
|
||||||
|
err := schema.FieldsByName[directFieldName].Set(context.TODO(), reflectValue, tPtr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error setting field: %s", directFieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if *directFieldPtr != directZeroValue {
|
||||||
|
t.Fatalf("value didn't got reset to its default value: %s", directFieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = schema.FieldsByName[indirectFieldName].Set(context.TODO(), reflectValue, tPtrPtr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error setting field: %s", indirectFieldName)
|
||||||
|
}
|
||||||
|
if *indirectFieldPtr != nil {
|
||||||
|
t.Fatalf("ptr value didn't got reset to its default value: %s", indirectFieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
@ -240,3 +242,62 @@ func TestScanToEmbedded(t *testing.T) {
|
|||||||
err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error
|
err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error
|
||||||
AssertEqual(t, err, nil)
|
AssertEqual(t, err, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanNilHandling(t *testing.T) {
|
||||||
|
err := DB.Exec(`DELETE FROM null_values`).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error emptying null values: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
if i%2 == 0 {
|
||||||
|
err := DB.Exec(
|
||||||
|
`INSERT INTO null_values (id) VALUES (?)`, i,
|
||||||
|
).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot insert record %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := DB.Save(&NullValue{
|
||||||
|
Model: gorm.Model{
|
||||||
|
ID: uint(i),
|
||||||
|
},
|
||||||
|
ScannerValue: NewDummyString(fmt.Sprintf("%d", i)),
|
||||||
|
NullScannerValue: ToPointer(NewDummyString(fmt.Sprintf("%d", i))),
|
||||||
|
IntValue: i,
|
||||||
|
NullIntValue: ToPointer(i),
|
||||||
|
TimeValue: time.Now(),
|
||||||
|
NullTimeValue: ToPointer(time.Now()),
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot insert record %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := DB.Model(&NullValue{}).Order("id asc").Rows()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot query nullvalues: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
test := NullValue{}
|
||||||
|
for rows.Next() {
|
||||||
|
if err := DB.ScanRows(rows, &test); err != nil {
|
||||||
|
t.Fatalf("cannot scan nullvalues: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.ID%2 == 0 {
|
||||||
|
AssertEqual(t, test.ScannerValue, NewDummyString(""))
|
||||||
|
AssertEqual(t, test.NullScannerValue, nil)
|
||||||
|
AssertEqual(t, test.IntValue, int(0))
|
||||||
|
AssertEqual(t, test.NullIntValue, nil)
|
||||||
|
AssertEqual(t, test.TimeValue, time.Time{})
|
||||||
|
AssertEqual(t, test.NullTimeValue, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
|||||||
|
|
||||||
func RunMigrations() {
|
func RunMigrations() {
|
||||||
var err error
|
var err error
|
||||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}}
|
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}, &NullValue{}}
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||||
|
|
||||||
|
35
utils/tests/dummy_scanner.go
Normal file
35
utils/tests/dummy_scanner.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DummyString struct {
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDummyString(s string) DummyString {
|
||||||
|
return DummyString{
|
||||||
|
value: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DummyString) Scan(value interface{}) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
d.value = v
|
||||||
|
default:
|
||||||
|
d.value = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DummyString) Value() (driver.Value, error) {
|
||||||
|
return d.value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DummyString) String() string {
|
||||||
|
return d.value
|
||||||
|
}
|
@ -30,6 +30,7 @@ type User struct {
|
|||||||
Languages []Language `gorm:"many2many:UserSpeak;"`
|
Languages []Language `gorm:"many2many:UserSpeak;"`
|
||||||
Friends []*User `gorm:"many2many:user_friends;"`
|
Friends []*User `gorm:"many2many:user_friends;"`
|
||||||
Active bool
|
Active bool
|
||||||
|
Motto *DummyString
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account struct {
|
type Account struct {
|
||||||
@ -102,3 +103,13 @@ type Child struct {
|
|||||||
ParentID *uint
|
ParentID *uint
|
||||||
Parent *Parent
|
Parent *Parent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NullValue struct {
|
||||||
|
gorm.Model
|
||||||
|
ScannerValue DummyString
|
||||||
|
NullScannerValue *DummyString
|
||||||
|
IntValue int
|
||||||
|
NullIntValue *int
|
||||||
|
TimeValue time.Time
|
||||||
|
NullTimeValue *time.Time
|
||||||
|
}
|
||||||
|
@ -126,3 +126,7 @@ func Now() *time.Time {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
return &now
|
return &now
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToPointer[T any](v T) *T {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user