update
This commit is contained in:
parent
472c70caa4
commit
89ac376a2f
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"go.generateTestsFlags": ["-v"]
|
||||
}
|
@ -76,7 +76,7 @@ func queryCallback(scope *Scope) {
|
||||
elem = reflect.New(resultType).Elem()
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
|
||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields(), elem.Addr().Interface())
|
||||
|
||||
if isSlice {
|
||||
if isPtr {
|
||||
|
89
interface_test.go
Normal file
89
interface_test.go
Normal file
@ -0,0 +1,89 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/kr/pretty"
|
||||
)
|
||||
|
||||
type (
|
||||
UserInterface interface {
|
||||
UserName() string
|
||||
UserType() string
|
||||
}
|
||||
|
||||
UserCommon struct {
|
||||
Name string
|
||||
Type string
|
||||
}
|
||||
|
||||
BasicUser struct {
|
||||
User
|
||||
}
|
||||
|
||||
AdminUser struct {
|
||||
BasicUser
|
||||
}
|
||||
|
||||
GroupUser struct {
|
||||
GroupID int64
|
||||
User UserInterface
|
||||
}
|
||||
|
||||
Group struct {
|
||||
Users []GroupUser
|
||||
}
|
||||
)
|
||||
|
||||
func (m *BasicUser) UserName() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
func (m *BasicUser) Type() string {
|
||||
return "basic"
|
||||
}
|
||||
|
||||
func (m *AdminUser) Type() string {
|
||||
return "admin"
|
||||
}
|
||||
|
||||
// ScanType returns the scan type for the field
|
||||
func (m *GroupUser) ScanType(field string) reflect.Type {
|
||||
switch field {
|
||||
case "User":
|
||||
// The geometry data should be encoded as a []byte first
|
||||
return reflect.TypeOf(User{})
|
||||
default:
|
||||
return reflect.TypeOf(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// ScanField handle exporting scanned fields
|
||||
func (m *GroupUser) ScanField(field string, data interface{}) error {
|
||||
switch field {
|
||||
case "User":
|
||||
m.User = data.(UserInterface)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var tt *testing.T
|
||||
|
||||
func TestInterface(t *testing.T) {
|
||||
tt = t
|
||||
DB.AutoMigrate(&UserCommon{})
|
||||
|
||||
user1 := UserCommon{Name: "RowUser1", type: "basic"}
|
||||
|
||||
DB.Save(&user1)
|
||||
|
||||
t.Log("loading the users")
|
||||
users := make([]*UserWrapper, 0)
|
||||
|
||||
if DB.Table("users").Find(&users).Error != nil {
|
||||
t.Errorf("No errors should happen if set table for find")
|
||||
}
|
||||
t.Logf(pretty.Sprint(users))
|
||||
}
|
@ -59,6 +59,7 @@ type StructField struct {
|
||||
IsNormal bool
|
||||
IsIgnored bool
|
||||
IsScanner bool
|
||||
IsInterface bool
|
||||
HasDefaultValue bool
|
||||
Tag reflect.StructTag
|
||||
TagSettings map[string]string
|
||||
@ -100,6 +101,7 @@ func (structField *StructField) clone() *StructField {
|
||||
IsNormal: structField.IsNormal,
|
||||
IsIgnored: structField.IsIgnored,
|
||||
IsScanner: structField.IsScanner,
|
||||
IsInterface: structField.IsInterface,
|
||||
HasDefaultValue: structField.HasDefaultValue,
|
||||
Tag: structField.Tag,
|
||||
TagSettings: map[string]string{},
|
||||
@ -171,7 +173,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
|
||||
// Get all fields
|
||||
for i := 0; i < reflectType.NumField(); i++ {
|
||||
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
fieldStruct := reflectType.Field(i)
|
||||
|
||||
field := &StructField{
|
||||
Struct: fieldStruct,
|
||||
Name: fieldStruct.Name,
|
||||
@ -180,6 +183,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
TagSettings: parseTagSetting(fieldStruct.Tag),
|
||||
}
|
||||
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
if _, ok := field.TagSettingsGet("INTERFACE"); ok {
|
||||
field.IsInterface = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// is ignored field
|
||||
if _, ok := field.TagSettingsGet("-"); ok {
|
||||
field.IsIgnored = true
|
||||
@ -588,11 +599,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
}
|
||||
}(field)
|
||||
case reflect.Interface:
|
||||
field.IsInterface = true
|
||||
default:
|
||||
field.IsNormal = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Even it is ignored, also possible to decode db value into the field
|
||||
if value, ok := field.TagSettingsGet("COLUMN"); ok {
|
||||
|
37
scope.go
37
scope.go
@ -10,6 +10,8 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kr/pretty"
|
||||
)
|
||||
|
||||
// Scope contain current operation's information when you perform any operation on the database
|
||||
@ -473,18 +475,23 @@ func (scope *Scope) quoteIfPossible(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem ...interface{}) {
|
||||
var (
|
||||
ignored interface{}
|
||||
values = make([]interface{}, len(columns))
|
||||
selectFields []*Field
|
||||
selectedColumnsMap = map[string]int{}
|
||||
resetFields = map[int]*Field{}
|
||||
interfaceFields = map[string]interface{}{}
|
||||
rootElem interface{}
|
||||
)
|
||||
|
||||
if len(elem) > 0 {
|
||||
rootElem = elem[0]
|
||||
}
|
||||
|
||||
for index, column := range columns {
|
||||
values[index] = &ignored
|
||||
|
||||
selectFields = fields
|
||||
offset := 0
|
||||
if idx, ok := selectedColumnsMap[column]; ok {
|
||||
@ -494,7 +501,17 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
||||
|
||||
for fieldIndex, field := range selectFields {
|
||||
if field.DBName == column {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
if field.IsInterface {
|
||||
pretty.Log(column)
|
||||
if i, ok := rootElem.(interface {
|
||||
ScanType(field string) reflect.Type
|
||||
}); ok {
|
||||
t := i.ScanType(field.DBName)
|
||||
val := reflect.New(t).Interface()
|
||||
values[index] = val
|
||||
interfaceFields[field.DBName] = values[index]
|
||||
}
|
||||
} else if field.Field.Kind() == reflect.Ptr {
|
||||
values[index] = field.Field.Addr().Interface()
|
||||
} else {
|
||||
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
|
||||
@ -514,6 +531,20 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
||||
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
for k, v := range interfaceFields {
|
||||
if i, ok := elem[0].(interface {
|
||||
ScanField(field string, data interface{}) error
|
||||
}); ok {
|
||||
val := reflect.ValueOf(v)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
if err := i.ScanField(k, val.Interface()); err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for index, field := range resetFields {
|
||||
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
||||
field.Field.Set(v)
|
||||
|
Loading…
x
Reference in New Issue
Block a user