Improve after scan method callback performance
This commit is contained in:
parent
c23c829d0d
commit
341532ce4b
58
main.go
58
main.go
@ -766,3 +766,61 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
|
||||
}
|
||||
}
|
||||
|
||||
// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all
|
||||
func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB {
|
||||
key := "gorm:disable_after_scan"
|
||||
|
||||
s = s.clone()
|
||||
|
||||
if len(typs) == 0 {
|
||||
s.values[key] = true
|
||||
return s
|
||||
}
|
||||
|
||||
for _, typ := range typs {
|
||||
rType := indirectType(reflect.TypeOf(typ))
|
||||
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all.
|
||||
// The disabled types will not be enabled unless they are specifically informed.
|
||||
func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB {
|
||||
key := "gorm:disable_after_scan"
|
||||
|
||||
s = s.clone()
|
||||
|
||||
if len(typs) == 0 {
|
||||
s.values[key] = false
|
||||
return s
|
||||
}
|
||||
|
||||
for _, typ := range typs {
|
||||
rType := indirectType(reflect.TypeOf(typ))
|
||||
s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed
|
||||
// typs.
|
||||
func (s *DB) EnabledAfterScanCallback(typs ...interface{}) (ok bool) {
|
||||
key := "gorm:disable_after_scan"
|
||||
|
||||
if v, ok := s.values[key]; !ok || v.(bool) {
|
||||
for _, typ := range typs {
|
||||
rType := indirectType(reflect.TypeOf(typ))
|
||||
v, ok = s.values[key + ":" + rType.PkgPath() + "." + rType.Name()]
|
||||
if ok && !v.(bool) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -3,17 +3,124 @@ package gorm
|
||||
import (
|
||||
"reflect"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0)
|
||||
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
|
||||
|
||||
type StructFieldMethodCallbacksRegistrator struct {
|
||||
Callbacks map[string]reflect.Value
|
||||
type safeEnabledFieldTypes struct {
|
||||
m map[reflect.Type]bool
|
||||
l *sync.RWMutex
|
||||
}
|
||||
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error {
|
||||
value := reflect.ValueOf(caller)
|
||||
func newSafeEnabledFieldTypes() safeEnabledFieldTypes {
|
||||
return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)}
|
||||
}
|
||||
|
||||
func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) {
|
||||
switch k := key.(type) {
|
||||
case reflect.Type:
|
||||
k = indirectType(k)
|
||||
s.l.Lock()
|
||||
defer s.l.Unlock()
|
||||
s.m[k] = enabled
|
||||
default:
|
||||
s.Set(reflect.TypeOf(key), enabled)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) {
|
||||
switch k := key.(type) {
|
||||
case reflect.Type:
|
||||
k = indirectType(k)
|
||||
s.l.RLock()
|
||||
defer s.l.RUnlock()
|
||||
enabled, ok = s.m[k]
|
||||
return
|
||||
default:
|
||||
return s.Get(reflect.TypeOf(key))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) {
|
||||
switch k := key.(type) {
|
||||
case reflect.Type:
|
||||
k = indirectType(k)
|
||||
s.l.RLock()
|
||||
defer s.l.RUnlock()
|
||||
_, ok = s.m[k]
|
||||
return
|
||||
default:
|
||||
return s.Has(reflect.TypeOf(key))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) {
|
||||
switch k := key.(type) {
|
||||
case reflect.Type:
|
||||
k = indirectType(k)
|
||||
s.l.Lock()
|
||||
defer s.l.Unlock()
|
||||
_, ok = s.m[k]
|
||||
if ok {
|
||||
delete(s.m, k)
|
||||
}
|
||||
return
|
||||
default:
|
||||
return s.Del(reflect.TypeOf(key))
|
||||
}
|
||||
}
|
||||
|
||||
type StructFieldMethodCallbacksRegistrator struct {
|
||||
Callbacks map[string]reflect.Value
|
||||
FieldTypes safeEnabledFieldTypes
|
||||
l *sync.RWMutex
|
||||
}
|
||||
|
||||
// Register new field type and enable all available callbacks for here
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) {
|
||||
for _, typ := range typs {
|
||||
if !registrator.FieldTypes.Has(typ) {
|
||||
registrator.FieldTypes.Set(typ, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unregister field type and return if ok
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) {
|
||||
return registrator.FieldTypes.Del(typ)
|
||||
}
|
||||
|
||||
// Enable all callbacks for field type
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) {
|
||||
for _, typ := range typs {
|
||||
registrator.FieldTypes.Set(typ, true)
|
||||
}
|
||||
}
|
||||
|
||||
// Disable all callbacks for field type
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) {
|
||||
for _, typ := range typs {
|
||||
registrator.FieldTypes.Set(typ, false)
|
||||
}
|
||||
}
|
||||
|
||||
// Return if all callbacks for field type is enabled
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) EnabledFieldType(typ interface{}) bool {
|
||||
if enabled, ok := registrator.FieldTypes.Get(typ); ok {
|
||||
return enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Return if field type is registered
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool {
|
||||
return registrator.FieldTypes.Has(typ)
|
||||
}
|
||||
|
||||
// Register new callback for fields have method methodName
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error {
|
||||
value := indirect(reflect.ValueOf(caller))
|
||||
|
||||
if value.Kind() != reflect.Func {
|
||||
return fmt.Errorf("Caller of method %q isn't a function.", methodName)
|
||||
@ -28,18 +135,21 @@ func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName st
|
||||
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
|
||||
}
|
||||
|
||||
if value.Type().In(1) != interfaceType {
|
||||
if value.Type().In(1).Kind() != reflect.Interface {
|
||||
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
|
||||
}
|
||||
|
||||
registrator.l.Lock()
|
||||
defer registrator.l.Unlock()
|
||||
registrator.Callbacks[methodName] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error {
|
||||
// Register many callbacks where key is the methodName and value is a caller function.
|
||||
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error {
|
||||
for i, m := range items {
|
||||
for methodName, callback := range m {
|
||||
err := registrator.Register(methodName, callback)
|
||||
err := registrator.registerCallback(methodName, callback)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
|
||||
}
|
||||
@ -49,7 +159,8 @@ func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...
|
||||
}
|
||||
|
||||
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
|
||||
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)}
|
||||
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(),
|
||||
new(sync.RWMutex)}
|
||||
}
|
||||
|
||||
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
|
||||
@ -86,7 +197,7 @@ func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Fiel
|
||||
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
|
||||
|
||||
func init() {
|
||||
checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{
|
||||
checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{
|
||||
"AfterScan": AfterScanMethodCallback,
|
||||
}))
|
||||
}
|
||||
}
|
@ -1,16 +1,59 @@
|
||||
package gorm_test
|
||||
|
||||
import "fmt"
|
||||
|
||||
func ExampleAfterScanMethodCallback() {
|
||||
fmt.Println(`package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
}
|
||||
|
||||
func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() {
|
||||
fmt.Println(`
|
||||
if registrator.EnabledFieldType(&Media{}) {
|
||||
registrator.DisableFieldType(&Media{})
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() {
|
||||
fmt.Println(`
|
||||
if !registrator.EnabledFieldType(&Media{}) {
|
||||
println("not enabled")
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() {
|
||||
fmt.Println(`
|
||||
if !registrator.EnabledFieldType(&Media{}) {
|
||||
registrator.EnableFieldType(&Media{})
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() {
|
||||
fmt.Println(`
|
||||
if registrator.RegisteredFieldType(&Media{}) {
|
||||
println("not registered")
|
||||
}`)
|
||||
}
|
||||
|
||||
func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() {
|
||||
fmt.Println("registrator.RegisterFieldType(&Media{})")
|
||||
}
|
||||
|
||||
func ExampleAfterScanMethodCallback() {
|
||||
println(`
|
||||
package main
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"github.com/jinzhu/gorm"
|
||||
"database/sql/driver"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Media struct {
|
||||
@ -45,7 +88,8 @@ func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) {
|
||||
}
|
||||
|
||||
func (image *Media) URL() string {
|
||||
return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name)
|
||||
return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()),
|
||||
*image.fieldName, image.Name}, "/")
|
||||
}
|
||||
|
||||
type User struct {
|
||||
@ -58,10 +102,14 @@ func (user *User) GetID() int {
|
||||
}
|
||||
|
||||
func main() {
|
||||
// register media type
|
||||
gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{})
|
||||
|
||||
db, err := gorm.Open("sqlite3", "db.db")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db.AutoMigrate(&User{})
|
||||
|
||||
baseUrl := "http://example.com/media"
|
||||
@ -79,6 +127,7 @@ func main() {
|
||||
panic(db_.Error)
|
||||
}
|
||||
|
||||
fmt.Println(model.MainImage.URL())
|
||||
}`)
|
||||
println("Media URL:", model.MainImage.URL())
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
@ -14,6 +14,10 @@ import (
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{})
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Id int64
|
||||
Age int64
|
||||
|
@ -605,6 +605,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
}
|
||||
|
||||
// register method callbacks now for improve performance
|
||||
field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
|
||||
|
||||
for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {
|
||||
|
37
scope.go
37
scope.go
@ -473,6 +473,25 @@ func (scope *Scope) quoteIfPossible(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
// call after field method callbacks
|
||||
func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) {
|
||||
if !scope.HasError() {
|
||||
if scope.DB().EnabledAfterScanCallback(scope.Value) {
|
||||
scopeValue := reflect.ValueOf(scope)
|
||||
for index, field := range scannerFields {
|
||||
// if calbacks enabled for field type
|
||||
if StructFieldMethodCallbacks.EnabledFieldType(field.Field.Type()) {
|
||||
// not disabled on scan
|
||||
if _, ok := disableScanField[index]; !ok {
|
||||
reflectValue := field.Field.Addr()
|
||||
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
||||
var (
|
||||
ignored interface{}
|
||||
@ -527,23 +546,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
||||
}
|
||||
}
|
||||
|
||||
if !scope.HasError() {
|
||||
key := "gorm:disable_after_scan"
|
||||
if v, ok := scope.Get(key); !ok || !v.(bool) {
|
||||
valueType := indirect(reflect.ValueOf(scope.Value)).Type()
|
||||
if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) {
|
||||
scopeValue := reflect.ValueOf(scope)
|
||||
for index, field := range scannerFields {
|
||||
if _, ok := disableScanField[index]; !ok {
|
||||
if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) {
|
||||
reflectValue := field.Field.Addr()
|
||||
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
scope.afterScanCallback(scannerFields, disableScanField)
|
||||
}
|
||||
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
||||
@ -124,29 +123,80 @@ func TestAfterFieldScanDisableCallback(t *testing.T) {
|
||||
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||
}
|
||||
|
||||
run := func(key string) {
|
||||
DB := DB.Set(key, true)
|
||||
run := func(typs ... interface{}) {
|
||||
DB := DB.DisableAfterScanCallback(typs...)
|
||||
var model2 WithFieldAfterScanCallback
|
||||
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
|
||||
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err)
|
||||
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err)
|
||||
}
|
||||
|
||||
dotest := func(i int, value string, field AfterScanFieldInterface) {
|
||||
if !field.CalledFieldIsNill() {
|
||||
t.Errorf("%q: Expected Name%v.calledField is not nil", key, i)
|
||||
t.Errorf("%q: Expected Name%v.calledField is not nil", len(typs), i)
|
||||
}
|
||||
|
||||
if !field.CalledScopeIsNill() {
|
||||
t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i)
|
||||
t.Errorf("%q: Expected Name%v.calledScope is not nil", len(typs), i)
|
||||
}
|
||||
}
|
||||
|
||||
dotest(1, model.Name1.data, model2.Name1)
|
||||
}
|
||||
|
||||
run("gorm:disable_after_scan")
|
||||
typ := reflect.ValueOf(model).Type()
|
||||
run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name())
|
||||
run()
|
||||
run(model)
|
||||
}
|
||||
|
||||
func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) {
|
||||
model := WithFieldAfterScanCallback{}
|
||||
model.Name1 = &AfterScanFieldPtr{data: randName()}
|
||||
model.Name2 = AfterScanFieldPtr{data: randName()}
|
||||
model.Name3 = &AfterScanField{data: randName()}
|
||||
model.Name4 = AfterScanField{data: randName()}
|
||||
|
||||
if err := DB.Save(&model).Error; err != nil {
|
||||
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||
}
|
||||
|
||||
enabled := func(i int, field AfterScanFieldInterface) {
|
||||
if field.CalledScopeIsNill() {
|
||||
t.Errorf("Expected Name%v.calledScope, but got nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
disabled := func(i int, field AfterScanFieldInterface) {
|
||||
if !field.CalledScopeIsNill() {
|
||||
t.Errorf("Expected Name%v.calledScope is not nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{})
|
||||
|
||||
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
|
||||
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||
}
|
||||
disabled(1, model.Name1)
|
||||
disabled(2, &model.Name2)
|
||||
disabled(3, model.Name3)
|
||||
disabled(4, &model.Name4)
|
||||
|
||||
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{})
|
||||
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
|
||||
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||
}
|
||||
enabled(1, model.Name1)
|
||||
enabled(2, &model.Name2)
|
||||
disabled(3, model.Name3)
|
||||
disabled(4, &model.Name4)
|
||||
|
||||
gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{})
|
||||
if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
|
||||
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||
}
|
||||
enabled(1, model.Name1)
|
||||
enabled(2, &model.Name2)
|
||||
enabled(3, model.Name3)
|
||||
enabled(4, &model.Name4)
|
||||
}
|
||||
|
||||
func TestAfterFieldScanInvalidCallback(t *testing.T) {
|
||||
|
12
utils.go
12
utils.go
@ -135,6 +135,18 @@ func indirect(reflectValue reflect.Value) reflect.Value {
|
||||
return reflectValue
|
||||
}
|
||||
|
||||
func indirectType(reflectType reflect.Type) reflect.Type {
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
return reflectType
|
||||
}
|
||||
|
||||
func ptrToType(reflectType reflect.Type) reflect.Type {
|
||||
reflectType = indirectType(reflectType)
|
||||
return reflect.PtrTo(reflectType)
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
index int
|
||||
name string
|
||||
|
Loading…
x
Reference in New Issue
Block a user