This commit is contained in:
Rob Rodriguez 2018-12-10 11:44:06 -08:00
parent 472c70caa4
commit 89ac376a2f
5 changed files with 484 additions and 349 deletions

3
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,3 @@
{
"go.generateTestsFlags": ["-v"]
}

View File

@ -76,7 +76,7 @@ func queryCallback(scope *Scope) {
elem = reflect.New(resultType).Elem() elem = reflect.New(resultType).Elem()
} }
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields(), elem.Addr().Interface())
if isSlice { if isSlice {
if isPtr { if isPtr {

89
interface_test.go Normal file
View File

@ -0,0 +1,89 @@
package gorm_test
import (
"reflect"
"testing"
"github.com/kr/pretty"
)
type (
UserInterface interface {
UserName() string
UserType() string
}
UserCommon struct {
Name string
Type string
}
BasicUser struct {
User
}
AdminUser struct {
BasicUser
}
GroupUser struct {
GroupID int64
User UserInterface
}
Group struct {
Users []GroupUser
}
)
func (m *BasicUser) UserName() string {
return m.Name
}
func (m *BasicUser) Type() string {
return "basic"
}
func (m *AdminUser) Type() string {
return "admin"
}
// ScanType returns the scan type for the field
func (m *GroupUser) ScanType(field string) reflect.Type {
switch field {
case "User":
// The geometry data should be encoded as a []byte first
return reflect.TypeOf(User{})
default:
return reflect.TypeOf(nil)
}
}
// ScanField handle exporting scanned fields
func (m *GroupUser) ScanField(field string, data interface{}) error {
switch field {
case "User":
m.User = data.(UserInterface)
}
return nil
}
var tt *testing.T
func TestInterface(t *testing.T) {
tt = t
DB.AutoMigrate(&UserCommon{})
user1 := UserCommon{Name: "RowUser1", type: "basic"}
DB.Save(&user1)
t.Log("loading the users")
users := make([]*UserWrapper, 0)
if DB.Table("users").Find(&users).Error != nil {
t.Errorf("No errors should happen if set table for find")
}
t.Logf(pretty.Sprint(users))
}

View File

@ -59,6 +59,7 @@ type StructField struct {
IsNormal bool IsNormal bool
IsIgnored bool IsIgnored bool
IsScanner bool IsScanner bool
IsInterface bool
HasDefaultValue bool HasDefaultValue bool
Tag reflect.StructTag Tag reflect.StructTag
TagSettings map[string]string TagSettings map[string]string
@ -100,6 +101,7 @@ func (structField *StructField) clone() *StructField {
IsNormal: structField.IsNormal, IsNormal: structField.IsNormal,
IsIgnored: structField.IsIgnored, IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner, IsScanner: structField.IsScanner,
IsInterface: structField.IsInterface,
HasDefaultValue: structField.HasDefaultValue, HasDefaultValue: structField.HasDefaultValue,
Tag: structField.Tag, Tag: structField.Tag,
TagSettings: map[string]string{}, TagSettings: map[string]string{},
@ -171,7 +173,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
// Get all fields // Get all fields
for i := 0; i < reflectType.NumField(); i++ { for i := 0; i < reflectType.NumField(); i++ {
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { fieldStruct := reflectType.Field(i)
field := &StructField{ field := &StructField{
Struct: fieldStruct, Struct: fieldStruct,
Name: fieldStruct.Name, Name: fieldStruct.Name,
@ -180,6 +183,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
TagSettings: parseTagSetting(fieldStruct.Tag), TagSettings: parseTagSetting(fieldStruct.Tag),
} }
if !ast.IsExported(fieldStruct.Name) {
if _, ok := field.TagSettingsGet("INTERFACE"); ok {
field.IsInterface = true
} else {
continue
}
}
// is ignored field // is ignored field
if _, ok := field.TagSettingsGet("-"); ok { if _, ok := field.TagSettingsGet("-"); ok {
field.IsIgnored = true field.IsIgnored = true
@ -588,11 +599,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
}(field) }(field)
case reflect.Interface:
field.IsInterface = true
default: default:
field.IsNormal = true field.IsNormal = true
} }
} }
}
// Even it is ignored, also possible to decode db value into the field // Even it is ignored, also possible to decode db value into the field
if value, ok := field.TagSettingsGet("COLUMN"); ok { if value, ok := field.TagSettingsGet("COLUMN"); ok {

View File

@ -10,6 +10,8 @@ import (
"regexp" "regexp"
"strings" "strings"
"time" "time"
"github.com/kr/pretty"
) )
// Scope contain current operation's information when you perform any operation on the database // Scope contain current operation's information when you perform any operation on the database
@ -473,18 +475,23 @@ func (scope *Scope) quoteIfPossible(str string) string {
return str return str
} }
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem ...interface{}) {
var ( var (
ignored interface{} ignored interface{}
values = make([]interface{}, len(columns)) values = make([]interface{}, len(columns))
selectFields []*Field selectFields []*Field
selectedColumnsMap = map[string]int{} selectedColumnsMap = map[string]int{}
resetFields = map[int]*Field{} resetFields = map[int]*Field{}
interfaceFields = map[string]interface{}{}
rootElem interface{}
) )
if len(elem) > 0 {
rootElem = elem[0]
}
for index, column := range columns { for index, column := range columns {
values[index] = &ignored values[index] = &ignored
selectFields = fields selectFields = fields
offset := 0 offset := 0
if idx, ok := selectedColumnsMap[column]; ok { if idx, ok := selectedColumnsMap[column]; ok {
@ -494,7 +501,17 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
for fieldIndex, field := range selectFields { for fieldIndex, field := range selectFields {
if field.DBName == column { if field.DBName == column {
if field.Field.Kind() == reflect.Ptr { if field.IsInterface {
pretty.Log(column)
if i, ok := rootElem.(interface {
ScanType(field string) reflect.Type
}); ok {
t := i.ScanType(field.DBName)
val := reflect.New(t).Interface()
values[index] = val
interfaceFields[field.DBName] = values[index]
}
} else if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface() values[index] = field.Field.Addr().Interface()
} else { } else {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
@ -514,6 +531,20 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
scope.Err(rows.Scan(values...)) scope.Err(rows.Scan(values...))
for k, v := range interfaceFields {
if i, ok := elem[0].(interface {
ScanField(field string, data interface{}) error
}); ok {
val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if err := i.ScanField(k, val.Interface()); err != nil {
fmt.Println(err)
}
}
}
for index, field := range resetFields { for index, field := range resetFields {
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
field.Field.Set(v) field.Field.Set(v)