From 63534145fda9a2ac9ba703650b1a44da6a03e45e Mon Sep 17 00:00:00 2001 From: aclich <71011237+aclich@users.noreply.github.com> Date: Mon, 15 May 2023 09:59:26 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=90=9B=20embedded=20struct=20test?= =?UTF-8?q?=20failed=20with=20custom=20datatypes=20(#6311)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 🐛 embedded struct test failed with custom datatypes Fix the pointer embedded struct within custom datatypes and *time.time should be nil issue. * fix: 🐛 change test case to avoid mssql driver issue change test cases from bytes to string to avoid mssql driver issue --- schema/field.go | 18 +++----- tests/embedded_struct_test.go | 80 +++++++++++++++++++++++++++++------ 2 files changed, 75 insertions(+), 23 deletions(-) diff --git a/schema/field.go b/schema/field.go index 7d1a1789..dd08e056 100644 --- a/schema/field.go +++ b/schema/field.go @@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case **time.Time: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) } case time.Time: @@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { @@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 3747dad9..4314f88c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,9 @@ import ( "database/sql/driver" "encoding/json" "errors" + "reflect" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -104,10 +106,14 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { } type Author struct { - ID string - Name string - Email string - Age int + ID string + Name string + Email string + Age int + Content Content + ContentPtr *Content + Birthday time.Time + BirthdayPtr *time.Time } type HNPost struct { @@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { if hnPost.Author != nil { t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) } + + now := time.Now().Round(time.Second) + NewPost := HNPost{ + BasePost: &BasePost{Title: "embedded_pointer_type2"}, + Author: &Author{ + Name: "test", + Content: Content{"test"}, + ContentPtr: nil, + Birthday: now, + BirthdayPtr: nil, + }, + } + DB.Create(&NewPost) + + hnPost = HNPost{} + if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != NewPost.Title { + t.Errorf("Should find correct value for embedded pointer type") + } + + if hnPost.Author.Name != NewPost.Author.Name { + t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name) + } + + if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) { + t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content) + } + + if hnPost.Author.ContentPtr != nil { + t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr) + } + + if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() { + t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday) + } + + if hnPost.Author.BirthdayPtr != nil { + t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr) + } } type Content struct { @@ -142,18 +190,26 @@ type Content struct { } func (c Content) Value() (driver.Value, error) { - return json.Marshal(c) + // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530, + b, err := json.Marshal(c) + return string(b[:]), err } func (c *Content) Scan(src interface{}) error { - b, ok := src.([]byte) - if !ok { - return errors.New("Embedded.Scan byte assertion failed") - } - var value Content - if err := json.Unmarshal(b, &value); err != nil { - return err + str, ok := src.(string) + if !ok { + byt, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + if err := json.Unmarshal(byt, &value); err != nil { + return err + } + } else { + if err := json.Unmarshal([]byte(str), &value); err != nil { + return err + } } *c = value