203 lines
5.6 KiB
Go
203 lines
5.6 KiB
Go
package gorm
|
|
|
|
import (
|
|
"reflect"
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// methodPtrType is the reflect.Type of *Method. registerCallback requires it
// as the type of the first parameter of every callback function.
//
// reflect.TypeOf on a typed nil pointer yields the same type as the
// deprecated reflect.PtrTo(reflect.TypeOf(Method{})) without building a value.
var methodPtrType = reflect.TypeOf((*Method)(nil))
|
|
|
|
type safeEnabledFieldTypes struct {
|
|
m map[reflect.Type]bool
|
|
l *sync.RWMutex
|
|
}
|
|
|
|
func newSafeEnabledFieldTypes() safeEnabledFieldTypes {
|
|
return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)}
|
|
}
|
|
|
|
func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) {
|
|
switch k := key.(type) {
|
|
case reflect.Type:
|
|
k = indirectType(k)
|
|
s.l.Lock()
|
|
defer s.l.Unlock()
|
|
s.m[k] = enabled
|
|
default:
|
|
s.Set(reflect.TypeOf(key), enabled)
|
|
}
|
|
}
|
|
|
|
func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) {
|
|
switch k := key.(type) {
|
|
case reflect.Type:
|
|
k = indirectType(k)
|
|
s.l.RLock()
|
|
defer s.l.RUnlock()
|
|
enabled, ok = s.m[k]
|
|
return
|
|
default:
|
|
return s.Get(reflect.TypeOf(key))
|
|
}
|
|
}
|
|
|
|
func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) {
|
|
switch k := key.(type) {
|
|
case reflect.Type:
|
|
k = indirectType(k)
|
|
s.l.RLock()
|
|
defer s.l.RUnlock()
|
|
_, ok = s.m[k]
|
|
return
|
|
default:
|
|
return s.Has(reflect.TypeOf(key))
|
|
}
|
|
}
|
|
|
|
func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) {
|
|
switch k := key.(type) {
|
|
case reflect.Type:
|
|
k = indirectType(k)
|
|
s.l.Lock()
|
|
defer s.l.Unlock()
|
|
_, ok = s.m[k]
|
|
if ok {
|
|
delete(s.m, k)
|
|
}
|
|
return
|
|
default:
|
|
return s.Del(reflect.TypeOf(key))
|
|
}
|
|
}
|
|
|
|
type StructFieldMethodCallbacksRegistrator struct {
|
|
Callbacks map[string]reflect.Value
|
|
FieldTypes safeEnabledFieldTypes
|
|
l *sync.RWMutex
|
|
}
|
|
|
|
// Register new field type and enable all available callbacks for here
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) {
|
|
for _, typ := range typs {
|
|
if !registrator.FieldTypes.Has(typ) {
|
|
registrator.FieldTypes.Set(typ, true)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Unregister field type and return if ok
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) {
|
|
return registrator.FieldTypes.Del(typ)
|
|
}
|
|
|
|
// Enable all callbacks for field type
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) {
|
|
for _, typ := range typs {
|
|
registrator.FieldTypes.Set(typ, true)
|
|
}
|
|
}
|
|
|
|
// Disable all callbacks for field type
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) {
|
|
for _, typ := range typs {
|
|
registrator.FieldTypes.Set(typ, false)
|
|
}
|
|
}
|
|
|
|
// Return if all callbacks for field type is enabled
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) EnabledFieldType(typ interface{}) bool {
|
|
if enabled, ok := registrator.FieldTypes.Get(typ); ok {
|
|
return enabled
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Return if field type is registered
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool {
|
|
return registrator.FieldTypes.Has(typ)
|
|
}
|
|
|
|
// Register new callback for fields have method methodName
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error {
|
|
value := indirect(reflect.ValueOf(caller))
|
|
|
|
if value.Kind() != reflect.Func {
|
|
return fmt.Errorf("Caller of method %q isn't a function.", methodName)
|
|
}
|
|
|
|
if value.Type().NumIn() < 2 {
|
|
return fmt.Errorf("The caller function %v for method %q require two args. Example: func(methodInfo *gorm.Method, method interface{}).",
|
|
value.Type(), methodName)
|
|
}
|
|
|
|
if value.Type().In(0) != methodPtrType {
|
|
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
|
|
}
|
|
|
|
if value.Type().In(1).Kind() != reflect.Interface {
|
|
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
|
|
}
|
|
|
|
registrator.l.Lock()
|
|
defer registrator.l.Unlock()
|
|
registrator.Callbacks[methodName] = value
|
|
return nil
|
|
}
|
|
|
|
// Register many callbacks where key is the methodName and value is a caller function.
|
|
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error {
|
|
for i, m := range items {
|
|
for methodName, callback := range m {
|
|
err := registrator.registerCallback(methodName, callback)
|
|
if err != nil {
|
|
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
|
|
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(),
|
|
new(sync.RWMutex)}
|
|
}
|
|
|
|
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
|
|
switch method := method.(type) {
|
|
case func():
|
|
method()
|
|
case func(*Scope):
|
|
method(scope)
|
|
case func(*Scope, *Field):
|
|
method(scope, field)
|
|
case func(*DB, *Field):
|
|
newDB := scope.NewDB()
|
|
method(newDB, field)
|
|
scope.Err(newDB.Error)
|
|
case func() error:
|
|
scope.Err(method())
|
|
case func(*Scope) error:
|
|
scope.Err(method(scope))
|
|
case func(*Scope, *Field) error:
|
|
scope.Err(method(scope, field))
|
|
case func(*DB) error:
|
|
newDB := scope.NewDB()
|
|
scope.Err(method(newDB))
|
|
scope.Err(newDB.Error)
|
|
default:
|
|
scope.Err(fmt.Errorf("Invalid AfterScan method callback %v of type %v", reflect.ValueOf(method).Type(), field.Struct.Type))
|
|
}
|
|
}
|
|
|
|
// StructFieldMethodCallbacks is a default registrator for model fields callbacks where the field type is a struct
|
|
// and have a callback method.
|
|
// Default methods callbacks:
|
|
// AfterScan: Call method `AfterScanMethodCallback` after scan field from sql row.
|
|
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
|
|
|
|
func init() {
|
|
checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{
|
|
"AfterScan": AfterScanMethodCallback,
|
|
}))
|
|
} |