Added Scope After Scan Method Callback
This commit is contained in:
parent
55945afb34
commit
c97b650e0e
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
9
field.go
9
field.go
@ -14,6 +14,15 @@ type Field struct {
|
|||||||
Field reflect.Value
|
Field reflect.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (field *Field) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) {
|
||||||
|
field.StructField.CallMethodCallbackArgs(name, object, append([]reflect.Value{reflect.ValueOf(field)}, in...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the method callback
|
||||||
|
func (field *Field) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) {
|
||||||
|
field.CallMethodCallbackArgs(name, object, in)
|
||||||
|
}
|
||||||
|
|
||||||
// Set set a value to the field
|
// Set set a value to the field
|
||||||
func (field *Field) Set(value interface{}) (err error) {
|
func (field *Field) Set(value interface{}) (err error) {
|
||||||
if !field.Field.IsValid() {
|
if !field.Field.IsValid() {
|
||||||
|
92
methodcallback.go
Normal file
92
methodcallback.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var interfaceType = reflect.TypeOf(func(a interface{}) {}).In(0)
|
||||||
|
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
|
||||||
|
|
||||||
|
type StructFieldMethodCallbacksRegistrator struct {
|
||||||
|
Callbacks map[string]reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) Register(methodName string, caller interface{}) error {
|
||||||
|
value := reflect.ValueOf(caller)
|
||||||
|
|
||||||
|
if value.Kind() != reflect.Func {
|
||||||
|
return fmt.Errorf("Caller of method %q isn't a function.", methodName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Type().NumIn() < 2 {
|
||||||
|
return fmt.Errorf("The caller function %v for method %q require two args. Example: func(methodInfo *gorm.Method, method interface{}).",
|
||||||
|
value.Type(), methodName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Type().In(0) != methodPtrType {
|
||||||
|
return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Type().In(1) != interfaceType {
|
||||||
|
return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
|
||||||
|
}
|
||||||
|
|
||||||
|
registrator.Callbacks[methodName] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterMany(items ...map[string]interface{}) error {
|
||||||
|
for i, m := range items {
|
||||||
|
for methodName, callback := range m {
|
||||||
|
err := registrator.Register(methodName, callback)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
|
||||||
|
return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
|
||||||
|
switch method := method.(type) {
|
||||||
|
case func():
|
||||||
|
method()
|
||||||
|
case func(*Scope):
|
||||||
|
method(scope)
|
||||||
|
case func(*Scope, *Field):
|
||||||
|
method(scope, field)
|
||||||
|
case func(*DB, *Field):
|
||||||
|
newDB := scope.NewDB()
|
||||||
|
method(newDB, field)
|
||||||
|
scope.Err(newDB.Error)
|
||||||
|
case func() error:
|
||||||
|
scope.Err(method())
|
||||||
|
case func(*Scope) error:
|
||||||
|
scope.Err(method(scope))
|
||||||
|
case func(*Scope, *Field) error:
|
||||||
|
scope.Err(method(scope, field))
|
||||||
|
case func(*DB) error:
|
||||||
|
newDB := scope.NewDB()
|
||||||
|
scope.Err(method(newDB))
|
||||||
|
scope.Err(newDB.Error)
|
||||||
|
default:
|
||||||
|
scope.Err(fmt.Errorf("Invalid AfterScan method callback %v of type %v", reflect.ValueOf(method).Type(), field.Struct.Type))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StructFieldMethodCallbacks is a default registrator for model fields callbacks where the field type is a struct
|
||||||
|
// and have a callback method.
|
||||||
|
// Default methods callbacks:
|
||||||
|
// AfterScan: Call method `AfterScanMethodCallback` after scan field from sql row.
|
||||||
|
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
checkOrPanic(StructFieldMethodCallbacks.RegisterMany(map[string]interface{}{
|
||||||
|
"AfterScan": AfterScanMethodCallback,
|
||||||
|
}))
|
||||||
|
}
|
84
methodcallback_test.go
Normal file
84
methodcallback_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func ExampleAfterScanMethodCallback() {
|
||||||
|
fmt.Println(`package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"database/sql/driver"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Media struct {
|
||||||
|
Name string
|
||||||
|
baseUrl *string
|
||||||
|
modelType reflect.Type
|
||||||
|
model interface {
|
||||||
|
GetID() int
|
||||||
|
}
|
||||||
|
fieldName *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (image *Media) Scan(value interface{}) error {
|
||||||
|
image.Name = string(value.([]byte))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (image *Media) Value() (driver.Value, error) {
|
||||||
|
return image.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) {
|
||||||
|
image.fieldName, image.model = &field.StructField.Name, scope.Value.(interface {
|
||||||
|
GetID() int
|
||||||
|
})
|
||||||
|
baseUrl, _ := scope.DB().Get("base_url")
|
||||||
|
image.baseUrl = baseUrl.(*string)
|
||||||
|
image.modelType = reflect.TypeOf(scope.Value)
|
||||||
|
for image.modelType.Kind() == reflect.Ptr {
|
||||||
|
image.modelType = image.modelType.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (image *Media) URL() string {
|
||||||
|
return fmt.Sprintf("%v/%v/%v/%v/%v", *image.baseUrl, image.modelType.Name(), image.model.GetID(), *image.fieldName, image.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int
|
||||||
|
MainImage Media
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) GetID() int {
|
||||||
|
return user.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db, err := gorm.Open("sqlite3", "db.db")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
db.AutoMigrate(&User{})
|
||||||
|
|
||||||
|
baseUrl := "http://example.com/media"
|
||||||
|
db = db.Set("base_url", &baseUrl)
|
||||||
|
|
||||||
|
var model User
|
||||||
|
db_ := db.Where("id = ?", 1).First(&model)
|
||||||
|
if db_.RecordNotFound() {
|
||||||
|
db.Save(&User{MainImage: Media{Name: "picture.jpg"}})
|
||||||
|
err = db.Where("id = ?", 1).First(&model).Error
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
} else if db_.Error != nil {
|
||||||
|
panic(db_.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(model.MainImage.URL())
|
||||||
|
}`)
|
||||||
|
}
|
@ -253,6 +253,95 @@ func (nt NullTime) Value() (driver.Value, error) {
|
|||||||
return nt.Time, nil
|
return nt.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AfterScanFieldInterface interface {
|
||||||
|
CalledScopeIsNill() bool
|
||||||
|
CalledFieldIsNill() bool
|
||||||
|
Data() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterScanFieldPtr struct {
|
||||||
|
data string
|
||||||
|
calledScopeNotIsNill bool
|
||||||
|
calledFieldNotIsNill bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AfterScanFieldPtr) Data() string {
|
||||||
|
return s.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AfterScanFieldPtr) CalledScopeIsNill() bool {
|
||||||
|
return !s.calledScopeNotIsNill
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AfterScanFieldPtr) CalledFieldIsNill() bool {
|
||||||
|
return !s.calledFieldNotIsNill
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AfterScanFieldPtr) Scan(value interface{}) error {
|
||||||
|
s.data = string(value.([]byte))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AfterScanFieldPtr) Value() (driver.Value, error) {
|
||||||
|
return s.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AfterScanFieldPtr) AfterScan(scope *gorm.Scope, field *gorm.Field) {
|
||||||
|
s.calledScopeNotIsNill = scope != nil
|
||||||
|
s.calledFieldNotIsNill = field != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterScanField struct {
|
||||||
|
data string
|
||||||
|
calledScopeNotIsNill bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AfterScanField) Data() string {
|
||||||
|
return s.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AfterScanField) CalledScopeIsNill() bool {
|
||||||
|
return !s.calledScopeNotIsNill
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AfterScanField) CalledFieldIsNill() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AfterScanField) Scan(value interface{}) error {
|
||||||
|
s.data = string(value.([]byte))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AfterScanField) Value() (driver.Value, error) {
|
||||||
|
return s.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AfterScanField) AfterScan(scope *gorm.Scope) {
|
||||||
|
s.calledScopeNotIsNill = scope != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type WithFieldAfterScanCallback struct {
|
||||||
|
ID int
|
||||||
|
Name1 *AfterScanFieldPtr
|
||||||
|
Name2 AfterScanFieldPtr
|
||||||
|
Name3 *AfterScanField
|
||||||
|
Name4 AfterScanField
|
||||||
|
}
|
||||||
|
|
||||||
|
type InvalidAfterScanField struct {
|
||||||
|
AfterScanField
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s InvalidAfterScanField) AfterScan(invalidArg int) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type WithFieldAfterScanInvalidCallback struct {
|
||||||
|
ID int
|
||||||
|
Name InvalidAfterScanField
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func getPreparedUser(name string, role string) *User {
|
func getPreparedUser(name string, role string) *User {
|
||||||
var company Company
|
var company Company
|
||||||
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||||
@ -284,7 +373,7 @@ func runMigration() {
|
|||||||
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
|
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &WithFieldAfterScanCallback{}, &WithFieldAfterScanInvalidCallback{}}
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
DB.DropTable(value)
|
DB.DropTable(value)
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,15 @@ func (s *ModelStruct) TableName(db *DB) string {
|
|||||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StructFieldMethodCallback struct {
|
||||||
|
Method
|
||||||
|
Caller reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s StructFieldMethodCallback) Call(object reflect.Value, in []reflect.Value) {
|
||||||
|
s.Caller.Call(append([]reflect.Value{reflect.ValueOf(&s.Method), s.ObjectMethod(object)}, in...))
|
||||||
|
}
|
||||||
|
|
||||||
// StructField model field's struct definition
|
// StructField model field's struct definition
|
||||||
type StructField struct {
|
type StructField struct {
|
||||||
DBName string
|
DBName string
|
||||||
@ -81,6 +90,19 @@ type StructField struct {
|
|||||||
Struct reflect.StructField
|
Struct reflect.StructField
|
||||||
IsForeignKey bool
|
IsForeignKey bool
|
||||||
Relationship *Relationship
|
Relationship *Relationship
|
||||||
|
MethodCallbacks map[string]StructFieldMethodCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the method callback if exists by name.
|
||||||
|
func (structField *StructField) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) {
|
||||||
|
if callback, ok := structField.MethodCallbacks[name]; ok {
|
||||||
|
callback.Call(object, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the method callback if exists by name. the
|
||||||
|
func (structField *StructField) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) {
|
||||||
|
structField.CallMethodCallbackArgs(name, object, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (structField *StructField) clone() *StructField {
|
func (structField *StructField) clone() *StructField {
|
||||||
@ -97,6 +119,7 @@ func (structField *StructField) clone() *StructField {
|
|||||||
TagSettings: map[string]string{},
|
TagSettings: map[string]string{},
|
||||||
Struct: structField.Struct,
|
Struct: structField.Struct,
|
||||||
IsForeignKey: structField.IsForeignKey,
|
IsForeignKey: structField.IsForeignKey,
|
||||||
|
MethodCallbacks: structField.MethodCallbacks,
|
||||||
}
|
}
|
||||||
|
|
||||||
if structField.Relationship != nil {
|
if structField.Relationship != nil {
|
||||||
@ -581,6 +604,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
|
||||||
|
|
||||||
|
for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {
|
||||||
|
if callbackMethod := MethodByName(indirectType, callbackName); callbackMethod.valid {
|
||||||
|
field.MethodCallbacks[callbackName] = StructFieldMethodCallback{callbackMethod, caller}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Even it is ignored, also possible to decode db value into the field
|
// Even it is ignored, also possible to decode db value into the field
|
||||||
|
33
scope.go
33
scope.go
@ -238,7 +238,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
|||||||
return errors.New("could not convert column to field")
|
return errors.New("could not convert column to field")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
|
// CallMethod call scope value's Method, if it is a slice, will call its element's Method one by one
|
||||||
func (scope *Scope) CallMethod(methodName string) {
|
func (scope *Scope) CallMethod(methodName string) {
|
||||||
if scope.Value == nil {
|
if scope.Value == nil {
|
||||||
return
|
return
|
||||||
@ -482,6 +482,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
resetFields = map[int]*Field{}
|
resetFields = map[int]*Field{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
scannerFields := make(map[int]*Field)
|
||||||
|
|
||||||
for index, column := range columns {
|
for index, column := range columns {
|
||||||
values[index] = &ignored
|
values[index] = &ignored
|
||||||
|
|
||||||
@ -492,6 +494,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
|
|
||||||
for fieldIndex, field := range selectFields {
|
for fieldIndex, field := range selectFields {
|
||||||
if field.DBName == column {
|
if field.DBName == column {
|
||||||
|
scannerFields[index] = field
|
||||||
|
|
||||||
if field.Field.Kind() == reflect.Ptr {
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
values[index] = field.Field.Addr().Interface()
|
values[index] = field.Field.Addr().Interface()
|
||||||
} else {
|
} else {
|
||||||
@ -512,9 +516,32 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
|
|
||||||
scope.Err(rows.Scan(values...))
|
scope.Err(rows.Scan(values...))
|
||||||
|
|
||||||
|
disableScanField := make(map[int]bool)
|
||||||
|
|
||||||
for index, field := range resetFields {
|
for index, field := range resetFields {
|
||||||
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
reflectValue := reflect.ValueOf(values[index]).Elem().Elem()
|
||||||
field.Field.Set(v)
|
if reflectValue.IsValid() {
|
||||||
|
field.Field.Set(reflectValue)
|
||||||
|
} else {
|
||||||
|
disableScanField[index] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.HasError() {
|
||||||
|
key := "gorm:disable_after_scan"
|
||||||
|
if v, ok := scope.Get(key); !ok || !v.(bool) {
|
||||||
|
valueType := indirect(reflect.ValueOf(scope.Value)).Type()
|
||||||
|
if v, ok := scope.Get(key + ":" + valueType.PkgPath() + "." + valueType.Name()); !ok || !v.(bool) {
|
||||||
|
scopeValue := reflect.ValueOf(scope)
|
||||||
|
for index, field := range scannerFields {
|
||||||
|
if _, ok := disableScanField[index]; !ok {
|
||||||
|
if field.Field.Kind() == reflect.Struct || (field.Field.Kind() == reflect.Ptr && field.Field.Elem().Kind() == reflect.Struct) {
|
||||||
|
reflectValue := field.Field.Addr()
|
||||||
|
field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
||||||
@ -78,3 +79,90 @@ func TestFailedValuer(t *testing.T) {
|
|||||||
t.Errorf("The error should be returned from Valuer, but get %v", err)
|
t.Errorf("The error should be returned from Valuer, but get %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAfterFieldScanCallback(t *testing.T) {
|
||||||
|
model := WithFieldAfterScanCallback{}
|
||||||
|
model.Name1 = &AfterScanFieldPtr{data: randName()}
|
||||||
|
model.Name2 = AfterScanFieldPtr{data: randName()}
|
||||||
|
model.Name3 = &AfterScanField{data: randName()}
|
||||||
|
model.Name4 = AfterScanField{data: randName()}
|
||||||
|
|
||||||
|
if err := DB.Save(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var model2 WithFieldAfterScanCallback
|
||||||
|
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dotest := func(i int, value string, field AfterScanFieldInterface) {
|
||||||
|
if field.CalledFieldIsNill() {
|
||||||
|
t.Errorf("Expected Name%v.calledField, but got nil", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.CalledScopeIsNill() {
|
||||||
|
t.Errorf("Expected Name%v.calledScope, but got nil", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.Data() != value {
|
||||||
|
t.Errorf("Expected Name%v.data %q, but got %q", i, value, field.Data())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dotest(1, model.Name1.data, model2.Name1)
|
||||||
|
dotest(2, model.Name2.data, &model2.Name2)
|
||||||
|
dotest(3, model.Name3.data, model2.Name3)
|
||||||
|
dotest(4, model.Name4.data, &model2.Name4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAfterFieldScanDisableCallback(t *testing.T) {
|
||||||
|
model := WithFieldAfterScanCallback{}
|
||||||
|
model.Name1 = &AfterScanFieldPtr{data: randName()}
|
||||||
|
|
||||||
|
if err := DB.Save(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
run := func(key string) {
|
||||||
|
DB := DB.Set(key, true)
|
||||||
|
var model2 WithFieldAfterScanCallback
|
||||||
|
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
|
||||||
|
t.Errorf("%q: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dotest := func(i int, value string, field AfterScanFieldInterface) {
|
||||||
|
if !field.CalledFieldIsNill() {
|
||||||
|
t.Errorf("%q: Expected Name%v.calledField is not nil", key, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !field.CalledScopeIsNill() {
|
||||||
|
t.Errorf("%q: Expected Name%v.calledScope is not nil", key, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dotest(1, model.Name1.data, model2.Name1)
|
||||||
|
}
|
||||||
|
|
||||||
|
run("gorm:disable_after_scan")
|
||||||
|
typ := reflect.ValueOf(model).Type()
|
||||||
|
run("gorm:disable_after_scan:" + typ.PkgPath() + "." + typ.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAfterFieldScanInvalidCallback(t *testing.T) {
|
||||||
|
model := WithFieldAfterScanInvalidCallback{}
|
||||||
|
model.Name = InvalidAfterScanField{AfterScanField{data: randName()}}
|
||||||
|
|
||||||
|
if err := DB.Save(&model).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var model2 WithFieldAfterScanInvalidCallback
|
||||||
|
if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "Invalid AfterScan method callback") {
|
||||||
|
t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("Expected error, but got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
73
utils.go
73
utils.go
@ -135,6 +135,73 @@ func indirect(reflectValue reflect.Value) reflect.Value {
|
|||||||
return reflectValue
|
return reflectValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Method struct {
|
||||||
|
index int
|
||||||
|
name string
|
||||||
|
ptr bool
|
||||||
|
valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) Index() int {
|
||||||
|
return m.index
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) Ptr() bool {
|
||||||
|
return m.ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) Valid() bool {
|
||||||
|
return m.valid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) TypeMethod(typ reflect.Type) reflect.Method {
|
||||||
|
for typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
|
}
|
||||||
|
if m.ptr {
|
||||||
|
typ = reflect.PtrTo(typ)
|
||||||
|
}
|
||||||
|
return typ.Method(m.index)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) ObjectMethod(object reflect.Value) reflect.Value {
|
||||||
|
object = indirect(object)
|
||||||
|
if m.ptr {
|
||||||
|
object = object.Addr()
|
||||||
|
}
|
||||||
|
return object.Method(m.index)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MethodByName(typ reflect.Type, name string) (m Method) {
|
||||||
|
for typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if typ.Kind() != reflect.Struct {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if method, ok := typ.MethodByName(name); ok {
|
||||||
|
m.index = method.Index
|
||||||
|
m.name = name
|
||||||
|
m.valid = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if method, ok := reflect.PtrTo(typ).MethodByName(name); ok {
|
||||||
|
m.index = method.Index
|
||||||
|
m.name = name
|
||||||
|
m.ptr = true
|
||||||
|
m.valid = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||||
var results []string
|
var results []string
|
||||||
|
|
||||||
@ -283,3 +350,9 @@ func addExtraSpaceIfExist(str string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkOrPanic(err error) {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user