fix(scan): update Scan function to reset structs to zero values for each scan (#7061)
Co-authored-by: waleed.masoom <waleed.masoom@wheniwork.com>
This commit is contained in:
		
							parent
							
								
									05167fd591
								
							
						
					
					
						commit
						73a988ceb2
					
				
							
								
								
									
										3
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								scan.go
									
									
									
									
									
								
							| @ -331,6 +331,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | |||||||
| 			} | 			} | ||||||
| 		case reflect.Struct, reflect.Ptr: | 		case reflect.Struct, reflect.Ptr: | ||||||
| 			if initialized || rows.Next() { | 			if initialized || rows.Next() { | ||||||
|  | 				if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct { | ||||||
|  | 					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) | ||||||
|  | 				} | ||||||
| 				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) | 				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) | ||||||
| 			} | 			} | ||||||
| 		default: | 		default: | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"sort" | 	"sort" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| @ -126,7 +127,7 @@ func TestScanRows(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() | 	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Not error should happen, got %v", err) | 		t.Errorf("No error should happen, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	type Result struct { | 	type Result struct { | ||||||
| @ -148,7 +149,7 @@ func TestScanRows(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { | 	if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { | ||||||
| 		t.Errorf("Should find expected results") | 		t.Errorf("Should find expected results, got %+v", results) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var ages int | 	var ages int | ||||||
| @ -158,7 +159,105 @@ func TestScanRows(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	var name string | 	var name string | ||||||
| 	if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { | 	if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { | ||||||
| 		t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) | 		t.Fatalf("failed to scan name, got error %v, name: %v", err, name) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) { | ||||||
|  | 	DB.Save(&User{}) | ||||||
|  | 
 | ||||||
|  | 	rows, err := DB.Table("users"). | ||||||
|  | 		Select(` | ||||||
|  | 			NULL AS bool_field, | ||||||
|  | 			NULL AS int_field, | ||||||
|  | 			NULL AS int8_field, | ||||||
|  | 			NULL AS int16_field, | ||||||
|  | 			NULL AS int32_field, | ||||||
|  | 			NULL AS int64_field, | ||||||
|  | 			NULL AS uint_field, | ||||||
|  | 			NULL AS uint8_field, | ||||||
|  | 			NULL AS uint16_field, | ||||||
|  | 			NULL AS uint32_field, | ||||||
|  | 			NULL AS uint64_field, | ||||||
|  | 			NULL AS float32_field, | ||||||
|  | 			NULL AS float64_field, | ||||||
|  | 			NULL AS string_field, | ||||||
|  | 			NULL AS time_field, | ||||||
|  | 			NULL AS time_ptr_field, | ||||||
|  | 			NULL AS embedded_int_field, | ||||||
|  | 			NULL AS nested_embedded_int_field, | ||||||
|  | 			NULL AS embedded_ptr_int_field | ||||||
|  |         `).Rows() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("No error should happen, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type NestedEmbeddedStruct struct { | ||||||
|  | 		NestedEmbeddedIntField            int | ||||||
|  | 		NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type EmbeddedStruct struct { | ||||||
|  | 		EmbeddedIntField     int | ||||||
|  | 		NestedEmbeddedStruct `gorm:"embedded"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type EmbeddedPtrStruct struct { | ||||||
|  | 		EmbeddedPtrIntField   int | ||||||
|  | 		*NestedEmbeddedStruct `gorm:"embedded"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type Result struct { | ||||||
|  | 		BoolField          bool | ||||||
|  | 		IntField           int | ||||||
|  | 		Int8Field          int8 | ||||||
|  | 		Int16Field         int16 | ||||||
|  | 		Int32Field         int32 | ||||||
|  | 		Int64Field         int64 | ||||||
|  | 		UIntField          uint | ||||||
|  | 		UInt8Field         uint8 | ||||||
|  | 		UInt16Field        uint16 | ||||||
|  | 		UInt32Field        uint32 | ||||||
|  | 		UInt64Field        uint64 | ||||||
|  | 		Float32Field       float32 | ||||||
|  | 		Float64Field       float64 | ||||||
|  | 		StringField        string | ||||||
|  | 		TimeField          time.Time | ||||||
|  | 		TimePtrField       *time.Time | ||||||
|  | 		EmbeddedStruct     `gorm:"embedded"` | ||||||
|  | 		*EmbeddedPtrStruct `gorm:"embedded"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	currTime := time.Now() | ||||||
|  | 	reusedVar := Result{ | ||||||
|  | 		BoolField:         true, | ||||||
|  | 		IntField:          1, | ||||||
|  | 		Int8Field:         1, | ||||||
|  | 		Int16Field:        1, | ||||||
|  | 		Int32Field:        1, | ||||||
|  | 		Int64Field:        1, | ||||||
|  | 		UIntField:         1, | ||||||
|  | 		UInt8Field:        1, | ||||||
|  | 		UInt16Field:       1, | ||||||
|  | 		UInt32Field:       1, | ||||||
|  | 		UInt64Field:       1, | ||||||
|  | 		Float32Field:      1.1, | ||||||
|  | 		Float64Field:      1.1, | ||||||
|  | 		StringField:       "hello", | ||||||
|  | 		TimeField:         currTime, | ||||||
|  | 		TimePtrField:      &currTime, | ||||||
|  | 		EmbeddedStruct:    EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, | ||||||
|  | 		EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		if err := DB.ScanRows(rows, &reusedVar); err != nil { | ||||||
|  | 			t.Errorf("should get no error, but got %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !reflect.DeepEqual(reusedVar, Result{}) { | ||||||
|  | 		t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Waleed Masoom
						Waleed Masoom