Improve after scan method callback performance

This commit is contained in:
Moises P. Sena 2018-02-21 11:38:29 -03:00
parent c23c829d0d
commit 341532ce4b
8 changed files with 333 additions and 45 deletions

58
main.go
View File

@ -766,3 +766,61 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
}
}
// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all
func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB {
key := "gorm:disable_after_scan"
s = s.clone()
if len(typs) == 0 {
s.values[key] = true
return s
}
for _, typ := range typs {
rType := indirectType(reflect.TypeOf(typ))
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true
}
return s
}
// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all.
// The disabled types will not be enabled unless they are specifically informed.
func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB {
key := "gorm:disable_after_scan"
s = s.clone()
if len(typs) == 0 {
s.values[key] = false
return s
}
for _, typ := range typs {
rType := indirectType(reflect.TypeOf(typ))
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false
}
return s
}
// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed
// typs.
func (s *DB) EnabledAfterScanCallback(typs ...interface{}) (ok bool) {
key := "gorm:disable_after_scan"
if v, ok := s.values[key]; !ok || v.(bool) {
for _, typ := range typs {
rType := indirectType(reflect.TypeOf(typ))
v, ok = s.values[key + ":" + rType.PkgPath() + "." + rType.Name()]
if ok && !v.(bool) {
return false
}
}
return true
}
return false
}

View File

@ -3,17 +3,124 @@ package gorm
import (
"reflect"
"fmt"
"sync"
)
var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0)
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
type StructFieldMethodCallbacksRegistrator struct {
Callbacks map[string]reflect.Value
type safeEnabledFieldTypes struct {
m map[reflect.Type]bool
l *sync.RWMutex
}
func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error {
value := reflect.ValueOf(caller)
func newSafeEnabledFieldTypes() safeEnabledFieldTypes {
return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)}
}
func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.Lock()
defer s.l.Unlock()
s.m[k] = enabled
default:
s.Set(reflect.TypeOf(key), enabled)
}
}
func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.RLock()
defer s.l.RUnlock()
enabled, ok = s.m[k]
return
default:
return s.Get(reflect.TypeOf(key))
}
}
func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.RLock()
defer s.l.RUnlock()
_, ok = s.m[k]
return
default:
return s.Has(reflect.TypeOf(key))
}
}
func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.Lock()
defer s.l.Unlock()
_, ok = s.m[k]
if ok {
delete(s.m, k)
}
return
default:
return s.Del(reflect.TypeOf(key))
}
}
type StructFieldMethodCallbacksRegistrator struct {
Callbacks map[string]reflect.Value
FieldTypes safeEnabledFieldTypes
l *sync.RWMutex
}
// Register new field type and enable all available callbacks for here
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) {
for _, typ := range typs {
if !registrator.FieldTypes.Has(typ) {
registrator.FieldTypes.Set(typ, true)
}
}
}
// Unregister field type and return if ok
func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) {
return registrator.FieldTypes.Del(typ)
}
// Enable all callbacks for field type
func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) {
for _, typ := range typs {
registrator.FieldTypes.Set(typ, true)
}
}
// Disable all callbacks for field type
func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) {
for _, typ := range typs {
registrator.FieldTypes.Set(typ, false)
}
}
// Return if all callbacks for field type is enabled
func (registrator *StructFieldMethodCallbacksRegistrator) EnabledFieldType(typ interface{}) bool {
if enabled, ok := registrator.FieldTypes.Get(typ); ok {
return enabled
}
return false
}
// Return if field type is registered
func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool {
return registrator.FieldTypes.Has(typ)
}
// Register new callback for fields have method methodName
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error {
value := indirect(reflect.ValueOf(caller))
if value.Kind() != reflect.Func {
return fmt.Errorf("Caller of method %q isn't a function.", methodName)
@ -28,18 +135,21 @@ func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName st
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
}
if value.Type().In(1) != interfaceType {
if value.Type().In(1).Kind() != reflect.Interface {
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
}
registrator.l.Lock()
defer registrator.l.Unlock()
registrator.Callbacks[methodName] = value
return nil
}
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error {
// Register many callbacks where key is the methodName and value is a caller function.
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error {
for i, m := range items {
for methodName, callback := range m {
err := registrator.Register(methodName, callback)
err := registrator.registerCallback(methodName, callback)
if err != nil {
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
}
@ -49,7 +159,8 @@ func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...
}
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)}
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(),
new(sync.RWMutex)}
}
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
@ -86,7 +197,7 @@ func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Fiel
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
func init() {
checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{
checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{
"AfterScan": AfterScanMethodCallback,
}))
}
}

View File

@ -1,16 +1,59 @@
package gorm_test
import "fmt"
func ExampleAfterScanMethodCallback() {
fmt.Println(`package main
import (
"fmt"
)
func init() {
}
func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() {
fmt.Println(`
if registrator.EnabledFieldType(&Media{}) {
registrator.DisableFieldType(&Media{})
}
`)
}
func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() {
fmt.Println(`
if !registrator.EnabledFieldType(&Media{}) {
println("not enabled")
}
`)
}
func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() {
fmt.Println(`
if !registrator.EnabledFieldType(&Media{}) {
registrator.EnableFieldType(&Media{})
}
`)
}
func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() {
fmt.Println(`
if registrator.RegisteredFieldType(&Media{}) {
println("not registered")
}`)
}
func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() {
fmt.Println("registrator.RegisterFieldType(&Media{})")
}
func ExampleAfterScanMethodCallback() {
println(`
package main
import (
"reflect"
"github.com/jinzhu/gorm"
"database/sql/driver"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"strconv"
"strings"
)
type Media struct {
@ -45,7 +88,8 @@ func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) {
}
func (image *Media) URL() string {
return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name)
return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()),
*image.fieldName, image.Name}, "/")
}
type User struct {
@ -58,10 +102,14 @@ func (user *User) GetID() int {
}
func main() {
// register media type
gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{})
db, err := gorm.Open("sqlite3", "db.db")
if err != nil {
panic(err)
}
db.AutoMigrate(&User{})
baseUrl := "http://example.com/media"
@ -79,6 +127,7 @@ func main() {
panic(db_.Error)
}
fmt.Println(model.MainImage.URL())
}`)
println("Media URL:", model.MainImage.URL())
}
`)
}

View File

@ -14,6 +14,10 @@ import (
"github.com/jinzhu/gorm"
)
func init() {
gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{})
}
type User struct {
Id int64
Age int64

View File

@ -605,6 +605,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}
// register method callbacks now for improve performance
field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {

View File

@ -473,6 +473,25 @@ func (scope *Scope) quoteIfPossible(str string) string {
return str
}
// call after field method callbacks
func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) {
if !scope.HasError() {
if scope.DB().EnabledAfterScanCallback(scope.Value) {
scopeValue := reflect.ValueOf(scope)
for index, field := range scannerFields {
// if calbacks enabled for field type
if StructFieldMethodCallbacks.EnabledFieldType(field.Field.Type()) {
// not disabled on scan
if _, ok := disableScanField[index]; !ok {
reflectValue := field.Field.Addr()
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
}
}
}
}
}
}
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
var (
ignored interface{}
@ -527,23 +546,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
}
}
if !scope.HasError() {
key := "gorm:disable_after_scan"
if v, ok := scope.Get(key); !ok || !v.(bool) {
valueType := indirect(reflect.ValueOf(scope.Value)).Type()
if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) {
scopeValue := reflect.ValueOf(scope)
for index, field := range scannerFields {
if _, ok := disableScanField[index]; !ok {
if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) {
reflectValue := field.Field.Addr()
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
}
}
}
}
}
}
scope.afterScanCallback(scannerFields, disableScanField)
}
func (scope *Scope) primaryCondition(value interface{}) string {

View File

@ -7,7 +7,6 @@ import (
"testing"
"github.com/jinzhu/gorm"
"reflect"
)
func NameIn1And2(d *gorm.DB) *gorm.DB {
@ -124,29 +123,80 @@ func TestAfterFieldScanDisableCallback(t *testing.T) {
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
}
run := func(key string) {
DB := DB.Set(key, true)
run := func(typs ... interface{}) {
DB := DB.DisableAfterScanCallback(typs...)
var model2 WithFieldAfterScanCallback
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err)
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err)
}
dotest := func(i int, value string, field AfterScanFieldInterface) {
if !field.CalledFieldIsNill() {
t.Errorf("%q: Expected Name%v.calledField is not nil", key, i)
t.Errorf("%q: Expected Name%v.calledField is not nil", len(typs), i)
}
if !field.CalledScopeIsNill() {
t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i)
t.Errorf("%q: Expected Name%v.calledScope is not nil", len(typs), i)
}
}
dotest(1, model.Name1.data, model2.Name1)
}
run("gorm:disable_after_scan")
typ := reflect.ValueOf(model).Type()
run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name())
run()
run(model)
}
func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) {
model := WithFieldAfterScanCallback{}
model.Name1 = &AfterScanFieldPtr{data: randName()}
model.Name2 = AfterScanFieldPtr{data: randName()}
model.Name3 = &AfterScanField{data: randName()}
model.Name4 = AfterScanField{data: randName()}
if err := DB.Save(&model).Error; err != nil {
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
}
enabled := func(i int, field AfterScanFieldInterface) {
if field.CalledScopeIsNill() {
t.Errorf("Expected Name%v.calledScope, but got nil", i)
}
}
disabled := func(i int, field AfterScanFieldInterface) {
if !field.CalledScopeIsNill() {
t.Errorf("Expected Name%v.calledScope is not nil", i)
}
}
gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{})
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
}
disabled(1, model.Name1)
disabled(2, &model.Name2)
disabled(3, model.Name3)
disabled(4, &model.Name4)
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{})
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
}
enabled(1, model.Name1)
enabled(2, &model.Name2)
disabled(3, model.Name3)
disabled(4, &model.Name4)
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{})
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
}
enabled(1, model.Name1)
enabled(2, &model.Name2)
enabled(3, model.Name3)
enabled(4, &model.Name4)
}
func TestAfterFieldScanInvalidCallback(t *testing.T) {

View File

@ -135,6 +135,18 @@ func indirect(reflectValue reflect.Value) reflect.Value {
return reflectValue
}
func indirectType(reflectType reflect.Type) reflect.Type {
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
return reflectType
}
func ptrToType(reflectType reflect.Type) reflect.Type {
reflectType = indirectType(reflectType)
return reflect.PtrTo(reflectType)
}
type Method struct {
index int
name string