Improve after scan method callback performance
This commit is contained in:
parent
c23c829d0d
commit
341532ce4b
58
main.go
58
main.go
@ -766,3 +766,61 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
|||||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
|
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all
|
||||||
|
func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB {
|
||||||
|
key := "gorm:disable_after_scan"
|
||||||
|
|
||||||
|
s = s.clone()
|
||||||
|
|
||||||
|
if len(typs) == 0 {
|
||||||
|
s.values[key] = true
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, typ := range typs {
|
||||||
|
rType := indirectType(reflect.TypeOf(typ))
|
||||||
|
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all.
|
||||||
|
// The disabled types will not be enabled unless they are specifically informed.
|
||||||
|
func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB {
|
||||||
|
key := "gorm:disable_after_scan"
|
||||||
|
|
||||||
|
s = s.clone()
|
||||||
|
|
||||||
|
if len(typs) == 0 {
|
||||||
|
s.values[key] = false
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, typ := range typs {
|
||||||
|
rType := indirectType(reflect.TypeOf(typ))
|
||||||
|
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed
|
||||||
|
// typs.
|
||||||
|
func (s *DB) EnabledAfterScanCallback(typs ...interface{}) (ok bool) {
|
||||||
|
key := "gorm:disable_after_scan"
|
||||||
|
|
||||||
|
if v, ok := s.values[key]; !ok || v.(bool) {
|
||||||
|
for _, typ := range typs {
|
||||||
|
rType := indirectType(reflect.TypeOf(typ))
|
||||||
|
v, ok = s.values[key + ":" + rType.PkgPath() + "." + rType.Name()]
|
||||||
|
if ok && !v.(bool) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -3,17 +3,124 @@ package gorm
|
|||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0)
|
|
||||||
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
|
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
|
||||||
|
|
||||||
|
type safeEnabledFieldTypes struct {
|
||||||
|
m map[reflect.Type]bool
|
||||||
|
l *sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSafeEnabledFieldTypes() safeEnabledFieldTypes {
|
||||||
|
return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) {
|
||||||
|
switch k := key.(type) {
|
||||||
|
case reflect.Type:
|
||||||
|
k = indirectType(k)
|
||||||
|
s.l.Lock()
|
||||||
|
defer s.l.Unlock()
|
||||||
|
s.m[k] = enabled
|
||||||
|
default:
|
||||||
|
s.Set(reflect.TypeOf(key), enabled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) {
|
||||||
|
switch k := key.(type) {
|
||||||
|
case reflect.Type:
|
||||||
|
k = indirectType(k)
|
||||||
|
s.l.RLock()
|
||||||
|
defer s.l.RUnlock()
|
||||||
|
enabled, ok = s.m[k]
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
return s.Get(reflect.TypeOf(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) {
|
||||||
|
switch k := key.(type) {
|
||||||
|
case reflect.Type:
|
||||||
|
k = indirectType(k)
|
||||||
|
s.l.RLock()
|
||||||
|
defer s.l.RUnlock()
|
||||||
|
_, ok = s.m[k]
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
return s.Has(reflect.TypeOf(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) {
|
||||||
|
switch k := key.(type) {
|
||||||
|
case reflect.Type:
|
||||||
|
k = indirectType(k)
|
||||||
|
s.l.Lock()
|
||||||
|
defer s.l.Unlock()
|
||||||
|
_, ok = s.m[k]
|
||||||
|
if ok {
|
||||||
|
delete(s.m, k)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
return s.Del(reflect.TypeOf(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type StructFieldMethodCallbacksRegistrator struct {
|
type StructFieldMethodCallbacksRegistrator struct {
|
||||||
Callbacks map[string]reflect.Value
|
Callbacks map[string]reflect.Value
|
||||||
|
FieldTypes safeEnabledFieldTypes
|
||||||
|
l *sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error {
|
// Register new field type and enable all available callbacks for here
|
||||||
value := reflect.ValueOf(caller)
|
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) {
|
||||||
|
for _, typ := range typs {
|
||||||
|
if !registrator.FieldTypes.Has(typ) {
|
||||||
|
registrator.FieldTypes.Set(typ, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister field type and return if ok
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) {
|
||||||
|
return registrator.FieldTypes.Del(typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable all callbacks for field type
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) {
|
||||||
|
for _, typ := range typs {
|
||||||
|
registrator.FieldTypes.Set(typ, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable all callbacks for field type
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) {
|
||||||
|
for _, typ := range typs {
|
||||||
|
registrator.FieldTypes.Set(typ, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if all callbacks for field type is enabled
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) EnabledFieldType(typ interface{}) bool {
|
||||||
|
if enabled, ok := registrator.FieldTypes.Get(typ); ok {
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if field type is registered
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool {
|
||||||
|
return registrator.FieldTypes.Has(typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register new callback for fields have method methodName
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error {
|
||||||
|
value := indirect(reflect.ValueOf(caller))
|
||||||
|
|
||||||
if value.Kind() != reflect.Func {
|
if value.Kind() != reflect.Func {
|
||||||
return fmt.Errorf("Caller of method %q isn't a function.", methodName)
|
return fmt.Errorf("Caller of method %q isn't a function.", methodName)
|
||||||
@ -28,18 +135,21 @@ func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName st
|
|||||||
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
|
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value.Type().In(1) != interfaceType {
|
if value.Type().In(1).Kind() != reflect.Interface {
|
||||||
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
|
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
registrator.l.Lock()
|
||||||
|
defer registrator.l.Unlock()
|
||||||
registrator.Callbacks[methodName] = value
|
registrator.Callbacks[methodName] = value
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error {
|
// Register many callbacks where key is the methodName and value is a caller function.
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error {
|
||||||
for i, m := range items {
|
for i, m := range items {
|
||||||
for methodName, callback := range m {
|
for methodName, callback := range m {
|
||||||
err := registrator.Register(methodName, callback)
|
err := registrator.registerCallback(methodName, callback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
|
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
|
||||||
}
|
}
|
||||||
@ -49,7 +159,8 @@ func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
|
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
|
||||||
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)}
|
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(),
|
||||||
|
new(sync.RWMutex)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
|
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
|
||||||
@ -86,7 +197,7 @@ func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Fiel
|
|||||||
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
|
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{
|
checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{
|
||||||
"AfterScan": AfterScanMethodCallback,
|
"AfterScan": AfterScanMethodCallback,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
@ -1,16 +1,59 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
func ExampleAfterScanMethodCallback() {
|
|
||||||
fmt.Println(`package main
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() {
|
||||||
|
fmt.Println(`
|
||||||
|
if registrator.EnabledFieldType(&Media{}) {
|
||||||
|
registrator.DisableFieldType(&Media{})
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() {
|
||||||
|
fmt.Println(`
|
||||||
|
if !registrator.EnabledFieldType(&Media{}) {
|
||||||
|
println("not enabled")
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() {
|
||||||
|
fmt.Println(`
|
||||||
|
if !registrator.EnabledFieldType(&Media{}) {
|
||||||
|
registrator.EnableFieldType(&Media{})
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() {
|
||||||
|
fmt.Println(`
|
||||||
|
if registrator.RegisteredFieldType(&Media{}) {
|
||||||
|
println("not registered")
|
||||||
|
}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() {
|
||||||
|
fmt.Println("registrator.RegisterFieldType(&Media{})")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleAfterScanMethodCallback() {
|
||||||
|
println(`
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Media struct {
|
type Media struct {
|
||||||
@ -45,7 +88,8 @@ func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (image *Media) URL() string {
|
func (image *Media) URL() string {
|
||||||
return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name)
|
return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()),
|
||||||
|
*image.fieldName, image.Name}, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
@ -58,10 +102,14 @@ func (user *User) GetID() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// register media type
|
||||||
|
gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{})
|
||||||
|
|
||||||
db, err := gorm.Open("sqlite3", "db.db")
|
db, err := gorm.Open("sqlite3", "db.db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AutoMigrate(&User{})
|
db.AutoMigrate(&User{})
|
||||||
|
|
||||||
baseUrl := "http://example.com/media"
|
baseUrl := "http://example.com/media"
|
||||||
@ -79,6 +127,7 @@ func main() {
|
|||||||
panic(db_.Error)
|
panic(db_.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(model.MainImage.URL())
|
println("Media URL:", model.MainImage.URL())
|
||||||
}`)
|
}
|
||||||
|
`)
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,10 @@ import (
|
|||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{})
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64
|
Id int64
|
||||||
Age int64
|
Age int64
|
||||||
|
@ -605,6 +605,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// register method callbacks now for improve performance
|
||||||
field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
|
field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
|
||||||
|
|
||||||
for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {
|
for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {
|
||||||
|
37
scope.go
37
scope.go
@ -473,6 +473,25 @@ func (scope *Scope) quoteIfPossible(str string) string {
|
|||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// call after field method callbacks
|
||||||
|
func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
if scope.DB().EnabledAfterScanCallback(scope.Value) {
|
||||||
|
scopeValue := reflect.ValueOf(scope)
|
||||||
|
for index, field := range scannerFields {
|
||||||
|
// if calbacks enabled for field type
|
||||||
|
if StructFieldMethodCallbacks.EnabledFieldType(field.Field.Type()) {
|
||||||
|
// not disabled on scan
|
||||||
|
if _, ok := disableScanField[index]; !ok {
|
||||||
|
reflectValue := field.Field.Addr()
|
||||||
|
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
||||||
var (
|
var (
|
||||||
ignored interface{}
|
ignored interface{}
|
||||||
@ -527,23 +546,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.HasError() {
|
scope.afterScanCallback(scannerFields, disableScanField)
|
||||||
key := "gorm:disable_after_scan"
|
|
||||||
if v, ok := scope.Get(key); !ok || !v.(bool) {
|
|
||||||
valueType := indirect(reflect.ValueOf(scope.Value)).Type()
|
|
||||||
if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) {
|
|
||||||
scopeValue := reflect.ValueOf(scope)
|
|
||||||
for index, field := range scannerFields {
|
|
||||||
if _, ok := disableScanField[index]; !ok {
|
|
||||||
if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) {
|
|
||||||
reflectValue := field.Field.Addr()
|
|
||||||
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"reflect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
||||||
@ -124,29 +123,80 @@ func TestAfterFieldScanDisableCallback(t *testing.T) {
|
|||||||
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
run := func(key string) {
|
run := func(typs ... interface{}) {
|
||||||
DB := DB.Set(key, true)
|
DB := DB.DisableAfterScanCallback(typs...)
|
||||||
var model2 WithFieldAfterScanCallback
|
var model2 WithFieldAfterScanCallback
|
||||||
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
|
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
|
||||||
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err)
|
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dotest := func(i int, value string, field AfterScanFieldInterface) {
|
dotest := func(i int, value string, field AfterScanFieldInterface) {
|
||||||
if !field.CalledFieldIsNill() {
|
if !field.CalledFieldIsNill() {
|
||||||
t.Errorf("%q: Expected Name%v.calledField is not nil", key, i)
|
t.Errorf("%q: Expected Name%v.calledField is not nil", len(typs), i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !field.CalledScopeIsNill() {
|
if !field.CalledScopeIsNill() {
|
||||||
t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i)
|
t.Errorf("%q: Expected Name%v.calledScope is not nil", len(typs), i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dotest(1, model.Name1.data, model2.Name1)
|
dotest(1, model.Name1.data, model2.Name1)
|
||||||
}
|
}
|
||||||
|
|
||||||
run("gorm:disable_after_scan")
|
run()
|
||||||
typ := reflect.ValueOf(model).Type()
|
run(model)
|
||||||
run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name())
|
}
|
||||||
|
|
||||||
|
func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) {
|
||||||
|
model := WithFieldAfterScanCallback{}
|
||||||
|
model.Name1 = &AfterScanFieldPtr{data: randName()}
|
||||||
|
model.Name2 = AfterScanFieldPtr{data: randName()}
|
||||||
|
model.Name3 = &AfterScanField{data: randName()}
|
||||||
|
model.Name4 = AfterScanField{data: randName()}
|
||||||
|
|
||||||
|
if err := DB.Save(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := func(i int, field AfterScanFieldInterface) {
|
||||||
|
if field.CalledScopeIsNill() {
|
||||||
|
t.Errorf("Expected Name%v.calledScope, but got nil", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disabled := func(i int, field AfterScanFieldInterface) {
|
||||||
|
if !field.CalledScopeIsNill() {
|
||||||
|
t.Errorf("Expected Name%v.calledScope is not nil", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{})
|
||||||
|
|
||||||
|
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||||
|
}
|
||||||
|
disabled(1, model.Name1)
|
||||||
|
disabled(2, &model.Name2)
|
||||||
|
disabled(3, model.Name3)
|
||||||
|
disabled(4, &model.Name4)
|
||||||
|
|
||||||
|
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{})
|
||||||
|
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||||
|
}
|
||||||
|
enabled(1, model.Name1)
|
||||||
|
enabled(2, &model.Name2)
|
||||||
|
disabled(3, model.Name3)
|
||||||
|
disabled(4, &model.Name4)
|
||||||
|
|
||||||
|
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{})
|
||||||
|
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||||
|
}
|
||||||
|
enabled(1, model.Name1)
|
||||||
|
enabled(2, &model.Name2)
|
||||||
|
enabled(3, model.Name3)
|
||||||
|
enabled(4, &model.Name4)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAfterFieldScanInvalidCallback(t *testing.T) {
|
func TestAfterFieldScanInvalidCallback(t *testing.T) {
|
||||||
|
12
utils.go
12
utils.go
@ -135,6 +135,18 @@ func indirect(reflectValue reflect.Value) reflect.Value {
|
|||||||
return reflectValue
|
return reflectValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func indirectType(reflectType reflect.Type) reflect.Type {
|
||||||
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
return reflectType
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrToType(reflectType reflect.Type) reflect.Type {
|
||||||
|
reflectType = indirectType(reflectType)
|
||||||
|
return reflect.PtrTo(reflectType)
|
||||||
|
}
|
||||||
|
|
||||||
type Method struct {
|
type Method struct {
|
||||||
index int
|
index int
|
||||||
name string
|
name string
|
||||||
|
Loading…
x
Reference in New Issue
Block a user