diff --git a/schema/field.go b/schema/field.go index d11c555e..dfcc1b15 100644 --- a/schema/field.go +++ b/schema/field.go @@ -432,7 +432,7 @@ func (field *Field) setupValuerAndSetter() { New: func() interface{} { return &Serializer{ Field: field, - Interface: reflect.New(reflect.ValueOf(field.Serializer).Type()).Interface().(SerializerInterface), + Interface: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), } }, } @@ -489,14 +489,14 @@ func (field *Field) setupValuerAndSetter() { return value, zero } - serializer, ok := value.(SerializerInterface) + serializer, ok := value.(SerializerValuerInterface) if !ok { serializer = field.Serializer } return Serializer{ Field: field, - Interface: serializer, + Valuer: serializer, Destination: v, Context: ctx, fieldValue: value, @@ -564,6 +564,9 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return } else { err = setter(ctx, value, reflectV.Elem().Interface()) } @@ -585,7 +588,7 @@ func (field *Field) setupValuerAndSetter() { if serializer, ok := v.(*Serializer); ok { serializer.Interface.Scan(ctx, field, value, serializer.value) fallbackSetter(ctx, value, serializer.Interface, field.Set) - serializer.Interface = reflect.New(reflect.ValueOf(field.Serializer).Type()).Interface().(SerializerInterface) + serializer.Interface = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } return nil } diff --git a/schema/interfaces.go b/schema/interfaces.go index b0ad19b4..57d353a4 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -26,6 +26,7 @@ type FieldNewValuePool interface { type Serializer struct { Field *Field Interface SerializerInterface + Valuer SerializerValuerInterface Destination reflect.Value Context context.Context value interface{} @@ -40,12 +41,17 @@ func (s *Serializer) Scan(value interface{}) error { // Value implements driver.Valuer interface func (s Serializer) Value() (driver.Value, error) { - return s.Interface.Value(s.Context, s.Field, s.Destination, s.fieldValue) + return s.Valuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) } // SerializerInterface serializer interface type SerializerInterface interface { Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface interface { Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) } diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 29ac5c8f..4abd4ac6 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -1,29 +1,55 @@ package tests_test import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" "testing" "gorm.io/gorm" + "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) type SerializerStruct struct { gorm.Model - Name []byte `gorm:"json"` - Roles Roles `gorm:"json"` + Name []byte `gorm:"json"` + Roles Roles `gorm:"json"` + EncryptedString EncryptedString } type Roles []string +type EncryptedString string -func TestSerializerJSON(t *testing.T) { +func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) + case string: + *es = EncryptedString(strings.TrimPrefix(value, "hello")) + default: + return fmt.Errorf("unsupported data %v", dbValue) + } + return nil +} + +// Value implements serializer interface +func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return "hello" + string(es), nil +} + +func TestSerializer(t *testing.T) { DB.Migrator().DropTable(&SerializerStruct{}) if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } data := SerializerStruct{ - Name: []byte("jinzhu"), - Roles: []string{"r1", "r2"}, + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + EncryptedString: EncryptedString("pass"), } if err := DB.Create(&data).Error; err != nil { @@ -31,7 +57,9 @@ func TestSerializerJSON(t *testing.T) { } var result SerializerStruct - DB.First(&result, data.ID) + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } AssertEqual(t, result, data) }