From 51de8ca836a318dc806826e2bf98eb019da9f099 Mon Sep 17 00:00:00 2001 From: Rob Rodriguez Date: Mon, 10 Dec 2018 16:06:56 -0800 Subject: [PATCH] adding final with tests --- callback_create.go | 9 ++++ callback_save.go | 1 - callback_update.go | 16 +++++-- encoder.go | 35 ++++++++++++++ encoder_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++ interface_test.go | 89 ----------------------------------- main_test.go | 6 ++- model_struct.go | 27 ++++++----- scope.go | 47 +++++++----------- 9 files changed, 207 insertions(+), 138 deletions(-) create mode 100644 encoder.go create mode 100644 encoder_test.go delete mode 100644 interface_test.go diff --git a/callback_create.go b/callback_create.go index 2ab05d3b..56126268 100644 --- a/callback_create.go +++ b/callback_create.go @@ -74,6 +74,15 @@ func createCallback(scope *Scope) { placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) } } + } else if field.UseEncoder { + if enc, ok := scope.Value.(Encoder); ok { + if val, err := enc.EncodeField(scope, field.DBName); err == nil { + columns = append(columns, scope.Quote(field.DBName)) + placeholders = append(placeholders, scope.AddToVars(val)) + } else { + scope.Err(err) + } + } } } } diff --git a/callback_save.go b/callback_save.go index 3b4e0589..4d271f3c 100644 --- a/callback_save.go +++ b/callback_save.go @@ -141,7 +141,6 @@ func saveAfterAssociationsCallback(scope *Scope) { default: elem := value.Addr().Interface() newScope := scope.New(elem) - if saveReference { if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { diff --git a/callback_update.go b/callback_update.go index f6ba0ffd..2d4c42a7 100644 --- a/callback_update.go +++ b/callback_update.go @@ -75,9 +75,19 @@ func updateCallback(scope *Scope) { } else { for _, field := range scope.Fields() { if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsPrimaryKey { + if field.IsNormal { + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } + } else if field.UseEncoder { + if enc, ok := scope.Value.(Encoder); ok { + if val, err := enc.EncodeField(scope, field.DBName); err == nil { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(val))) + } else { + scope.Err(err) + } + } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { diff --git a/encoder.go b/encoder.go new file mode 100644 index 00000000..b2c77e49 --- /dev/null +++ b/encoder.go @@ -0,0 +1,35 @@ +package gorm + +// Encoder is a value encoding interface for complex field types +type Encoder interface { + EncodeField(*Scope, string) (interface{}, error) + DecodeField(scope *Scope, column string, value interface{}) error +} + +// decoder defers decoding until necessary +type decoder struct { + Encoder + scope *Scope + column string + value interface{} +} + +func newDecoder(encoder Encoder, scope *Scope, column string) *decoder { + return &decoder{ + encoder, + scope, + column, + nil, + } +} + +// Scan implements the sql.Scanner interface +func (d *decoder) Scan(src interface{}) error { + d.value = src + return nil +} + +// Decode handles the decoding at a later time +func (d *decoder) Decode() error { + return d.DecodeField(d.scope, d.column, d.value) +} diff --git a/encoder_test.go b/encoder_test.go new file mode 100644 index 00000000..c7cdce23 --- /dev/null +++ b/encoder_test.go @@ -0,0 +1,115 @@ +package gorm_test + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/jinzhu/gorm" +) + +type ( + Widget interface { + GetType() string + } + + WidgetUser struct { + User + WidgetType string + Widget Widget `gorm:"use_encoder;column:widget;type:jsonb"` + } + + SimpleWidget struct { + Type string `json:"type"` + Width int64 `json:"width"` + Height int64 `json:"height"` + } + + ComplexWidget struct { + SimpleWidget + Color string `json:"color"` + } +) + +func (m *SimpleWidget) GetType() string { + return "simple" +} + +func (m *ComplexWidget) GetType() string { + return "complex" +} + +func (m *WidgetUser) EncodeField(scope *gorm.Scope, column string) (interface{}, error) { + switch column { + case "widget": + val, err := json.Marshal(m.Widget) + if err != nil { + return nil, err + } + return string(val), nil + } + + return nil, nil +} + +func (m *WidgetUser) DecodeField(scope *gorm.Scope, column string, value interface{}) error { + switch column { + case "widget": + b, ok := value.([]byte) + if !ok { + return errors.New("Invalid type for Widget") + } + switch m.WidgetType { + case "simple": + var result SimpleWidget + if err := json.Unmarshal(b, &result); err != nil { + return err + } + m.Widget = &result + case "complex": + var result ComplexWidget + if err := json.Unmarshal(b, &result); err != nil { + return err + } + m.Widget = &result + default: + return errors.New("unsupported Widget type") + } + } + return nil +} + +func TestEncoder(t *testing.T) { + DB.AutoMigrate(&WidgetUser{}) + + user := &WidgetUser{ + User: User{ + Id: 1, + Name: "bob", + }, + WidgetType: "simple", + Widget: &SimpleWidget{Type: "simple", Width: 12, Height: 10}, + } + + if err := DB.Save(user).Error; err != nil { + t.Errorf("failed to save WidgetUser %v", err) + } + + user1 := WidgetUser{} + + if err := DB.First(&user1, "id=?", 1).Error; err != nil { + t.Errorf("failed to retrieve WidgetUser %v", err) + } + + if user1.Widget.GetType() != "simple" { + t.Errorf("user widget invalid") + } + + if w, ok := user1.Widget.(*SimpleWidget); !ok { + t.Errorf("user widget is not valid") + } else { + if w.Width != 12 || w.Height != 10 { + t.Errorf("user widget is not valid") + } + } +} diff --git a/interface_test.go b/interface_test.go deleted file mode 100644 index f325e1bd..00000000 --- a/interface_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package gorm_test - -import ( - "reflect" - "testing" - - "github.com/kr/pretty" -) - -type ( - UserInterface interface { - UserName() string - UserType() string - } - - UserCommon struct { - Name string - Type string - } - - BasicUser struct { - User - } - - AdminUser struct { - BasicUser - } - - GroupUser struct { - GroupID int64 - User UserInterface - } - - Group struct { - Users []GroupUser - } -) - -func (m *BasicUser) UserName() string { - return m.Name -} - -func (m *BasicUser) Type() string { - return "basic" -} - -func (m *AdminUser) Type() string { - return "admin" -} - -// ScanType returns the scan type for the field -func (m *GroupUser) ScanType(field string) reflect.Type { - switch field { - case "User": - // The geometry data should be encoded as a []byte first - return reflect.TypeOf(User{}) - default: - return reflect.TypeOf(nil) - } -} - -// ScanField handle exporting scanned fields -func (m *GroupUser) ScanField(field string, data interface{}) error { - switch field { - case "User": - m.User = data.(UserInterface) - } - - return nil -} - -var tt *testing.T - -func TestInterface(t *testing.T) { - tt = t - DB.AutoMigrate(&UserCommon{}) - - user1 := UserCommon{Name: "RowUser1", type: "basic"} - - DB.Save(&user1) - - t.Log("loading the users") - users := make([]*UserWrapper, 0) - - if DB.Table("users").Find(&users).Error != nil { - t.Errorf("No errors should happen if set table for find") - } - t.Logf(pretty.Sprint(users)) -} diff --git a/main_test.go b/main_test.go index 94d2fa39..ffda85b3 100644 --- a/main_test.go +++ b/main_test.go @@ -28,7 +28,6 @@ var ( func init() { var err error - if DB, err = OpenTestConnection(); err != nil { panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) } @@ -64,7 +63,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open("mssql", dbDSN) default: fmt.Println("testing sqlite3...") - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) + if dbDSN == "" { + dbDSN = filepath.Join(os.TempDir(), "gorm.db") + } + db, err = gorm.Open("sqlite3", dbDSN) } // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) diff --git a/model_struct.go b/model_struct.go index 2ae21006..2c2a0b4a 100644 --- a/model_struct.go +++ b/model_struct.go @@ -59,7 +59,7 @@ type StructField struct { IsNormal bool IsIgnored bool IsScanner bool - IsInterface bool + UseEncoder bool HasDefaultValue bool Tag reflect.StructTag TagSettings map[string]string @@ -101,7 +101,7 @@ func (structField *StructField) clone() *StructField { IsNormal: structField.IsNormal, IsIgnored: structField.IsIgnored, IsScanner: structField.IsScanner, - IsInterface: structField.IsInterface, + UseEncoder: structField.UseEncoder, HasDefaultValue: structField.HasDefaultValue, Tag: structField.Tag, TagSettings: map[string]string{}, @@ -183,18 +183,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { TagSettings: parseTagSetting(fieldStruct.Tag), } - if !ast.IsExported(fieldStruct.Name) { - if _, ok := field.TagSettingsGet("INTERFACE"); ok { - field.IsInterface = true - } else { - continue - } - } - // is ignored field if _, ok := field.TagSettingsGet("-"); ok { field.IsIgnored = true } else { + if _, ok := field.TagSettingsGet("USE_ENCODER"); ok { + field.UseEncoder = true + } + + // private interface fields can be exported explicitly + if !ast.IsExported(fieldStruct.Name) && !field.UseEncoder { + continue + } + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) @@ -599,10 +600,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } }(field) - case reflect.Interface: - field.IsInterface = true default: - field.IsNormal = true + if !field.UseEncoder { + field.IsNormal = true + } } } diff --git a/scope.go b/scope.go index 8e8758ab..31874ad7 100644 --- a/scope.go +++ b/scope.go @@ -10,8 +10,6 @@ import ( "regexp" "strings" "time" - - "github.com/kr/pretty" ) // Scope contain current operation's information when you perform any operation on the database @@ -482,12 +480,12 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem selectFields []*Field selectedColumnsMap = map[string]int{} resetFields = map[int]*Field{} - interfaceFields = map[string]interface{}{} - rootElem interface{} + encodedFields = make([]*decoder, 0) + elemScope = scope ) if len(elem) > 0 { - rootElem = elem[0] + elemScope = scope.New(elem[0]) } for index, column := range columns { @@ -501,15 +499,11 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem for fieldIndex, field := range selectFields { if field.DBName == column { - if field.IsInterface { - pretty.Log(column) - if i, ok := rootElem.(interface { - ScanType(field string) reflect.Type - }); ok { - t := i.ScanType(field.DBName) - val := reflect.New(t).Interface() - values[index] = val - interfaceFields[field.DBName] = values[index] + if field.UseEncoder { + if enc, ok := elemScope.Value.(Encoder); ok { + dec := newDecoder(enc, elemScope, field.DBName) + values[index] = dec + encodedFields = append(encodedFields, dec) } } else if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() @@ -531,25 +525,18 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem scope.Err(rows.Scan(values...)) - for k, v := range interfaceFields { - if i, ok := elem[0].(interface { - ScanField(field string, data interface{}) error - }); ok { - val := reflect.ValueOf(v) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if err := i.ScanField(k, val.Interface()); err != nil { - fmt.Println(err) - } - } - } - for index, field := range resetFields { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { field.Field.Set(v) } } + + // process the decoders + for _, d := range encodedFields { + if err := d.Decode(); err != nil { + scope.Err(err) + } + } } func (scope *Scope) primaryCondition(value interface{}) string { @@ -1183,7 +1170,7 @@ func (scope *Scope) createTable() *Scope { var primaryKeys []string var primaryKeyInColumnType = false for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { + if field.IsNormal || field.UseEncoder { sqlTag := scope.Dialect().DataTypeOf(field) // Check if the primary key constraint was specified as @@ -1284,7 +1271,7 @@ func (scope *Scope) autoMigrate() *Scope { } else { for _, field := range scope.GetModelStruct().StructFields { if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { + if field.IsNormal || field.UseEncoder { sqlTag := scope.Dialect().DataTypeOf(field) scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() }