Merge 7c9ed0b3118b5cae4ec75bab95217f5033ecb24c into 6ed508ec6a4ecb3531899a69cbc746ccf65a4166
This commit is contained in:
		
						commit
						2289ae7a7c
					
				
							
								
								
									
										6
									
								
								.idea/vcs.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.idea/vcs.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							@ -0,0 +1,6 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="VcsDirectoryMappings">
 | 
			
		||||
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
 | 
			
		||||
  </component>
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										9
									
								
								field.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								field.go
									
									
									
									
									
								
							@ -14,6 +14,15 @@ type Field struct {
 | 
			
		||||
	Field   reflect.Value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (field *Field) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) {
 | 
			
		||||
	field.StructField.CallMethodCallbackArgs(name, object, append([]reflect.Value{reflect.ValueOf(field)}, in...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Call the method callback
 | 
			
		||||
func (field *Field) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) {
 | 
			
		||||
	field.CallMethodCallbackArgs(name, object, in)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Set set a value to the field
 | 
			
		||||
func (field *Field) Set(value interface{}) (err error) {
 | 
			
		||||
	if !field.Field.IsValid() {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										68
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										68
									
								
								main.go
									
									
									
									
									
								
							@ -177,15 +177,6 @@ func (s *DB) QueryExpr() *expr {
 | 
			
		||||
	return Expr(scope.SQL, scope.SQLVars...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SubQuery returns the query as sub query
 | 
			
		||||
func (s *DB) SubQuery() *expr {
 | 
			
		||||
	scope := s.NewScope(s.Value)
 | 
			
		||||
	scope.InstanceSet("skip_bindvar", true)
 | 
			
		||||
	scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
	return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
 | 
			
		||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
 | 
			
		||||
	return s.clone().search.Where(query, args...).db
 | 
			
		||||
@ -775,3 +766,62 @@ func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
 | 
			
		||||
		s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Disable after scan callback. If typs not is empty, disable for typs, other else, disable for all
 | 
			
		||||
func (s *DB) DisableAfterScanCallback(typs ...interface{}) *DB  {
 | 
			
		||||
	key := "gorm:disable_after_scan"
 | 
			
		||||
 | 
			
		||||
	s = s.clone()
 | 
			
		||||
 | 
			
		||||
	if len(typs) == 0 {
 | 
			
		||||
		s.values[key] = true
 | 
			
		||||
		return s
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, typ := range typs {
 | 
			
		||||
		rType := indirectType(reflect.TypeOf(typ))
 | 
			
		||||
		s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Enable after scan callback. If typs not is empty, enable for typs, other else, enable for all.
 | 
			
		||||
// The disabled types will not be enabled unless they are specifically informed.
 | 
			
		||||
func (s *DB) EnableAfterScanCallback(typs ...interface{}) *DB  {
 | 
			
		||||
	key := "gorm:disable_after_scan"
 | 
			
		||||
 | 
			
		||||
	s = s.clone()
 | 
			
		||||
 | 
			
		||||
	if len(typs) == 0 {
 | 
			
		||||
		s.values[key] = false
 | 
			
		||||
		return s
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, typ := range typs {
 | 
			
		||||
		rType := indirectType(reflect.TypeOf(typ))
 | 
			
		||||
		s.values[key + ":" + rType.PkgPath() + "." + rType.Name()] = false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return if after scan callbacks has be enable. If typs is empty, return default, other else, return for informed
 | 
			
		||||
// typs.
 | 
			
		||||
func (s *DB) IsEnabledAfterScanCallback(typs ...interface{}) (ok bool) {
 | 
			
		||||
	key := "gorm:disable_after_scan"
 | 
			
		||||
 | 
			
		||||
	if v, ok := s.values[key]; ok {
 | 
			
		||||
		return !v.(bool)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, typ := range typs {
 | 
			
		||||
		rType := indirectType(reflect.TypeOf(typ))
 | 
			
		||||
		v, ok := s.values[key + ":" + rType.PkgPath() + "." + rType.Name()]
 | 
			
		||||
		if ok && v.(bool) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										203
									
								
								methodcallback.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								methodcallback.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,203 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var methodPtrType = reflect.PtrTo(reflect.TypeOf(Method{}))
 | 
			
		||||
 | 
			
		||||
type safeEnabledFieldTypes struct {
 | 
			
		||||
	m map[reflect.Type]bool
 | 
			
		||||
	l *sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSafeEnabledFieldTypes() safeEnabledFieldTypes {
 | 
			
		||||
	return safeEnabledFieldTypes{make(map[reflect.Type]bool), new(sync.RWMutex)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *safeEnabledFieldTypes) Set(key interface{}, enabled bool) {
 | 
			
		||||
	switch k := key.(type) {
 | 
			
		||||
	case reflect.Type:
 | 
			
		||||
		k = indirectType(k)
 | 
			
		||||
		s.l.Lock()
 | 
			
		||||
		defer s.l.Unlock()
 | 
			
		||||
		s.m[k] = enabled
 | 
			
		||||
	default:
 | 
			
		||||
		s.Set(reflect.TypeOf(key), enabled)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *safeEnabledFieldTypes) Get(key interface{}) (enabled bool, ok bool) {
 | 
			
		||||
	switch k := key.(type) {
 | 
			
		||||
	case reflect.Type:
 | 
			
		||||
		k = indirectType(k)
 | 
			
		||||
		s.l.RLock()
 | 
			
		||||
		defer s.l.RUnlock()
 | 
			
		||||
		enabled, ok = s.m[k]
 | 
			
		||||
		return
 | 
			
		||||
	default:
 | 
			
		||||
		return s.Get(reflect.TypeOf(key))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *safeEnabledFieldTypes) Has(key interface{}) (ok bool) {
 | 
			
		||||
	switch k := key.(type) {
 | 
			
		||||
	case reflect.Type:
 | 
			
		||||
		k = indirectType(k)
 | 
			
		||||
		s.l.RLock()
 | 
			
		||||
		defer s.l.RUnlock()
 | 
			
		||||
		_, ok = s.m[k]
 | 
			
		||||
		return
 | 
			
		||||
	default:
 | 
			
		||||
		return s.Has(reflect.TypeOf(key))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *safeEnabledFieldTypes) Del(key interface{}) (ok bool) {
 | 
			
		||||
	switch k := key.(type) {
 | 
			
		||||
	case reflect.Type:
 | 
			
		||||
		k = indirectType(k)
 | 
			
		||||
		s.l.Lock()
 | 
			
		||||
		defer s.l.Unlock()
 | 
			
		||||
		_, ok = s.m[k]
 | 
			
		||||
		if ok {
 | 
			
		||||
			delete(s.m, k)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	default:
 | 
			
		||||
		return s.Del(reflect.TypeOf(key))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StructFieldMethodCallbacksRegistrator struct {
 | 
			
		||||
	Callbacks  map[string]reflect.Value
 | 
			
		||||
	FieldTypes safeEnabledFieldTypes
 | 
			
		||||
	l          *sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Register new field type and enable all available callbacks for here
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) RegisterFieldType(typs ...interface{}) {
 | 
			
		||||
	for _, typ := range typs {
 | 
			
		||||
		if !registrator.FieldTypes.Has(typ) {
 | 
			
		||||
			registrator.FieldTypes.Set(typ, true)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Unregister field type and return if ok
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) UnregisterFieldType(typ interface{}) (ok bool) {
 | 
			
		||||
	return registrator.FieldTypes.Del(typ)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Enable all callbacks for field type
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) EnableFieldType(typs ...interface{}) {
 | 
			
		||||
	for _, typ := range typs {
 | 
			
		||||
		registrator.FieldTypes.Set(typ, true)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Disable all callbacks for field type
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) DisableFieldType(typs ...interface{}) {
 | 
			
		||||
	for _, typ := range typs {
 | 
			
		||||
		registrator.FieldTypes.Set(typ, false)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return if all callbacks for field type is enabled
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) IsEnabledFieldType(typ interface{}) bool {
 | 
			
		||||
	if enabled, ok := registrator.FieldTypes.Get(typ); ok {
 | 
			
		||||
		return enabled
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return if field type is registered
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) RegisteredFieldType(typ interface{}) bool {
 | 
			
		||||
	return registrator.FieldTypes.Has(typ)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Register new callback for fields have method methodName
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallback(methodName string, caller interface{}) error {
 | 
			
		||||
	value := indirect(reflect.ValueOf(caller))
 | 
			
		||||
 | 
			
		||||
	if value.Kind() != reflect.Func {
 | 
			
		||||
		return fmt.Errorf("Caller of method %q isn't a function.", methodName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value.Type().NumIn() < 2 {
 | 
			
		||||
		return fmt.Errorf("The caller function %v for method %q require two args. Example: func(methodInfo *gorm.Method, method interface{}).",
 | 
			
		||||
			value.Type(), methodName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value.Type().In(0) != methodPtrType {
 | 
			
		||||
		return fmt.Errorf("First arg of caller %v for method %q isn't a %v type.", value.Type(), methodName, methodPtrType)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value.Type().In(1).Kind() != reflect.Interface {
 | 
			
		||||
		return fmt.Errorf("Second arg of caller %v for method %q isn't a interface{} type.", value.Type(), methodName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	registrator.l.Lock()
 | 
			
		||||
	defer registrator.l.Unlock()
 | 
			
		||||
	registrator.Callbacks[methodName] = value
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Register many callbacks where key is the methodName and value is a caller function.
 | 
			
		||||
func (registrator *StructFieldMethodCallbacksRegistrator) registerCallbackMany(items ...map[string]interface{}) error {
 | 
			
		||||
	for i, m := range items {
 | 
			
		||||
		for methodName, callback := range m {
 | 
			
		||||
			err := registrator.registerCallback(methodName, callback)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return fmt.Errorf("Register arg[%v][%q] failed: %v", i, methodName, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewStructFieldMethodCallbacksRegistrator() *StructFieldMethodCallbacksRegistrator {
 | 
			
		||||
	return &StructFieldMethodCallbacksRegistrator{make(map[string]reflect.Value), newSafeEnabledFieldTypes(),
 | 
			
		||||
		new(sync.RWMutex)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AfterScanMethodCallback(methodInfo *Method, method interface{}, field *Field, scope *Scope) {
 | 
			
		||||
	switch method := method.(type) {
 | 
			
		||||
	case func():
 | 
			
		||||
		method()
 | 
			
		||||
	case func(*Scope):
 | 
			
		||||
		method(scope)
 | 
			
		||||
	case func(*Scope, *Field):
 | 
			
		||||
		method(scope, field)
 | 
			
		||||
	case func(*DB, *Field):
 | 
			
		||||
		newDB := scope.NewDB()
 | 
			
		||||
		method(newDB, field)
 | 
			
		||||
		scope.Err(newDB.Error)
 | 
			
		||||
	case func() error:
 | 
			
		||||
		scope.Err(method())
 | 
			
		||||
	case func(*Scope) error:
 | 
			
		||||
		scope.Err(method(scope))
 | 
			
		||||
	case func(*Scope, *Field) error:
 | 
			
		||||
		scope.Err(method(scope, field))
 | 
			
		||||
	case func(*DB) error:
 | 
			
		||||
		newDB := scope.NewDB()
 | 
			
		||||
		scope.Err(method(newDB))
 | 
			
		||||
		scope.Err(newDB.Error)
 | 
			
		||||
	default:
 | 
			
		||||
		scope.Err(fmt.Errorf("Invalid AfterScan method callback %v of type %v", reflect.ValueOf(method).Type(), field.Struct.Type))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StructFieldMethodCallbacks is a default registrator for model fields callbacks where the field type is a struct
 | 
			
		||||
// and have a callback method.
 | 
			
		||||
// Default methods callbacks:
 | 
			
		||||
// AfterScan: Call method `AfterScanMethodCallback` after scan field from sql row.
 | 
			
		||||
var StructFieldMethodCallbacks = NewStructFieldMethodCallbacksRegistrator()
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	checkOrPanic(StructFieldMethodCallbacks.registerCallbackMany(map[string]interface{}{
 | 
			
		||||
		"AfterScan": AfterScanMethodCallback,
 | 
			
		||||
	}))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										127
									
								
								methodcallback_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								methodcallback_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,127 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExampleStructFieldMethodCallbacksRegistrator_DisableFieldType() {
 | 
			
		||||
	fmt.Println(`if registrator.IsEnabledFieldType(&Media{}) {
 | 
			
		||||
	registrator.DisableFieldType(&Media{})
 | 
			
		||||
}`)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExampleStructFieldMethodCallbacksRegistrator_EnabledFieldType() {
 | 
			
		||||
	fmt.Println(`if !registrator.IsEnabledFieldType(&Media{}) {
 | 
			
		||||
	println("not enabled")
 | 
			
		||||
}`)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExampleStructFieldMethodCallbacksRegistrator_EnableFieldType() {
 | 
			
		||||
	fmt.Println(`if !registrator.IsEnabledFieldType(&Media{}) {
 | 
			
		||||
	registrator.EnableFieldType(&Media{})
 | 
			
		||||
}`)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExampleStructFieldMethodCallbacksRegistrator_RegisteredFieldType() {
 | 
			
		||||
	fmt.Println(`
 | 
			
		||||
if registrator.RegisteredFieldType(&Media{}) {
 | 
			
		||||
	println("not registered")
 | 
			
		||||
}`)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExampleStructFieldMethodCallbacksRegistrator_RegisterFieldType() {
 | 
			
		||||
	fmt.Println("registrator.RegisterFieldType(&Media{})")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExampleAfterScanMethodCallback() {
 | 
			
		||||
	println(`
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	_ "github.com/jinzhu/gorm/dialects/sqlite"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Media struct {
 | 
			
		||||
	Name      string
 | 
			
		||||
	baseUrl   *string
 | 
			
		||||
	modelType reflect.Type
 | 
			
		||||
	model interface {
 | 
			
		||||
		GetID() int
 | 
			
		||||
	}
 | 
			
		||||
	fieldName *string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (image *Media) Scan(value interface{}) error {
 | 
			
		||||
	image.Name = string(value.([]byte))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (image *Media) Value() (driver.Value, error) {
 | 
			
		||||
	return image.Name, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (image *Media) AfterScan(scope *gorm.Scope, field *gorm.Field) {
 | 
			
		||||
	image.fieldName, image.model = &field.StructField.Name, scope.Value.(interface {
 | 
			
		||||
		GetID() int
 | 
			
		||||
	})
 | 
			
		||||
	baseUrl, _ := scope.DB().Get("base_url")
 | 
			
		||||
	image.baseUrl = baseUrl.(*string)
 | 
			
		||||
	image.modelType = reflect.TypeOf(scope.Value)
 | 
			
		||||
	for image.modelType.Kind() == reflect.Ptr {
 | 
			
		||||
		image.modelType = image.modelType.Elem()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (image *Media) URL() string {
 | 
			
		||||
	return strings.Join([]string{*image.baseUrl, image.modelType.Name(), strconv.Itoa(image.model.GetID()),
 | 
			
		||||
		*image.fieldName, image.Name}, "/")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type User struct {
 | 
			
		||||
	ID        int
 | 
			
		||||
	MainImage Media
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (user *User) GetID() int {
 | 
			
		||||
	return user.ID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	// register media type
 | 
			
		||||
	gorm.StructFieldMethodCallbacks.RegisterFieldType(&Media{})
 | 
			
		||||
 | 
			
		||||
	db, err := gorm.Open("sqlite3", "db.db")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.AutoMigrate(&User{})
 | 
			
		||||
 | 
			
		||||
	baseUrl := "http://example.com/media"
 | 
			
		||||
	db = db.Set("base_url", &baseUrl)
 | 
			
		||||
 | 
			
		||||
	var model User
 | 
			
		||||
	db_ := db.Where("id = ?", 1).First(&model)
 | 
			
		||||
	if db_.RecordNotFound() {
 | 
			
		||||
		db.Save(&User{MainImage: Media{Name: "picture.jpg"}})
 | 
			
		||||
		err = db.Where("id = ?", 1).First(&model).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			panic(err)
 | 
			
		||||
		}
 | 
			
		||||
	} else if db_.Error != nil {
 | 
			
		||||
		panic(db_.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	println("Media URL:", model.MainImage.URL())
 | 
			
		||||
}
 | 
			
		||||
`)
 | 
			
		||||
}
 | 
			
		||||
@ -14,6 +14,10 @@ import (
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init()  {
 | 
			
		||||
	gorm.StructFieldMethodCallbacks.RegisterFieldType(&AfterScanField{}, &AfterScanFieldPtr{}, &InvalidAfterScanField{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type User struct {
 | 
			
		||||
	Id                int64
 | 
			
		||||
	Age               int64
 | 
			
		||||
@ -253,6 +257,95 @@ func (nt NullTime) Value() (driver.Value, error) {
 | 
			
		||||
	return nt.Time, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AfterScanFieldInterface interface {
 | 
			
		||||
	CalledScopeIsNill() bool
 | 
			
		||||
	CalledFieldIsNill() bool
 | 
			
		||||
	Data() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AfterScanFieldPtr struct {
 | 
			
		||||
	data                 string
 | 
			
		||||
	calledScopeNotIsNill bool
 | 
			
		||||
	calledFieldNotIsNill bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AfterScanFieldPtr) Data() string {
 | 
			
		||||
	return s.data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AfterScanFieldPtr) CalledScopeIsNill() bool {
 | 
			
		||||
	return !s.calledScopeNotIsNill
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AfterScanFieldPtr) CalledFieldIsNill() bool {
 | 
			
		||||
	return !s.calledFieldNotIsNill
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AfterScanFieldPtr) Scan(value interface{}) error {
 | 
			
		||||
	s.data = string(value.([]byte))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AfterScanFieldPtr) Value() (driver.Value, error) {
 | 
			
		||||
	return s.data, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AfterScanFieldPtr) AfterScan(scope *gorm.Scope, field *gorm.Field) {
 | 
			
		||||
	s.calledScopeNotIsNill = scope != nil
 | 
			
		||||
	s.calledFieldNotIsNill = field != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AfterScanField struct {
 | 
			
		||||
	data                 string
 | 
			
		||||
	calledScopeNotIsNill bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AfterScanField) Data() string {
 | 
			
		||||
	return s.data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AfterScanField) CalledScopeIsNill() bool {
 | 
			
		||||
	return !s.calledScopeNotIsNill
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AfterScanField) CalledFieldIsNill() bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AfterScanField) Scan(value interface{}) error {
 | 
			
		||||
	s.data = string(value.([]byte))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AfterScanField) Value() (driver.Value, error) {
 | 
			
		||||
	return s.data, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AfterScanField) AfterScan(scope *gorm.Scope) {
 | 
			
		||||
	s.calledScopeNotIsNill = scope != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WithFieldAfterScanCallback struct {
 | 
			
		||||
	ID    int
 | 
			
		||||
	Name1 *AfterScanFieldPtr
 | 
			
		||||
	Name2 AfterScanFieldPtr
 | 
			
		||||
	Name3 *AfterScanField
 | 
			
		||||
	Name4 AfterScanField
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InvalidAfterScanField struct {
 | 
			
		||||
	AfterScanField
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s InvalidAfterScanField) AfterScan(invalidArg int) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WithFieldAfterScanInvalidCallback struct {
 | 
			
		||||
	ID    int
 | 
			
		||||
	Name InvalidAfterScanField
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
func getPreparedUser(name string, role string) *User {
 | 
			
		||||
	var company Company
 | 
			
		||||
	DB.Where(Company{Name: role}).FirstOrCreate(&company)
 | 
			
		||||
@ -284,7 +377,7 @@ func runMigration() {
 | 
			
		||||
		DB.Exec(fmt.Sprintf("drop table %v;", table))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
 | 
			
		||||
	values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &WithFieldAfterScanCallback{}, &WithFieldAfterScanInvalidCallback{}}
 | 
			
		||||
	for _, value := range values {
 | 
			
		||||
		DB.DropTable(value)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -66,6 +66,15 @@ func (s *ModelStruct) TableName(db *DB) string {
 | 
			
		||||
	return DefaultTableNameHandler(db, s.defaultTableName)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StructFieldMethodCallback struct {
 | 
			
		||||
	Method
 | 
			
		||||
	Caller reflect.Value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s StructFieldMethodCallback) Call(object reflect.Value, in []reflect.Value) {
 | 
			
		||||
	s.Caller.Call(append([]reflect.Value{reflect.ValueOf(&s.Method), s.ObjectMethod(object)}, in...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StructField model field's struct definition
 | 
			
		||||
type StructField struct {
 | 
			
		||||
	DBName          string
 | 
			
		||||
@ -81,6 +90,19 @@ type StructField struct {
 | 
			
		||||
	Struct          reflect.StructField
 | 
			
		||||
	IsForeignKey    bool
 | 
			
		||||
	Relationship    *Relationship
 | 
			
		||||
	MethodCallbacks map[string]StructFieldMethodCallback
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Call the method callback if exists by name.
 | 
			
		||||
func (structField *StructField) CallMethodCallbackArgs(name string, object reflect.Value, in []reflect.Value) {
 | 
			
		||||
	if callback, ok := structField.MethodCallbacks[name]; ok {
 | 
			
		||||
		callback.Call(object, in)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Call the method callback if exists by name. the
 | 
			
		||||
func (structField *StructField) CallMethodCallback(name string, object reflect.Value, in ...reflect.Value) {
 | 
			
		||||
	structField.CallMethodCallbackArgs(name, object, in)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (structField *StructField) clone() *StructField {
 | 
			
		||||
@ -97,6 +119,7 @@ func (structField *StructField) clone() *StructField {
 | 
			
		||||
		TagSettings:     map[string]string{},
 | 
			
		||||
		Struct:          structField.Struct,
 | 
			
		||||
		IsForeignKey:    structField.IsForeignKey,
 | 
			
		||||
		MethodCallbacks: structField.MethodCallbacks,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if structField.Relationship != nil {
 | 
			
		||||
@ -581,6 +604,15 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 | 
			
		||||
						field.IsNormal = true
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// register method callbacks now for improve performance
 | 
			
		||||
				field.MethodCallbacks = make(map[string]StructFieldMethodCallback)
 | 
			
		||||
 | 
			
		||||
				for callbackName, caller := range StructFieldMethodCallbacks.Callbacks {
 | 
			
		||||
					if callbackMethod := MethodByName(indirectType, callbackName); callbackMethod.valid {
 | 
			
		||||
						field.MethodCallbacks[callbackName] = StructFieldMethodCallback{callbackMethod, caller}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Even it is ignored, also possible to decode db value into the field
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										38
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								scope.go
									
									
									
									
									
								
							@ -238,7 +238,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
 | 
			
		||||
	return errors.New("could not convert column to field")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
 | 
			
		||||
// CallMethod call scope value's Method, if it is a slice, will call its element's Method one by one
 | 
			
		||||
func (scope *Scope) CallMethod(methodName string) {
 | 
			
		||||
	if scope.Value == nil {
 | 
			
		||||
		return
 | 
			
		||||
@ -473,6 +473,27 @@ func (scope *Scope) quoteIfPossible(str string) string {
 | 
			
		||||
	return str
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// call after field method callbacks
 | 
			
		||||
func (scope *Scope) afterScanCallback(scannerFields map[int]*Field, disableScanField map[int]bool) {
 | 
			
		||||
	if !scope.HasError() && scope.Value != nil {
 | 
			
		||||
		if scope.DB().IsEnabledAfterScanCallback(scope.Value) {
 | 
			
		||||
			scopeValue := reflect.ValueOf(scope)
 | 
			
		||||
			for index, field := range scannerFields {
 | 
			
		||||
				// if not is nill and if calbacks enabled for field type
 | 
			
		||||
				if StructFieldMethodCallbacks.IsEnabledFieldType(field.Field.Type()) {
 | 
			
		||||
					// not disabled on scan
 | 
			
		||||
					if _, ok := disableScanField[index]; !ok {
 | 
			
		||||
						if !isNil(field.Field) {
 | 
			
		||||
							reflectValue := field.Field.Addr()
 | 
			
		||||
							field.CallMethodCallback("AfterScan", reflectValue, scopeValue)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
 | 
			
		||||
	var (
 | 
			
		||||
		ignored            interface{}
 | 
			
		||||
@ -482,6 +503,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
 | 
			
		||||
		resetFields        = map[int]*Field{}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	scannerFields := make(map[int]*Field)
 | 
			
		||||
 | 
			
		||||
	for index, column := range columns {
 | 
			
		||||
		values[index] = &ignored
 | 
			
		||||
 | 
			
		||||
@ -492,6 +515,8 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
 | 
			
		||||
 | 
			
		||||
		for fieldIndex, field := range selectFields {
 | 
			
		||||
			if field.DBName == column {
 | 
			
		||||
				scannerFields[index] = field
 | 
			
		||||
 | 
			
		||||
				if field.Field.Kind() == reflect.Ptr {
 | 
			
		||||
					values[index] = field.Field.Addr().Interface()
 | 
			
		||||
				} else {
 | 
			
		||||
@ -512,11 +537,18 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
 | 
			
		||||
 | 
			
		||||
	scope.Err(rows.Scan(values...))
 | 
			
		||||
 | 
			
		||||
	disableScanField := make(map[int]bool)
 | 
			
		||||
 | 
			
		||||
	for index, field := range resetFields {
 | 
			
		||||
		if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
 | 
			
		||||
			field.Field.Set(v)
 | 
			
		||||
		reflectValue := reflect.ValueOf(values[index]).Elem().Elem()
 | 
			
		||||
		if reflectValue.IsValid() {
 | 
			
		||||
			field.Field.Set(reflectValue)
 | 
			
		||||
		} else {
 | 
			
		||||
			disableScanField[index] = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope.afterScanCallback(scannerFields, disableScanField)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) primaryCondition(value interface{}) string {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										134
									
								
								scope_test.go
									
									
									
									
									
								
							
							
						
						
									
										134
									
								
								scope_test.go
									
									
									
									
									
								
							@ -78,3 +78,137 @@ func TestFailedValuer(t *testing.T) {
 | 
			
		||||
		t.Errorf("The error should be returned from Valuer, but get %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAfterFieldScanCallback(t *testing.T) {
 | 
			
		||||
	model := WithFieldAfterScanCallback{}
 | 
			
		||||
	model.Name1 = &AfterScanFieldPtr{data: randName()}
 | 
			
		||||
	model.Name2 = AfterScanFieldPtr{data: randName()}
 | 
			
		||||
	model.Name3 = &AfterScanField{data: randName()}
 | 
			
		||||
	model.Name4 = AfterScanField{data: randName()}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var model2 WithFieldAfterScanCallback
 | 
			
		||||
	if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dotest := func(i int, value string, field AfterScanFieldInterface) {
 | 
			
		||||
		if field.CalledFieldIsNill() {
 | 
			
		||||
			t.Errorf("Expected Name%v.calledField, but got nil", i)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if field.CalledScopeIsNill() {
 | 
			
		||||
			t.Errorf("Expected Name%v.calledScope, but got nil", i)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if field.Data() != value {
 | 
			
		||||
			t.Errorf("Expected Name%v.data %q, but got %q", i, value, field.Data())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dotest(1, model.Name1.data, model2.Name1)
 | 
			
		||||
	dotest(2, model.Name2.data, &model2.Name2)
 | 
			
		||||
	dotest(3, model.Name3.data, model2.Name3)
 | 
			
		||||
	dotest(4, model.Name4.data, &model2.Name4)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAfterFieldScanDisableCallback(t *testing.T) {
 | 
			
		||||
	model := WithFieldAfterScanCallback{}
 | 
			
		||||
	model.Name1 = &AfterScanFieldPtr{data: randName()}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	run := func(typs ... interface{}) {
 | 
			
		||||
		DB := DB.DisableAfterScanCallback(typs...)
 | 
			
		||||
		var model2 WithFieldAfterScanCallback
 | 
			
		||||
		if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
 | 
			
		||||
			t.Errorf("%v: No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", len(typs), err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		dotest := func(i int, field AfterScanFieldInterface) {
 | 
			
		||||
			if !field.CalledFieldIsNill() {
 | 
			
		||||
				t.Errorf("%v: Expected Name%v.calledField is nil", len(typs), i)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		dotest(1, model2.Name1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	run()
 | 
			
		||||
	run(model)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAfterFieldScanCallbackTypeDisabled(t *testing.T) {
 | 
			
		||||
	model := WithFieldAfterScanCallback{}
 | 
			
		||||
	model.Name1 = &AfterScanFieldPtr{data: randName()}
 | 
			
		||||
	model.Name2 = AfterScanFieldPtr{data: randName()}
 | 
			
		||||
	model.Name3 = &AfterScanField{data: randName()}
 | 
			
		||||
	model.Name4 = AfterScanField{data: randName()}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	enabled := func(i int, field AfterScanFieldInterface) {
 | 
			
		||||
		if field.CalledScopeIsNill() {
 | 
			
		||||
			t.Errorf("Expected Name%v.calledScope, but got nil", i)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	disabled := func(i int, field AfterScanFieldInterface) {
 | 
			
		||||
		if !field.CalledScopeIsNill() {
 | 
			
		||||
			t.Errorf("Expected Name%v.calledScope is not nil", i)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gorm.StructFieldMethodCallbacks.DisableFieldType(&AfterScanFieldPtr{}, &AfterScanField{})
 | 
			
		||||
 | 
			
		||||
	if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	disabled(1, model.Name1)
 | 
			
		||||
	disabled(2, &model.Name2)
 | 
			
		||||
	disabled(3, model.Name3)
 | 
			
		||||
	disabled(4, &model.Name4)
 | 
			
		||||
 | 
			
		||||
	gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanFieldPtr{})
 | 
			
		||||
	if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	enabled(1, model.Name1)
 | 
			
		||||
	enabled(2, &model.Name2)
 | 
			
		||||
	disabled(3, model.Name3)
 | 
			
		||||
	disabled(4, &model.Name4)
 | 
			
		||||
 | 
			
		||||
	gorm.StructFieldMethodCallbacks.EnableFieldType(&AfterScanField{})
 | 
			
		||||
	if err := DB.Where("id = ?", model.ID).First(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	enabled(1, model.Name1)
 | 
			
		||||
	enabled(2, &model.Name2)
 | 
			
		||||
	enabled(3, model.Name3)
 | 
			
		||||
	enabled(4, &model.Name4)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAfterFieldScanInvalidCallback(t *testing.T) {
 | 
			
		||||
	model := WithFieldAfterScanInvalidCallback{}
 | 
			
		||||
	model.Name = InvalidAfterScanField{AfterScanField{data: randName()}}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&model).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when saving WithFieldAfterScanCallback, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var model2 WithFieldAfterScanInvalidCallback
 | 
			
		||||
	if err := DB.Where("id = ?", model.ID).First(&model2).Error; err != nil {
 | 
			
		||||
		if !strings.Contains(err.Error(), "Invalid AfterScan method callback") {
 | 
			
		||||
			t.Errorf("No error should happen when querying WithFieldAfterScanCallback with valuer, but got %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		t.Errorf("Expected error, but got nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										96
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										96
									
								
								utils.go
									
									
									
									
									
								
							@ -135,6 +135,85 @@ func indirect(reflectValue reflect.Value) reflect.Value {
 | 
			
		||||
	return reflectValue
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func indirectType(reflectType reflect.Type) reflect.Type {
 | 
			
		||||
	for reflectType.Kind() == reflect.Ptr {
 | 
			
		||||
		reflectType = reflectType.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	return reflectType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ptrToType(reflectType reflect.Type) reflect.Type {
 | 
			
		||||
	reflectType = indirectType(reflectType)
 | 
			
		||||
	return reflect.PtrTo(reflectType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Method struct {
 | 
			
		||||
	index int
 | 
			
		||||
	name  string
 | 
			
		||||
	ptr   bool
 | 
			
		||||
	valid bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Method) Index() int {
 | 
			
		||||
	return m.index
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Method) Name() string {
 | 
			
		||||
	return m.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Method) Ptr() bool {
 | 
			
		||||
	return m.ptr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Method) Valid() bool {
 | 
			
		||||
	return m.valid
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Method) TypeMethod(typ reflect.Type) reflect.Method {
 | 
			
		||||
	for typ.Kind() == reflect.Ptr {
 | 
			
		||||
		typ = typ.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	if m.ptr {
 | 
			
		||||
		typ = reflect.PtrTo(typ)
 | 
			
		||||
	}
 | 
			
		||||
	return typ.Method(m.index)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Method) ObjectMethod(object reflect.Value) reflect.Value {
 | 
			
		||||
	object = indirect(object)
 | 
			
		||||
	if m.ptr {
 | 
			
		||||
		object = object.Addr()
 | 
			
		||||
	}
 | 
			
		||||
	return object.Method(m.index)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MethodByName(typ reflect.Type, name string) (m Method) {
 | 
			
		||||
	for typ.Kind() == reflect.Ptr {
 | 
			
		||||
		typ = typ.Elem()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if typ.Kind() != reflect.Struct {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if method, ok := typ.MethodByName(name); ok {
 | 
			
		||||
		m.index = method.Index
 | 
			
		||||
		m.name = name
 | 
			
		||||
		m.valid = true
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if method, ok := reflect.PtrTo(typ).MethodByName(name); ok {
 | 
			
		||||
		m.index = method.Index
 | 
			
		||||
		m.name = name
 | 
			
		||||
		m.ptr = true
 | 
			
		||||
		m.valid = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toQueryMarks(primaryValues [][]interface{}) string {
 | 
			
		||||
	var results []string
 | 
			
		||||
 | 
			
		||||
@ -283,3 +362,20 @@ func addExtraSpaceIfExist(str string) string {
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkOrPanic(err error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// check if value is nil
 | 
			
		||||
func isNil(value reflect.Value) bool {
 | 
			
		||||
	if value.Kind() != reflect.Ptr {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if value.Pointer() == 0 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user