fix(scan): update Scan function to reset structs to zero values for each scan
This commit is contained in:
		
							parent
							
								
									5e599a07ec
								
							
						
					
					
						commit
						8956449e75
					
				
							
								
								
									
										3
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								scan.go
									
									
									
									
									
								
							| @ -331,6 +331,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 			} | ||||
| 		case reflect.Struct, reflect.Ptr: | ||||
| 			if initialized || rows.Next() { | ||||
| 				if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct { | ||||
| 					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) | ||||
| 				} | ||||
| 				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) | ||||
| 			} | ||||
| 		default: | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| @ -126,7 +127,7 @@ func TestScanRows(t *testing.T) { | ||||
| 
 | ||||
| 	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Not error should happen, got %v", err) | ||||
| 		t.Errorf("No error should happen, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	type Result struct { | ||||
| @ -148,7 +149,7 @@ func TestScanRows(t *testing.T) { | ||||
| 	}) | ||||
| 
 | ||||
| 	if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { | ||||
| 		t.Errorf("Should find expected results") | ||||
| 		t.Errorf("Should find expected results, got %+v", results) | ||||
| 	} | ||||
| 
 | ||||
| 	var ages int | ||||
| @ -158,7 +159,105 @@ func TestScanRows(t *testing.T) { | ||||
| 
 | ||||
| 	var name string | ||||
| 	if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { | ||||
| 		t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) | ||||
| 		t.Fatalf("failed to scan name, got error %v, name: %v", err, name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) { | ||||
| 	DB.Save(&User{}) | ||||
| 
 | ||||
| 	rows, err := DB.Table("users"). | ||||
| 		Select(` | ||||
| 			NULL AS bool_field, | ||||
| 			NULL AS int_field, | ||||
| 			NULL AS int8_field, | ||||
| 			NULL AS int16_field, | ||||
| 			NULL AS int32_field, | ||||
| 			NULL AS int64_field, | ||||
| 			NULL AS uint_field, | ||||
| 			NULL AS uint8_field, | ||||
| 			NULL AS uint16_field, | ||||
| 			NULL AS uint32_field, | ||||
| 			NULL AS uint64_field, | ||||
| 			NULL AS float32_field, | ||||
| 			NULL AS float64_field, | ||||
| 			NULL AS string_field, | ||||
| 			NULL AS time_field, | ||||
| 			NULL AS time_ptr_field, | ||||
| 			NULL AS embedded_int_field, | ||||
| 			NULL AS nested_embedded_int_field, | ||||
| 			NULL AS embedded_ptr_int_field | ||||
|         `).Rows() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("No error should happen, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	type NestedEmbeddedStruct struct { | ||||
| 		NestedEmbeddedIntField            int | ||||
| 		NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"` | ||||
| 	} | ||||
| 
 | ||||
| 	type EmbeddedStruct struct { | ||||
| 		EmbeddedIntField     int | ||||
| 		NestedEmbeddedStruct `gorm:"embedded"` | ||||
| 	} | ||||
| 
 | ||||
| 	type EmbeddedPtrStruct struct { | ||||
| 		EmbeddedPtrIntField   int | ||||
| 		*NestedEmbeddedStruct `gorm:"embedded"` | ||||
| 	} | ||||
| 
 | ||||
| 	type Result struct { | ||||
| 		BoolField          bool | ||||
| 		IntField           int | ||||
| 		Int8Field          int8 | ||||
| 		Int16Field         int16 | ||||
| 		Int32Field         int32 | ||||
| 		Int64Field         int64 | ||||
| 		UIntField          uint | ||||
| 		UInt8Field         uint8 | ||||
| 		UInt16Field        uint16 | ||||
| 		UInt32Field        uint32 | ||||
| 		UInt64Field        uint64 | ||||
| 		Float32Field       float32 | ||||
| 		Float64Field       float64 | ||||
| 		StringField        string | ||||
| 		TimeField          time.Time | ||||
| 		TimePtrField       *time.Time | ||||
| 		EmbeddedStruct     `gorm:"embedded"` | ||||
| 		*EmbeddedPtrStruct `gorm:"embedded"` | ||||
| 	} | ||||
| 
 | ||||
| 	currTime := time.Now() | ||||
| 	reusedVar := Result{ | ||||
| 		BoolField:         true, | ||||
| 		IntField:          1, | ||||
| 		Int8Field:         1, | ||||
| 		Int16Field:        1, | ||||
| 		Int32Field:        1, | ||||
| 		Int64Field:        1, | ||||
| 		UIntField:         1, | ||||
| 		UInt8Field:        1, | ||||
| 		UInt16Field:       1, | ||||
| 		UInt32Field:       1, | ||||
| 		UInt64Field:       1, | ||||
| 		Float32Field:      1.1, | ||||
| 		Float64Field:      1.1, | ||||
| 		StringField:       "hello", | ||||
| 		TimeField:         currTime, | ||||
| 		TimePtrField:      &currTime, | ||||
| 		EmbeddedStruct:    EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, | ||||
| 		EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, | ||||
| 	} | ||||
| 
 | ||||
| 	for rows.Next() { | ||||
| 		if err := DB.ScanRows(rows, &reusedVar); err != nil { | ||||
| 			t.Errorf("should get no error, but got %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !reflect.DeepEqual(reusedVar, Result{}) { | ||||
| 		t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 waleed.masoom
						waleed.masoom