From c97b650e0ec7a5734f9b851266ad2833ca41e617 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Tue, 20 Feb 2018 17:49:38 -0300 Subject: [PATCH] Added Scope After Scan Method Callback --- .idea/vcs.xml | 6 +++ field.go | 9 +++++ methodcallback.go | 92 ++++++++++++++++++++++++++++++++++++++++++ methodcallback_test.go | 84 ++++++++++++++++++++++++++++++++++++++ migration_test.go | 91 ++++++++++++++++++++++++++++++++++++++++- model_struct.go | 31 ++++++++++++++ scope.go | 33 +++++++++++++-- scope_test.go | 88 ++++++++++++++++++++++++++++++++++++++++ utils.go | 73 +++++++++++++++++++++++++++++++++ 9 files changed, 503 insertions(+), 4 deletions(-) create mode 100644 .idea/vcs.xml create mode 100644 methodcallback.go create mode 100644 methodcallback_test.go diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..94a25f7f --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/field.go b/field.go index 11c410b0..161a2dc3 100644 --- a/field.go +++ b/field.go @@ -14,6 +14,15 @@ type Field struct { Field reflect.Value } +func (field *Field) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) { + field.StructField.CallMethodCallbackArgs(name, object, append([]reflect.Value{reflect.ValueOf(field)}, in...)) +} + +// Call the method callback +func (field *Field) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) { + field.CallMethodCallbackArgs(name, object, in) +} + // Set set a value to the field func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { diff --git a/methodcallback.go b/methodcallback.go new file mode 100644 index 00000000..f87e111e --- /dev/null +++ b/methodcallback.go @@ -0,0 +1,92 @@ +package gorm + +import ( + "reflect" + "fmt" +) + +var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0) +var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{})) + +type StructFieldMethodCallbacksRegistrator struct { + Callbacks map[string]reflect.Value +} + +func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error { + value := reflect.ValueOf(caller) + + if value.Kind() != reflect.Func { + return fmt.Errorf("Caller of method %q isn't a function.", methodName) + } + + if value.Type().NumIn() < 2 { + return fmt.Errorf("The caller function %v for method %q require two args. Example: func(methodInfo *gorm.Method, method interface{}).", + value.Type(), methodName) + } + + if value.Type().In(0) != methodPtrType { + return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType) + } + + if value.Type().In(1) != interfaceType { + return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName) + } + + registrator.Callbacks[methodName] = value + return nil +} + +func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error { + for i, m := range items { + for methodName, callback := range m { + err := registrator.Register(methodName, callback) + if err != nil { + return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err) + } + } + } + return nil +} + +func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator { + return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)} +} + +func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) { + switch method := method.(type) { + case func(): + method() + case func(*Scope): + method(scope) + case func(*Scope, *Field): + method(scope, field) + case func(*DB, *Field): + newDB := scope.NewDB() + method(newDB, field) + scope.Err(newDB.Error) + case func() error: + scope.Err(method()) + case func(*Scope) error: + scope.Err(method(scope)) + case func(*Scope, *Field) error: + scope.Err(method(scope, field)) + case func(*DB) error: + newDB := scope.NewDB() + scope.Err(method(newDB)) + scope.Err(newDB.Error) + default: + scope.Err(fmt.Errorf("Invalid AfterScan method callback %v of type %v", reflect.ValueOf(method).Type(), field.Struct.Type)) + } +} + +// StructFieldMethodCallbacks is a default registrator for model fields callbacks where the field type is a struct +// and have a callback method. +// Default methods callbacks: +// AfterScan: Call method `AfterScanMethodCallback` after scan field from sql row. +var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator() + +func init() { + checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{ + "AfterScan": AfterScanMethodCallback, + })) +} diff --git a/methodcallback_test.go b/methodcallback_test.go new file mode 100644 index 00000000..0836aff4 --- /dev/null +++ b/methodcallback_test.go @@ -0,0 +1,84 @@ +package gorm_test + +import "fmt" + +func ExampleAfterScanMethodCallback() { + fmt.Println(`package main + +import ( + "fmt" + "reflect" + "github.com/jinzhu/gorm" + "database/sql/driver" + _ "github.com/jinzhu/gorm/dialects/sqlite" +) + +type Media struct { + Name string + baseUrl *string + modelType reflect.Type + model interface { + GetID() int + } + fieldName *string +} + +func (image *Media) Scan(value interface{}) error { + image.Name = string(value.([]byte)) + return nil +} + +func (image *Media) Value() (driver.Value, error) { + return image.Name, nil +} + +func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) { + image.fieldName, image.model = &field.StructField.Name, scope.Value.(interface { + GetID() int + }) + baseUrl, _ := scope.DB().Get("base_url") + image.baseUrl = baseUrl.(*string) + image.modelType = reflect.TypeOf(scope.Value) + for image.modelType.Kind() == reflect.Ptr { + image.modelType = image.modelType.Elem() + } +} + +func (image *Media) URL() string { + return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name) +} + +type User struct { + ID int + MainImage Media +} + +func (user *User) GetID() int { + return user.ID +} + +func main() { + db, err := gorm.Open("sqlite3", "db.db") + if err != nil { + panic(err) + } + db.AutoMigrate(&User{}) + + baseUrl := "http://example.com/media" + db = db.Set("base_url", &baseUrl) + + var model User + db_ := db.Where("id = ?", 1).First(&model) + if db_.RecordNotFound() { + db.Save(&User{MainImage: Media{Name: "picture.jpg"}}) + err = db.Where("id = ?", 1).First(&model).Error + if err != nil { + panic(err) + } + } else if db_.Error != nil { + panic(db_.Error) + } + + fmt.Println(model.MainImage.URL()) +}`) +} diff --git a/migration_test.go b/migration_test.go index 7c694485..21c27fdd 100644 --- a/migration_test.go +++ b/migration_test.go @@ -253,6 +253,95 @@ func (nt NullTime) Value() (driver.Value, error) { return nt.Time, nil } +type AfterScanFieldInterface interface { + CalledScopeIsNill() bool + CalledFieldIsNill() bool + Data() string +} + +type AfterScanFieldPtr struct { + data string + calledScopeNotIsNill bool + calledFieldNotIsNill bool +} + +func (s *AfterScanFieldPtr) Data() string { + return s.data +} + +func (s AfterScanFieldPtr) CalledScopeIsNill() bool { + return !s.calledScopeNotIsNill +} + +func (s AfterScanFieldPtr) CalledFieldIsNill() bool { + return !s.calledFieldNotIsNill +} + +func (s *AfterScanFieldPtr) Scan(value interface{}) error { + s.data = string(value.([]byte)) + return nil +} + +func (s AfterScanFieldPtr) Value() (driver.Value, error) { + return s.data, nil +} + +func (s *AfterScanFieldPtr) AfterScan(scope *gorm.Scope, field *gorm.Field) { + s.calledScopeNotIsNill = scope != nil + s.calledFieldNotIsNill = field != nil +} + +type AfterScanField struct { + data string + calledScopeNotIsNill bool +} + +func (s *AfterScanField) Data() string { + return s.data +} + +func (s AfterScanField) CalledScopeIsNill() bool { + return !s.calledScopeNotIsNill +} + +func (s AfterScanField) CalledFieldIsNill() bool { + return false +} + +func (s *AfterScanField) Scan(value interface{}) error { + s.data = string(value.([]byte)) + return nil +} + +func (s AfterScanField) Value() (driver.Value, error) { + return s.data, nil +} + +func (s *AfterScanField) AfterScan(scope *gorm.Scope) { + s.calledScopeNotIsNill = scope != nil +} + +type WithFieldAfterScanCallback struct { + ID int + Name1 *AfterScanFieldPtr + Name2 AfterScanFieldPtr + Name3 *AfterScanField + Name4 AfterScanField +} + +type InvalidAfterScanField struct { + AfterScanField +} + +func (s InvalidAfterScanField) AfterScan(invalidArg int) { +} + +type WithFieldAfterScanInvalidCallback struct { + ID int + Name InvalidAfterScanField +} + + func getPreparedUser(name string, role string) *User { var company Company DB.Where(Company{Name: role}).FirstOrCreate(&company) @@ -284,7 +373,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &WithFieldAfterScanCallback{}, &WithFieldAfterScanInvalidCallback{}} for _, value := range values { DB.DropTable(value) } diff --git a/model_struct.go b/model_struct.go index f571e2e8..d45a7143 100644 --- a/model_struct.go +++ b/model_struct.go @@ -66,6 +66,15 @@ func (s *ModelStruct) TableName(db *DB) string { return DefaultTableNameHandler(db, s.defaultTableName) } +type StructFieldMethodCallback struct { + Method + Caller reflect.Value +} + +func (s StructFieldMethodCallback) Call(object reflect.Value, in []reflect.Value) { + s.Caller.Call(append([]reflect.Value{reflect.ValueOf(&s.Method), s.ObjectMethod(object)}, in...)) +} + // StructField model field's struct definition type StructField struct { DBName string @@ -81,6 +90,19 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + MethodCallbacks map[string]StructFieldMethodCallback +} + +// Call the method callback if exists by name. +func (structField *StructField) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) { + if callback, ok := structField.MethodCallbacks[name]; ok { + callback.Call(object, in) + } +} + +// Call the method callback if exists by name. the +func (structField *StructField) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) { + structField.CallMethodCallbackArgs(name, object, in) } func (structField *StructField) clone() *StructField { @@ -97,6 +119,7 @@ func (structField *StructField) clone() *StructField { TagSettings: map[string]string{}, Struct: structField.Struct, IsForeignKey: structField.IsForeignKey, + MethodCallbacks: structField.MethodCallbacks, } if structField.Relationship != nil { @@ -581,6 +604,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsNormal = true } } + + field.MethodCallbacks = make(map[string]StructFieldMethodCallback) + + for callbackName, caller := range StructFieldMethodCallbacks.Callbacks { + if callbackMethod := MethodByName(indirectType, callbackName); callbackMethod.valid { + field.MethodCallbacks[callbackName] = StructFieldMethodCallback{callbackMethod, caller} + } + } } // Even it is ignored, also possible to decode db value into the field diff --git a/scope.go b/scope.go index 25077efc..8bea93d0 100644 --- a/scope.go +++ b/scope.go @@ -238,7 +238,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { return errors.New("could not convert column to field") } -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one +// CallMethod call scope value's Method, if it is a slice, will call its element's Method one by one func (scope *Scope) CallMethod(methodName string) { if scope.Value == nil { return @@ -482,6 +482,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields = map[int]*Field{} ) + scannerFields := make(map[int]*Field) + for index, column := range columns { values[index] = &ignored @@ -492,6 +494,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { for fieldIndex, field := range selectFields { if field.DBName == column { + scannerFields[index] = field + if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { @@ -512,9 +516,32 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { scope.Err(rows.Scan(values...)) + disableScanField := make(map[int]bool) + for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) + reflectValue := reflect.ValueOf(values[index]).Elem().Elem() + if reflectValue.IsValid() { + field.Field.Set(reflectValue) + } else { + disableScanField[index] = true + } + } + + if !scope.HasError() { + key := "gorm:disable_after_scan" + if v, ok := scope.Get(key); !ok || !v.(bool) { + valueType := indirect(reflect.ValueOf(scope.Value)).Type() + if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) { + scopeValue := reflect.ValueOf(scope) + for index, field := range scannerFields { + if _, ok := disableScanField[index]; !ok { + if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) { + reflectValue := field.Field.Addr() + field.CallMethodCallback("AfterScan", reflectValue, scopeValue) + } + } + } + } } } } diff --git a/scope_test.go b/scope_test.go index 3018f350..c4df9c6b 100644 --- a/scope_test.go +++ b/scope_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/jinzhu/gorm" + "reflect" ) func NameIn1And2(d *gorm.DB) *gorm.DB { @@ -78,3 +79,90 @@ func TestFailedValuer(t *testing.T) { t.Errorf("The error should be returned from Valuer, but get %v", err) } } + +func TestAfterFieldScanCallback(t *testing.T) { + model := WithFieldAfterScanCallback{} + model.Name1 = &AfterScanFieldPtr{data: randName()} + model.Name2 = AfterScanFieldPtr{data: randName()} + model.Name3 = &AfterScanField{data: randName()} + model.Name4 = AfterScanField{data: randName()} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + var model2 WithFieldAfterScanCallback + if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + + dotest := func(i int, value string, field AfterScanFieldInterface) { + if field.CalledFieldIsNill() { + t.Errorf("Expected Name%v.calledField, but got nil", i) + } + + if field.CalledScopeIsNill() { + t.Errorf("Expected Name%v.calledScope, but got nil", i) + } + + if field.Data() != value { + t.Errorf("Expected Name%v.data %q, but got %q", i, value, field.Data()) + } + } + + dotest(1, model.Name1.data, model2.Name1) + dotest(2, model.Name2.data, &model2.Name2) + dotest(3, model.Name3.data, model2.Name3) + dotest(4, model.Name4.data, &model2.Name4) +} + +func TestAfterFieldScanDisableCallback(t *testing.T) { + model := WithFieldAfterScanCallback{} + model.Name1 = &AfterScanFieldPtr{data: randName()} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + run := func(key string) { + DB := DB.Set(key, true) + var model2 WithFieldAfterScanCallback + if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { + t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err) + } + + dotest := func(i int, value string, field AfterScanFieldInterface) { + if !field.CalledFieldIsNill() { + t.Errorf("%q: Expected Name%v.calledField is not nil", key, i) + } + + if !field.CalledScopeIsNill() { + t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i) + } + } + + dotest(1, model.Name1.data, model2.Name1) + } + + run("gorm:disable_after_scan") + typ := reflect.ValueOf(model).Type() + run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name()) +} + +func TestAfterFieldScanInvalidCallback(t *testing.T) { + model := WithFieldAfterScanInvalidCallback{} + model.Name = InvalidAfterScanField{AfterScanField{data: randName()}} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + var model2 WithFieldAfterScanInvalidCallback + if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { + if !strings.Contains(err.Error(), "Invalid AfterScan method callback") { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + } else { + t.Errorf("Expected error, but got nil") + } +} diff --git a/utils.go b/utils.go index dfaae939..999e8894 100644 --- a/utils.go +++ b/utils.go @@ -135,6 +135,73 @@ func indirect(reflectValue reflect.Value) reflect.Value { return reflectValue } +type Method struct { + index int + name string + ptr bool + valid bool +} + +func (m Method) Index() int { + return m.index +} + +func (m Method) Name() string { + return m.name +} + +func (m Method) Ptr() bool { + return m.ptr +} + +func (m Method) Valid() bool { + return m.valid +} + +func (m Method) TypeMethod(typ reflect.Type) reflect.Method { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if m.ptr { + typ = reflect.PtrTo(typ) + } + return typ.Method(m.index) +} + +func (m Method) ObjectMethod(object reflect.Value) reflect.Value { + object = indirect(object) + if m.ptr { + object = object.Addr() + } + return object.Method(m.index) +} + +func MethodByName(typ reflect.Type, name string) (m Method) { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + if typ.Kind() != reflect.Struct { + return + } + + if method, ok := typ.MethodByName(name); ok { + m.index = method.Index + m.name = name + m.valid = true + return + } + + if method, ok := reflect.PtrTo(typ).MethodByName(name); ok { + m.index = method.Index + m.name = name + m.ptr = true + m.valid = true + } + + return +} + func toQueryMarks(primaryValues [][]interface{}) string { var results []string @@ -283,3 +350,9 @@ func addExtraSpaceIfExist(str string) string { } return "" } + +func checkOrPanic(err error) { + if err != nil { + panic(err) + } +} \ No newline at end of file