Fix Scan with interface
This commit is contained in:
		
							parent
							
								
									61b018cb94
								
							
						
					
					
						commit
						12bbde89e6
					
				| @ -506,7 +506,12 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { | ||||
| 	tx.Statement.Dest = dest | ||||
| 	tx.Statement.ReflectValue = reflect.ValueOf(dest) | ||||
| 	for tx.Statement.ReflectValue.Kind() == reflect.Ptr { | ||||
| 		tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() | ||||
| 		elem := tx.Statement.ReflectValue.Elem() | ||||
| 		if !elem.IsValid() { | ||||
| 			elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) | ||||
| 			tx.Statement.ReflectValue.Set(elem) | ||||
| 		} | ||||
| 		tx.Statement.ReflectValue = elem | ||||
| 	} | ||||
| 	Scan(rows, tx, true) | ||||
| 	return tx.Error | ||||
|  | ||||
							
								
								
									
										20
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								scan.go
									
									
									
									
									
								
							| @ -97,11 +97,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 		} | ||||
| 	default: | ||||
| 		Schema := db.Statement.Schema | ||||
| 		reflectValue := db.Statement.ReflectValue | ||||
| 		if reflectValue.Kind() == reflect.Interface { | ||||
| 			reflectValue = reflectValue.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		switch db.Statement.ReflectValue.Kind() { | ||||
| 		switch reflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var ( | ||||
| 				reflectValueType = db.Statement.ReflectValue.Type().Elem() | ||||
| 				reflectValueType = reflectValue.Type().Elem() | ||||
| 				isPtr            = reflectValueType.Kind() == reflect.Ptr | ||||
| 				fields           = make([]*schema.Field, len(columns)) | ||||
| 				joinFields       [][2]*schema.Field | ||||
| @ -111,7 +115,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 				reflectValueType = reflectValueType.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) | ||||
| 			db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) | ||||
| 
 | ||||
| 			if Schema != nil { | ||||
| 				if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { | ||||
| @ -186,13 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 				} | ||||
| 
 | ||||
| 				if isPtr { | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem)) | ||||
| 				} else { | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem())) | ||||
| 				} | ||||
| 			} | ||||
| 		case reflect.Struct, reflect.Ptr: | ||||
| 			if db.Statement.ReflectValue.Type() != Schema.ModelType { | ||||
| 			if reflectValue.Type() != Schema.ModelType { | ||||
| 				Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) | ||||
| 			} | ||||
| 
 | ||||
| @ -220,11 +224,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 
 | ||||
| 				for idx, column := range columns { | ||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 						field.Set(db.Statement.ReflectValue, values[idx]) | ||||
| 						field.Set(reflectValue, values[idx]) | ||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { | ||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 								relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) | ||||
| 								relValue := rel.Field.ReflectValueOf(reflectValue) | ||||
| 								value := reflect.ValueOf(values[idx]).Elem() | ||||
| 
 | ||||
| 								if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||
|  | ||||
| @ -77,7 +77,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | ||||
| 		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||
| 	} | ||||
| 
 | ||||
| 	modelType := reflect.ValueOf(dest).Type() | ||||
| 	modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() | ||||
| 	if modelType.Kind() == reflect.Interface { | ||||
| 		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() | ||||
| 	} | ||||
| 
 | ||||
| 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { | ||||
| 		modelType = modelType.Elem() | ||||
| 	} | ||||
|  | ||||
| @ -29,8 +29,9 @@ func TestScan(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	var resPointer *result | ||||
| 	DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer) | ||||
| 	if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { | ||||
| 	if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil { | ||||
| 		t.Fatalf("Failed to query with pointer of value, got error %v", err) | ||||
| 	} else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { | ||||
| 		t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) | ||||
| 	} | ||||
| 
 | ||||
| @ -70,6 +71,38 @@ func TestScan(t *testing.T) { | ||||
| 	if uint(id) != user2.ID { | ||||
| 		t.Errorf("Failed to scan to customized data type") | ||||
| 	} | ||||
| 
 | ||||
| 	var resInt interface{} | ||||
| 	resInt = &User{} | ||||
| 	if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil { | ||||
| 		t.Fatalf("Failed to query with pointer of value, got error %v", err) | ||||
| 	} else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age { | ||||
| 		t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3) | ||||
| 	} | ||||
| 
 | ||||
| 	var resInt2 interface{} | ||||
| 	resInt2 = &User{} | ||||
| 	if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil { | ||||
| 		t.Fatalf("Failed to query with pointer of value, got error %v", err) | ||||
| 	} else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age { | ||||
| 		t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3) | ||||
| 	} | ||||
| 
 | ||||
| 	var resInt3 interface{} | ||||
| 	resInt3 = []User{} | ||||
| 	if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil { | ||||
| 		t.Fatalf("Failed to query with pointer of value, got error %v", err) | ||||
| 	} else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { | ||||
| 		t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3) | ||||
| 	} | ||||
| 
 | ||||
| 	var resInt4 interface{} | ||||
| 	resInt4 = []User{} | ||||
| 	if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil { | ||||
| 		t.Fatalf("Failed to query with pointer of value, got error %v", err) | ||||
| 	} else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { | ||||
| 		t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestScanRows(t *testing.T) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu