adding final with tests

This commit is contained in:
Rob Rodriguez 2018-12-10 16:06:56 -08:00
parent 89ac376a2f
commit 51de8ca836
9 changed files with 207 additions and 138 deletions

View File

@ -74,6 +74,15 @@ func createCallback(scope *Scope) {
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
} }
} }
} else if field.UseEncoder {
if enc, ok := scope.Value.(Encoder); ok {
if val, err := enc.EncodeField(scope, field.DBName); err == nil {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(val))
} else {
scope.Err(err)
}
}
} }
} }
} }

View File

@ -141,7 +141,6 @@ func saveAfterAssociationsCallback(scope *Scope) {
default: default:
elem := value.Addr().Interface() elem := value.Addr().Interface()
newScope := scope.New(elem) newScope := scope.New(elem)
if saveReference { if saveReference {
if len(relationship.ForeignFieldNames) != 0 { if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames { for idx, fieldName := range relationship.ForeignFieldNames {

View File

@ -75,10 +75,20 @@ func updateCallback(scope *Scope) {
} else { } else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if scope.changeableField(field) { if scope.changeableField(field) {
if !field.IsPrimaryKey && field.IsNormal { if !field.IsPrimaryKey {
if field.IsNormal {
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }
} else if field.UseEncoder {
if enc, ok := scope.Value.(Encoder); ok {
if val, err := enc.EncodeField(scope, field.DBName); err == nil {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(val)))
} else {
scope.Err(err)
}
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, foreignKey := range relationship.ForeignDBNames { for _, foreignKey := range relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {

35
encoder.go Normal file
View File

@ -0,0 +1,35 @@
package gorm
// Encoder is a value encoding interface for complex field types
type Encoder interface {
EncodeField(*Scope, string) (interface{}, error)
DecodeField(scope *Scope, column string, value interface{}) error
}
// decoder defers decoding until necessary
type decoder struct {
Encoder
scope *Scope
column string
value interface{}
}
func newDecoder(encoder Encoder, scope *Scope, column string) *decoder {
return &decoder{
encoder,
scope,
column,
nil,
}
}
// Scan implements the sql.Scanner interface
func (d *decoder) Scan(src interface{}) error {
d.value = src
return nil
}
// Decode handles the decoding at a later time
func (d *decoder) Decode() error {
return d.DecodeField(d.scope, d.column, d.value)
}

115
encoder_test.go Normal file
View File

@ -0,0 +1,115 @@
package gorm_test
import (
"encoding/json"
"errors"
"testing"
"github.com/jinzhu/gorm"
)
type (
Widget interface {
GetType() string
}
WidgetUser struct {
User
WidgetType string
Widget Widget `gorm:"use_encoder;column:widget;type:jsonb"`
}
SimpleWidget struct {
Type string `json:"type"`
Width int64 `json:"width"`
Height int64 `json:"height"`
}
ComplexWidget struct {
SimpleWidget
Color string `json:"color"`
}
)
func (m *SimpleWidget) GetType() string {
return "simple"
}
func (m *ComplexWidget) GetType() string {
return "complex"
}
func (m *WidgetUser) EncodeField(scope *gorm.Scope, column string) (interface{}, error) {
switch column {
case "widget":
val, err := json.Marshal(m.Widget)
if err != nil {
return nil, err
}
return string(val), nil
}
return nil, nil
}
func (m *WidgetUser) DecodeField(scope *gorm.Scope, column string, value interface{}) error {
switch column {
case "widget":
b, ok := value.([]byte)
if !ok {
return errors.New("Invalid type for Widget")
}
switch m.WidgetType {
case "simple":
var result SimpleWidget
if err := json.Unmarshal(b, &result); err != nil {
return err
}
m.Widget = &result
case "complex":
var result ComplexWidget
if err := json.Unmarshal(b, &result); err != nil {
return err
}
m.Widget = &result
default:
return errors.New("unsupported Widget type")
}
}
return nil
}
func TestEncoder(t *testing.T) {
DB.AutoMigrate(&WidgetUser{})
user := &WidgetUser{
User: User{
Id: 1,
Name: "bob",
},
WidgetType: "simple",
Widget: &SimpleWidget{Type: "simple", Width: 12, Height: 10},
}
if err := DB.Save(user).Error; err != nil {
t.Errorf("failed to save WidgetUser %v", err)
}
user1 := WidgetUser{}
if err := DB.First(&user1, "id=?", 1).Error; err != nil {
t.Errorf("failed to retrieve WidgetUser %v", err)
}
if user1.Widget.GetType() != "simple" {
t.Errorf("user widget invalid")
}
if w, ok := user1.Widget.(*SimpleWidget); !ok {
t.Errorf("user widget is not valid")
} else {
if w.Width != 12 || w.Height != 10 {
t.Errorf("user widget is not valid")
}
}
}

View File

@ -1,89 +0,0 @@
package gorm_test
import (
"reflect"
"testing"
"github.com/kr/pretty"
)
type (
UserInterface interface {
UserName() string
UserType() string
}
UserCommon struct {
Name string
Type string
}
BasicUser struct {
User
}
AdminUser struct {
BasicUser
}
GroupUser struct {
GroupID int64
User UserInterface
}
Group struct {
Users []GroupUser
}
)
func (m *BasicUser) UserName() string {
return m.Name
}
func (m *BasicUser) Type() string {
return "basic"
}
func (m *AdminUser) Type() string {
return "admin"
}
// ScanType returns the scan type for the field
func (m *GroupUser) ScanType(field string) reflect.Type {
switch field {
case "User":
// The geometry data should be encoded as a []byte first
return reflect.TypeOf(User{})
default:
return reflect.TypeOf(nil)
}
}
// ScanField handle exporting scanned fields
func (m *GroupUser) ScanField(field string, data interface{}) error {
switch field {
case "User":
m.User = data.(UserInterface)
}
return nil
}
var tt *testing.T
func TestInterface(t *testing.T) {
tt = t
DB.AutoMigrate(&UserCommon{})
user1 := UserCommon{Name: "RowUser1", type: "basic"}
DB.Save(&user1)
t.Log("loading the users")
users := make([]*UserWrapper, 0)
if DB.Table("users").Find(&users).Error != nil {
t.Errorf("No errors should happen if set table for find")
}
t.Logf(pretty.Sprint(users))
}

View File

@ -28,7 +28,6 @@ var (
func init() { func init() {
var err error var err error
if DB, err = OpenTestConnection(); err != nil { if DB, err = OpenTestConnection(); err != nil {
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
} }
@ -64,7 +63,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
db, err = gorm.Open("mssql", dbDSN) db, err = gorm.Open("mssql", dbDSN)
default: default:
fmt.Println("testing sqlite3...") fmt.Println("testing sqlite3...")
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) if dbDSN == "" {
dbDSN = filepath.Join(os.TempDir(), "gorm.db")
}
db, err = gorm.Open("sqlite3", dbDSN)
} }
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})

View File

@ -59,7 +59,7 @@ type StructField struct {
IsNormal bool IsNormal bool
IsIgnored bool IsIgnored bool
IsScanner bool IsScanner bool
IsInterface bool UseEncoder bool
HasDefaultValue bool HasDefaultValue bool
Tag reflect.StructTag Tag reflect.StructTag
TagSettings map[string]string TagSettings map[string]string
@ -101,7 +101,7 @@ func (structField *StructField) clone() *StructField {
IsNormal: structField.IsNormal, IsNormal: structField.IsNormal,
IsIgnored: structField.IsIgnored, IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner, IsScanner: structField.IsScanner,
IsInterface: structField.IsInterface, UseEncoder: structField.UseEncoder,
HasDefaultValue: structField.HasDefaultValue, HasDefaultValue: structField.HasDefaultValue,
Tag: structField.Tag, Tag: structField.Tag,
TagSettings: map[string]string{}, TagSettings: map[string]string{},
@ -183,18 +183,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
TagSettings: parseTagSetting(fieldStruct.Tag), TagSettings: parseTagSetting(fieldStruct.Tag),
} }
if !ast.IsExported(fieldStruct.Name) {
if _, ok := field.TagSettingsGet("INTERFACE"); ok {
field.IsInterface = true
} else {
continue
}
}
// is ignored field // is ignored field
if _, ok := field.TagSettingsGet("-"); ok { if _, ok := field.TagSettingsGet("-"); ok {
field.IsIgnored = true field.IsIgnored = true
} else { } else {
if _, ok := field.TagSettingsGet("USE_ENCODER"); ok {
field.UseEncoder = true
}
// private interface fields can be exported explicitly
if !ast.IsExported(fieldStruct.Name) && !field.UseEncoder {
continue
}
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
field.IsPrimaryKey = true field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
@ -599,12 +600,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
}(field) }(field)
case reflect.Interface:
field.IsInterface = true
default: default:
if !field.UseEncoder {
field.IsNormal = true field.IsNormal = true
} }
} }
}
// Even it is ignored, also possible to decode db value into the field // Even it is ignored, also possible to decode db value into the field
if value, ok := field.TagSettingsGet("COLUMN"); ok { if value, ok := field.TagSettingsGet("COLUMN"); ok {

View File

@ -10,8 +10,6 @@ import (
"regexp" "regexp"
"strings" "strings"
"time" "time"
"github.com/kr/pretty"
) )
// Scope contain current operation's information when you perform any operation on the database // Scope contain current operation's information when you perform any operation on the database
@ -482,12 +480,12 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem
selectFields []*Field selectFields []*Field
selectedColumnsMap = map[string]int{} selectedColumnsMap = map[string]int{}
resetFields = map[int]*Field{} resetFields = map[int]*Field{}
interfaceFields = map[string]interface{}{} encodedFields = make([]*decoder, 0)
rootElem interface{} elemScope = scope
) )
if len(elem) > 0 { if len(elem) > 0 {
rootElem = elem[0] elemScope = scope.New(elem[0])
} }
for index, column := range columns { for index, column := range columns {
@ -501,15 +499,11 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem
for fieldIndex, field := range selectFields { for fieldIndex, field := range selectFields {
if field.DBName == column { if field.DBName == column {
if field.IsInterface { if field.UseEncoder {
pretty.Log(column) if enc, ok := elemScope.Value.(Encoder); ok {
if i, ok := rootElem.(interface { dec := newDecoder(enc, elemScope, field.DBName)
ScanType(field string) reflect.Type values[index] = dec
}); ok { encodedFields = append(encodedFields, dec)
t := i.ScanType(field.DBName)
val := reflect.New(t).Interface()
values[index] = val
interfaceFields[field.DBName] = values[index]
} }
} else if field.Field.Kind() == reflect.Ptr { } else if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface() values[index] = field.Field.Addr().Interface()
@ -531,25 +525,18 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem
scope.Err(rows.Scan(values...)) scope.Err(rows.Scan(values...))
for k, v := range interfaceFields {
if i, ok := elem[0].(interface {
ScanField(field string, data interface{}) error
}); ok {
val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if err := i.ScanField(k, val.Interface()); err != nil {
fmt.Println(err)
}
}
}
for index, field := range resetFields { for index, field := range resetFields {
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
field.Field.Set(v) field.Field.Set(v)
} }
} }
// process the decoders
for _, d := range encodedFields {
if err := d.Decode(); err != nil {
scope.Err(err)
}
}
} }
func (scope *Scope) primaryCondition(value interface{}) string { func (scope *Scope) primaryCondition(value interface{}) string {
@ -1183,7 +1170,7 @@ func (scope *Scope) createTable() *Scope {
var primaryKeys []string var primaryKeys []string
var primaryKeyInColumnType = false var primaryKeyInColumnType = false
for _, field := range scope.GetModelStruct().StructFields { for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal { if field.IsNormal || field.UseEncoder {
sqlTag := scope.Dialect().DataTypeOf(field) sqlTag := scope.Dialect().DataTypeOf(field)
// Check if the primary key constraint was specified as // Check if the primary key constraint was specified as
@ -1284,7 +1271,7 @@ func (scope *Scope) autoMigrate() *Scope {
} else { } else {
for _, field := range scope.GetModelStruct().StructFields { for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(tableName, field.DBName) { if !scope.Dialect().HasColumn(tableName, field.DBName) {
if field.IsNormal { if field.IsNormal || field.UseEncoder {
sqlTag := scope.Dialect().DataTypeOf(field) sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
} }