diff --git a/main.go b/main.go index 4bbaadab..15c0547d 100644 --- a/main.go +++ b/main.go @@ -766,3 +766,61 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) } } + +// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all +func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB { + key := "gorm:disable_after_scan" + + s = s.clone() + + if len(typs) == 0 { + s.values[key] = true + return s + } + + for _, typ := range typs { + rType := indirectType(reflect.TypeOf(typ)) + s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true + } + + return s +} + +// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all. +// The disabled types will not be enabled unless they are specifically informed. +func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB { + key := "gorm:disable_after_scan" + + s = s.clone() + + if len(typs) == 0 { + s.values[key] = false + return s + } + + for _, typ := range typs { + rType := indirectType(reflect.TypeOf(typ)) + s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false + } + + return s +} + +// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed +// typs. +func (s *DB) EnabledAfterScanCallback(typs ...interface{}) (ok bool) { + key := "gorm:disable_after_scan" + + if v, ok := s.values[key]; !ok || v.(bool) { + for _, typ := range typs { + rType := indirectType(reflect.TypeOf(typ)) + v, ok = s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] + if ok && !v.(bool) { + return false + } + } + return true + } + + return false +} diff --git a/methodcallback.go b/methodcallback.go index f87e111e..5f437b10 100644 --- a/methodcallback.go +++ b/methodcallback.go @@ -3,17 +3,124 @@ package gorm import ( "reflect" "fmt" + "sync" ) -var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0) var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{})) -type StructFieldMethodCallbacksRegistrator struct { - Callbacks map[string]reflect.Value +type safeEnabledFieldTypes struct { + m map[reflect.Type]bool + l *sync.RWMutex } -func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error { - value := reflect.ValueOf(caller) +func newSafeEnabledFieldTypes() safeEnabledFieldTypes { + return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)} +} + +func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.Lock() + defer s.l.Unlock() + s.m[k] = enabled + default: + s.Set(reflect.TypeOf(key), enabled) + } +} + +func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.RLock() + defer s.l.RUnlock() + enabled, ok = s.m[k] + return + default: + return s.Get(reflect.TypeOf(key)) + } +} + +func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.RLock() + defer s.l.RUnlock() + _, ok = s.m[k] + return + default: + return s.Has(reflect.TypeOf(key)) + } +} + +func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.Lock() + defer s.l.Unlock() + _, ok = s.m[k] + if ok { + delete(s.m, k) + } + return + default: + return s.Del(reflect.TypeOf(key)) + } +} + +type StructFieldMethodCallbacksRegistrator struct { + Callbacks map[string]reflect.Value + FieldTypes safeEnabledFieldTypes + l *sync.RWMutex +} + +// Register new field type and enable all available callbacks for here +func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) { + for _, typ := range typs { + if !registrator.FieldTypes.Has(typ) { + registrator.FieldTypes.Set(typ, true) + } + } +} + +// Unregister field type and return if ok +func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) { + return registrator.FieldTypes.Del(typ) +} + +// Enable all callbacks for field type +func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) { + for _, typ := range typs { + registrator.FieldTypes.Set(typ, true) + } +} + +// Disable all callbacks for field type +func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) { + for _, typ := range typs { + registrator.FieldTypes.Set(typ, false) + } +} + +// Return if all callbacks for field type is enabled +func (registrator *StructFieldMethodCallbacksRegistrator) EnabledFieldType(typ interface{}) bool { + if enabled, ok := registrator.FieldTypes.Get(typ); ok { + return enabled + } + return false +} + +// Return if field type is registered +func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool { + return registrator.FieldTypes.Has(typ) +} + +// Register new callback for fields have method methodName +func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error { + value := indirect(reflect.ValueOf(caller)) if value.Kind() != reflect.Func { return fmt.Errorf("Caller of method %q isn't a function.", methodName) @@ -28,18 +135,21 @@ func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName st return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType) } - if value.Type().In(1) != interfaceType { + if value.Type().In(1).Kind() != reflect.Interface { return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName) } + registrator.l.Lock() + defer registrator.l.Unlock() registrator.Callbacks[methodName] = value return nil } -func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error { +// Register many callbacks where key is the methodName and value is a caller function. +func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error { for i, m := range items { for methodName, callback := range m { - err := registrator.Register(methodName, callback) + err := registrator.registerCallback(methodName, callback) if err != nil { return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err) } @@ -49,7 +159,8 @@ func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ... } func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator { - return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)} + return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(), + new(sync.RWMutex)} } func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) { @@ -86,7 +197,7 @@ func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Fiel var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator() func init() { - checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{ + checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{ "AfterScan": AfterScanMethodCallback, })) -} +} \ No newline at end of file diff --git a/methodcallback_test.go b/methodcallback_test.go index 0836aff4..6f417c89 100644 --- a/methodcallback_test.go +++ b/methodcallback_test.go @@ -1,16 +1,59 @@ package gorm_test -import "fmt" - -func ExampleAfterScanMethodCallback() { - fmt.Println(`package main - import ( "fmt" +) + +func init() { + +} + +func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() { + fmt.Println(` +if registrator.EnabledFieldType(&Media{}) { + registrator.DisableFieldType(&Media{}) +} +`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() { + fmt.Println(` +if !registrator.EnabledFieldType(&Media{}) { + println("not enabled") +} +`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() { + fmt.Println(` +if !registrator.EnabledFieldType(&Media{}) { + registrator.EnableFieldType(&Media{}) +} +`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() { + fmt.Println(` +if registrator.RegisteredFieldType(&Media{}) { + println("not registered") +}`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() { + fmt.Println("registrator.RegisterFieldType(&Media{})") +} + +func ExampleAfterScanMethodCallback() { + println(` +package main + +import ( "reflect" "github.com/jinzhu/gorm" "database/sql/driver" _ "github.com/jinzhu/gorm/dialects/sqlite" + "strconv" + "strings" ) type Media struct { @@ -45,7 +88,8 @@ func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) { } func (image *Media) URL() string { - return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name) + return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()), + *image.fieldName, image.Name}, "/") } type User struct { @@ -58,10 +102,14 @@ func (user *User) GetID() int { } func main() { + // register media type + gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{}) + db, err := gorm.Open("sqlite3", "db.db") if err != nil { panic(err) } + db.AutoMigrate(&User{}) baseUrl := "http://example.com/media" @@ -79,6 +127,7 @@ func main() { panic(db_.Error) } - fmt.Println(model.MainImage.URL()) -}`) + println("Media URL:", model.MainImage.URL()) +} +`) } diff --git a/migration_test.go b/migration_test.go index 21c27fdd..e5b8ee55 100644 --- a/migration_test.go +++ b/migration_test.go @@ -14,6 +14,10 @@ import ( "github.com/jinzhu/gorm" ) +func init() { + gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{}) +} + type User struct { Id int64 Age int64 diff --git a/model_struct.go b/model_struct.go index d45a7143..c32b2e90 100644 --- a/model_struct.go +++ b/model_struct.go @@ -605,6 +605,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + // register method callbacks now for improve performance field.MethodCallbacks = make(map[string]StructFieldMethodCallback) for callbackName, caller := range StructFieldMethodCallbacks.Callbacks { diff --git a/scope.go b/scope.go index 8bea93d0..6935219e 100644 --- a/scope.go +++ b/scope.go @@ -473,6 +473,25 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } +// call after field method callbacks +func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) { + if !scope.HasError() { + if scope.DB().EnabledAfterScanCallback(scope.Value) { + scopeValue := reflect.ValueOf(scope) + for index, field := range scannerFields { + // if calbacks enabled for field type + if StructFieldMethodCallbacks.EnabledFieldType(field.Field.Type()) { + // not disabled on scan + if _, ok := disableScanField[index]; !ok { + reflectValue := field.Field.Addr() + field.CallMethodCallback("AfterScan", reflectValue, scopeValue) + } + } + } + } + } +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} @@ -527,23 +546,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { } } - if !scope.HasError() { - key := "gorm:disable_after_scan" - if v, ok := scope.Get(key); !ok || !v.(bool) { - valueType := indirect(reflect.ValueOf(scope.Value)).Type() - if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) { - scopeValue := reflect.ValueOf(scope) - for index, field := range scannerFields { - if _, ok := disableScanField[index]; !ok { - if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) { - reflectValue := field.Field.Addr() - field.CallMethodCallback("AfterScan", reflectValue, scopeValue) - } - } - } - } - } - } + scope.afterScanCallback(scannerFields, disableScanField) } func (scope *Scope) primaryCondition(value interface{}) string { diff --git a/scope_test.go b/scope_test.go index c4df9c6b..fa4c30fb 100644 --- a/scope_test.go +++ b/scope_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/jinzhu/gorm" - "reflect" ) func NameIn1And2(d *gorm.DB) *gorm.DB { @@ -124,29 +123,80 @@ func TestAfterFieldScanDisableCallback(t *testing.T) { t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) } - run := func(key string) { - DB := DB.Set(key, true) + run := func(typs ... interface{}) { + DB := DB.DisableAfterScanCallback(typs...) var model2 WithFieldAfterScanCallback if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { - t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err) + t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err) } dotest := func(i int, value string, field AfterScanFieldInterface) { if !field.CalledFieldIsNill() { - t.Errorf("%q: Expected Name%v.calledField is not nil", key, i) + t.Errorf("%q: Expected Name%v.calledField is not nil", len(typs), i) } if !field.CalledScopeIsNill() { - t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i) + t.Errorf("%q: Expected Name%v.calledScope is not nil", len(typs), i) } } dotest(1, model.Name1.data, model2.Name1) } - run("gorm:disable_after_scan") - typ := reflect.ValueOf(model).Type() - run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name()) + run() + run(model) +} + +func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) { + model := WithFieldAfterScanCallback{} + model.Name1 = &AfterScanFieldPtr{data: randName()} + model.Name2 = AfterScanFieldPtr{data: randName()} + model.Name3 = &AfterScanField{data: randName()} + model.Name4 = AfterScanField{data: randName()} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + enabled := func(i int, field AfterScanFieldInterface) { + if field.CalledScopeIsNill() { + t.Errorf("Expected Name%v.calledScope, but got nil", i) + } + } + + disabled := func(i int, field AfterScanFieldInterface) { + if !field.CalledScopeIsNill() { + t.Errorf("Expected Name%v.calledScope is not nil", i) + } + } + + gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{}) + + if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + disabled(1, model.Name1) + disabled(2, &model.Name2) + disabled(3, model.Name3) + disabled(4, &model.Name4) + + gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{}) + if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + enabled(1, model.Name1) + enabled(2, &model.Name2) + disabled(3, model.Name3) + disabled(4, &model.Name4) + + gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{}) + if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + enabled(1, model.Name1) + enabled(2, &model.Name2) + enabled(3, model.Name3) + enabled(4, &model.Name4) } func TestAfterFieldScanInvalidCallback(t *testing.T) { diff --git a/utils.go b/utils.go index 999e8894..870f143d 100644 --- a/utils.go +++ b/utils.go @@ -135,6 +135,18 @@ func indirect(reflectValue reflect.Value) reflect.Value { return reflectValue } +func indirectType(reflectType reflect.Type) reflect.Type { + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + return reflectType +} + +func ptrToType(reflectType reflect.Type) reflect.Type { + reflectType = indirectType(reflectType) + return reflect.PtrTo(reflectType) +} + type Method struct { index int name string