Refactor structure
This commit is contained in:
		
							parent
							
								
									24ed796198
								
							
						
					
					
						commit
						8b567b49d0
					
				@ -1,9 +1,10 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "log"
 | 
			
		||||
import (
 | 
			
		||||
	"log"
 | 
			
		||||
 | 
			
		||||
// DefaultCallback default callbacks defined by gorm
 | 
			
		||||
var DefaultCallback = &Callback{}
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Callback is a struct that contains all CRUD callbacks
 | 
			
		||||
//   Field `creates` contains callbacks will be call when creating object
 | 
			
		||||
@ -13,23 +14,23 @@ var DefaultCallback = &Callback{}
 | 
			
		||||
//   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
 | 
			
		||||
//   Field `processors` contains all callback processors, will be used to generate above callbacks in order
 | 
			
		||||
type Callback struct {
 | 
			
		||||
	creates    []*func(scope *Scope)
 | 
			
		||||
	updates    []*func(scope *Scope)
 | 
			
		||||
	deletes    []*func(scope *Scope)
 | 
			
		||||
	queries    []*func(scope *Scope)
 | 
			
		||||
	rowQueries []*func(scope *Scope)
 | 
			
		||||
	creates    []*func(*gorm.DB)
 | 
			
		||||
	updates    []*func(*gorm.DB)
 | 
			
		||||
	deletes    []*func(*gorm.DB)
 | 
			
		||||
	queries    []*func(*gorm.DB)
 | 
			
		||||
	rowQueries []*func(*gorm.DB)
 | 
			
		||||
	processors []*CallbackProcessor
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CallbackProcessor contains callback informations
 | 
			
		||||
type CallbackProcessor struct {
 | 
			
		||||
	name      string              // current callback's name
 | 
			
		||||
	before    string              // register current callback before a callback
 | 
			
		||||
	after     string              // register current callback after a callback
 | 
			
		||||
	replace   bool                // replace callbacks with same name
 | 
			
		||||
	remove    bool                // delete callbacks with same name
 | 
			
		||||
	kind      string              // callback type: create, update, delete, query, row_query
 | 
			
		||||
	processor *func(scope *Scope) // callback handler
 | 
			
		||||
	name      string          // current callback's name
 | 
			
		||||
	before    string          // register current callback before a callback
 | 
			
		||||
	after     string          // register current callback after a callback
 | 
			
		||||
	replace   bool            // replace callbacks with same name
 | 
			
		||||
	remove    bool            // delete callbacks with same name
 | 
			
		||||
	kind      string          // callback type: create, update, delete, query, row_query
 | 
			
		||||
	processor *func(*gorm.DB) // callback handler
 | 
			
		||||
	parent    *Callback
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -45,7 +46,7 @@ func (c *Callback) clone() *Callback {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create could be used to register callbacks for creating object
 | 
			
		||||
//     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
 | 
			
		||||
//     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*gorm.DB) {
 | 
			
		||||
//       // business logic
 | 
			
		||||
//       ...
 | 
			
		||||
//
 | 
			
		||||
@ -90,7 +91,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Register a new callback, refer `Callbacks.Create`
 | 
			
		||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
 | 
			
		||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(*gorm.DB)) {
 | 
			
		||||
	if cp.kind == "row_query" {
 | 
			
		||||
		if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
 | 
			
		||||
			log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
 | 
			
		||||
@ -107,7 +108,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
 | 
			
		||||
// Remove a registered callback
 | 
			
		||||
//     db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
 | 
			
		||||
func (cp *CallbackProcessor) Remove(callbackName string) {
 | 
			
		||||
	log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
 | 
			
		||||
	log.Printf("[info] removing callback `%v` from %v\n", callbackName, utils.FileWithLineNum())
 | 
			
		||||
	cp.name = callbackName
 | 
			
		||||
	cp.remove = true
 | 
			
		||||
	cp.parent.processors = append(cp.parent.processors, cp)
 | 
			
		||||
@ -115,12 +116,12 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Replace a registered callback with new callback
 | 
			
		||||
//     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
 | 
			
		||||
//     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*gorm.DB) {
 | 
			
		||||
//		   scope.SetColumn("Created", now)
 | 
			
		||||
//		   scope.SetColumn("Updated", now)
 | 
			
		||||
//     })
 | 
			
		||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
 | 
			
		||||
	log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
 | 
			
		||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(*gorm.DB)) {
 | 
			
		||||
	log.Printf("[info] replacing callback `%v` from %v\n", callbackName, utils.FileWithLineNum())
 | 
			
		||||
	cp.name = callbackName
 | 
			
		||||
	cp.processor = &callback
 | 
			
		||||
	cp.replace = true
 | 
			
		||||
@ -130,7 +131,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
 | 
			
		||||
 | 
			
		||||
// Get registered callback
 | 
			
		||||
//    db.Callback().Create().Get("gorm:create")
 | 
			
		||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
 | 
			
		||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(*gorm.DB)) {
 | 
			
		||||
	for _, p := range cp.parent.processors {
 | 
			
		||||
		if p.name == callbackName && p.kind == cp.kind && !cp.remove {
 | 
			
		||||
			return *p.processor
 | 
			
		||||
@ -150,7 +151,7 @@ func getRIndex(strs []string, str string) int {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// sortProcessors sort callback processors based on its before, after, remove, replace
 | 
			
		||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
 | 
			
		||||
func sortProcessors(cps []*CallbackProcessor) []*func(*gorm.DB) {
 | 
			
		||||
	var (
 | 
			
		||||
		allNames, sortedNames []string
 | 
			
		||||
		sortCallbackProcessor func(c *CallbackProcessor)
 | 
			
		||||
@ -159,7 +160,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
 | 
			
		||||
	for _, cp := range cps {
 | 
			
		||||
		// show warning message the callback name already exists
 | 
			
		||||
		if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
 | 
			
		||||
			log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
 | 
			
		||||
			log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, utils.FileWithLineNum())
 | 
			
		||||
		}
 | 
			
		||||
		allNames = append(allNames, cp.name)
 | 
			
		||||
	}
 | 
			
		||||
@ -203,7 +204,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
 | 
			
		||||
		sortCallbackProcessor(cp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sortedFuncs []*func(scope *Scope)
 | 
			
		||||
	var sortedFuncs []*func(*gorm.DB)
 | 
			
		||||
	for _, name := range sortedNames {
 | 
			
		||||
		if index := getRIndex(allNames, name); !cps[index].remove {
 | 
			
		||||
			sortedFuncs = append(sortedFuncs, cps[index].processor)
 | 
			
		||||
 | 
			
		||||
@ -1,164 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for creating
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:create", createCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
 | 
			
		||||
	DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
 | 
			
		||||
func beforeCreateCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("BeforeSave")
 | 
			
		||||
	}
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("BeforeCreate")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
 | 
			
		||||
func updateTimeStampForCreateCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		now := NowFunc()
 | 
			
		||||
 | 
			
		||||
		if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
 | 
			
		||||
			if createdAtField.IsBlank {
 | 
			
		||||
				createdAtField.Set(now)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
 | 
			
		||||
			if updatedAtField.IsBlank {
 | 
			
		||||
				updatedAtField.Set(now)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createCallback the callback used to insert data into database
 | 
			
		||||
func createCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		defer scope.trace(NowFunc())
 | 
			
		||||
 | 
			
		||||
		var (
 | 
			
		||||
			columns, placeholders        []string
 | 
			
		||||
			blankColumnsWithDefaultValue []string
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		for _, field := range scope.Fields() {
 | 
			
		||||
			if scope.changeableField(field) {
 | 
			
		||||
				if field.IsNormal {
 | 
			
		||||
					if field.IsBlank && field.HasDefaultValue {
 | 
			
		||||
						blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
 | 
			
		||||
						scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
 | 
			
		||||
					} else if !field.IsPrimaryKey || !field.IsBlank {
 | 
			
		||||
						columns = append(columns, scope.Quote(field.DBName))
 | 
			
		||||
						placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
 | 
			
		||||
					}
 | 
			
		||||
				} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
 | 
			
		||||
					for _, foreignKey := range field.Relationship.ForeignDBNames {
 | 
			
		||||
						if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
 | 
			
		||||
							columns = append(columns, scope.Quote(foreignField.DBName))
 | 
			
		||||
							placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var (
 | 
			
		||||
			returningColumn = "*"
 | 
			
		||||
			quotedTableName = scope.QuotedTableName()
 | 
			
		||||
			primaryField    = scope.PrimaryField()
 | 
			
		||||
			extraOption     string
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		if str, ok := scope.Get("gorm:insert_option"); ok {
 | 
			
		||||
			extraOption = fmt.Sprint(str)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if primaryField != nil {
 | 
			
		||||
			returningColumn = scope.Quote(primaryField.DBName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
 | 
			
		||||
 | 
			
		||||
		if len(columns) == 0 {
 | 
			
		||||
			scope.Raw(fmt.Sprintf(
 | 
			
		||||
				"INSERT INTO %v %v%v%v",
 | 
			
		||||
				quotedTableName,
 | 
			
		||||
				scope.Dialect().DefaultValueStr(),
 | 
			
		||||
				addExtraSpaceIfExist(extraOption),
 | 
			
		||||
				addExtraSpaceIfExist(lastInsertIDReturningSuffix),
 | 
			
		||||
			))
 | 
			
		||||
		} else {
 | 
			
		||||
			scope.Raw(fmt.Sprintf(
 | 
			
		||||
				"INSERT INTO %v (%v) VALUES (%v)%v%v",
 | 
			
		||||
				scope.QuotedTableName(),
 | 
			
		||||
				strings.Join(columns, ","),
 | 
			
		||||
				strings.Join(placeholders, ","),
 | 
			
		||||
				addExtraSpaceIfExist(extraOption),
 | 
			
		||||
				addExtraSpaceIfExist(lastInsertIDReturningSuffix),
 | 
			
		||||
			))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// execute create sql
 | 
			
		||||
		if lastInsertIDReturningSuffix == "" || primaryField == nil {
 | 
			
		||||
			if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
				// set rows affected count
 | 
			
		||||
				scope.db.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
 | 
			
		||||
				// set primary value to primary field
 | 
			
		||||
				if primaryField != nil && primaryField.IsBlank {
 | 
			
		||||
					if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
 | 
			
		||||
						scope.Err(primaryField.Set(primaryValue))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			if primaryField.Field.CanAddr() {
 | 
			
		||||
				if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
 | 
			
		||||
					primaryField.IsBlank = false
 | 
			
		||||
					scope.db.RowsAffected = 1
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				scope.Err(ErrUnaddressable)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
 | 
			
		||||
func forceReloadAfterCreateCallback(scope *Scope) {
 | 
			
		||||
	if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
 | 
			
		||||
		db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
 | 
			
		||||
		for _, field := range scope.Fields() {
 | 
			
		||||
			if field.IsPrimaryKey && !field.IsBlank {
 | 
			
		||||
				db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		db.Scan(scope.Value)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
 | 
			
		||||
func afterCreateCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("AfterCreate")
 | 
			
		||||
	}
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("AfterSave")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,63 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for deleting
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
			
		||||
	DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
 | 
			
		||||
	DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
 | 
			
		||||
	DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
 | 
			
		||||
	DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
 | 
			
		||||
func beforeDeleteCallback(scope *Scope) {
 | 
			
		||||
	if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
 | 
			
		||||
		scope.Err(errors.New("Missing WHERE clause while deleting"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("BeforeDelete")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
 | 
			
		||||
func deleteCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		var extraOption string
 | 
			
		||||
		if str, ok := scope.Get("gorm:delete_option"); ok {
 | 
			
		||||
			extraOption = fmt.Sprint(str)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
 | 
			
		||||
 | 
			
		||||
		if !scope.Search.Unscoped && hasDeletedAtField {
 | 
			
		||||
			scope.Raw(fmt.Sprintf(
 | 
			
		||||
				"UPDATE %v SET %v=%v%v%v",
 | 
			
		||||
				scope.QuotedTableName(),
 | 
			
		||||
				scope.Quote(deletedAtField.DBName),
 | 
			
		||||
				scope.AddToVars(NowFunc()),
 | 
			
		||||
				addExtraSpaceIfExist(scope.CombinedConditionSql()),
 | 
			
		||||
				addExtraSpaceIfExist(extraOption),
 | 
			
		||||
			)).Exec()
 | 
			
		||||
		} else {
 | 
			
		||||
			scope.Raw(fmt.Sprintf(
 | 
			
		||||
				"DELETE FROM %v%v%v",
 | 
			
		||||
				scope.QuotedTableName(),
 | 
			
		||||
				addExtraSpaceIfExist(scope.CombinedConditionSql()),
 | 
			
		||||
				addExtraSpaceIfExist(extraOption),
 | 
			
		||||
			)).Exec()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
 | 
			
		||||
func afterDeleteCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("AfterDelete")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,479 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for querying
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:query", queryCallback)
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:preload", preloadCallback)
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// queryCallback used to query data from database
 | 
			
		||||
func queryCallback(scope *Scope) {
 | 
			
		||||
	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer scope.trace(NowFunc())
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		isSlice, isPtr bool
 | 
			
		||||
		resultType     reflect.Type
 | 
			
		||||
		results        = scope.IndirectValue()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
 | 
			
		||||
		if primaryField := scope.PrimaryField(); primaryField != nil {
 | 
			
		||||
			scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value, ok := scope.Get("gorm:query_destination"); ok {
 | 
			
		||||
		results = indirect(reflect.ValueOf(value))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if kind := results.Kind(); kind == reflect.Slice {
 | 
			
		||||
		isSlice = true
 | 
			
		||||
		resultType = results.Type().Elem()
 | 
			
		||||
		results.Set(reflect.MakeSlice(results.Type(), 0, 0))
 | 
			
		||||
 | 
			
		||||
		if resultType.Kind() == reflect.Ptr {
 | 
			
		||||
			isPtr = true
 | 
			
		||||
			resultType = resultType.Elem()
 | 
			
		||||
		}
 | 
			
		||||
	} else if kind != reflect.Struct {
 | 
			
		||||
		scope.Err(errors.New("unsupported destination, should be slice or struct"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.db.RowsAffected = 0
 | 
			
		||||
		if str, ok := scope.Get("gorm:query_option"); ok {
 | 
			
		||||
			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			defer rows.Close()
 | 
			
		||||
 | 
			
		||||
			columns, _ := rows.Columns()
 | 
			
		||||
			for rows.Next() {
 | 
			
		||||
				scope.db.RowsAffected++
 | 
			
		||||
 | 
			
		||||
				elem := results
 | 
			
		||||
				if isSlice {
 | 
			
		||||
					elem = reflect.New(resultType).Elem()
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
 | 
			
		||||
 | 
			
		||||
				if isSlice {
 | 
			
		||||
					if isPtr {
 | 
			
		||||
						results.Set(reflect.Append(results, elem.Addr()))
 | 
			
		||||
					} else {
 | 
			
		||||
						results.Set(reflect.Append(results, elem))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err := rows.Err(); err != nil {
 | 
			
		||||
				scope.Err(err)
 | 
			
		||||
			} else if scope.db.RowsAffected == 0 && !isSlice {
 | 
			
		||||
				scope.Err(ErrRecordNotFound)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// afterQueryCallback will invoke `AfterFind` method after querying
 | 
			
		||||
func afterQueryCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("AfterFind")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// preloadCallback used to preload associations
 | 
			
		||||
func preloadCallback(scope *Scope) {
 | 
			
		||||
	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, ok := scope.Get("gorm:auto_preload"); ok {
 | 
			
		||||
		autoPreload(scope)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if scope.Search.preload == nil || scope.HasError() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		preloadedMap = map[string]bool{}
 | 
			
		||||
		fields       = scope.Fields()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for _, preload := range scope.Search.preload {
 | 
			
		||||
		var (
 | 
			
		||||
			preloadFields = strings.Split(preload.schema, ".")
 | 
			
		||||
			currentScope  = scope
 | 
			
		||||
			currentFields = fields
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		for idx, preloadField := range preloadFields {
 | 
			
		||||
			var currentPreloadConditions []interface{}
 | 
			
		||||
 | 
			
		||||
			if currentScope == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// if not preloaded
 | 
			
		||||
			if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
 | 
			
		||||
 | 
			
		||||
				// assign search conditions to last preload
 | 
			
		||||
				if idx == len(preloadFields)-1 {
 | 
			
		||||
					currentPreloadConditions = preload.conditions
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				for _, field := range currentFields {
 | 
			
		||||
					if field.Name != preloadField || field.Relationship == nil {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					switch field.Relationship.Kind {
 | 
			
		||||
					case "has_one":
 | 
			
		||||
						currentScope.handleHasOnePreload(field, currentPreloadConditions)
 | 
			
		||||
					case "has_many":
 | 
			
		||||
						currentScope.handleHasManyPreload(field, currentPreloadConditions)
 | 
			
		||||
					case "belongs_to":
 | 
			
		||||
						currentScope.handleBelongsToPreload(field, currentPreloadConditions)
 | 
			
		||||
					case "many_to_many":
 | 
			
		||||
						currentScope.handleManyToManyPreload(field, currentPreloadConditions)
 | 
			
		||||
					default:
 | 
			
		||||
						scope.Err(errors.New("unsupported relation"))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					preloadedMap[preloadKey] = true
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if !preloadedMap[preloadKey] {
 | 
			
		||||
					scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// preload next level
 | 
			
		||||
			if idx < len(preloadFields)-1 {
 | 
			
		||||
				currentScope = currentScope.getColumnAsScope(preloadField)
 | 
			
		||||
				if currentScope != nil {
 | 
			
		||||
					currentFields = currentScope.Fields()
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func autoPreload(scope *Scope) {
 | 
			
		||||
	for _, field := range scope.Fields() {
 | 
			
		||||
		if field.Relationship == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if val, ok := field.TagSettings["PRELOAD"]; ok {
 | 
			
		||||
			if preload, err := strconv.ParseBool(val); err != nil {
 | 
			
		||||
				scope.Err(errors.New("invalid preload option"))
 | 
			
		||||
				return
 | 
			
		||||
			} else if !preload {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		scope.Search.Preload(field.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
 | 
			
		||||
	var (
 | 
			
		||||
		preloadDB         = scope.NewDB()
 | 
			
		||||
		preloadConditions []interface{}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for _, condition := range conditions {
 | 
			
		||||
		if scopes, ok := condition.(func(*DB) *DB); ok {
 | 
			
		||||
			preloadDB = scopes(preloadDB)
 | 
			
		||||
		} else {
 | 
			
		||||
			preloadConditions = append(preloadConditions, condition)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return preloadDB, preloadConditions
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleHasOnePreload used to preload has one associations
 | 
			
		||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
 | 
			
		||||
	relation := field.Relationship
 | 
			
		||||
 | 
			
		||||
	// get relations's primary keys
 | 
			
		||||
	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
 | 
			
		||||
	if len(primaryKeys) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// preload conditions
 | 
			
		||||
	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
			
		||||
 | 
			
		||||
	// find relations
 | 
			
		||||
	query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
 | 
			
		||||
	values := toQueryValues(primaryKeys)
 | 
			
		||||
	if relation.PolymorphicType != "" {
 | 
			
		||||
		query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
 | 
			
		||||
		values = append(values, relation.PolymorphicValue)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	results := makeSlice(field.Struct.Type)
 | 
			
		||||
	scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
 | 
			
		||||
 | 
			
		||||
	// assign find results
 | 
			
		||||
	var (
 | 
			
		||||
		resultsValue       = indirect(reflect.ValueOf(results))
 | 
			
		||||
		indirectScopeValue = scope.IndirectValue()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if indirectScopeValue.Kind() == reflect.Slice {
 | 
			
		||||
		for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
			
		||||
			for i := 0; i < resultsValue.Len(); i++ {
 | 
			
		||||
				result := resultsValue.Index(i)
 | 
			
		||||
				foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
 | 
			
		||||
				if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
 | 
			
		||||
					indirectValue.FieldByName(field.Name).Set(result)
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		for i := 0; i < resultsValue.Len(); i++ {
 | 
			
		||||
			result := resultsValue.Index(i)
 | 
			
		||||
			scope.Err(field.Set(result))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleHasManyPreload used to preload has many associations
 | 
			
		||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
 | 
			
		||||
	relation := field.Relationship
 | 
			
		||||
 | 
			
		||||
	// get relations's primary keys
 | 
			
		||||
	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
 | 
			
		||||
	if len(primaryKeys) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// preload conditions
 | 
			
		||||
	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
			
		||||
 | 
			
		||||
	// find relations
 | 
			
		||||
	query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
 | 
			
		||||
	values := toQueryValues(primaryKeys)
 | 
			
		||||
	if relation.PolymorphicType != "" {
 | 
			
		||||
		query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
 | 
			
		||||
		values = append(values, relation.PolymorphicValue)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	results := makeSlice(field.Struct.Type)
 | 
			
		||||
	scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
 | 
			
		||||
 | 
			
		||||
	// assign find results
 | 
			
		||||
	var (
 | 
			
		||||
		resultsValue       = indirect(reflect.ValueOf(results))
 | 
			
		||||
		indirectScopeValue = scope.IndirectValue()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if indirectScopeValue.Kind() == reflect.Slice {
 | 
			
		||||
		preloadMap := make(map[string][]reflect.Value)
 | 
			
		||||
		for i := 0; i < resultsValue.Len(); i++ {
 | 
			
		||||
			result := resultsValue.Index(i)
 | 
			
		||||
			foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
 | 
			
		||||
			preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
			
		||||
			object := indirect(indirectScopeValue.Index(j))
 | 
			
		||||
			objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
 | 
			
		||||
			f := object.FieldByName(field.Name)
 | 
			
		||||
			if results, ok := preloadMap[toString(objectRealValue)]; ok {
 | 
			
		||||
				f.Set(reflect.Append(f, results...))
 | 
			
		||||
			} else {
 | 
			
		||||
				f.Set(reflect.MakeSlice(f.Type(), 0, 0))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		scope.Err(field.Set(resultsValue))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleBelongsToPreload used to preload belongs to associations
 | 
			
		||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
 | 
			
		||||
	relation := field.Relationship
 | 
			
		||||
 | 
			
		||||
	// preload conditions
 | 
			
		||||
	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
			
		||||
 | 
			
		||||
	// get relations's primary keys
 | 
			
		||||
	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
 | 
			
		||||
	if len(primaryKeys) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// find relations
 | 
			
		||||
	results := makeSlice(field.Struct.Type)
 | 
			
		||||
	scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 | 
			
		||||
 | 
			
		||||
	// assign find results
 | 
			
		||||
	var (
 | 
			
		||||
		resultsValue       = indirect(reflect.ValueOf(results))
 | 
			
		||||
		indirectScopeValue = scope.IndirectValue()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < resultsValue.Len(); i++ {
 | 
			
		||||
		result := resultsValue.Index(i)
 | 
			
		||||
		if indirectScopeValue.Kind() == reflect.Slice {
 | 
			
		||||
			value := getValueFromFields(result, relation.AssociationForeignFieldNames)
 | 
			
		||||
			for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
			
		||||
				object := indirect(indirectScopeValue.Index(j))
 | 
			
		||||
				if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
 | 
			
		||||
					object.FieldByName(field.Name).Set(result)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			scope.Err(field.Set(result))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleManyToManyPreload used to preload many to many associations
 | 
			
		||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
 | 
			
		||||
	var (
 | 
			
		||||
		relation         = field.Relationship
 | 
			
		||||
		joinTableHandler = relation.JoinTableHandler
 | 
			
		||||
		fieldType        = field.Struct.Type.Elem()
 | 
			
		||||
		foreignKeyValue  interface{}
 | 
			
		||||
		foreignKeyType   = reflect.ValueOf(&foreignKeyValue).Type()
 | 
			
		||||
		linkHash         = map[string][]reflect.Value{}
 | 
			
		||||
		isPtr            bool
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if fieldType.Kind() == reflect.Ptr {
 | 
			
		||||
		isPtr = true
 | 
			
		||||
		fieldType = fieldType.Elem()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sourceKeys = []string{}
 | 
			
		||||
	for _, key := range joinTableHandler.SourceForeignKeys() {
 | 
			
		||||
		sourceKeys = append(sourceKeys, key.DBName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// preload conditions
 | 
			
		||||
	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
 | 
			
		||||
 | 
			
		||||
	// generate query with join table
 | 
			
		||||
	newScope := scope.New(reflect.New(fieldType).Interface())
 | 
			
		||||
	preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
 | 
			
		||||
 | 
			
		||||
	if len(preloadDB.search.selects) == 0 {
 | 
			
		||||
		preloadDB = preloadDB.Select("*")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
 | 
			
		||||
 | 
			
		||||
	// preload inline conditions
 | 
			
		||||
	if len(preloadConditions) > 0 {
 | 
			
		||||
		preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := preloadDB.Rows()
 | 
			
		||||
 | 
			
		||||
	if scope.Err(err) != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	columns, _ := rows.Columns()
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var (
 | 
			
		||||
			elem   = reflect.New(fieldType).Elem()
 | 
			
		||||
			fields = scope.New(elem.Addr().Interface()).Fields()
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		// register foreign keys in join tables
 | 
			
		||||
		var joinTableFields []*Field
 | 
			
		||||
		for _, sourceKey := range sourceKeys {
 | 
			
		||||
			joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		scope.scan(rows, columns, append(fields, joinTableFields...))
 | 
			
		||||
 | 
			
		||||
		scope.New(elem.Addr().Interface()).
 | 
			
		||||
			InstanceSet("gorm:skip_query_callback", true).
 | 
			
		||||
			callCallbacks(scope.db.parent.callbacks.queries)
 | 
			
		||||
 | 
			
		||||
		var foreignKeys = make([]interface{}, len(sourceKeys))
 | 
			
		||||
		// generate hashed forkey keys in join table
 | 
			
		||||
		for idx, joinTableField := range joinTableFields {
 | 
			
		||||
			if !joinTableField.Field.IsNil() {
 | 
			
		||||
				foreignKeys[idx] = joinTableField.Field.Elem().Interface()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		hashedSourceKeys := toString(foreignKeys)
 | 
			
		||||
 | 
			
		||||
		if isPtr {
 | 
			
		||||
			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
 | 
			
		||||
		} else {
 | 
			
		||||
			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := rows.Err(); err != nil {
 | 
			
		||||
		scope.Err(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// assign find results
 | 
			
		||||
	var (
 | 
			
		||||
		indirectScopeValue = scope.IndirectValue()
 | 
			
		||||
		fieldsSourceMap    = map[string][]reflect.Value{}
 | 
			
		||||
		foreignFieldNames  = []string{}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for _, dbName := range relation.ForeignFieldNames {
 | 
			
		||||
		if field, ok := scope.FieldByName(dbName); ok {
 | 
			
		||||
			foreignFieldNames = append(foreignFieldNames, field.Name)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if indirectScopeValue.Kind() == reflect.Slice {
 | 
			
		||||
		for j := 0; j < indirectScopeValue.Len(); j++ {
 | 
			
		||||
			object := indirect(indirectScopeValue.Index(j))
 | 
			
		||||
			key := toString(getValueFromFields(object, foreignFieldNames))
 | 
			
		||||
			fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
 | 
			
		||||
		}
 | 
			
		||||
	} else if indirectScopeValue.IsValid() {
 | 
			
		||||
		key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
 | 
			
		||||
		fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
 | 
			
		||||
	}
 | 
			
		||||
	for source, link := range linkHash {
 | 
			
		||||
		for i, field := range fieldsSourceMap[source] {
 | 
			
		||||
			//If not 0 this means Value is a pointer and we already added preloaded models to it
 | 
			
		||||
			if fieldsSourceMap[source][i].Len() != 0 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,30 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// Define callbacks for row query
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RowQueryResult struct {
 | 
			
		||||
	Row *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RowsQueryResult struct {
 | 
			
		||||
	Rows  *sql.Rows
 | 
			
		||||
	Error error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// queryCallback used to query data from database
 | 
			
		||||
func rowQueryCallback(scope *Scope) {
 | 
			
		||||
	if result, ok := scope.InstanceGet("row_query_result"); ok {
 | 
			
		||||
		scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
		if rowResult, ok := result.(*RowQueryResult); ok {
 | 
			
		||||
			rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
 | 
			
		||||
		} else if rowsResult, ok := result.(*RowsQueryResult); ok {
 | 
			
		||||
			rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,170 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func beginTransactionCallback(scope *Scope) {
 | 
			
		||||
	scope.Begin()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func commitOrRollbackTransactionCallback(scope *Scope) {
 | 
			
		||||
	scope.CommitOrRollback()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
 | 
			
		||||
	checkTruth := func(value interface{}) bool {
 | 
			
		||||
		if v, ok := value.(bool); ok && !v {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if v, ok := value.(string); ok {
 | 
			
		||||
			v = strings.ToLower(v)
 | 
			
		||||
			if v == "false" || v != "skip" {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
 | 
			
		||||
		if r = field.Relationship; r != nil {
 | 
			
		||||
			autoUpdate, autoCreate, saveReference = true, true, true
 | 
			
		||||
 | 
			
		||||
			if value, ok := scope.Get("gorm:save_associations"); ok {
 | 
			
		||||
				autoUpdate = checkTruth(value)
 | 
			
		||||
				autoCreate = autoUpdate
 | 
			
		||||
			} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
 | 
			
		||||
				autoUpdate = checkTruth(value)
 | 
			
		||||
				autoCreate = autoUpdate
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if value, ok := scope.Get("gorm:association_autoupdate"); ok {
 | 
			
		||||
				autoUpdate = checkTruth(value)
 | 
			
		||||
			} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
 | 
			
		||||
				autoUpdate = checkTruth(value)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if value, ok := scope.Get("gorm:association_autocreate"); ok {
 | 
			
		||||
				autoCreate = checkTruth(value)
 | 
			
		||||
			} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
 | 
			
		||||
				autoCreate = checkTruth(value)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if value, ok := scope.Get("gorm:association_save_reference"); ok {
 | 
			
		||||
				saveReference = checkTruth(value)
 | 
			
		||||
			} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
 | 
			
		||||
				saveReference = checkTruth(value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func saveBeforeAssociationsCallback(scope *Scope) {
 | 
			
		||||
	for _, field := range scope.Fields() {
 | 
			
		||||
		autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
 | 
			
		||||
 | 
			
		||||
		if relationship != nil && relationship.Kind == "belongs_to" {
 | 
			
		||||
			fieldValue := field.Field.Addr().Interface()
 | 
			
		||||
			newScope := scope.New(fieldValue)
 | 
			
		||||
 | 
			
		||||
			if newScope.PrimaryKeyZero() {
 | 
			
		||||
				if autoCreate {
 | 
			
		||||
					scope.Err(scope.NewDB().Save(fieldValue).Error)
 | 
			
		||||
				}
 | 
			
		||||
			} else if autoUpdate {
 | 
			
		||||
				scope.Err(scope.NewDB().Save(fieldValue).Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if saveReference {
 | 
			
		||||
				if len(relationship.ForeignFieldNames) != 0 {
 | 
			
		||||
					// set value's foreign key
 | 
			
		||||
					for idx, fieldName := range relationship.ForeignFieldNames {
 | 
			
		||||
						associationForeignName := relationship.AssociationForeignDBNames[idx]
 | 
			
		||||
						if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
 | 
			
		||||
							scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func saveAfterAssociationsCallback(scope *Scope) {
 | 
			
		||||
	for _, field := range scope.Fields() {
 | 
			
		||||
		autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
 | 
			
		||||
 | 
			
		||||
		if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
 | 
			
		||||
			value := field.Field
 | 
			
		||||
 | 
			
		||||
			switch value.Kind() {
 | 
			
		||||
			case reflect.Slice:
 | 
			
		||||
				for i := 0; i < value.Len(); i++ {
 | 
			
		||||
					newDB := scope.NewDB()
 | 
			
		||||
					elem := value.Index(i).Addr().Interface()
 | 
			
		||||
					newScope := newDB.NewScope(elem)
 | 
			
		||||
 | 
			
		||||
					if saveReference {
 | 
			
		||||
						if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
 | 
			
		||||
							for idx, fieldName := range relationship.ForeignFieldNames {
 | 
			
		||||
								associationForeignName := relationship.AssociationForeignDBNames[idx]
 | 
			
		||||
								if f, ok := scope.FieldByName(associationForeignName); ok {
 | 
			
		||||
									scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						if relationship.PolymorphicType != "" {
 | 
			
		||||
							scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if newScope.PrimaryKeyZero() {
 | 
			
		||||
						if autoCreate {
 | 
			
		||||
							scope.Err(newDB.Save(elem).Error)
 | 
			
		||||
						}
 | 
			
		||||
					} else if autoUpdate {
 | 
			
		||||
						scope.Err(newDB.Save(elem).Error)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
 | 
			
		||||
						if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
 | 
			
		||||
							scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			default:
 | 
			
		||||
				elem := value.Addr().Interface()
 | 
			
		||||
				newScope := scope.New(elem)
 | 
			
		||||
 | 
			
		||||
				if saveReference {
 | 
			
		||||
					if len(relationship.ForeignFieldNames) != 0 {
 | 
			
		||||
						for idx, fieldName := range relationship.ForeignFieldNames {
 | 
			
		||||
							associationForeignName := relationship.AssociationForeignDBNames[idx]
 | 
			
		||||
							if f, ok := scope.FieldByName(associationForeignName); ok {
 | 
			
		||||
								scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if relationship.PolymorphicType != "" {
 | 
			
		||||
						scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if newScope.PrimaryKeyZero() {
 | 
			
		||||
					if autoCreate {
 | 
			
		||||
						scope.Err(scope.NewDB().Save(elem).Error)
 | 
			
		||||
					}
 | 
			
		||||
				} else if autoUpdate {
 | 
			
		||||
					scope.Err(scope.NewDB().Save(elem).Error)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,119 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for updating
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:update", updateCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
 | 
			
		||||
	DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// assignUpdatingAttributesCallback assign updating attributes to model
 | 
			
		||||
func assignUpdatingAttributesCallback(scope *Scope) {
 | 
			
		||||
	if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
 | 
			
		||||
		if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
 | 
			
		||||
			scope.InstanceSet("gorm:update_attrs", updateMaps)
 | 
			
		||||
		} else {
 | 
			
		||||
			scope.SkipLeft()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
 | 
			
		||||
func beforeUpdateCallback(scope *Scope) {
 | 
			
		||||
	if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
 | 
			
		||||
		scope.Err(errors.New("Missing WHERE clause while updating"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
			
		||||
		if !scope.HasError() {
 | 
			
		||||
			scope.CallMethod("BeforeSave")
 | 
			
		||||
		}
 | 
			
		||||
		if !scope.HasError() {
 | 
			
		||||
			scope.CallMethod("BeforeUpdate")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
 | 
			
		||||
func updateTimeStampForUpdateCallback(scope *Scope) {
 | 
			
		||||
	if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
			
		||||
		scope.SetColumn("UpdatedAt", NowFunc())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// updateCallback the callback used to update data to database
 | 
			
		||||
func updateCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		var sqls []string
 | 
			
		||||
 | 
			
		||||
		if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
			
		||||
			// Sort the column names so that the generated SQL is the same every time.
 | 
			
		||||
			updateMap := updateAttrs.(map[string]interface{})
 | 
			
		||||
			var columns []string
 | 
			
		||||
			for c := range updateMap {
 | 
			
		||||
				columns = append(columns, c)
 | 
			
		||||
			}
 | 
			
		||||
			sort.Strings(columns)
 | 
			
		||||
 | 
			
		||||
			for _, column := range columns {
 | 
			
		||||
				value := updateMap[column]
 | 
			
		||||
				sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			for _, field := range scope.Fields() {
 | 
			
		||||
				if scope.changeableField(field) {
 | 
			
		||||
					if !field.IsPrimaryKey && field.IsNormal {
 | 
			
		||||
						sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 | 
			
		||||
					} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
			
		||||
						for _, foreignKey := range relationship.ForeignDBNames {
 | 
			
		||||
							if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
 | 
			
		||||
								sqls = append(sqls,
 | 
			
		||||
									fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var extraOption string
 | 
			
		||||
		if str, ok := scope.Get("gorm:update_option"); ok {
 | 
			
		||||
			extraOption = fmt.Sprint(str)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(sqls) > 0 {
 | 
			
		||||
			scope.Raw(fmt.Sprintf(
 | 
			
		||||
				"UPDATE %v SET %v%v%v",
 | 
			
		||||
				scope.QuotedTableName(),
 | 
			
		||||
				strings.Join(sqls, ", "),
 | 
			
		||||
				addExtraSpaceIfExist(scope.CombinedConditionSql()),
 | 
			
		||||
				addExtraSpaceIfExist(extraOption),
 | 
			
		||||
			)).Exec()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
 | 
			
		||||
func afterUpdateCallback(scope *Scope) {
 | 
			
		||||
	if _, ok := scope.Get("gorm:update_column"); !ok {
 | 
			
		||||
		if !scope.HasError() {
 | 
			
		||||
			scope.CallMethod("AfterUpdate")
 | 
			
		||||
		}
 | 
			
		||||
		if !scope.HasError() {
 | 
			
		||||
			scope.CallMethod("AfterSave")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,10 +1,8 @@
 | 
			
		||||
package model
 | 
			
		||||
package sqlbuilder
 | 
			
		||||
 | 
			
		||||
import "errors"
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// ErrInvalidTable invalid table name
 | 
			
		||||
	ErrInvalidTable = errors.New("invalid table name")
 | 
			
		||||
	// ErrUnaddressable unaddressable value
 | 
			
		||||
	ErrUnaddressable = errors.New("using unaddressable value")
 | 
			
		||||
)
 | 
			
		||||
@ -1,80 +1,28 @@
 | 
			
		||||
package model
 | 
			
		||||
package sqlbuilder
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/builder"
 | 
			
		||||
	"github.com/jinzhu/gorm/model"
 | 
			
		||||
	"github.com/jinzhu/gorm/schema"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Field GORM model field
 | 
			
		||||
type Field struct {
 | 
			
		||||
	*schema.Field
 | 
			
		||||
	IsBlank bool
 | 
			
		||||
	Value   reflect.Value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Set set a value to the field
 | 
			
		||||
func (field *Field) Set(value interface{}) (err error) {
 | 
			
		||||
	if !field.Value.IsValid() {
 | 
			
		||||
		return errors.New("field value not valid")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !field.Value.CanAddr() {
 | 
			
		||||
		return ErrUnaddressable
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reflectValue, ok := value.(reflect.Value)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		reflectValue = reflect.ValueOf(value)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fieldValue := field.Value
 | 
			
		||||
	if reflectValue.IsValid() {
 | 
			
		||||
		if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
 | 
			
		||||
			fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
 | 
			
		||||
		} else {
 | 
			
		||||
			if fieldValue.Kind() == reflect.Ptr {
 | 
			
		||||
				if fieldValue.IsNil() {
 | 
			
		||||
					fieldValue.Set(reflect.New(field.StructField.Type.Elem()))
 | 
			
		||||
				}
 | 
			
		||||
				fieldValue = fieldValue.Elem()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
 | 
			
		||||
				fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
 | 
			
		||||
			} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
 | 
			
		||||
				err = scanner.Scan(reflectValue.Interface())
 | 
			
		||||
			} else {
 | 
			
		||||
				err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		field.Value.Set(reflect.Zero(fieldValue.Type()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	field.IsBlank = isBlank(field.Value)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAssignments get assignments
 | 
			
		||||
func GetAssignments(tx *gorm.DB) chan [][]*Field {
 | 
			
		||||
	fieldChan := make(chan [][]*Field)
 | 
			
		||||
// GetAssignmentFields get assignment fields
 | 
			
		||||
func GetAssignmentFields(tx *gorm.DB) chan [][]*model.Field {
 | 
			
		||||
	fieldChan := make(chan [][]*model.Field)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		assignableChecker := generateAssignableChecker(selectAttrs(tx.Statement), omitAttrs(tx.Statement))
 | 
			
		||||
 | 
			
		||||
		switch dest := tx.Statement.Dest.(type) {
 | 
			
		||||
		case map[string]interface{}:
 | 
			
		||||
			fieldChan <- [][]*Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)}
 | 
			
		||||
			fieldChan <- [][]*model.Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)}
 | 
			
		||||
		case []map[string]interface{}:
 | 
			
		||||
			fields := [][]*Field{}
 | 
			
		||||
			fields := [][]*model.Field{}
 | 
			
		||||
			tableSchema := schema.Parse(tx.Statement.Table)
 | 
			
		||||
 | 
			
		||||
			for _, v := range dest {
 | 
			
		||||
@ -87,13 +35,13 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field {
 | 
			
		||||
 | 
			
		||||
				switch results.Kind() {
 | 
			
		||||
				case reflect.Slice:
 | 
			
		||||
					fields := [][]*Field{}
 | 
			
		||||
					fields := [][]*model.Field{}
 | 
			
		||||
					for i := 0; i < results.Len(); i++ {
 | 
			
		||||
						fields = append(fields, structToField(indirect(results.Index(i)), s, assignableChecker))
 | 
			
		||||
					}
 | 
			
		||||
					fieldChan <- fields
 | 
			
		||||
				case reflect.Struct:
 | 
			
		||||
					fieldChan <- [][]*Field{structToField(results, s, assignableChecker)}
 | 
			
		||||
					fieldChan <- [][]*model.Field{structToField(results, s, assignableChecker)}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@ -102,12 +50,12 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field {
 | 
			
		||||
	return fieldChan
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) {
 | 
			
		||||
func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) {
 | 
			
		||||
	// TODO assign those value to dest
 | 
			
		||||
	for k, v := range value {
 | 
			
		||||
		if s != nil {
 | 
			
		||||
			if f := s.FieldByName(k); f != nil {
 | 
			
		||||
				field := &Field{Field: f, Value: reflect.ValueOf(v)}
 | 
			
		||||
				field := &model.Field{Field: f, Value: reflect.ValueOf(v)}
 | 
			
		||||
				if assignableChecker(field) {
 | 
			
		||||
					fields = append(fields, field)
 | 
			
		||||
				}
 | 
			
		||||
@ -115,7 +63,7 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		field := &Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}
 | 
			
		||||
		field := &model.Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}
 | 
			
		||||
		if assignableChecker(field) {
 | 
			
		||||
			fields = append(fields, field)
 | 
			
		||||
		}
 | 
			
		||||
@ -127,14 +75,14 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) {
 | 
			
		||||
func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) {
 | 
			
		||||
	// TODO use Offset to replace FieldByName?
 | 
			
		||||
	for _, sf := range s.Fields {
 | 
			
		||||
		obj := value
 | 
			
		||||
		for _, bn := range sf.BindNames {
 | 
			
		||||
			obj = value.FieldByName(bn)
 | 
			
		||||
		}
 | 
			
		||||
		field := &Field{Field: sf, Value: obj, IsBlank: isBlank(obj)}
 | 
			
		||||
		field := &model.Field{Field: sf, Value: obj, IsBlank: model.IsBlank(obj)}
 | 
			
		||||
		if assignableChecker(field) {
 | 
			
		||||
			fields = append(fields, field)
 | 
			
		||||
		}
 | 
			
		||||
@ -143,8 +91,8 @@ func structToField(value reflect.Value, s *schema.Schema, assignableChecker func
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// generateAssignableChecker generate checker to check if field is assignable or not
 | 
			
		||||
func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*Field) bool {
 | 
			
		||||
	return func(field *Field) bool {
 | 
			
		||||
func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*model.Field) bool {
 | 
			
		||||
	return func(field *model.Field) bool {
 | 
			
		||||
		if len(selectAttrs) > 0 {
 | 
			
		||||
			for _, attr := range selectAttrs {
 | 
			
		||||
				if field.Name == attr || field.DBName == attr {
 | 
			
		||||
@ -164,7 +112,7 @@ func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*F
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// omitAttrs return selected attributes of stmt
 | 
			
		||||
func selectAttrs(stmt *builder.Statement) []string {
 | 
			
		||||
func selectAttrs(stmt *gorm.Statement) []string {
 | 
			
		||||
	columns := stmt.Select.Columns
 | 
			
		||||
	for _, arg := range stmt.Select.Args {
 | 
			
		||||
		columns = append(columns, fmt.Sprint(arg))
 | 
			
		||||
@ -173,6 +121,6 @@ func selectAttrs(stmt *builder.Statement) []string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// omitAttrs return omitted attributes of stmt
 | 
			
		||||
func omitAttrs(stmt *builder.Statement) []string {
 | 
			
		||||
func omitAttrs(stmt *gorm.Statement) []string {
 | 
			
		||||
	return stmt.Omit
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										63
									
								
								dialects/common/sqlbuilder/sqlbuilder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								dialects/common/sqlbuilder/sqlbuilder.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,63 @@
 | 
			
		||||
package sqlbuilder
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/model"
 | 
			
		||||
	"github.com/jinzhu/gorm/schema"
 | 
			
		||||
	"github.com/jinzhu/inflection"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetTable get table name for current db operation
 | 
			
		||||
func GetTable(tx *gorm.DB) chan string {
 | 
			
		||||
	tableChan := make(chan string)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		var tableName string
 | 
			
		||||
		if name, ok := tx.Statement.Table.(string); ok {
 | 
			
		||||
			tableName = name
 | 
			
		||||
		} else {
 | 
			
		||||
			for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
 | 
			
		||||
				if v != nil {
 | 
			
		||||
					if t, ok := v.(tabler); ok {
 | 
			
		||||
						tableName = t.TableName()
 | 
			
		||||
					} else if t, ok := v.(dbTabler); ok {
 | 
			
		||||
						tableName = t.TableName(tx)
 | 
			
		||||
					} else if s := schema.Parse(v); s != nil {
 | 
			
		||||
						if s.TableName != "" {
 | 
			
		||||
							tableName = s.TableName
 | 
			
		||||
						} else {
 | 
			
		||||
							tableName = schema.ToDBName(s.ModelType.Name())
 | 
			
		||||
							if !tx.Config.SingularTable {
 | 
			
		||||
								tableName = inflection.Plural(tableName)
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if tableName != "" {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tableName != "" {
 | 
			
		||||
			if model.DefaultTableNameHandler != nil {
 | 
			
		||||
				tableChan <- model.DefaultTableNameHandler(tx, tableName)
 | 
			
		||||
			} else {
 | 
			
		||||
				tableChan <- tableName
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			tx.AddError(ErrInvalidTable)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return tableChan
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type tabler interface {
 | 
			
		||||
	TableName() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dbTabler interface {
 | 
			
		||||
	TableName(*gorm.DB) string
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										10
									
								
								dialects/common/sqlbuilder/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								dialects/common/sqlbuilder/utils.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
package sqlbuilder
 | 
			
		||||
 | 
			
		||||
import "reflect"
 | 
			
		||||
 | 
			
		||||
func indirect(reflectValue reflect.Value) reflect.Value {
 | 
			
		||||
	for reflectValue.Kind() == reflect.Ptr {
 | 
			
		||||
		reflectValue = reflectValue.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	return reflectValue
 | 
			
		||||
}
 | 
			
		||||
@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/model"
 | 
			
		||||
	"github.com/jinzhu/gorm/dialects/common/destination"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Dialect Sqlite3 Dialect for GORM
 | 
			
		||||
@ -23,9 +23,9 @@ func (dialect Dialect) Quote(name string) string {
 | 
			
		||||
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		args            []interface{}
 | 
			
		||||
		assignmentsChan = model.GetAssignments(tx)
 | 
			
		||||
		tableNameChan   = model.GetTable(tx)
 | 
			
		||||
		primaryFields   []*model.Field
 | 
			
		||||
		assignmentsChan = destination.GetAssignments(tx)
 | 
			
		||||
		tableNameChan   = destination.GetTable(tx)
 | 
			
		||||
		primaryFields   []*destination.Field
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	s := bytes.NewBufferString("INSERT INTO ")
 | 
			
		||||
@ -41,7 +41,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
 | 
			
		||||
		valueBuffer := bytes.NewBufferString("VALUES ")
 | 
			
		||||
 | 
			
		||||
		for idx, fields := range assignments {
 | 
			
		||||
			var primaryField *model.Field
 | 
			
		||||
			var primaryField *destination.Field
 | 
			
		||||
			if idx != 0 {
 | 
			
		||||
				valueBuffer.WriteString(",")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										21
									
								
								logger/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								logger/utils.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"runtime"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
 | 
			
		||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
 | 
			
		||||
 | 
			
		||||
// FileWithLineNum get filename with line num for logging
 | 
			
		||||
func FileWithLineNum() string {
 | 
			
		||||
	for i := 2; i < 15; i++ {
 | 
			
		||||
		_, file, line, ok := runtime.Caller(i)
 | 
			
		||||
		if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
 | 
			
		||||
			return fmt.Sprintf("%v:%v", file, line)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
@ -1,9 +1,13 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/schema"
 | 
			
		||||
	"github.com/jinzhu/inflection"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DefaultTableNameHandler default table name handler
 | 
			
		||||
@ -12,57 +16,52 @@ import (
 | 
			
		||||
//    }
 | 
			
		||||
var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string
 | 
			
		||||
 | 
			
		||||
// GetTable get table name for current db operation
 | 
			
		||||
func GetTable(tx *gorm.DB) chan string {
 | 
			
		||||
	tableChan := make(chan string)
 | 
			
		||||
// Field GORM model field
 | 
			
		||||
type Field struct {
 | 
			
		||||
	*schema.Field
 | 
			
		||||
	IsBlank bool
 | 
			
		||||
	Value   reflect.Value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		var tableName string
 | 
			
		||||
		if name, ok := tx.Statement.Table.(string); ok {
 | 
			
		||||
			tableName = name
 | 
			
		||||
// Set set a value to the field
 | 
			
		||||
func (field *Field) Set(value interface{}) (err error) {
 | 
			
		||||
	if !field.Value.IsValid() {
 | 
			
		||||
		return errors.New("field value not valid")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !field.Value.CanAddr() {
 | 
			
		||||
		return gorm.ErrUnaddressable
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reflectValue, ok := value.(reflect.Value)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		reflectValue = reflect.ValueOf(value)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fieldValue := field.Value
 | 
			
		||||
	if reflectValue.IsValid() {
 | 
			
		||||
		if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
 | 
			
		||||
			fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
 | 
			
		||||
		} else {
 | 
			
		||||
			for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
 | 
			
		||||
				if v != nil {
 | 
			
		||||
					if t, ok := v.(tabler); ok {
 | 
			
		||||
						tableName = t.TableName()
 | 
			
		||||
					} else if t, ok := v.(dbTabler); ok {
 | 
			
		||||
						tableName = t.TableName(tx)
 | 
			
		||||
					} else if s := schema.Parse(v); s != nil {
 | 
			
		||||
						if s.TableName != "" {
 | 
			
		||||
							tableName = s.TableName
 | 
			
		||||
						} else {
 | 
			
		||||
							tableName = schema.ToDBName(s.ModelType.Name())
 | 
			
		||||
							if !tx.Config.SingularTable {
 | 
			
		||||
								tableName = inflection.Plural(tableName)
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if tableName != "" {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
			if fieldValue.Kind() == reflect.Ptr {
 | 
			
		||||
				if fieldValue.IsNil() {
 | 
			
		||||
					fieldValue.Set(reflect.New(field.StructField.Type.Elem()))
 | 
			
		||||
				}
 | 
			
		||||
				fieldValue = fieldValue.Elem()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tableName != "" {
 | 
			
		||||
			if DefaultTableNameHandler != nil {
 | 
			
		||||
				tableChan <- DefaultTableNameHandler(tx, tableName)
 | 
			
		||||
			if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
 | 
			
		||||
				fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
 | 
			
		||||
			} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
 | 
			
		||||
				err = scanner.Scan(reflectValue.Interface())
 | 
			
		||||
			} else {
 | 
			
		||||
				tableChan <- tableName
 | 
			
		||||
				err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			tx.AddError(ErrInvalidTable)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	} else {
 | 
			
		||||
		field.Value.Set(reflect.Zero(fieldValue.Type()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tableChan
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type tabler interface {
 | 
			
		||||
	TableName() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dbTabler interface {
 | 
			
		||||
	TableName(*gorm.DB) string
 | 
			
		||||
	field.IsBlank = IsBlank(field.Value)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2,32 +2,7 @@ package model
 | 
			
		||||
 | 
			
		||||
import "reflect"
 | 
			
		||||
 | 
			
		||||
// ToSearchableMap convert attrs to searchable map
 | 
			
		||||
func ToSearchableMap(attrs ...interface{}) (result interface{}) {
 | 
			
		||||
	if len(attrs) > 1 {
 | 
			
		||||
		if str, ok := attrs[0].(string); ok {
 | 
			
		||||
			result = map[string]interface{}{str: attrs[1]}
 | 
			
		||||
		}
 | 
			
		||||
	} else if len(attrs) == 1 {
 | 
			
		||||
		if attr, ok := attrs[0].(map[string]interface{}); ok {
 | 
			
		||||
			result = attr
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if attr, ok := attrs[0].(interface{}); ok {
 | 
			
		||||
			result = attr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func indirect(reflectValue reflect.Value) reflect.Value {
 | 
			
		||||
	for reflectValue.Kind() == reflect.Ptr {
 | 
			
		||||
		reflectValue = reflectValue.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	return reflectValue
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isBlank(value reflect.Value) bool {
 | 
			
		||||
func IsBlank(value reflect.Value) bool {
 | 
			
		||||
	switch value.Kind() {
 | 
			
		||||
	case reflect.String:
 | 
			
		||||
		return value.Len() == 0
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user