gorm/methodcallback.go
2018-02-20 17:50:49 -03:00

93 lines
2.8 KiB
Go

package gorm
import (
"reflect"
"fmt"
)
// interfaceType is the reflect.Type of the empty interface, interface{}.
// It is obtained by dereferencing the type of a typed nil *interface{}, which
// avoids building a throwaway function value just to read its parameter type.
var interfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
// methodPtrType is the reflect.Type of *Method. Register requires it as the
// type of a caller's first parameter.
var methodPtrType = reflect.TypeOf((*Method)(nil))
// StructFieldMethodCallbacksRegistrator maps a struct field method name (for
// example "AfterScan") to the caller function responsible for invoking it.
// Use NewStructFieldMethodCallbacksRegistrator to obtain an initialized value.
type StructFieldMethodCallbacksRegistrator struct {
	// Callbacks holds the caller functions accepted by Register, keyed by
	// method name.
	Callbacks map[string]reflect.Value
}
// Register stores caller as the callback for methodName, replacing any caller
// previously registered under the same name.
//
// caller must be a function whose first two parameters are *Method and
// interface{}. Additional parameters are allowed; AfterScanMethodCallback, for
// example, also takes *Field and *Scope. If caller does not satisfy this shape,
// an error is returned and nothing is registered.
//
// Register also works on a zero-value registrator: the Callbacks map is
// allocated on first use instead of panicking on a nil-map write.
func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error {
	value := reflect.ValueOf(caller)
	// A nil caller yields an invalid Value whose Kind is Invalid, so it is
	// rejected here as well.
	if value.Kind() != reflect.Func {
		return fmt.Errorf("Caller of method %q isn't a function.", methodName)
	}
	typ := value.Type()
	if typ.NumIn() < 2 {
		return fmt.Errorf("The caller function %v for method %q require two args. Example: func(methodInfo *gorm.Method, method interface{}).",
			typ, methodName)
	}
	if typ.In(0) != methodPtrType {
		return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", typ, methodName, methodPtrType)
	}
	if typ.In(1) != interfaceType {
		return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", typ, methodName)
	}
	if registrator.Callbacks == nil {
		// Writing to a nil map panics; make the zero value usable.
		registrator.Callbacks = make(map[string]reflect.Value)
	}
	registrator.Callbacks[methodName] = value
	return nil
}
// RegisterMany calls Register for every methodName -> caller pair in items,
// visiting the maps in argument order.
//
// It stops at the first failure and returns an error naming the offending
// argument index and method name. Pairs registered before the failure stay
// registered. Go does not specify map iteration order, so when one map holds
// several invalid callers, which one is reported first is not deterministic.
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error {
	for idx := range items {
		for name, caller := range items[idx] {
			if err := registrator.Register(name, caller); err != nil {
				return fmt.Errorf("Register arg[%v][%q] failed: %v", idx, name, err)
			}
		}
	}
	return nil
}
// NewStructFieldMethodCallbacksRegistrator returns a registrator whose
// Callbacks map is allocated and empty, ready for Register.
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
	registrator := new(StructFieldMethodCallbacksRegistrator)
	registrator.Callbacks = map[string]reflect.Value{}
	return registrator
}
// AfterScanMethodCallback invokes method, a field's AfterScan method value,
// with the arguments its signature asks for. Errors are reported through
// scope.Err.
//
// Supported signatures:
//
//	func()                  func() error
//	func(*Scope)            func(*Scope) error
//	func(*Scope, *Field)    func(*Scope, *Field) error
//	func(*DB)               func(*DB) error
//	func(*DB, *Field)       func(*DB, *Field) error
//
// The *DB variants receive scope.NewDB(), and that DB's Error is also passed
// to scope.Err after the call. Any other value is reported on scope as an
// "Invalid AfterScan method callback" error. This includes a nil method, which
// previously caused a panic. methodInfo is not used here.
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
	switch method := method.(type) {
	case func():
		method()
	case func() error:
		scope.Err(method())
	case func(*Scope):
		method(scope)
	case func(*Scope) error:
		scope.Err(method(scope))
	case func(*Scope, *Field):
		method(scope, field)
	case func(*Scope, *Field) error:
		scope.Err(method(scope, field))
	case func(*DB):
		newDB := scope.NewDB()
		method(newDB)
		scope.Err(newDB.Error)
	case func(*DB) error:
		newDB := scope.NewDB()
		scope.Err(method(newDB))
		scope.Err(newDB.Error)
	case func(*DB, *Field):
		newDB := scope.NewDB()
		method(newDB, field)
		scope.Err(newDB.Error)
	case func(*DB, *Field) error:
		newDB := scope.NewDB()
		scope.Err(method(newDB, field))
		scope.Err(newDB.Error)
	default:
		// %T prints the same type string as reflect.ValueOf(method).Type()
		// for non-nil values. Unlike that call, it does not panic when method
		// is a nil interface; it prints "<nil>".
		scope.Err(fmt.Errorf("Invalid AfterScan method callback %T of type %v", method, field.Struct.Type))
	}
}
// StructFieldMethodCallbacks is the default registrator of method callbacks
// for model fields whose type is a struct that has a callback method.
//
// Callbacks installed by this package's init:
//   - "AfterScan": AfterScanMethodCallback, run after the field is scanned
//     from an SQL row.
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
func init() {
checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{
"AfterScan": AfterScanMethodCallback,
}))
}