Refactor structure
This commit is contained in:
parent
24ed796198
commit
8b567b49d0
@ -1,9 +1,10 @@
|
||||
package gorm
|
||||
|
||||
import "log"
|
||||
import (
|
||||
"log"
|
||||
|
||||
// DefaultCallback default callbacks defined by gorm
|
||||
var DefaultCallback = &Callback{}
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Callback is a struct that contains all CRUD callbacks
|
||||
// Field `creates` contains callbacks will be call when creating object
|
||||
@ -13,23 +14,23 @@ var DefaultCallback = &Callback{}
|
||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||
type Callback struct {
|
||||
creates []*func(scope *Scope)
|
||||
updates []*func(scope *Scope)
|
||||
deletes []*func(scope *Scope)
|
||||
queries []*func(scope *Scope)
|
||||
rowQueries []*func(scope *Scope)
|
||||
creates []*func(*gorm.DB)
|
||||
updates []*func(*gorm.DB)
|
||||
deletes []*func(*gorm.DB)
|
||||
queries []*func(*gorm.DB)
|
||||
rowQueries []*func(*gorm.DB)
|
||||
processors []*CallbackProcessor
|
||||
}
|
||||
|
||||
// CallbackProcessor contains callback informations
|
||||
type CallbackProcessor struct {
|
||||
name string // current callback's name
|
||||
before string // register current callback before a callback
|
||||
after string // register current callback after a callback
|
||||
replace bool // replace callbacks with same name
|
||||
remove bool // delete callbacks with same name
|
||||
kind string // callback type: create, update, delete, query, row_query
|
||||
processor *func(scope *Scope) // callback handler
|
||||
name string // current callback's name
|
||||
before string // register current callback before a callback
|
||||
after string // register current callback after a callback
|
||||
replace bool // replace callbacks with same name
|
||||
remove bool // delete callbacks with same name
|
||||
kind string // callback type: create, update, delete, query, row_query
|
||||
processor *func(*gorm.DB) // callback handler
|
||||
parent *Callback
|
||||
}
|
||||
|
||||
@ -45,7 +46,7 @@ func (c *Callback) clone() *Callback {
|
||||
}
|
||||
|
||||
// Create could be used to register callbacks for creating object
|
||||
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*gorm.DB) {
|
||||
// // business logic
|
||||
// ...
|
||||
//
|
||||
@ -90,7 +91,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||
}
|
||||
|
||||
// Register a new callback, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(*gorm.DB)) {
|
||||
if cp.kind == "row_query" {
|
||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
||||
@ -107,7 +108,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
||||
// Remove a registered callback
|
||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||
log.Printf("[info] removing callback `%v` from %v\n", callbackName, utils.FileWithLineNum())
|
||||
cp.name = callbackName
|
||||
cp.remove = true
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
@ -115,12 +116,12 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
}
|
||||
|
||||
// Replace a registered callback with new callback
|
||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*gorm.DB) {
|
||||
// scope.SetColumn("Created", now)
|
||||
// scope.SetColumn("Updated", now)
|
||||
// })
|
||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(*gorm.DB)) {
|
||||
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, utils.FileWithLineNum())
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.replace = true
|
||||
@ -130,7 +131,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
||||
|
||||
// Get registered callback
|
||||
// db.Callback().Create().Get("gorm:create")
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(*gorm.DB)) {
|
||||
for _, p := range cp.parent.processors {
|
||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||
return *p.processor
|
||||
@ -150,7 +151,7 @@ func getRIndex(strs []string, str string) int {
|
||||
}
|
||||
|
||||
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
func sortProcessors(cps []*CallbackProcessor) []*func(*gorm.DB) {
|
||||
var (
|
||||
allNames, sortedNames []string
|
||||
sortCallbackProcessor func(c *CallbackProcessor)
|
||||
@ -159,7 +160,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
for _, cp := range cps {
|
||||
// show warning message the callback name already exists
|
||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, utils.FileWithLineNum())
|
||||
}
|
||||
allNames = append(allNames, cp.name)
|
||||
}
|
||||
@ -203,7 +204,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
|
||||
var sortedFuncs []*func(scope *Scope)
|
||||
var sortedFuncs []*func(*gorm.DB)
|
||||
for _, name := range sortedNames {
|
||||
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||
|
@ -1,164 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for creating
|
||||
func init() {
|
||||
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:create", createCallback)
|
||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||
func beforeCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeCreate")
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
now := NowFunc()
|
||||
|
||||
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
||||
if createdAtField.IsBlank {
|
||||
createdAtField.Set(now)
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
|
||||
if updatedAtField.IsBlank {
|
||||
updatedAtField.Set(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createCallback the callback used to insert data into database
|
||||
func createCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
columns, placeholders []string
|
||||
blankColumnsWithDefaultValue []string
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal {
|
||||
if field.IsBlank && field.HasDefaultValue {
|
||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||
} else if !field.IsPrimaryKey || !field.IsBlank {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
returningColumn = "*"
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryField = scope.PrimaryField()
|
||||
extraOption string
|
||||
)
|
||||
|
||||
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if primaryField != nil {
|
||||
returningColumn = scope.Quote(primaryField.DBName)
|
||||
}
|
||||
|
||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||
|
||||
if len(columns) == 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v %v%v%v",
|
||||
quotedTableName,
|
||||
scope.Dialect().DefaultValueStr(),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(placeholders, ","),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql
|
||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if primaryField.Field.CanAddr() {
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
primaryField.IsBlank = false
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
} else {
|
||||
scope.Err(ErrUnaddressable)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsPrimaryKey && !field.IsBlank {
|
||||
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
db.Scan(scope.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||
func afterCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterCreate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Define callbacks for deleting
|
||||
func init() {
|
||||
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||
func beforeDeleteCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("Missing WHERE clause while deleting"))
|
||||
return
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeDelete")
|
||||
}
|
||||
}
|
||||
|
||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||
func deleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
|
||||
|
||||
if !scope.Search.Unscoped && hasDeletedAtField {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v=%v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
scope.Quote(deletedAtField.DBName),
|
||||
scope.AddToVars(NowFunc()),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"DELETE FROM %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||
func afterDeleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterDelete")
|
||||
}
|
||||
}
|
@ -1,479 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for querying
|
||||
func init() {
|
||||
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
||||
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
||||
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func queryCallback(scope *Scope) {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
isSlice, isPtr bool
|
||||
resultType reflect.Type
|
||||
results = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||
if primaryField := scope.PrimaryField(); primaryField != nil {
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||
results = indirect(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
if kind := results.Kind(); kind == reflect.Slice {
|
||||
isSlice = true
|
||||
resultType = results.Type().Elem()
|
||||
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
||||
|
||||
if resultType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
resultType = resultType.Elem()
|
||||
}
|
||||
} else if kind != reflect.Struct {
|
||||
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
||||
return
|
||||
}
|
||||
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.db.RowsAffected = 0
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
scope.db.RowsAffected++
|
||||
|
||||
elem := results
|
||||
if isSlice {
|
||||
elem = reflect.New(resultType).Elem()
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
|
||||
|
||||
if isSlice {
|
||||
if isPtr {
|
||||
results.Set(reflect.Append(results, elem.Addr()))
|
||||
} else {
|
||||
results.Set(reflect.Append(results, elem))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
scope.Err(err)
|
||||
} else if scope.db.RowsAffected == 0 && !isSlice {
|
||||
scope.Err(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||
func afterQueryCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterFind")
|
||||
}
|
||||
}
|
||||
|
||||
// preloadCallback used to preload associations
|
||||
func preloadCallback(scope *Scope) {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
||||
autoPreload(scope)
|
||||
}
|
||||
|
||||
if scope.Search.preload == nil || scope.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
preloadedMap = map[string]bool{}
|
||||
fields = scope.Fields()
|
||||
)
|
||||
|
||||
for _, preload := range scope.Search.preload {
|
||||
var (
|
||||
preloadFields = strings.Split(preload.schema, ".")
|
||||
currentScope = scope
|
||||
currentFields = fields
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
var currentPreloadConditions []interface{}
|
||||
|
||||
if currentScope == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if not preloaded
|
||||
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||
|
||||
// assign search conditions to last preload
|
||||
if idx == len(preloadFields)-1 {
|
||||
currentPreloadConditions = preload.conditions
|
||||
}
|
||||
|
||||
for _, field := range currentFields {
|
||||
if field.Name != preloadField || field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||
case "has_many":
|
||||
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||
case "many_to_many":
|
||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||
default:
|
||||
scope.Err(errors.New("unsupported relation"))
|
||||
}
|
||||
|
||||
preloadedMap[preloadKey] = true
|
||||
break
|
||||
}
|
||||
|
||||
if !preloadedMap[preloadKey] {
|
||||
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// preload next level
|
||||
if idx < len(preloadFields)-1 {
|
||||
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||
if currentScope != nil {
|
||||
currentFields = currentScope.Fields()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func autoPreload(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["PRELOAD"]; ok {
|
||||
if preload, err := strconv.ParseBool(val); err != nil {
|
||||
scope.Err(errors.New("invalid preload option"))
|
||||
return
|
||||
} else if !preload {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
scope.Search.Preload(field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
||||
var (
|
||||
preloadDB = scope.NewDB()
|
||||
preloadConditions []interface{}
|
||||
)
|
||||
|
||||
for _, condition := range conditions {
|
||||
if scopes, ok := condition.(func(*DB) *DB); ok {
|
||||
preloadDB = scopes(preloadDB)
|
||||
} else {
|
||||
preloadConditions = append(preloadConditions, condition)
|
||||
}
|
||||
}
|
||||
|
||||
return preloadDB, preloadConditions
|
||||
}
|
||||
|
||||
// handleHasOnePreload used to preload has one associations
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||
values := toQueryValues(primaryKeys)
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
||||
indirectValue.FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleHasManyPreload used to preload has many associations
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||
values := toQueryValues(primaryKeys)
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
preloadMap := make(map[string][]reflect.Value)
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
|
||||
}
|
||||
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
|
||||
f := object.FieldByName(field.Name)
|
||||
if results, ok := preloadMap[toString(objectRealValue)]; ok {
|
||||
f.Set(reflect.Append(f, results...))
|
||||
} else {
|
||||
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(resultsValue))
|
||||
}
|
||||
}
|
||||
|
||||
// handleBelongsToPreload used to preload belongs to associations
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleManyToManyPreload used to preload many to many associations
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||
var (
|
||||
relation = field.Relationship
|
||||
joinTableHandler = relation.JoinTableHandler
|
||||
fieldType = field.Struct.Type.Elem()
|
||||
foreignKeyValue interface{}
|
||||
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
||||
linkHash = map[string][]reflect.Value{}
|
||||
isPtr bool
|
||||
)
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
var sourceKeys = []string{}
|
||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||
sourceKeys = append(sourceKeys, key.DBName)
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// generate query with join table
|
||||
newScope := scope.New(reflect.New(fieldType).Interface())
|
||||
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
|
||||
|
||||
if len(preloadDB.search.selects) == 0 {
|
||||
preloadDB = preloadDB.Select("*")
|
||||
}
|
||||
|
||||
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||
|
||||
// preload inline conditions
|
||||
if len(preloadConditions) > 0 {
|
||||
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
||||
}
|
||||
|
||||
rows, err := preloadDB.Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
var (
|
||||
elem = reflect.New(fieldType).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||
)
|
||||
|
||||
// register foreign keys in join tables
|
||||
var joinTableFields []*Field
|
||||
for _, sourceKey := range sourceKeys {
|
||||
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||
|
||||
scope.New(elem.Addr().Interface()).
|
||||
InstanceSet("gorm:skip_query_callback", true).
|
||||
callCallbacks(scope.db.parent.callbacks.queries)
|
||||
|
||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||
// generate hashed forkey keys in join table
|
||||
for idx, joinTableField := range joinTableFields {
|
||||
if !joinTableField.Field.IsNil() {
|
||||
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
|
||||
}
|
||||
}
|
||||
hashedSourceKeys := toString(foreignKeys)
|
||||
|
||||
if isPtr {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
||||
} else {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
scope.Err(err)
|
||||
}
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
fieldsSourceMap = map[string][]reflect.Value{}
|
||||
foreignFieldNames = []string{}
|
||||
)
|
||||
|
||||
for _, dbName := range relation.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
key := toString(getValueFromFields(object, foreignFieldNames))
|
||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
|
||||
}
|
||||
} else if indirectScopeValue.IsValid() {
|
||||
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
||||
}
|
||||
for source, link := range linkHash {
|
||||
for i, field := range fieldsSourceMap[source] {
|
||||
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
||||
if fieldsSourceMap[source][i].Len() != 0 {
|
||||
continue
|
||||
}
|
||||
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// Define callbacks for row query
|
||||
func init() {
|
||||
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
|
||||
}
|
||||
|
||||
type RowQueryResult struct {
|
||||
Row *sql.Row
|
||||
}
|
||||
|
||||
type RowsQueryResult struct {
|
||||
Rows *sql.Rows
|
||||
Error error
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func rowQueryCallback(scope *Scope) {
|
||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
|
||||
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,170 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func beginTransactionCallback(scope *Scope) {
|
||||
scope.Begin()
|
||||
}
|
||||
|
||||
func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||
scope.CommitOrRollback()
|
||||
}
|
||||
|
||||
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
|
||||
checkTruth := func(value interface{}) bool {
|
||||
if v, ok := value.(bool); ok && !v {
|
||||
return false
|
||||
}
|
||||
|
||||
if v, ok := value.(string); ok {
|
||||
v = strings.ToLower(v)
|
||||
if v == "false" || v != "skip" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||
if r = field.Relationship; r != nil {
|
||||
autoUpdate, autoCreate, saveReference = true, true, true
|
||||
|
||||
if value, ok := scope.Get("gorm:save_associations"); ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
autoCreate = autoUpdate
|
||||
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
autoCreate = autoUpdate
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
|
||||
autoUpdate = checkTruth(value)
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
||||
autoCreate = checkTruth(value)
|
||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
|
||||
autoCreate = checkTruth(value)
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
||||
saveReference = checkTruth(value)
|
||||
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
|
||||
saveReference = checkTruth(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||
|
||||
if relationship != nil && relationship.Kind == "belongs_to" {
|
||||
fieldValue := field.Field.Addr().Interface()
|
||||
newScope := scope.New(fieldValue)
|
||||
|
||||
if newScope.PrimaryKeyZero() {
|
||||
if autoCreate {
|
||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||
}
|
||||
} else if autoUpdate {
|
||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||
}
|
||||
|
||||
if saveReference {
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
// set value's foreign key
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
||||
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saveAfterAssociationsCallback(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||
|
||||
if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||
value := field.Field
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
newDB := scope.NewDB()
|
||||
elem := value.Index(i).Addr().Interface()
|
||||
newScope := newDB.NewScope(elem)
|
||||
|
||||
if saveReference {
|
||||
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||
}
|
||||
}
|
||||
|
||||
if newScope.PrimaryKeyZero() {
|
||||
if autoCreate {
|
||||
scope.Err(newDB.Save(elem).Error)
|
||||
}
|
||||
} else if autoUpdate {
|
||||
scope.Err(newDB.Save(elem).Error)
|
||||
}
|
||||
|
||||
if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
|
||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
elem := value.Addr().Interface()
|
||||
newScope := scope.New(elem)
|
||||
|
||||
if saveReference {
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||
}
|
||||
}
|
||||
|
||||
if newScope.PrimaryKeyZero() {
|
||||
if autoCreate {
|
||||
scope.Err(scope.NewDB().Save(elem).Error)
|
||||
}
|
||||
} else if autoUpdate {
|
||||
scope.Err(scope.NewDB().Save(elem).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,119 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for updating
|
||||
func init() {
|
||||
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||
} else {
|
||||
scope.SkipLeft()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||
func beforeUpdateCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("Missing WHERE clause while updating"))
|
||||
return
|
||||
}
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeUpdate")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.SetColumn("UpdatedAt", NowFunc())
|
||||
}
|
||||
}
|
||||
|
||||
// updateCallback the callback used to update data to database
|
||||
func updateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var sqls []string
|
||||
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
// Sort the column names so that the generated SQL is the same every time.
|
||||
updateMap := updateAttrs.(map[string]interface{})
|
||||
var columns []string
|
||||
for c := range updateMap {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
for _, column := range columns {
|
||||
value := updateMap[column]
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if !field.IsPrimaryKey && field.IsNormal {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
sqls = append(sqls,
|
||||
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:update_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if len(sqls) > 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(sqls, ", "),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||
func afterUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterUpdate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
}
|
@ -1,10 +1,8 @@
|
||||
package model
|
||||
package sqlbuilder
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrInvalidTable invalid table name
|
||||
ErrInvalidTable = errors.New("invalid table name")
|
||||
// ErrUnaddressable unaddressable value
|
||||
ErrUnaddressable = errors.New("using unaddressable value")
|
||||
)
|
@ -1,80 +1,28 @@
|
||||
package model
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/builder"
|
||||
"github.com/jinzhu/gorm/model"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
// Field GORM model field
|
||||
type Field struct {
|
||||
*schema.Field
|
||||
IsBlank bool
|
||||
Value reflect.Value
|
||||
}
|
||||
|
||||
// Set set a value to the field
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
if !field.Value.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
}
|
||||
|
||||
if !field.Value.CanAddr() {
|
||||
return ErrUnaddressable
|
||||
}
|
||||
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
fieldValue := field.Value
|
||||
if reflectValue.IsValid() {
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else {
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.StructField.Type.Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
err = scanner.Scan(reflectValue.Interface())
|
||||
} else {
|
||||
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Value.Set(reflect.Zero(fieldValue.Type()))
|
||||
}
|
||||
|
||||
field.IsBlank = isBlank(field.Value)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAssignments get assignments
|
||||
func GetAssignments(tx *gorm.DB) chan [][]*Field {
|
||||
fieldChan := make(chan [][]*Field)
|
||||
// GetAssignmentFields get assignment fields
|
||||
func GetAssignmentFields(tx *gorm.DB) chan [][]*model.Field {
|
||||
fieldChan := make(chan [][]*model.Field)
|
||||
|
||||
go func() {
|
||||
assignableChecker := generateAssignableChecker(selectAttrs(tx.Statement), omitAttrs(tx.Statement))
|
||||
|
||||
switch dest := tx.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
fieldChan <- [][]*Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)}
|
||||
fieldChan <- [][]*model.Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)}
|
||||
case []map[string]interface{}:
|
||||
fields := [][]*Field{}
|
||||
fields := [][]*model.Field{}
|
||||
tableSchema := schema.Parse(tx.Statement.Table)
|
||||
|
||||
for _, v := range dest {
|
||||
@ -87,13 +35,13 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field {
|
||||
|
||||
switch results.Kind() {
|
||||
case reflect.Slice:
|
||||
fields := [][]*Field{}
|
||||
fields := [][]*model.Field{}
|
||||
for i := 0; i < results.Len(); i++ {
|
||||
fields = append(fields, structToField(indirect(results.Index(i)), s, assignableChecker))
|
||||
}
|
||||
fieldChan <- fields
|
||||
case reflect.Struct:
|
||||
fieldChan <- [][]*Field{structToField(results, s, assignableChecker)}
|
||||
fieldChan <- [][]*model.Field{structToField(results, s, assignableChecker)}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -102,12 +50,12 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field {
|
||||
return fieldChan
|
||||
}
|
||||
|
||||
func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) {
|
||||
func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) {
|
||||
// TODO assign those value to dest
|
||||
for k, v := range value {
|
||||
if s != nil {
|
||||
if f := s.FieldByName(k); f != nil {
|
||||
field := &Field{Field: f, Value: reflect.ValueOf(v)}
|
||||
field := &model.Field{Field: f, Value: reflect.ValueOf(v)}
|
||||
if assignableChecker(field) {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
@ -115,7 +63,7 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck
|
||||
}
|
||||
}
|
||||
|
||||
field := &Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}
|
||||
field := &model.Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}
|
||||
if assignableChecker(field) {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
@ -127,14 +75,14 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck
|
||||
return
|
||||
}
|
||||
|
||||
func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) {
|
||||
func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) {
|
||||
// TODO use Offset to replace FieldByName?
|
||||
for _, sf := range s.Fields {
|
||||
obj := value
|
||||
for _, bn := range sf.BindNames {
|
||||
obj = value.FieldByName(bn)
|
||||
}
|
||||
field := &Field{Field: sf, Value: obj, IsBlank: isBlank(obj)}
|
||||
field := &model.Field{Field: sf, Value: obj, IsBlank: model.IsBlank(obj)}
|
||||
if assignableChecker(field) {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
@ -143,8 +91,8 @@ func structToField(value reflect.Value, s *schema.Schema, assignableChecker func
|
||||
}
|
||||
|
||||
// generateAssignableChecker generate checker to check if field is assignable or not
|
||||
func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*Field) bool {
|
||||
return func(field *Field) bool {
|
||||
func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*model.Field) bool {
|
||||
return func(field *model.Field) bool {
|
||||
if len(selectAttrs) > 0 {
|
||||
for _, attr := range selectAttrs {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
@ -164,7 +112,7 @@ func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*F
|
||||
}
|
||||
|
||||
// omitAttrs return selected attributes of stmt
|
||||
func selectAttrs(stmt *builder.Statement) []string {
|
||||
func selectAttrs(stmt *gorm.Statement) []string {
|
||||
columns := stmt.Select.Columns
|
||||
for _, arg := range stmt.Select.Args {
|
||||
columns = append(columns, fmt.Sprint(arg))
|
||||
@ -173,6 +121,6 @@ func selectAttrs(stmt *builder.Statement) []string {
|
||||
}
|
||||
|
||||
// omitAttrs return omitted attributes of stmt
|
||||
func omitAttrs(stmt *builder.Statement) []string {
|
||||
func omitAttrs(stmt *gorm.Statement) []string {
|
||||
return stmt.Omit
|
||||
}
|
63
dialects/common/sqlbuilder/sqlbuilder.go
Normal file
63
dialects/common/sqlbuilder/sqlbuilder.go
Normal file
@ -0,0 +1,63 @@
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/model"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// GetTable get table name for current db operation
|
||||
func GetTable(tx *gorm.DB) chan string {
|
||||
tableChan := make(chan string)
|
||||
|
||||
go func() {
|
||||
var tableName string
|
||||
if name, ok := tx.Statement.Table.(string); ok {
|
||||
tableName = name
|
||||
} else {
|
||||
for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
|
||||
if v != nil {
|
||||
if t, ok := v.(tabler); ok {
|
||||
tableName = t.TableName()
|
||||
} else if t, ok := v.(dbTabler); ok {
|
||||
tableName = t.TableName(tx)
|
||||
} else if s := schema.Parse(v); s != nil {
|
||||
if s.TableName != "" {
|
||||
tableName = s.TableName
|
||||
} else {
|
||||
tableName = schema.ToDBName(s.ModelType.Name())
|
||||
if !tx.Config.SingularTable {
|
||||
tableName = inflection.Plural(tableName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tableName != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tableName != "" {
|
||||
if model.DefaultTableNameHandler != nil {
|
||||
tableChan <- model.DefaultTableNameHandler(tx, tableName)
|
||||
} else {
|
||||
tableChan <- tableName
|
||||
}
|
||||
} else {
|
||||
tx.AddError(ErrInvalidTable)
|
||||
}
|
||||
}()
|
||||
|
||||
return tableChan
|
||||
}
|
||||
|
||||
type tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type dbTabler interface {
|
||||
TableName(*gorm.DB) string
|
||||
}
|
10
dialects/common/sqlbuilder/utils.go
Normal file
10
dialects/common/sqlbuilder/utils.go
Normal file
@ -0,0 +1,10 @@
|
||||
package sqlbuilder
|
||||
|
||||
import "reflect"
|
||||
|
||||
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||
for reflectValue.Kind() == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
return reflectValue
|
||||
}
|
@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/model"
|
||||
"github.com/jinzhu/gorm/dialects/common/destination"
|
||||
)
|
||||
|
||||
// Dialect Sqlite3 Dialect for GORM
|
||||
@ -23,9 +23,9 @@ func (dialect Dialect) Quote(name string) string {
|
||||
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||
var (
|
||||
args []interface{}
|
||||
assignmentsChan = model.GetAssignments(tx)
|
||||
tableNameChan = model.GetTable(tx)
|
||||
primaryFields []*model.Field
|
||||
assignmentsChan = destination.GetAssignments(tx)
|
||||
tableNameChan = destination.GetTable(tx)
|
||||
primaryFields []*destination.Field
|
||||
)
|
||||
|
||||
s := bytes.NewBufferString("INSERT INTO ")
|
||||
@ -41,7 +41,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||
valueBuffer := bytes.NewBufferString("VALUES ")
|
||||
|
||||
for idx, fields := range assignments {
|
||||
var primaryField *model.Field
|
||||
var primaryField *destination.Field
|
||||
if idx != 0 {
|
||||
valueBuffer.WriteString(",")
|
||||
}
|
||||
|
21
logger/utils.go
Normal file
21
logger/utils.go
Normal file
@ -0,0 +1,21 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
|
||||
|
||||
// FileWithLineNum get filename with line num for logging
|
||||
func FileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", file, line)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
@ -1,9 +1,13 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// DefaultTableNameHandler default table name handler
|
||||
@ -12,57 +16,52 @@ import (
|
||||
// }
|
||||
var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string
|
||||
|
||||
// GetTable get table name for current db operation
|
||||
func GetTable(tx *gorm.DB) chan string {
|
||||
tableChan := make(chan string)
|
||||
// Field GORM model field
|
||||
type Field struct {
|
||||
*schema.Field
|
||||
IsBlank bool
|
||||
Value reflect.Value
|
||||
}
|
||||
|
||||
go func() {
|
||||
var tableName string
|
||||
if name, ok := tx.Statement.Table.(string); ok {
|
||||
tableName = name
|
||||
// Set set a value to the field
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
if !field.Value.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
}
|
||||
|
||||
if !field.Value.CanAddr() {
|
||||
return gorm.ErrUnaddressable
|
||||
}
|
||||
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
fieldValue := field.Value
|
||||
if reflectValue.IsValid() {
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else {
|
||||
for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
|
||||
if v != nil {
|
||||
if t, ok := v.(tabler); ok {
|
||||
tableName = t.TableName()
|
||||
} else if t, ok := v.(dbTabler); ok {
|
||||
tableName = t.TableName(tx)
|
||||
} else if s := schema.Parse(v); s != nil {
|
||||
if s.TableName != "" {
|
||||
tableName = s.TableName
|
||||
} else {
|
||||
tableName = schema.ToDBName(s.ModelType.Name())
|
||||
if !tx.Config.SingularTable {
|
||||
tableName = inflection.Plural(tableName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tableName != "" {
|
||||
break
|
||||
}
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.StructField.Type.Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
if tableName != "" {
|
||||
if DefaultTableNameHandler != nil {
|
||||
tableChan <- DefaultTableNameHandler(tx, tableName)
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
err = scanner.Scan(reflectValue.Interface())
|
||||
} else {
|
||||
tableChan <- tableName
|
||||
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||
}
|
||||
} else {
|
||||
tx.AddError(ErrInvalidTable)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
field.Value.Set(reflect.Zero(fieldValue.Type()))
|
||||
}
|
||||
|
||||
return tableChan
|
||||
}
|
||||
|
||||
type tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type dbTabler interface {
|
||||
TableName(*gorm.DB) string
|
||||
field.IsBlank = IsBlank(field.Value)
|
||||
return err
|
||||
}
|
||||
|
@ -2,32 +2,7 @@ package model
|
||||
|
||||
import "reflect"
|
||||
|
||||
// ToSearchableMap convert attrs to searchable map
|
||||
func ToSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||
if len(attrs) > 1 {
|
||||
if str, ok := attrs[0].(string); ok {
|
||||
result = map[string]interface{}{str: attrs[1]}
|
||||
}
|
||||
} else if len(attrs) == 1 {
|
||||
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
|
||||
if attr, ok := attrs[0].(interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||
for reflectValue.Kind() == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
return reflectValue
|
||||
}
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
func IsBlank(value reflect.Value) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.String:
|
||||
return value.Len() == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user