Improve after scan method callback performance

This commit is contained in:
Moises P. Sena 2018-02-21 11:38:29 -03:00
parent c23c829d0d
commit 341532ce4b
8 changed files with 333 additions and 45 deletions

58
main.go
View File

@ -766,3 +766,61 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
} }
} }
// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all
func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB {
key := "gorm:disable_after_scan"
s = s.clone()
if len(typs) == 0 {
s.values[key] = true
return s
}
for _, typ := range typs {
rType := indirectType(reflect.TypeOf(typ))
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true
}
return s
}
// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all.
// The disabled types will not be enabled unless they are specifically informed.
func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB {
key := "gorm:disable_after_scan"
s = s.clone()
if len(typs) == 0 {
s.values[key] = false
return s
}
for _, typ := range typs {
rType := indirectType(reflect.TypeOf(typ))
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false
}
return s
}
// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed
// typs.
func (s *DB) EnabledAfterScanCallback(typs ...interface{}) (ok bool) {
key := "gorm:disable_after_scan"
if v, ok := s.values[key]; !ok || v.(bool) {
for _, typ := range typs {
rType := indirectType(reflect.TypeOf(typ))
v, ok = s.values[key + ":" + rType.PkgPath() + "." + rType.Name()]
if ok && !v.(bool) {
return false
}
}
return true
}
return false
}

View File

@ -3,17 +3,124 @@ package gorm
import ( import (
"reflect" "reflect"
"fmt" "fmt"
"sync"
) )
var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0)
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{})) var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
type safeEnabledFieldTypes struct {
m map[reflect.Type]bool
l *sync.RWMutex
}
func newSafeEnabledFieldTypes() safeEnabledFieldTypes {
return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)}
}
func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.Lock()
defer s.l.Unlock()
s.m[k] = enabled
default:
s.Set(reflect.TypeOf(key), enabled)
}
}
func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.RLock()
defer s.l.RUnlock()
enabled, ok = s.m[k]
return
default:
return s.Get(reflect.TypeOf(key))
}
}
func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.RLock()
defer s.l.RUnlock()
_, ok = s.m[k]
return
default:
return s.Has(reflect.TypeOf(key))
}
}
func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) {
switch k := key.(type) {
case reflect.Type:
k = indirectType(k)
s.l.Lock()
defer s.l.Unlock()
_, ok = s.m[k]
if ok {
delete(s.m, k)
}
return
default:
return s.Del(reflect.TypeOf(key))
}
}
type StructFieldMethodCallbacksRegistrator struct { type StructFieldMethodCallbacksRegistrator struct {
Callbacks map[string]reflect.Value Callbacks map[string]reflect.Value
FieldTypes safeEnabledFieldTypes
l *sync.RWMutex
} }
func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error { // Register new field type and enable all available callbacks for here
value := reflect.ValueOf(caller) func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) {
for _, typ := range typs {
if !registrator.FieldTypes.Has(typ) {
registrator.FieldTypes.Set(typ, true)
}
}
}
// Unregister field type and return if ok
func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) {
return registrator.FieldTypes.Del(typ)
}
// Enable all callbacks for field type
func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) {
for _, typ := range typs {
registrator.FieldTypes.Set(typ, true)
}
}
// Disable all callbacks for field type
func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) {
for _, typ := range typs {
registrator.FieldTypes.Set(typ, false)
}
}
// Return if all callbacks for field type is enabled
func (registrator *StructFieldMethodCallbacksRegistrator) EnabledFieldType(typ interface{}) bool {
if enabled, ok := registrator.FieldTypes.Get(typ); ok {
return enabled
}
return false
}
// Return if field type is registered
func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool {
return registrator.FieldTypes.Has(typ)
}
// Register new callback for fields have method methodName
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error {
value := indirect(reflect.ValueOf(caller))
if value.Kind() != reflect.Func { if value.Kind() != reflect.Func {
return fmt.Errorf("Caller of method %q isn't a function.", methodName) return fmt.Errorf("Caller of method %q isn't a function.", methodName)
@ -28,18 +135,21 @@ func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName st
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType) return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
} }
if value.Type().In(1) != interfaceType { if value.Type().In(1).Kind() != reflect.Interface {
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName) return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
} }
registrator.l.Lock()
defer registrator.l.Unlock()
registrator.Callbacks[methodName] = value registrator.Callbacks[methodName] = value
return nil return nil
} }
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error { // Register many callbacks where key is the methodName and value is a caller function.
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error {
for i, m := range items { for i, m := range items {
for methodName, callback := range m { for methodName, callback := range m {
err := registrator.Register(methodName, callback) err := registrator.registerCallback(methodName, callback)
if err != nil { if err != nil {
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err) return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
} }
@ -49,7 +159,8 @@ func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...
} }
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator { func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)} return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(),
new(sync.RWMutex)}
} }
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) { func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
@ -86,7 +197,7 @@ func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Fiel
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator() var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
func init() { func init() {
checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{ checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{
"AfterScan": AfterScanMethodCallback, "AfterScan": AfterScanMethodCallback,
})) }))
} }

View File

@ -1,16 +1,59 @@
package gorm_test package gorm_test
import "fmt"
func ExampleAfterScanMethodCallback() {
fmt.Println(`package main
import ( import (
"fmt" "fmt"
)
func init() {
}
func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() {
fmt.Println(`
if registrator.EnabledFieldType(&Media{}) {
registrator.DisableFieldType(&Media{})
}
`)
}
func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() {
fmt.Println(`
if !registrator.EnabledFieldType(&Media{}) {
println("not enabled")
}
`)
}
func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() {
fmt.Println(`
if !registrator.EnabledFieldType(&Media{}) {
registrator.EnableFieldType(&Media{})
}
`)
}
func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() {
fmt.Println(`
if registrator.RegisteredFieldType(&Media{}) {
println("not registered")
}`)
}
func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() {
fmt.Println("registrator.RegisterFieldType(&Media{})")
}
func ExampleAfterScanMethodCallback() {
println(`
package main
import (
"reflect" "reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"database/sql/driver" "database/sql/driver"
_ "github.com/jinzhu/gorm/dialects/sqlite" _ "github.com/jinzhu/gorm/dialects/sqlite"
"strconv"
"strings"
) )
type Media struct { type Media struct {
@ -45,7 +88,8 @@ func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) {
} }
func (image *Media) URL() string { func (image *Media) URL() string {
return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name) return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()),
*image.fieldName, image.Name}, "/")
} }
type User struct { type User struct {
@ -58,10 +102,14 @@ func (user *User) GetID() int {
} }
func main() { func main() {
// register media type
gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{})
db, err := gorm.Open("sqlite3", "db.db") db, err := gorm.Open("sqlite3", "db.db")
if err != nil { if err != nil {
panic(err) panic(err)
} }
db.AutoMigrate(&User{}) db.AutoMigrate(&User{})
baseUrl := "http://example.com/media" baseUrl := "http://example.com/media"
@ -79,6 +127,7 @@ func main() {
panic(db_.Error) panic(db_.Error)
} }
fmt.Println(model.MainImage.URL()) println("Media URL:", model.MainImage.URL())
}`) }
`)
} }

View File

@ -14,6 +14,10 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
func init() {
gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{})
}
type User struct { type User struct {
Id int64 Id int64
Age int64 Age int64

View File

@ -605,6 +605,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
// register method callbacks now for improve performance
field.MethodCallbacks = make(map[string]StructFieldMethodCallback) field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
for callbackName, caller := range StructFieldMethodCallbacks.Callbacks { for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {

View File

@ -473,6 +473,25 @@ func (scope *Scope) quoteIfPossible(str string) string {
return str return str
} }
// call after field method callbacks
func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) {
if !scope.HasError() {
if scope.DB().EnabledAfterScanCallback(scope.Value) {
scopeValue := reflect.ValueOf(scope)
for index, field := range scannerFields {
// if calbacks enabled for field type
if StructFieldMethodCallbacks.EnabledFieldType(field.Field.Type()) {
// not disabled on scan
if _, ok := disableScanField[index]; !ok {
reflectValue := field.Field.Addr()
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
}
}
}
}
}
}
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
var ( var (
ignored interface{} ignored interface{}
@ -527,23 +546,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
} }
} }
if !scope.HasError() { scope.afterScanCallback(scannerFields, disableScanField)
key := "gorm:disable_after_scan"
if v, ok := scope.Get(key); !ok || !v.(bool) {
valueType := indirect(reflect.ValueOf(scope.Value)).Type()
if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) {
scopeValue := reflect.ValueOf(scope)
for index, field := range scannerFields {
if _, ok := disableScanField[index]; !ok {
if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) {
reflectValue := field.Field.Addr()
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
}
}
}
}
}
}
} }
func (scope *Scope) primaryCondition(value interface{}) string { func (scope *Scope) primaryCondition(value interface{}) string {

View File

@ -7,7 +7,6 @@ import (
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"reflect"
) )
func NameIn1And2(d *gorm.DB) *gorm.DB { func NameIn1And2(d *gorm.DB) *gorm.DB {
@ -124,29 +123,80 @@ func TestAfterFieldScanDisableCallback(t *testing.T) {
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
} }
run := func(key string) { run := func(typs ... interface{}) {
DB := DB.Set(key, true) DB := DB.DisableAfterScanCallback(typs...)
var model2 WithFieldAfterScanCallback var model2 WithFieldAfterScanCallback
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err) t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err)
} }
dotest := func(i int, value string, field AfterScanFieldInterface) { dotest := func(i int, value string, field AfterScanFieldInterface) {
if !field.CalledFieldIsNill() { if !field.CalledFieldIsNill() {
t.Errorf("%q: Expected Name%v.calledField is not nil", key, i) t.Errorf("%q: Expected Name%v.calledField is not nil", len(typs), i)
} }
if !field.CalledScopeIsNill() { if !field.CalledScopeIsNill() {
t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i) t.Errorf("%q: Expected Name%v.calledScope is not nil", len(typs), i)
} }
} }
dotest(1, model.Name1.data, model2.Name1) dotest(1, model.Name1.data, model2.Name1)
} }
run("gorm:disable_after_scan") run()
typ := reflect.ValueOf(model).Type() run(model)
run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name()) }
func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) {
model := WithFieldAfterScanCallback{}
model.Name1 = &AfterScanFieldPtr{data: randName()}
model.Name2 = AfterScanFieldPtr{data: randName()}
model.Name3 = &AfterScanField{data: randName()}
model.Name4 = AfterScanField{data: randName()}
if err := DB.Save(&model).Error; err != nil {
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
}
enabled := func(i int, field AfterScanFieldInterface) {
if field.CalledScopeIsNill() {
t.Errorf("Expected Name%v.calledScope, but got nil", i)
}
}
disabled := func(i int, field AfterScanFieldInterface) {
if !field.CalledScopeIsNill() {
t.Errorf("Expected Name%v.calledScope is not nil", i)
}
}
gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{})
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
}
disabled(1, model.Name1)
disabled(2, &model.Name2)
disabled(3, model.Name3)
disabled(4, &model.Name4)
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{})
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
}
enabled(1, model.Name1)
enabled(2, &model.Name2)
disabled(3, model.Name3)
disabled(4, &model.Name4)
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{})
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
}
enabled(1, model.Name1)
enabled(2, &model.Name2)
enabled(3, model.Name3)
enabled(4, &model.Name4)
} }
func TestAfterFieldScanInvalidCallback(t *testing.T) { func TestAfterFieldScanInvalidCallback(t *testing.T) {

View File

@ -135,6 +135,18 @@ func indirect(reflectValue reflect.Value) reflect.Value {
return reflectValue return reflectValue
} }
func indirectType(reflectType reflect.Type) reflect.Type {
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
return reflectType
}
func ptrToType(reflectType reflect.Type) reflect.Type {
reflectType = indirectType(reflectType)
return reflect.PtrTo(reflectType)
}
type Method struct { type Method struct {
index int index int
name string name string