adding final with tests
This commit is contained in:
parent
89ac376a2f
commit
51de8ca836
@ -74,6 +74,15 @@ func createCallback(scope *Scope) {
|
|||||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if field.UseEncoder {
|
||||||
|
if enc, ok := scope.Value.(Encoder); ok {
|
||||||
|
if val, err := enc.EncodeField(scope, field.DBName); err == nil {
|
||||||
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
|
placeholders = append(placeholders, scope.AddToVars(val))
|
||||||
|
} else {
|
||||||
|
scope.Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,6 @@ func saveAfterAssociationsCallback(scope *Scope) {
|
|||||||
default:
|
default:
|
||||||
elem := value.Addr().Interface()
|
elem := value.Addr().Interface()
|
||||||
newScope := scope.New(elem)
|
newScope := scope.New(elem)
|
||||||
|
|
||||||
if saveReference {
|
if saveReference {
|
||||||
if len(relationship.ForeignFieldNames) != 0 {
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
|
@ -75,10 +75,20 @@ func updateCallback(scope *Scope) {
|
|||||||
} else {
|
} else {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if !field.IsPrimaryKey && field.IsNormal {
|
if !field.IsPrimaryKey {
|
||||||
|
if field.IsNormal {
|
||||||
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
|
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
}
|
}
|
||||||
|
} else if field.UseEncoder {
|
||||||
|
if enc, ok := scope.Value.(Encoder); ok {
|
||||||
|
if val, err := enc.EncodeField(scope, field.DBName); err == nil {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(val)))
|
||||||
|
} else {
|
||||||
|
scope.Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
|
35
encoder.go
Normal file
35
encoder.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
// Encoder is a value encoding interface for complex field types
|
||||||
|
type Encoder interface {
|
||||||
|
EncodeField(*Scope, string) (interface{}, error)
|
||||||
|
DecodeField(scope *Scope, column string, value interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// decoder defers decoding until necessary
|
||||||
|
type decoder struct {
|
||||||
|
Encoder
|
||||||
|
scope *Scope
|
||||||
|
column string
|
||||||
|
value interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDecoder(encoder Encoder, scope *Scope, column string) *decoder {
|
||||||
|
return &decoder{
|
||||||
|
encoder,
|
||||||
|
scope,
|
||||||
|
column,
|
||||||
|
nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the sql.Scanner interface
|
||||||
|
func (d *decoder) Scan(src interface{}) error {
|
||||||
|
d.value = src
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode handles the decoding at a later time
|
||||||
|
func (d *decoder) Decode() error {
|
||||||
|
return d.DecodeField(d.scope, d.column, d.value)
|
||||||
|
}
|
115
encoder_test.go
Normal file
115
encoder_test.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Widget interface {
|
||||||
|
GetType() string
|
||||||
|
}
|
||||||
|
|
||||||
|
WidgetUser struct {
|
||||||
|
User
|
||||||
|
WidgetType string
|
||||||
|
Widget Widget `gorm:"use_encoder;column:widget;type:jsonb"`
|
||||||
|
}
|
||||||
|
|
||||||
|
SimpleWidget struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Width int64 `json:"width"`
|
||||||
|
Height int64 `json:"height"`
|
||||||
|
}
|
||||||
|
|
||||||
|
ComplexWidget struct {
|
||||||
|
SimpleWidget
|
||||||
|
Color string `json:"color"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *SimpleWidget) GetType() string {
|
||||||
|
return "simple"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ComplexWidget) GetType() string {
|
||||||
|
return "complex"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WidgetUser) EncodeField(scope *gorm.Scope, column string) (interface{}, error) {
|
||||||
|
switch column {
|
||||||
|
case "widget":
|
||||||
|
val, err := json.Marshal(m.Widget)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return string(val), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WidgetUser) DecodeField(scope *gorm.Scope, column string, value interface{}) error {
|
||||||
|
switch column {
|
||||||
|
case "widget":
|
||||||
|
b, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("Invalid type for Widget")
|
||||||
|
}
|
||||||
|
switch m.WidgetType {
|
||||||
|
case "simple":
|
||||||
|
var result SimpleWidget
|
||||||
|
if err := json.Unmarshal(b, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Widget = &result
|
||||||
|
case "complex":
|
||||||
|
var result ComplexWidget
|
||||||
|
if err := json.Unmarshal(b, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Widget = &result
|
||||||
|
default:
|
||||||
|
return errors.New("unsupported Widget type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncoder(t *testing.T) {
|
||||||
|
DB.AutoMigrate(&WidgetUser{})
|
||||||
|
|
||||||
|
user := &WidgetUser{
|
||||||
|
User: User{
|
||||||
|
Id: 1,
|
||||||
|
Name: "bob",
|
||||||
|
},
|
||||||
|
WidgetType: "simple",
|
||||||
|
Widget: &SimpleWidget{Type: "simple", Width: 12, Height: 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(user).Error; err != nil {
|
||||||
|
t.Errorf("failed to save WidgetUser %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 := WidgetUser{}
|
||||||
|
|
||||||
|
if err := DB.First(&user1, "id=?", 1).Error; err != nil {
|
||||||
|
t.Errorf("failed to retrieve WidgetUser %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user1.Widget.GetType() != "simple" {
|
||||||
|
t.Errorf("user widget invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if w, ok := user1.Widget.(*SimpleWidget); !ok {
|
||||||
|
t.Errorf("user widget is not valid")
|
||||||
|
} else {
|
||||||
|
if w.Width != 12 || w.Height != 10 {
|
||||||
|
t.Errorf("user widget is not valid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,89 +0,0 @@
|
|||||||
package gorm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/kr/pretty"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
UserInterface interface {
|
|
||||||
UserName() string
|
|
||||||
UserType() string
|
|
||||||
}
|
|
||||||
|
|
||||||
UserCommon struct {
|
|
||||||
Name string
|
|
||||||
Type string
|
|
||||||
}
|
|
||||||
|
|
||||||
BasicUser struct {
|
|
||||||
User
|
|
||||||
}
|
|
||||||
|
|
||||||
AdminUser struct {
|
|
||||||
BasicUser
|
|
||||||
}
|
|
||||||
|
|
||||||
GroupUser struct {
|
|
||||||
GroupID int64
|
|
||||||
User UserInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
Group struct {
|
|
||||||
Users []GroupUser
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *BasicUser) UserName() string {
|
|
||||||
return m.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *BasicUser) Type() string {
|
|
||||||
return "basic"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AdminUser) Type() string {
|
|
||||||
return "admin"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanType returns the scan type for the field
|
|
||||||
func (m *GroupUser) ScanType(field string) reflect.Type {
|
|
||||||
switch field {
|
|
||||||
case "User":
|
|
||||||
// The geometry data should be encoded as a []byte first
|
|
||||||
return reflect.TypeOf(User{})
|
|
||||||
default:
|
|
||||||
return reflect.TypeOf(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanField handle exporting scanned fields
|
|
||||||
func (m *GroupUser) ScanField(field string, data interface{}) error {
|
|
||||||
switch field {
|
|
||||||
case "User":
|
|
||||||
m.User = data.(UserInterface)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var tt *testing.T
|
|
||||||
|
|
||||||
func TestInterface(t *testing.T) {
|
|
||||||
tt = t
|
|
||||||
DB.AutoMigrate(&UserCommon{})
|
|
||||||
|
|
||||||
user1 := UserCommon{Name: "RowUser1", type: "basic"}
|
|
||||||
|
|
||||||
DB.Save(&user1)
|
|
||||||
|
|
||||||
t.Log("loading the users")
|
|
||||||
users := make([]*UserWrapper, 0)
|
|
||||||
|
|
||||||
if DB.Table("users").Find(&users).Error != nil {
|
|
||||||
t.Errorf("No errors should happen if set table for find")
|
|
||||||
}
|
|
||||||
t.Logf(pretty.Sprint(users))
|
|
||||||
}
|
|
@ -28,7 +28,6 @@ var (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if DB, err = OpenTestConnection(); err != nil {
|
if DB, err = OpenTestConnection(); err != nil {
|
||||||
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
|
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
|
||||||
}
|
}
|
||||||
@ -64,7 +63,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||||||
db, err = gorm.Open("mssql", dbDSN)
|
db, err = gorm.Open("mssql", dbDSN)
|
||||||
default:
|
default:
|
||||||
fmt.Println("testing sqlite3...")
|
fmt.Println("testing sqlite3...")
|
||||||
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
|
if dbDSN == "" {
|
||||||
|
dbDSN = filepath.Join(os.TempDir(), "gorm.db")
|
||||||
|
}
|
||||||
|
db, err = gorm.Open("sqlite3", dbDSN)
|
||||||
}
|
}
|
||||||
|
|
||||||
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
|
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
|
||||||
|
@ -59,7 +59,7 @@ type StructField struct {
|
|||||||
IsNormal bool
|
IsNormal bool
|
||||||
IsIgnored bool
|
IsIgnored bool
|
||||||
IsScanner bool
|
IsScanner bool
|
||||||
IsInterface bool
|
UseEncoder bool
|
||||||
HasDefaultValue bool
|
HasDefaultValue bool
|
||||||
Tag reflect.StructTag
|
Tag reflect.StructTag
|
||||||
TagSettings map[string]string
|
TagSettings map[string]string
|
||||||
@ -101,7 +101,7 @@ func (structField *StructField) clone() *StructField {
|
|||||||
IsNormal: structField.IsNormal,
|
IsNormal: structField.IsNormal,
|
||||||
IsIgnored: structField.IsIgnored,
|
IsIgnored: structField.IsIgnored,
|
||||||
IsScanner: structField.IsScanner,
|
IsScanner: structField.IsScanner,
|
||||||
IsInterface: structField.IsInterface,
|
UseEncoder: structField.UseEncoder,
|
||||||
HasDefaultValue: structField.HasDefaultValue,
|
HasDefaultValue: structField.HasDefaultValue,
|
||||||
Tag: structField.Tag,
|
Tag: structField.Tag,
|
||||||
TagSettings: map[string]string{},
|
TagSettings: map[string]string{},
|
||||||
@ -183,18 +183,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
TagSettings: parseTagSetting(fieldStruct.Tag),
|
TagSettings: parseTagSetting(fieldStruct.Tag),
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ast.IsExported(fieldStruct.Name) {
|
|
||||||
if _, ok := field.TagSettingsGet("INTERFACE"); ok {
|
|
||||||
field.IsInterface = true
|
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// is ignored field
|
// is ignored field
|
||||||
if _, ok := field.TagSettingsGet("-"); ok {
|
if _, ok := field.TagSettingsGet("-"); ok {
|
||||||
field.IsIgnored = true
|
field.IsIgnored = true
|
||||||
} else {
|
} else {
|
||||||
|
if _, ok := field.TagSettingsGet("USE_ENCODER"); ok {
|
||||||
|
field.UseEncoder = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// private interface fields can be exported explicitly
|
||||||
|
if !ast.IsExported(fieldStruct.Name) && !field.UseEncoder {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
|
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
@ -599,12 +600,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(field)
|
}(field)
|
||||||
case reflect.Interface:
|
|
||||||
field.IsInterface = true
|
|
||||||
default:
|
default:
|
||||||
|
if !field.UseEncoder {
|
||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Even it is ignored, also possible to decode db value into the field
|
// Even it is ignored, also possible to decode db value into the field
|
||||||
if value, ok := field.TagSettingsGet("COLUMN"); ok {
|
if value, ok := field.TagSettingsGet("COLUMN"); ok {
|
||||||
|
47
scope.go
47
scope.go
@ -10,8 +10,6 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/kr/pretty"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Scope contain current operation's information when you perform any operation on the database
|
// Scope contain current operation's information when you perform any operation on the database
|
||||||
@ -482,12 +480,12 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem
|
|||||||
selectFields []*Field
|
selectFields []*Field
|
||||||
selectedColumnsMap = map[string]int{}
|
selectedColumnsMap = map[string]int{}
|
||||||
resetFields = map[int]*Field{}
|
resetFields = map[int]*Field{}
|
||||||
interfaceFields = map[string]interface{}{}
|
encodedFields = make([]*decoder, 0)
|
||||||
rootElem interface{}
|
elemScope = scope
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(elem) > 0 {
|
if len(elem) > 0 {
|
||||||
rootElem = elem[0]
|
elemScope = scope.New(elem[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
for index, column := range columns {
|
for index, column := range columns {
|
||||||
@ -501,15 +499,11 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem
|
|||||||
|
|
||||||
for fieldIndex, field := range selectFields {
|
for fieldIndex, field := range selectFields {
|
||||||
if field.DBName == column {
|
if field.DBName == column {
|
||||||
if field.IsInterface {
|
if field.UseEncoder {
|
||||||
pretty.Log(column)
|
if enc, ok := elemScope.Value.(Encoder); ok {
|
||||||
if i, ok := rootElem.(interface {
|
dec := newDecoder(enc, elemScope, field.DBName)
|
||||||
ScanType(field string) reflect.Type
|
values[index] = dec
|
||||||
}); ok {
|
encodedFields = append(encodedFields, dec)
|
||||||
t := i.ScanType(field.DBName)
|
|
||||||
val := reflect.New(t).Interface()
|
|
||||||
values[index] = val
|
|
||||||
interfaceFields[field.DBName] = values[index]
|
|
||||||
}
|
}
|
||||||
} else if field.Field.Kind() == reflect.Ptr {
|
} else if field.Field.Kind() == reflect.Ptr {
|
||||||
values[index] = field.Field.Addr().Interface()
|
values[index] = field.Field.Addr().Interface()
|
||||||
@ -531,25 +525,18 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem
|
|||||||
|
|
||||||
scope.Err(rows.Scan(values...))
|
scope.Err(rows.Scan(values...))
|
||||||
|
|
||||||
for k, v := range interfaceFields {
|
|
||||||
if i, ok := elem[0].(interface {
|
|
||||||
ScanField(field string, data interface{}) error
|
|
||||||
}); ok {
|
|
||||||
val := reflect.ValueOf(v)
|
|
||||||
if val.Kind() == reflect.Ptr {
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
if err := i.ScanField(k, val.Interface()); err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for index, field := range resetFields {
|
for index, field := range resetFields {
|
||||||
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
||||||
field.Field.Set(v)
|
field.Field.Set(v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// process the decoders
|
||||||
|
for _, d := range encodedFields {
|
||||||
|
if err := d.Decode(); err != nil {
|
||||||
|
scope.Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||||
@ -1183,7 +1170,7 @@ func (scope *Scope) createTable() *Scope {
|
|||||||
var primaryKeys []string
|
var primaryKeys []string
|
||||||
var primaryKeyInColumnType = false
|
var primaryKeyInColumnType = false
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if field.IsNormal {
|
if field.IsNormal || field.UseEncoder {
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
|
|
||||||
// Check if the primary key constraint was specified as
|
// Check if the primary key constraint was specified as
|
||||||
@ -1284,7 +1271,7 @@ func (scope *Scope) autoMigrate() *Scope {
|
|||||||
} else {
|
} else {
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal || field.UseEncoder {
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user