From 73a988ceb22651e01c968a9ec20ae1709e73c8e6 Mon Sep 17 00:00:00 2001 From: Waleed Masoom <92062428+Waldeedle@users.noreply.github.com> Date: Wed, 12 Jun 2024 06:57:36 -0400 Subject: [PATCH] fix(scan): update Scan function to reset structs to zero values for each scan (#7061) Co-authored-by: waleed.masoom --- scan.go | 3 ++ tests/scan_test.go | 105 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index 89b46c0a..eac6ca9d 100644 --- a/scan.go +++ b/scan.go @@ -331,6 +331,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { + if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: diff --git a/tests/scan_test.go b/tests/scan_test.go index 6f2e9f54..f7def909 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -5,6 +5,7 @@ import ( "sort" "strings" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) { rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { - t.Errorf("Not error should happen, got %v", err) + t.Errorf("No error should happen, got %v", err) } type Result struct { @@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) { }) if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") + t.Errorf("Should find expected results, got %+v", results) } var ages int @@ -158,7 +159,105 @@ func TestScanRows(t *testing.T) { var name string if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { - t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + t.Fatalf("failed to scan name, got error %v, name: %v", err, name) + } +} + +func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) { + DB.Save(&User{}) + + rows, err := DB.Table("users"). + Select(` + NULL AS bool_field, + NULL AS int_field, + NULL AS int8_field, + NULL AS int16_field, + NULL AS int32_field, + NULL AS int64_field, + NULL AS uint_field, + NULL AS uint8_field, + NULL AS uint16_field, + NULL AS uint32_field, + NULL AS uint64_field, + NULL AS float32_field, + NULL AS float64_field, + NULL AS string_field, + NULL AS time_field, + NULL AS time_ptr_field, + NULL AS embedded_int_field, + NULL AS nested_embedded_int_field, + NULL AS embedded_ptr_int_field + `).Rows() + if err != nil { + t.Errorf("No error should happen, got %v", err) + } + + type NestedEmbeddedStruct struct { + NestedEmbeddedIntField int + NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"` + } + + type EmbeddedStruct struct { + EmbeddedIntField int + NestedEmbeddedStruct `gorm:"embedded"` + } + + type EmbeddedPtrStruct struct { + EmbeddedPtrIntField int + *NestedEmbeddedStruct `gorm:"embedded"` + } + + type Result struct { + BoolField bool + IntField int + Int8Field int8 + Int16Field int16 + Int32Field int32 + Int64Field int64 + UIntField uint + UInt8Field uint8 + UInt16Field uint16 + UInt32Field uint32 + UInt64Field uint64 + Float32Field float32 + Float64Field float64 + StringField string + TimeField time.Time + TimePtrField *time.Time + EmbeddedStruct `gorm:"embedded"` + *EmbeddedPtrStruct `gorm:"embedded"` + } + + currTime := time.Now() + reusedVar := Result{ + BoolField: true, + IntField: 1, + Int8Field: 1, + Int16Field: 1, + Int32Field: 1, + Int64Field: 1, + UIntField: 1, + UInt8Field: 1, + UInt16Field: 1, + UInt32Field: 1, + UInt64Field: 1, + Float32Field: 1.1, + Float64Field: 1.1, + StringField: "hello", + TimeField: currTime, + TimePtrField: &currTime, + EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, + EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, + } + + for rows.Next() { + if err := DB.ScanRows(rows, &reusedVar); err != nil { + t.Errorf("should get no error, but got %v", err) + } + } + + if !reflect.DeepEqual(reusedVar, Result{}) { + t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar) } }