diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..94a25f7f --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/field.go b/field.go index 11c410b0..161a2dc3 100644 --- a/field.go +++ b/field.go @@ -14,6 +14,15 @@ type Field struct { Field reflect.Value } +func (field *Field) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) { + field.StructField.CallMethodCallbackArgs(name, object, append([]reflect.Value{reflect.ValueOf(field)}, in...)) +} + +// Call the method callback +func (field *Field) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) { + field.CallMethodCallbackArgs(name, object, in) +} + // Set set a value to the field func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { diff --git a/main.go b/main.go index c26e05c8..22171e02 100644 --- a/main.go +++ b/main.go @@ -177,15 +177,6 @@ func (s *DB) QueryExpr() *expr { return Expr(scope.SQL, scope.SQLVars...) } -// SubQuery returns the query as sub query -func (s *DB) SubQuery() *expr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) -} - // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db @@ -775,3 +766,62 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) } } + +// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all +func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB { + key := "gorm:disable_after_scan" + + s = s.clone() + + if len(typs) == 0 { + s.values[key] = true + return s + } + + for _, typ := range typs { + rType := indirectType(reflect.TypeOf(typ)) + s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true + } + + return s +} + +// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all. +// The disabled types will not be enabled unless they are specifically informed. +func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB { + key := "gorm:disable_after_scan" + + s = s.clone() + + if len(typs) == 0 { + s.values[key] = false + return s + } + + for _, typ := range typs { + rType := indirectType(reflect.TypeOf(typ)) + s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false + } + + return s +} + +// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed +// typs. +func (s *DB) IsEnabledAfterScanCallback(typs ...interface{}) (ok bool) { + key := "gorm:disable_after_scan" + + if v, ok := s.values[key]; ok { + return !v.(bool) + } + + for _, typ := range typs { + rType := indirectType(reflect.TypeOf(typ)) + v, ok := s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] + if ok && v.(bool) { + return false + } + } + + return true +} diff --git a/methodcallback.go b/methodcallback.go new file mode 100644 index 00000000..6046f50b --- /dev/null +++ b/methodcallback.go @@ -0,0 +1,203 @@ +package gorm + +import ( + "reflect" + "fmt" + "sync" +) + +var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{})) + +type safeEnabledFieldTypes struct { + m map[reflect.Type]bool + l *sync.RWMutex +} + +func newSafeEnabledFieldTypes() safeEnabledFieldTypes { + return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)} +} + +func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.Lock() + defer s.l.Unlock() + s.m[k] = enabled + default: + s.Set(reflect.TypeOf(key), enabled) + } +} + +func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.RLock() + defer s.l.RUnlock() + enabled, ok = s.m[k] + return + default: + return s.Get(reflect.TypeOf(key)) + } +} + +func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.RLock() + defer s.l.RUnlock() + _, ok = s.m[k] + return + default: + return s.Has(reflect.TypeOf(key)) + } +} + +func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) { + switch k := key.(type) { + case reflect.Type: + k = indirectType(k) + s.l.Lock() + defer s.l.Unlock() + _, ok = s.m[k] + if ok { + delete(s.m, k) + } + return + default: + return s.Del(reflect.TypeOf(key)) + } +} + +type StructFieldMethodCallbacksRegistrator struct { + Callbacks map[string]reflect.Value + FieldTypes safeEnabledFieldTypes + l *sync.RWMutex +} + +// Register new field type and enable all available callbacks for here +func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) { + for _, typ := range typs { + if !registrator.FieldTypes.Has(typ) { + registrator.FieldTypes.Set(typ, true) + } + } +} + +// Unregister field type and return if ok +func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) { + return registrator.FieldTypes.Del(typ) +} + +// Enable all callbacks for field type +func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) { + for _, typ := range typs { + registrator.FieldTypes.Set(typ, true) + } +} + +// Disable all callbacks for field type +func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) { + for _, typ := range typs { + registrator.FieldTypes.Set(typ, false) + } +} + +// Return if all callbacks for field type is enabled +func (registrator *StructFieldMethodCallbacksRegistrator) IsEnabledFieldType(typ interface{}) bool { + if enabled, ok := registrator.FieldTypes.Get(typ); ok { + return enabled + } + return false +} + +// Return if field type is registered +func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool { + return registrator.FieldTypes.Has(typ) +} + +// Register new callback for fields have method methodName +func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error { + value := indirect(reflect.ValueOf(caller)) + + if value.Kind() != reflect.Func { + return fmt.Errorf("Caller of method %q isn't a function.", methodName) + } + + if value.Type().NumIn() < 2 { + return fmt.Errorf("The caller function %v for method %q require two args. Example: func(methodInfo *gorm.Method, method interface{}).", + value.Type(), methodName) + } + + if value.Type().In(0) != methodPtrType { + return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType) + } + + if value.Type().In(1).Kind() != reflect.Interface { + return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName) + } + + registrator.l.Lock() + defer registrator.l.Unlock() + registrator.Callbacks[methodName] = value + return nil +} + +// Register many callbacks where key is the methodName and value is a caller function. +func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error { + for i, m := range items { + for methodName, callback := range m { + err := registrator.registerCallback(methodName, callback) + if err != nil { + return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err) + } + } + } + return nil +} + +func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator { + return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(), + new(sync.RWMutex)} +} + +func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) { + switch method := method.(type) { + case func(): + method() + case func(*Scope): + method(scope) + case func(*Scope, *Field): + method(scope, field) + case func(*DB, *Field): + newDB := scope.NewDB() + method(newDB, field) + scope.Err(newDB.Error) + case func() error: + scope.Err(method()) + case func(*Scope) error: + scope.Err(method(scope)) + case func(*Scope, *Field) error: + scope.Err(method(scope, field)) + case func(*DB) error: + newDB := scope.NewDB() + scope.Err(method(newDB)) + scope.Err(newDB.Error) + default: + scope.Err(fmt.Errorf("Invalid AfterScan method callback %v of type %v", reflect.ValueOf(method).Type(), field.Struct.Type)) + } +} + +// StructFieldMethodCallbacks is a default registrator for model fields callbacks where the field type is a struct +// and have a callback method. +// Default methods callbacks: +// AfterScan: Call method `AfterScanMethodCallback` after scan field from sql row. +var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator() + +func init() { + checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{ + "AfterScan": AfterScanMethodCallback, + })) +} \ No newline at end of file diff --git a/methodcallback_test.go b/methodcallback_test.go new file mode 100644 index 00000000..be7c6525 --- /dev/null +++ b/methodcallback_test.go @@ -0,0 +1,127 @@ +package gorm_test + +import ( + "fmt" +) + +func init() { + +} + +func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() { + fmt.Println(`if registrator.IsEnabledFieldType(&Media{}) { + registrator.DisableFieldType(&Media{}) +}`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() { + fmt.Println(`if !registrator.IsEnabledFieldType(&Media{}) { + println("not enabled") +}`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() { + fmt.Println(`if !registrator.IsEnabledFieldType(&Media{}) { + registrator.EnableFieldType(&Media{}) +}`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() { + fmt.Println(` +if registrator.RegisteredFieldType(&Media{}) { + println("not registered") +}`) +} + +func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() { + fmt.Println("registrator.RegisterFieldType(&Media{})") +} + +func ExampleAfterScanMethodCallback() { + println(` +package main + +import ( + "reflect" + "github.com/jinzhu/gorm" + "database/sql/driver" + _ "github.com/jinzhu/gorm/dialects/sqlite" + "strconv" + "strings" +) + +type Media struct { + Name string + baseUrl *string + modelType reflect.Type + model interface { + GetID() int + } + fieldName *string +} + +func (image *Media) Scan(value interface{}) error { + image.Name = string(value.([]byte)) + return nil +} + +func (image *Media) Value() (driver.Value, error) { + return image.Name, nil +} + +func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) { + image.fieldName, image.model = &field.StructField.Name, scope.Value.(interface { + GetID() int + }) + baseUrl, _ := scope.DB().Get("base_url") + image.baseUrl = baseUrl.(*string) + image.modelType = reflect.TypeOf(scope.Value) + for image.modelType.Kind() == reflect.Ptr { + image.modelType = image.modelType.Elem() + } +} + +func (image *Media) URL() string { + return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()), + *image.fieldName, image.Name}, "/") +} + +type User struct { + ID int + MainImage Media +} + +func (user *User) GetID() int { + return user.ID +} + +func main() { + // register media type + gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{}) + + db, err := gorm.Open("sqlite3", "db.db") + if err != nil { + panic(err) + } + + db.AutoMigrate(&User{}) + + baseUrl := "http://example.com/media" + db = db.Set("base_url", &baseUrl) + + var model User + db_ := db.Where("id = ?", 1).First(&model) + if db_.RecordNotFound() { + db.Save(&User{MainImage: Media{Name: "picture.jpg"}}) + err = db.Where("id = ?", 1).First(&model).Error + if err != nil { + panic(err) + } + } else if db_.Error != nil { + panic(db_.Error) + } + + println("Media URL:", model.MainImage.URL()) +} +`) +} diff --git a/migration_test.go b/migration_test.go index 7c694485..e5b8ee55 100644 --- a/migration_test.go +++ b/migration_test.go @@ -14,6 +14,10 @@ import ( "github.com/jinzhu/gorm" ) +func init() { + gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{}) +} + type User struct { Id int64 Age int64 @@ -253,6 +257,95 @@ func (nt NullTime) Value() (driver.Value, error) { return nt.Time, nil } +type AfterScanFieldInterface interface { + CalledScopeIsNill() bool + CalledFieldIsNill() bool + Data() string +} + +type AfterScanFieldPtr struct { + data string + calledScopeNotIsNill bool + calledFieldNotIsNill bool +} + +func (s *AfterScanFieldPtr) Data() string { + return s.data +} + +func (s AfterScanFieldPtr) CalledScopeIsNill() bool { + return !s.calledScopeNotIsNill +} + +func (s AfterScanFieldPtr) CalledFieldIsNill() bool { + return !s.calledFieldNotIsNill +} + +func (s *AfterScanFieldPtr) Scan(value interface{}) error { + s.data = string(value.([]byte)) + return nil +} + +func (s AfterScanFieldPtr) Value() (driver.Value, error) { + return s.data, nil +} + +func (s *AfterScanFieldPtr) AfterScan(scope *gorm.Scope, field *gorm.Field) { + s.calledScopeNotIsNill = scope != nil + s.calledFieldNotIsNill = field != nil +} + +type AfterScanField struct { + data string + calledScopeNotIsNill bool +} + +func (s *AfterScanField) Data() string { + return s.data +} + +func (s AfterScanField) CalledScopeIsNill() bool { + return !s.calledScopeNotIsNill +} + +func (s AfterScanField) CalledFieldIsNill() bool { + return false +} + +func (s *AfterScanField) Scan(value interface{}) error { + s.data = string(value.([]byte)) + return nil +} + +func (s AfterScanField) Value() (driver.Value, error) { + return s.data, nil +} + +func (s *AfterScanField) AfterScan(scope *gorm.Scope) { + s.calledScopeNotIsNill = scope != nil +} + +type WithFieldAfterScanCallback struct { + ID int + Name1 *AfterScanFieldPtr + Name2 AfterScanFieldPtr + Name3 *AfterScanField + Name4 AfterScanField +} + +type InvalidAfterScanField struct { + AfterScanField +} + +func (s InvalidAfterScanField) AfterScan(invalidArg int) { +} + +type WithFieldAfterScanInvalidCallback struct { + ID int + Name InvalidAfterScanField +} + + func getPreparedUser(name string, role string) *User { var company Company DB.Where(Company{Name: role}).FirstOrCreate(&company) @@ -284,7 +377,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &WithFieldAfterScanCallback{}, &WithFieldAfterScanInvalidCallback{}} for _, value := range values { DB.DropTable(value) } diff --git a/model_struct.go b/model_struct.go index f571e2e8..c32b2e90 100644 --- a/model_struct.go +++ b/model_struct.go @@ -66,6 +66,15 @@ func (s *ModelStruct) TableName(db *DB) string { return DefaultTableNameHandler(db, s.defaultTableName) } +type StructFieldMethodCallback struct { + Method + Caller reflect.Value +} + +func (s StructFieldMethodCallback) Call(object reflect.Value, in []reflect.Value) { + s.Caller.Call(append([]reflect.Value{reflect.ValueOf(&s.Method), s.ObjectMethod(object)}, in...)) +} + // StructField model field's struct definition type StructField struct { DBName string @@ -81,6 +90,19 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + MethodCallbacks map[string]StructFieldMethodCallback +} + +// Call the method callback if exists by name. +func (structField *StructField) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) { + if callback, ok := structField.MethodCallbacks[name]; ok { + callback.Call(object, in) + } +} + +// Call the method callback if exists by name. the +func (structField *StructField) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) { + structField.CallMethodCallbackArgs(name, object, in) } func (structField *StructField) clone() *StructField { @@ -97,6 +119,7 @@ func (structField *StructField) clone() *StructField { TagSettings: map[string]string{}, Struct: structField.Struct, IsForeignKey: structField.IsForeignKey, + MethodCallbacks: structField.MethodCallbacks, } if structField.Relationship != nil { @@ -581,6 +604,15 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsNormal = true } } + + // register method callbacks now for improve performance + field.MethodCallbacks = make(map[string]StructFieldMethodCallback) + + for callbackName, caller := range StructFieldMethodCallbacks.Callbacks { + if callbackMethod := MethodByName(indirectType, callbackName); callbackMethod.valid { + field.MethodCallbacks[callbackName] = StructFieldMethodCallback{callbackMethod, caller} + } + } } // Even it is ignored, also possible to decode db value into the field diff --git a/scope.go b/scope.go index 150ac710..d9f27ace 100644 --- a/scope.go +++ b/scope.go @@ -238,7 +238,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { return errors.New("could not convert column to field") } -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one +// CallMethod call scope value's Method, if it is a slice, will call its element's Method one by one func (scope *Scope) CallMethod(methodName string) { if scope.Value == nil { return @@ -473,6 +473,27 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } +// call after field method callbacks +func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) { + if !scope.HasError() && scope.Value != nil { + if scope.DB().IsEnabledAfterScanCallback(scope.Value) { + scopeValue := reflect.ValueOf(scope) + for index, field := range scannerFields { + // if not is nill and if calbacks enabled for field type + if StructFieldMethodCallbacks.IsEnabledFieldType(field.Field.Type()) { + // not disabled on scan + if _, ok := disableScanField[index]; !ok { + if !isNil(field.Field) { + reflectValue := field.Field.Addr() + field.CallMethodCallback("AfterScan", reflectValue, scopeValue) + } + } + } + } + } + } +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} @@ -482,6 +503,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields = map[int]*Field{} ) + scannerFields := make(map[int]*Field) + for index, column := range columns { values[index] = &ignored @@ -492,6 +515,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { for fieldIndex, field := range selectFields { if field.DBName == column { + scannerFields[index] = field + if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { @@ -512,11 +537,18 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { scope.Err(rows.Scan(values...)) + disableScanField := make(map[int]bool) + for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) + reflectValue := reflect.ValueOf(values[index]).Elem().Elem() + if reflectValue.IsValid() { + field.Field.Set(reflectValue) + } else { + disableScanField[index] = true } } + + scope.afterScanCallback(scannerFields, disableScanField) } func (scope *Scope) primaryCondition(value interface{}) string { diff --git a/scope_test.go b/scope_test.go index 3018f350..1a1e3621 100644 --- a/scope_test.go +++ b/scope_test.go @@ -78,3 +78,137 @@ func TestFailedValuer(t *testing.T) { t.Errorf("The error should be returned from Valuer, but get %v", err) } } + +func TestAfterFieldScanCallback(t *testing.T) { + model := WithFieldAfterScanCallback{} + model.Name1 = &AfterScanFieldPtr{data: randName()} + model.Name2 = AfterScanFieldPtr{data: randName()} + model.Name3 = &AfterScanField{data: randName()} + model.Name4 = AfterScanField{data: randName()} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + var model2 WithFieldAfterScanCallback + if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + + dotest := func(i int, value string, field AfterScanFieldInterface) { + if field.CalledFieldIsNill() { + t.Errorf("Expected Name%v.calledField, but got nil", i) + } + + if field.CalledScopeIsNill() { + t.Errorf("Expected Name%v.calledScope, but got nil", i) + } + + if field.Data() != value { + t.Errorf("Expected Name%v.data %q, but got %q", i, value, field.Data()) + } + } + + dotest(1, model.Name1.data, model2.Name1) + dotest(2, model.Name2.data, &model2.Name2) + dotest(3, model.Name3.data, model2.Name3) + dotest(4, model.Name4.data, &model2.Name4) +} + +func TestAfterFieldScanDisableCallback(t *testing.T) { + model := WithFieldAfterScanCallback{} + model.Name1 = &AfterScanFieldPtr{data: randName()} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + run := func(typs ... interface{}) { + DB := DB.DisableAfterScanCallback(typs...) + var model2 WithFieldAfterScanCallback + if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { + t.Errorf("%v: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err) + } + + dotest := func(i int, field AfterScanFieldInterface) { + if !field.CalledFieldIsNill() { + t.Errorf("%v: Expected Name%v.calledField is nil", len(typs), i) + } + } + + dotest(1, model2.Name1) + } + + run() + run(model) +} + +func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) { + model := WithFieldAfterScanCallback{} + model.Name1 = &AfterScanFieldPtr{data: randName()} + model.Name2 = AfterScanFieldPtr{data: randName()} + model.Name3 = &AfterScanField{data: randName()} + model.Name4 = AfterScanField{data: randName()} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + enabled := func(i int, field AfterScanFieldInterface) { + if field.CalledScopeIsNill() { + t.Errorf("Expected Name%v.calledScope, but got nil", i) + } + } + + disabled := func(i int, field AfterScanFieldInterface) { + if !field.CalledScopeIsNill() { + t.Errorf("Expected Name%v.calledScope is not nil", i) + } + } + + gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{}) + + if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + disabled(1, model.Name1) + disabled(2, &model.Name2) + disabled(3, model.Name3) + disabled(4, &model.Name4) + + gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{}) + if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + enabled(1, model.Name1) + enabled(2, &model.Name2) + disabled(3, model.Name3) + disabled(4, &model.Name4) + + gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{}) + if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + enabled(1, model.Name1) + enabled(2, &model.Name2) + enabled(3, model.Name3) + enabled(4, &model.Name4) +} + +func TestAfterFieldScanInvalidCallback(t *testing.T) { + model := WithFieldAfterScanInvalidCallback{} + model.Name = InvalidAfterScanField{AfterScanField{data: randName()}} + + if err := DB.Save(&model).Error; err != nil { + t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err) + } + + var model2 WithFieldAfterScanInvalidCallback + if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil { + if !strings.Contains(err.Error(), "Invalid AfterScan method callback") { + t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err) + } + } else { + t.Errorf("Expected error, but got nil") + } +} diff --git a/utils.go b/utils.go index dfaae939..355225b2 100644 --- a/utils.go +++ b/utils.go @@ -135,6 +135,85 @@ func indirect(reflectValue reflect.Value) reflect.Value { return reflectValue } +func indirectType(reflectType reflect.Type) reflect.Type { + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + return reflectType +} + +func ptrToType(reflectType reflect.Type) reflect.Type { + reflectType = indirectType(reflectType) + return reflect.PtrTo(reflectType) +} + +type Method struct { + index int + name string + ptr bool + valid bool +} + +func (m Method) Index() int { + return m.index +} + +func (m Method) Name() string { + return m.name +} + +func (m Method) Ptr() bool { + return m.ptr +} + +func (m Method) Valid() bool { + return m.valid +} + +func (m Method) TypeMethod(typ reflect.Type) reflect.Method { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if m.ptr { + typ = reflect.PtrTo(typ) + } + return typ.Method(m.index) +} + +func (m Method) ObjectMethod(object reflect.Value) reflect.Value { + object = indirect(object) + if m.ptr { + object = object.Addr() + } + return object.Method(m.index) +} + +func MethodByName(typ reflect.Type, name string) (m Method) { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + if typ.Kind() != reflect.Struct { + return + } + + if method, ok := typ.MethodByName(name); ok { + m.index = method.Index + m.name = name + m.valid = true + return + } + + if method, ok := reflect.PtrTo(typ).MethodByName(name); ok { + m.index = method.Index + m.name = name + m.ptr = true + m.valid = true + } + + return +} + func toQueryMarks(primaryValues [][]interface{}) string { var results []string @@ -283,3 +362,20 @@ func addExtraSpaceIfExist(str string) string { } return "" } + +func checkOrPanic(err error) { + if err != nil { + panic(err) + } +} + +// check if value is nil +func isNil(value reflect.Value) bool { + if value.Kind() != reflect.Ptr { + return false + } + if value.Pointer() == 0 { + return true + } + return false +} \ No newline at end of file