From 8039ba0eb103b21cfd285fc8f4af855159c8507b Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Mon, 13 Jun 2022 13:32:08 +0800 Subject: [PATCH] fix: can not set field in-place in join --- scan.go | 17 +++++++++++------ tests/query_test.go | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/scan.go b/scan.go index 1bb51560..6250fb57 100644 --- a/scan.go +++ b/scan.go @@ -66,18 +66,23 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) + joinedSchemaMap := make(map[*schema.Field]interface{}, 0) for idx, field := range fields { if field != nil { if len(joinFields) == 0 || joinFields[idx][0] == nil { db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) } else { - relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } } db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } diff --git a/tests/query_test.go b/tests/query_test.go index e0ee1c95..253d8409 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1260,6 +1260,11 @@ func TestQueryScannerWithSingleColumn(t *testing.T) { } func TestQueryResetNullValue(t *testing.T) { + type QueryResetItem struct { + ID string `gorm:"type:varchar(5)"` + Name string + } + type QueryResetNullValue struct { ID int Name string `gorm:"default:NULL"` @@ -1268,10 +1273,14 @@ func TestQueryResetNullValue(t *testing.T) { Number2 uint64 `gorm:"default:NULL"` Number3 float64 `gorm:"default:NULL"` Now *time.Time `gorm:"defalut:NULL"` + Item1Id string + Item1 *QueryResetItem `gorm:"references:ID"` + Item2Id string + Item2 *QueryResetItem `gorm:"references:ID"` } - DB.Migrator().DropTable(&QueryResetNullValue{}) - DB.AutoMigrate(&QueryResetNullValue{}) + DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{}) + DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{}) now := time.Now() q1 := QueryResetNullValue{ @@ -1281,9 +1290,26 @@ func TestQueryResetNullValue(t *testing.T) { Number2: 200, Number3: 300.1, Now: &now, + Item1: &QueryResetItem{ + ID: "u_1_1", + Name: "item_1_1", + }, + Item2: &QueryResetItem{ + ID: "u_1_2", + Name: "item_1_2", + }, } - q2 := QueryResetNullValue{} + q2 := QueryResetNullValue{ + Item1: &QueryResetItem{ + ID: "u_2_1", + Name: "item_2_1", + }, + Item2: &QueryResetItem{ + ID: "u_2_2", + Name: "item_2_2", + }, + } var err error err = DB.Create(&q1).Error @@ -1297,7 +1323,7 @@ func TestQueryResetNullValue(t *testing.T) { } var qs []QueryResetNullValue - err = DB.Find(&qs).Error + err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error if err != nil { t.Errorf("failed to find:%v", err) }