Refactor structure
This commit is contained in:
parent
24ed796198
commit
8b567b49d0
@ -1,9 +1,10 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "log"
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
// DefaultCallback default callbacks defined by gorm
|
"github.com/jinzhu/gorm"
|
||||||
var DefaultCallback = &Callback{}
|
)
|
||||||
|
|
||||||
// Callback is a struct that contains all CRUD callbacks
|
// Callback is a struct that contains all CRUD callbacks
|
||||||
// Field `creates` contains callbacks will be call when creating object
|
// Field `creates` contains callbacks will be call when creating object
|
||||||
@ -13,23 +14,23 @@ var DefaultCallback = &Callback{}
|
|||||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
type Callback struct {
|
type Callback struct {
|
||||||
creates []*func(scope *Scope)
|
creates []*func(*gorm.DB)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(*gorm.DB)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(*gorm.DB)
|
||||||
queries []*func(scope *Scope)
|
queries []*func(*gorm.DB)
|
||||||
rowQueries []*func(scope *Scope)
|
rowQueries []*func(*gorm.DB)
|
||||||
processors []*CallbackProcessor
|
processors []*CallbackProcessor
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallbackProcessor contains callback informations
|
// CallbackProcessor contains callback informations
|
||||||
type CallbackProcessor struct {
|
type CallbackProcessor struct {
|
||||||
name string // current callback's name
|
name string // current callback's name
|
||||||
before string // register current callback before a callback
|
before string // register current callback before a callback
|
||||||
after string // register current callback after a callback
|
after string // register current callback after a callback
|
||||||
replace bool // replace callbacks with same name
|
replace bool // replace callbacks with same name
|
||||||
remove bool // delete callbacks with same name
|
remove bool // delete callbacks with same name
|
||||||
kind string // callback type: create, update, delete, query, row_query
|
kind string // callback type: create, update, delete, query, row_query
|
||||||
processor *func(scope *Scope) // callback handler
|
processor *func(*gorm.DB) // callback handler
|
||||||
parent *Callback
|
parent *Callback
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,7 +46,7 @@ func (c *Callback) clone() *Callback {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create could be used to register callbacks for creating object
|
// Create could be used to register callbacks for creating object
|
||||||
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*gorm.DB) {
|
||||||
// // business logic
|
// // business logic
|
||||||
// ...
|
// ...
|
||||||
//
|
//
|
||||||
@ -90,7 +91,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register a new callback, refer `Callbacks.Create`
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(*gorm.DB)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
||||||
@ -107,7 +108,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
|||||||
// Remove a registered callback
|
// Remove a registered callback
|
||||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
log.Printf("[info] removing callback `%v` from %v\n", callbackName, utils.FileWithLineNum())
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.remove = true
|
cp.remove = true
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
@ -115,12 +116,12 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Replace a registered callback with new callback
|
// Replace a registered callback with new callback
|
||||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*gorm.DB) {
|
||||||
// scope.SetColumn("Created", now)
|
// scope.SetColumn("Created", now)
|
||||||
// scope.SetColumn("Updated", now)
|
// scope.SetColumn("Updated", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(*gorm.DB)) {
|
||||||
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, utils.FileWithLineNum())
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
@ -130,7 +131,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
|||||||
|
|
||||||
// Get registered callback
|
// Get registered callback
|
||||||
// db.Callback().Create().Get("gorm:create")
|
// db.Callback().Create().Get("gorm:create")
|
||||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(*gorm.DB)) {
|
||||||
for _, p := range cp.parent.processors {
|
for _, p := range cp.parent.processors {
|
||||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||||
return *p.processor
|
return *p.processor
|
||||||
@ -150,7 +151,7 @@ func getRIndex(strs []string, str string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sortProcessors sort callback processors based on its before, after, remove, replace
|
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
func sortProcessors(cps []*CallbackProcessor) []*func(*gorm.DB) {
|
||||||
var (
|
var (
|
||||||
allNames, sortedNames []string
|
allNames, sortedNames []string
|
||||||
sortCallbackProcessor func(c *CallbackProcessor)
|
sortCallbackProcessor func(c *CallbackProcessor)
|
||||||
@ -159,7 +160,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
|||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||||
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, utils.FileWithLineNum())
|
||||||
}
|
}
|
||||||
allNames = append(allNames, cp.name)
|
allNames = append(allNames, cp.name)
|
||||||
}
|
}
|
||||||
@ -203,7 +204,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
|||||||
sortCallbackProcessor(cp)
|
sortCallbackProcessor(cp)
|
||||||
}
|
}
|
||||||
|
|
||||||
var sortedFuncs []*func(scope *Scope)
|
var sortedFuncs []*func(*gorm.DB)
|
||||||
for _, name := range sortedNames {
|
for _, name := range sortedNames {
|
||||||
if index := getRIndex(allNames, name); !cps[index].remove {
|
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||||
|
@ -1,164 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Define callbacks for creating
|
|
||||||
func init() {
|
|
||||||
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:create", createCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
|
||||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
|
||||||
func beforeCreateCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("BeforeSave")
|
|
||||||
}
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("BeforeCreate")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
|
||||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
now := NowFunc()
|
|
||||||
|
|
||||||
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
|
||||||
if createdAtField.IsBlank {
|
|
||||||
createdAtField.Set(now)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
|
|
||||||
if updatedAtField.IsBlank {
|
|
||||||
updatedAtField.Set(now)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createCallback the callback used to insert data into database
|
|
||||||
func createCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
defer scope.trace(NowFunc())
|
|
||||||
|
|
||||||
var (
|
|
||||||
columns, placeholders []string
|
|
||||||
blankColumnsWithDefaultValue []string
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
if scope.changeableField(field) {
|
|
||||||
if field.IsNormal {
|
|
||||||
if field.IsBlank && field.HasDefaultValue {
|
|
||||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
|
||||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
|
||||||
} else if !field.IsPrimaryKey || !field.IsBlank {
|
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
|
||||||
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
|
||||||
}
|
|
||||||
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
|
||||||
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
|
||||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
|
||||||
columns = append(columns, scope.Quote(foreignField.DBName))
|
|
||||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
returningColumn = "*"
|
|
||||||
quotedTableName = scope.QuotedTableName()
|
|
||||||
primaryField = scope.PrimaryField()
|
|
||||||
extraOption string
|
|
||||||
)
|
|
||||||
|
|
||||||
if str, ok := scope.Get("gorm:insert_option"); ok {
|
|
||||||
extraOption = fmt.Sprint(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
if primaryField != nil {
|
|
||||||
returningColumn = scope.Quote(primaryField.DBName)
|
|
||||||
}
|
|
||||||
|
|
||||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
|
||||||
|
|
||||||
if len(columns) == 0 {
|
|
||||||
scope.Raw(fmt.Sprintf(
|
|
||||||
"INSERT INTO %v %v%v%v",
|
|
||||||
quotedTableName,
|
|
||||||
scope.Dialect().DefaultValueStr(),
|
|
||||||
addExtraSpaceIfExist(extraOption),
|
|
||||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
scope.Raw(fmt.Sprintf(
|
|
||||||
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
|
||||||
scope.QuotedTableName(),
|
|
||||||
strings.Join(columns, ","),
|
|
||||||
strings.Join(placeholders, ","),
|
|
||||||
addExtraSpaceIfExist(extraOption),
|
|
||||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
// execute create sql
|
|
||||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
|
||||||
// set rows affected count
|
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
|
|
||||||
// set primary value to primary field
|
|
||||||
if primaryField != nil && primaryField.IsBlank {
|
|
||||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
|
||||||
scope.Err(primaryField.Set(primaryValue))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if primaryField.Field.CanAddr() {
|
|
||||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
|
||||||
primaryField.IsBlank = false
|
|
||||||
scope.db.RowsAffected = 1
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scope.Err(ErrUnaddressable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
|
||||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
|
||||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
|
||||||
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
if field.IsPrimaryKey && !field.IsBlank {
|
|
||||||
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
db.Scan(scope.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
|
||||||
func afterCreateCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("AfterCreate")
|
|
||||||
}
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("AfterSave")
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,63 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Define callbacks for deleting
|
|
||||||
func init() {
|
|
||||||
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
|
||||||
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
|
||||||
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
|
||||||
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
|
||||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
|
||||||
func beforeDeleteCallback(scope *Scope) {
|
|
||||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
|
||||||
scope.Err(errors.New("Missing WHERE clause while deleting"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("BeforeDelete")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
|
||||||
func deleteCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
var extraOption string
|
|
||||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
|
||||||
extraOption = fmt.Sprint(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
|
|
||||||
|
|
||||||
if !scope.Search.Unscoped && hasDeletedAtField {
|
|
||||||
scope.Raw(fmt.Sprintf(
|
|
||||||
"UPDATE %v SET %v=%v%v%v",
|
|
||||||
scope.QuotedTableName(),
|
|
||||||
scope.Quote(deletedAtField.DBName),
|
|
||||||
scope.AddToVars(NowFunc()),
|
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
|
||||||
addExtraSpaceIfExist(extraOption),
|
|
||||||
)).Exec()
|
|
||||||
} else {
|
|
||||||
scope.Raw(fmt.Sprintf(
|
|
||||||
"DELETE FROM %v%v%v",
|
|
||||||
scope.QuotedTableName(),
|
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
|
||||||
addExtraSpaceIfExist(extraOption),
|
|
||||||
)).Exec()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
|
||||||
func afterDeleteCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("AfterDelete")
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,479 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Define callbacks for querying
|
|
||||||
func init() {
|
|
||||||
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
|
||||||
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
|
||||||
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryCallback used to query data from database
|
|
||||||
func queryCallback(scope *Scope) {
|
|
||||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer scope.trace(NowFunc())
|
|
||||||
|
|
||||||
var (
|
|
||||||
isSlice, isPtr bool
|
|
||||||
resultType reflect.Type
|
|
||||||
results = scope.IndirectValue()
|
|
||||||
)
|
|
||||||
|
|
||||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
|
||||||
if primaryField := scope.PrimaryField(); primaryField != nil {
|
|
||||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
|
||||||
results = indirect(reflect.ValueOf(value))
|
|
||||||
}
|
|
||||||
|
|
||||||
if kind := results.Kind(); kind == reflect.Slice {
|
|
||||||
isSlice = true
|
|
||||||
resultType = results.Type().Elem()
|
|
||||||
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
|
||||||
|
|
||||||
if resultType.Kind() == reflect.Ptr {
|
|
||||||
isPtr = true
|
|
||||||
resultType = resultType.Elem()
|
|
||||||
}
|
|
||||||
} else if kind != reflect.Struct {
|
|
||||||
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.prepareQuerySQL()
|
|
||||||
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.db.RowsAffected = 0
|
|
||||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
|
||||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
|
||||||
for rows.Next() {
|
|
||||||
scope.db.RowsAffected++
|
|
||||||
|
|
||||||
elem := results
|
|
||||||
if isSlice {
|
|
||||||
elem = reflect.New(resultType).Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
|
|
||||||
|
|
||||||
if isSlice {
|
|
||||||
if isPtr {
|
|
||||||
results.Set(reflect.Append(results, elem.Addr()))
|
|
||||||
} else {
|
|
||||||
results.Set(reflect.Append(results, elem))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
scope.Err(err)
|
|
||||||
} else if scope.db.RowsAffected == 0 && !isSlice {
|
|
||||||
scope.Err(ErrRecordNotFound)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// afterQueryCallback will invoke `AfterFind` method after querying
|
|
||||||
func afterQueryCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("AfterFind")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// preloadCallback used to preload associations
|
|
||||||
func preloadCallback(scope *Scope) {
|
|
||||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
|
||||||
autoPreload(scope)
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope.Search.preload == nil || scope.HasError() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
preloadedMap = map[string]bool{}
|
|
||||||
fields = scope.Fields()
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, preload := range scope.Search.preload {
|
|
||||||
var (
|
|
||||||
preloadFields = strings.Split(preload.schema, ".")
|
|
||||||
currentScope = scope
|
|
||||||
currentFields = fields
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, preloadField := range preloadFields {
|
|
||||||
var currentPreloadConditions []interface{}
|
|
||||||
|
|
||||||
if currentScope == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// if not preloaded
|
|
||||||
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
|
||||||
|
|
||||||
// assign search conditions to last preload
|
|
||||||
if idx == len(preloadFields)-1 {
|
|
||||||
currentPreloadConditions = preload.conditions
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, field := range currentFields {
|
|
||||||
if field.Name != preloadField || field.Relationship == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch field.Relationship.Kind {
|
|
||||||
case "has_one":
|
|
||||||
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
|
||||||
case "has_many":
|
|
||||||
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
|
||||||
case "belongs_to":
|
|
||||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
|
||||||
case "many_to_many":
|
|
||||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
|
||||||
default:
|
|
||||||
scope.Err(errors.New("unsupported relation"))
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadedMap[preloadKey] = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if !preloadedMap[preloadKey] {
|
|
||||||
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// preload next level
|
|
||||||
if idx < len(preloadFields)-1 {
|
|
||||||
currentScope = currentScope.getColumnAsScope(preloadField)
|
|
||||||
if currentScope != nil {
|
|
||||||
currentFields = currentScope.Fields()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func autoPreload(scope *Scope) {
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
if field.Relationship == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if val, ok := field.TagSettings["PRELOAD"]; ok {
|
|
||||||
if preload, err := strconv.ParseBool(val); err != nil {
|
|
||||||
scope.Err(errors.New("invalid preload option"))
|
|
||||||
return
|
|
||||||
} else if !preload {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Search.Preload(field.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
|
||||||
var (
|
|
||||||
preloadDB = scope.NewDB()
|
|
||||||
preloadConditions []interface{}
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, condition := range conditions {
|
|
||||||
if scopes, ok := condition.(func(*DB) *DB); ok {
|
|
||||||
preloadDB = scopes(preloadDB)
|
|
||||||
} else {
|
|
||||||
preloadConditions = append(preloadConditions, condition)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return preloadDB, preloadConditions
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleHasOnePreload used to preload has one associations
|
|
||||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
|
|
||||||
// get relations's primary keys
|
|
||||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// preload conditions
|
|
||||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
||||||
|
|
||||||
// find relations
|
|
||||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
|
||||||
values := toQueryValues(primaryKeys)
|
|
||||||
if relation.PolymorphicType != "" {
|
|
||||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
|
||||||
values = append(values, relation.PolymorphicValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
results := makeSlice(field.Struct.Type)
|
|
||||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
|
||||||
|
|
||||||
// assign find results
|
|
||||||
var (
|
|
||||||
resultsValue = indirect(reflect.ValueOf(results))
|
|
||||||
indirectScopeValue = scope.IndirectValue()
|
|
||||||
)
|
|
||||||
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
|
||||||
result := resultsValue.Index(i)
|
|
||||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
|
||||||
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
|
||||||
indirectValue.FieldByName(field.Name).Set(result)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
|
||||||
result := resultsValue.Index(i)
|
|
||||||
scope.Err(field.Set(result))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleHasManyPreload used to preload has many associations
|
|
||||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
|
|
||||||
// get relations's primary keys
|
|
||||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// preload conditions
|
|
||||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
||||||
|
|
||||||
// find relations
|
|
||||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
|
||||||
values := toQueryValues(primaryKeys)
|
|
||||||
if relation.PolymorphicType != "" {
|
|
||||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
|
||||||
values = append(values, relation.PolymorphicValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
results := makeSlice(field.Struct.Type)
|
|
||||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
|
||||||
|
|
||||||
// assign find results
|
|
||||||
var (
|
|
||||||
resultsValue = indirect(reflect.ValueOf(results))
|
|
||||||
indirectScopeValue = scope.IndirectValue()
|
|
||||||
)
|
|
||||||
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
|
||||||
preloadMap := make(map[string][]reflect.Value)
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
|
||||||
result := resultsValue.Index(i)
|
|
||||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
|
||||||
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
|
|
||||||
}
|
|
||||||
|
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
||||||
object := indirect(indirectScopeValue.Index(j))
|
|
||||||
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
|
|
||||||
f := object.FieldByName(field.Name)
|
|
||||||
if results, ok := preloadMap[toString(objectRealValue)]; ok {
|
|
||||||
f.Set(reflect.Append(f, results...))
|
|
||||||
} else {
|
|
||||||
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scope.Err(field.Set(resultsValue))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleBelongsToPreload used to preload belongs to associations
|
|
||||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
|
|
||||||
// preload conditions
|
|
||||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
||||||
|
|
||||||
// get relations's primary keys
|
|
||||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// find relations
|
|
||||||
results := makeSlice(field.Struct.Type)
|
|
||||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
|
||||||
|
|
||||||
// assign find results
|
|
||||||
var (
|
|
||||||
resultsValue = indirect(reflect.ValueOf(results))
|
|
||||||
indirectScopeValue = scope.IndirectValue()
|
|
||||||
)
|
|
||||||
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
|
||||||
result := resultsValue.Index(i)
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
|
||||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
||||||
object := indirect(indirectScopeValue.Index(j))
|
|
||||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
|
||||||
object.FieldByName(field.Name).Set(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scope.Err(field.Set(result))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleManyToManyPreload used to preload many to many associations
|
|
||||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
|
||||||
var (
|
|
||||||
relation = field.Relationship
|
|
||||||
joinTableHandler = relation.JoinTableHandler
|
|
||||||
fieldType = field.Struct.Type.Elem()
|
|
||||||
foreignKeyValue interface{}
|
|
||||||
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
|
||||||
linkHash = map[string][]reflect.Value{}
|
|
||||||
isPtr bool
|
|
||||||
)
|
|
||||||
|
|
||||||
if fieldType.Kind() == reflect.Ptr {
|
|
||||||
isPtr = true
|
|
||||||
fieldType = fieldType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
var sourceKeys = []string{}
|
|
||||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
|
||||||
sourceKeys = append(sourceKeys, key.DBName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// preload conditions
|
|
||||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
||||||
|
|
||||||
// generate query with join table
|
|
||||||
newScope := scope.New(reflect.New(fieldType).Interface())
|
|
||||||
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
|
|
||||||
|
|
||||||
if len(preloadDB.search.selects) == 0 {
|
|
||||||
preloadDB = preloadDB.Select("*")
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
|
||||||
|
|
||||||
// preload inline conditions
|
|
||||||
if len(preloadConditions) > 0 {
|
|
||||||
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := preloadDB.Rows()
|
|
||||||
|
|
||||||
if scope.Err(err) != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
|
||||||
for rows.Next() {
|
|
||||||
var (
|
|
||||||
elem = reflect.New(fieldType).Elem()
|
|
||||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
|
||||||
)
|
|
||||||
|
|
||||||
// register foreign keys in join tables
|
|
||||||
var joinTableFields []*Field
|
|
||||||
for _, sourceKey := range sourceKeys {
|
|
||||||
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
|
||||||
|
|
||||||
scope.New(elem.Addr().Interface()).
|
|
||||||
InstanceSet("gorm:skip_query_callback", true).
|
|
||||||
callCallbacks(scope.db.parent.callbacks.queries)
|
|
||||||
|
|
||||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
|
||||||
// generate hashed forkey keys in join table
|
|
||||||
for idx, joinTableField := range joinTableFields {
|
|
||||||
if !joinTableField.Field.IsNil() {
|
|
||||||
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hashedSourceKeys := toString(foreignKeys)
|
|
||||||
|
|
||||||
if isPtr {
|
|
||||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
|
||||||
} else {
|
|
||||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
scope.Err(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assign find results
|
|
||||||
var (
|
|
||||||
indirectScopeValue = scope.IndirectValue()
|
|
||||||
fieldsSourceMap = map[string][]reflect.Value{}
|
|
||||||
foreignFieldNames = []string{}
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, dbName := range relation.ForeignFieldNames {
|
|
||||||
if field, ok := scope.FieldByName(dbName); ok {
|
|
||||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
||||||
object := indirect(indirectScopeValue.Index(j))
|
|
||||||
key := toString(getValueFromFields(object, foreignFieldNames))
|
|
||||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
|
|
||||||
}
|
|
||||||
} else if indirectScopeValue.IsValid() {
|
|
||||||
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
|
||||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
|
||||||
}
|
|
||||||
for source, link := range linkHash {
|
|
||||||
for i, field := range fieldsSourceMap[source] {
|
|
||||||
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
|
||||||
if fieldsSourceMap[source][i].Len() != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,30 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import "database/sql"
|
|
||||||
|
|
||||||
// Define callbacks for row query
|
|
||||||
func init() {
|
|
||||||
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
type RowQueryResult struct {
|
|
||||||
Row *sql.Row
|
|
||||||
}
|
|
||||||
|
|
||||||
type RowsQueryResult struct {
|
|
||||||
Rows *sql.Rows
|
|
||||||
Error error
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryCallback used to query data from database
|
|
||||||
func rowQueryCallback(scope *Scope) {
|
|
||||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
|
||||||
scope.prepareQuerySQL()
|
|
||||||
|
|
||||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
|
||||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
|
||||||
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
|
|
||||||
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,170 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func beginTransactionCallback(scope *Scope) {
|
|
||||||
scope.Begin()
|
|
||||||
}
|
|
||||||
|
|
||||||
func commitOrRollbackTransactionCallback(scope *Scope) {
|
|
||||||
scope.CommitOrRollback()
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
|
|
||||||
checkTruth := func(value interface{}) bool {
|
|
||||||
if v, ok := value.(bool); ok && !v {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := value.(string); ok {
|
|
||||||
v = strings.ToLower(v)
|
|
||||||
if v == "false" || v != "skip" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
|
||||||
if r = field.Relationship; r != nil {
|
|
||||||
autoUpdate, autoCreate, saveReference = true, true, true
|
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:save_associations"); ok {
|
|
||||||
autoUpdate = checkTruth(value)
|
|
||||||
autoCreate = autoUpdate
|
|
||||||
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
|
|
||||||
autoUpdate = checkTruth(value)
|
|
||||||
autoCreate = autoUpdate
|
|
||||||
}
|
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
|
||||||
autoUpdate = checkTruth(value)
|
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
|
|
||||||
autoUpdate = checkTruth(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
|
||||||
autoCreate = checkTruth(value)
|
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
|
|
||||||
autoCreate = checkTruth(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
|
||||||
saveReference = checkTruth(value)
|
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
|
|
||||||
saveReference = checkTruth(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
|
||||||
|
|
||||||
if relationship != nil && relationship.Kind == "belongs_to" {
|
|
||||||
fieldValue := field.Field.Addr().Interface()
|
|
||||||
newScope := scope.New(fieldValue)
|
|
||||||
|
|
||||||
if newScope.PrimaryKeyZero() {
|
|
||||||
if autoCreate {
|
|
||||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
|
||||||
}
|
|
||||||
} else if autoUpdate {
|
|
||||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if saveReference {
|
|
||||||
if len(relationship.ForeignFieldNames) != 0 {
|
|
||||||
// set value's foreign key
|
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
|
||||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
|
||||||
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
|
||||||
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveAfterAssociationsCallback(scope *Scope) {
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
|
||||||
|
|
||||||
if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
|
||||||
value := field.Field
|
|
||||||
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
for i := 0; i < value.Len(); i++ {
|
|
||||||
newDB := scope.NewDB()
|
|
||||||
elem := value.Index(i).Addr().Interface()
|
|
||||||
newScope := newDB.NewScope(elem)
|
|
||||||
|
|
||||||
if saveReference {
|
|
||||||
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
|
||||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
|
||||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
|
||||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.PolymorphicType != "" {
|
|
||||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if newScope.PrimaryKeyZero() {
|
|
||||||
if autoCreate {
|
|
||||||
scope.Err(newDB.Save(elem).Error)
|
|
||||||
}
|
|
||||||
} else if autoUpdate {
|
|
||||||
scope.Err(newDB.Save(elem).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
|
|
||||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
|
||||||
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
elem := value.Addr().Interface()
|
|
||||||
newScope := scope.New(elem)
|
|
||||||
|
|
||||||
if saveReference {
|
|
||||||
if len(relationship.ForeignFieldNames) != 0 {
|
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
|
||||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
|
||||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
|
||||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.PolymorphicType != "" {
|
|
||||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if newScope.PrimaryKeyZero() {
|
|
||||||
if autoCreate {
|
|
||||||
scope.Err(scope.NewDB().Save(elem).Error)
|
|
||||||
}
|
|
||||||
} else if autoUpdate {
|
|
||||||
scope.Err(scope.NewDB().Save(elem).Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,119 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Define callbacks for updating
|
|
||||||
func init() {
|
|
||||||
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
|
||||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
|
||||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
|
||||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
|
||||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
|
||||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
|
||||||
} else {
|
|
||||||
scope.SkipLeft()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
|
||||||
func beforeUpdateCallback(scope *Scope) {
|
|
||||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
|
||||||
scope.Err(errors.New("Missing WHERE clause while updating"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("BeforeSave")
|
|
||||||
}
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("BeforeUpdate")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
|
||||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
|
||||||
scope.SetColumn("UpdatedAt", NowFunc())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateCallback the callback used to update data to database
|
|
||||||
func updateCallback(scope *Scope) {
|
|
||||||
if !scope.HasError() {
|
|
||||||
var sqls []string
|
|
||||||
|
|
||||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
|
||||||
// Sort the column names so that the generated SQL is the same every time.
|
|
||||||
updateMap := updateAttrs.(map[string]interface{})
|
|
||||||
var columns []string
|
|
||||||
for c := range updateMap {
|
|
||||||
columns = append(columns, c)
|
|
||||||
}
|
|
||||||
sort.Strings(columns)
|
|
||||||
|
|
||||||
for _, column := range columns {
|
|
||||||
value := updateMap[column]
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
if scope.changeableField(field) {
|
|
||||||
if !field.IsPrimaryKey && field.IsNormal {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
|
||||||
sqls = append(sqls,
|
|
||||||
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var extraOption string
|
|
||||||
if str, ok := scope.Get("gorm:update_option"); ok {
|
|
||||||
extraOption = fmt.Sprint(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(sqls) > 0 {
|
|
||||||
scope.Raw(fmt.Sprintf(
|
|
||||||
"UPDATE %v SET %v%v%v",
|
|
||||||
scope.QuotedTableName(),
|
|
||||||
strings.Join(sqls, ", "),
|
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
|
||||||
addExtraSpaceIfExist(extraOption),
|
|
||||||
)).Exec()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
|
||||||
func afterUpdateCallback(scope *Scope) {
|
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("AfterUpdate")
|
|
||||||
}
|
|
||||||
if !scope.HasError() {
|
|
||||||
scope.CallMethod("AfterSave")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,10 +1,8 @@
|
|||||||
package model
|
package sqlbuilder
|
||||||
|
|
||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrInvalidTable invalid table name
|
// ErrInvalidTable invalid table name
|
||||||
ErrInvalidTable = errors.New("invalid table name")
|
ErrInvalidTable = errors.New("invalid table name")
|
||||||
// ErrUnaddressable unaddressable value
|
|
||||||
ErrUnaddressable = errors.New("using unaddressable value")
|
|
||||||
)
|
)
|
@ -1,80 +1,28 @@
|
|||||||
package model
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/builder"
|
"github.com/jinzhu/gorm/model"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Field GORM model field
|
// GetAssignmentFields get assignment fields
|
||||||
type Field struct {
|
func GetAssignmentFields(tx *gorm.DB) chan [][]*model.Field {
|
||||||
*schema.Field
|
fieldChan := make(chan [][]*model.Field)
|
||||||
IsBlank bool
|
|
||||||
Value reflect.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set set a value to the field
|
|
||||||
func (field *Field) Set(value interface{}) (err error) {
|
|
||||||
if !field.Value.IsValid() {
|
|
||||||
return errors.New("field value not valid")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !field.Value.CanAddr() {
|
|
||||||
return ErrUnaddressable
|
|
||||||
}
|
|
||||||
|
|
||||||
reflectValue, ok := value.(reflect.Value)
|
|
||||||
if !ok {
|
|
||||||
reflectValue = reflect.ValueOf(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldValue := field.Value
|
|
||||||
if reflectValue.IsValid() {
|
|
||||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
|
||||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
|
||||||
} else {
|
|
||||||
if fieldValue.Kind() == reflect.Ptr {
|
|
||||||
if fieldValue.IsNil() {
|
|
||||||
fieldValue.Set(reflect.New(field.StructField.Type.Elem()))
|
|
||||||
}
|
|
||||||
fieldValue = fieldValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
|
||||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
|
||||||
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
|
||||||
err = scanner.Scan(reflectValue.Interface())
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
field.Value.Set(reflect.Zero(fieldValue.Type()))
|
|
||||||
}
|
|
||||||
|
|
||||||
field.IsBlank = isBlank(field.Value)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAssignments get assignments
|
|
||||||
func GetAssignments(tx *gorm.DB) chan [][]*Field {
|
|
||||||
fieldChan := make(chan [][]*Field)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
assignableChecker := generateAssignableChecker(selectAttrs(tx.Statement), omitAttrs(tx.Statement))
|
assignableChecker := generateAssignableChecker(selectAttrs(tx.Statement), omitAttrs(tx.Statement))
|
||||||
|
|
||||||
switch dest := tx.Statement.Dest.(type) {
|
switch dest := tx.Statement.Dest.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
fieldChan <- [][]*Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)}
|
fieldChan <- [][]*model.Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)}
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
fields := [][]*Field{}
|
fields := [][]*model.Field{}
|
||||||
tableSchema := schema.Parse(tx.Statement.Table)
|
tableSchema := schema.Parse(tx.Statement.Table)
|
||||||
|
|
||||||
for _, v := range dest {
|
for _, v := range dest {
|
||||||
@ -87,13 +35,13 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field {
|
|||||||
|
|
||||||
switch results.Kind() {
|
switch results.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
fields := [][]*Field{}
|
fields := [][]*model.Field{}
|
||||||
for i := 0; i < results.Len(); i++ {
|
for i := 0; i < results.Len(); i++ {
|
||||||
fields = append(fields, structToField(indirect(results.Index(i)), s, assignableChecker))
|
fields = append(fields, structToField(indirect(results.Index(i)), s, assignableChecker))
|
||||||
}
|
}
|
||||||
fieldChan <- fields
|
fieldChan <- fields
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
fieldChan <- [][]*Field{structToField(results, s, assignableChecker)}
|
fieldChan <- [][]*model.Field{structToField(results, s, assignableChecker)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,12 +50,12 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field {
|
|||||||
return fieldChan
|
return fieldChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) {
|
func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) {
|
||||||
// TODO assign those value to dest
|
// TODO assign those value to dest
|
||||||
for k, v := range value {
|
for k, v := range value {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
if f := s.FieldByName(k); f != nil {
|
if f := s.FieldByName(k); f != nil {
|
||||||
field := &Field{Field: f, Value: reflect.ValueOf(v)}
|
field := &model.Field{Field: f, Value: reflect.ValueOf(v)}
|
||||||
if assignableChecker(field) {
|
if assignableChecker(field) {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
}
|
}
|
||||||
@ -115,7 +63,7 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
field := &Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}
|
field := &model.Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}
|
||||||
if assignableChecker(field) {
|
if assignableChecker(field) {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
}
|
}
|
||||||
@ -127,14 +75,14 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) {
|
func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) {
|
||||||
// TODO use Offset to replace FieldByName?
|
// TODO use Offset to replace FieldByName?
|
||||||
for _, sf := range s.Fields {
|
for _, sf := range s.Fields {
|
||||||
obj := value
|
obj := value
|
||||||
for _, bn := range sf.BindNames {
|
for _, bn := range sf.BindNames {
|
||||||
obj = value.FieldByName(bn)
|
obj = value.FieldByName(bn)
|
||||||
}
|
}
|
||||||
field := &Field{Field: sf, Value: obj, IsBlank: isBlank(obj)}
|
field := &model.Field{Field: sf, Value: obj, IsBlank: model.IsBlank(obj)}
|
||||||
if assignableChecker(field) {
|
if assignableChecker(field) {
|
||||||
fields = append(fields, field)
|
fields = append(fields, field)
|
||||||
}
|
}
|
||||||
@ -143,8 +91,8 @@ func structToField(value reflect.Value, s *schema.Schema, assignableChecker func
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateAssignableChecker generate checker to check if field is assignable or not
|
// generateAssignableChecker generate checker to check if field is assignable or not
|
||||||
func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*Field) bool {
|
func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*model.Field) bool {
|
||||||
return func(field *Field) bool {
|
return func(field *model.Field) bool {
|
||||||
if len(selectAttrs) > 0 {
|
if len(selectAttrs) > 0 {
|
||||||
for _, attr := range selectAttrs {
|
for _, attr := range selectAttrs {
|
||||||
if field.Name == attr || field.DBName == attr {
|
if field.Name == attr || field.DBName == attr {
|
||||||
@ -164,7 +112,7 @@ func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*F
|
|||||||
}
|
}
|
||||||
|
|
||||||
// omitAttrs return selected attributes of stmt
|
// omitAttrs return selected attributes of stmt
|
||||||
func selectAttrs(stmt *builder.Statement) []string {
|
func selectAttrs(stmt *gorm.Statement) []string {
|
||||||
columns := stmt.Select.Columns
|
columns := stmt.Select.Columns
|
||||||
for _, arg := range stmt.Select.Args {
|
for _, arg := range stmt.Select.Args {
|
||||||
columns = append(columns, fmt.Sprint(arg))
|
columns = append(columns, fmt.Sprint(arg))
|
||||||
@ -173,6 +121,6 @@ func selectAttrs(stmt *builder.Statement) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// omitAttrs return omitted attributes of stmt
|
// omitAttrs return omitted attributes of stmt
|
||||||
func omitAttrs(stmt *builder.Statement) []string {
|
func omitAttrs(stmt *gorm.Statement) []string {
|
||||||
return stmt.Omit
|
return stmt.Omit
|
||||||
}
|
}
|
63
dialects/common/sqlbuilder/sqlbuilder.go
Normal file
63
dialects/common/sqlbuilder/sqlbuilder.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package sqlbuilder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/jinzhu/gorm/model"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
"github.com/jinzhu/inflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetTable get table name for current db operation
|
||||||
|
func GetTable(tx *gorm.DB) chan string {
|
||||||
|
tableChan := make(chan string)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var tableName string
|
||||||
|
if name, ok := tx.Statement.Table.(string); ok {
|
||||||
|
tableName = name
|
||||||
|
} else {
|
||||||
|
for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
|
||||||
|
if v != nil {
|
||||||
|
if t, ok := v.(tabler); ok {
|
||||||
|
tableName = t.TableName()
|
||||||
|
} else if t, ok := v.(dbTabler); ok {
|
||||||
|
tableName = t.TableName(tx)
|
||||||
|
} else if s := schema.Parse(v); s != nil {
|
||||||
|
if s.TableName != "" {
|
||||||
|
tableName = s.TableName
|
||||||
|
} else {
|
||||||
|
tableName = schema.ToDBName(s.ModelType.Name())
|
||||||
|
if !tx.Config.SingularTable {
|
||||||
|
tableName = inflection.Plural(tableName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableName != "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableName != "" {
|
||||||
|
if model.DefaultTableNameHandler != nil {
|
||||||
|
tableChan <- model.DefaultTableNameHandler(tx, tableName)
|
||||||
|
} else {
|
||||||
|
tableChan <- tableName
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tx.AddError(ErrInvalidTable)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return tableChan
|
||||||
|
}
|
||||||
|
|
||||||
|
type tabler interface {
|
||||||
|
TableName() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type dbTabler interface {
|
||||||
|
TableName(*gorm.DB) string
|
||||||
|
}
|
10
dialects/common/sqlbuilder/utils.go
Normal file
10
dialects/common/sqlbuilder/utils.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package sqlbuilder
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||||
|
for reflectValue.Kind() == reflect.Ptr {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
return reflectValue
|
||||||
|
}
|
@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/model"
|
"github.com/jinzhu/gorm/dialects/common/destination"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dialect Sqlite3 Dialect for GORM
|
// Dialect Sqlite3 Dialect for GORM
|
||||||
@ -23,9 +23,9 @@ func (dialect Dialect) Quote(name string) string {
|
|||||||
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||||
var (
|
var (
|
||||||
args []interface{}
|
args []interface{}
|
||||||
assignmentsChan = model.GetAssignments(tx)
|
assignmentsChan = destination.GetAssignments(tx)
|
||||||
tableNameChan = model.GetTable(tx)
|
tableNameChan = destination.GetTable(tx)
|
||||||
primaryFields []*model.Field
|
primaryFields []*destination.Field
|
||||||
)
|
)
|
||||||
|
|
||||||
s := bytes.NewBufferString("INSERT INTO ")
|
s := bytes.NewBufferString("INSERT INTO ")
|
||||||
@ -41,7 +41,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
|||||||
valueBuffer := bytes.NewBufferString("VALUES ")
|
valueBuffer := bytes.NewBufferString("VALUES ")
|
||||||
|
|
||||||
for idx, fields := range assignments {
|
for idx, fields := range assignments {
|
||||||
var primaryField *model.Field
|
var primaryField *destination.Field
|
||||||
if idx != 0 {
|
if idx != 0 {
|
||||||
valueBuffer.WriteString(",")
|
valueBuffer.WriteString(",")
|
||||||
}
|
}
|
||||||
|
21
logger/utils.go
Normal file
21
logger/utils.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
||||||
|
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
|
||||||
|
|
||||||
|
// FileWithLineNum get filename with line num for logging
|
||||||
|
func FileWithLineNum() string {
|
||||||
|
for i := 2; i < 15; i++ {
|
||||||
|
_, file, line, ok := runtime.Caller(i)
|
||||||
|
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
|
||||||
|
return fmt.Sprintf("%v:%v", file, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
@ -1,9 +1,13 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
"github.com/jinzhu/inflection"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultTableNameHandler default table name handler
|
// DefaultTableNameHandler default table name handler
|
||||||
@ -12,57 +16,52 @@ import (
|
|||||||
// }
|
// }
|
||||||
var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string
|
var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string
|
||||||
|
|
||||||
// GetTable get table name for current db operation
|
// Field GORM model field
|
||||||
func GetTable(tx *gorm.DB) chan string {
|
type Field struct {
|
||||||
tableChan := make(chan string)
|
*schema.Field
|
||||||
|
IsBlank bool
|
||||||
|
Value reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
// Set set a value to the field
|
||||||
var tableName string
|
func (field *Field) Set(value interface{}) (err error) {
|
||||||
if name, ok := tx.Statement.Table.(string); ok {
|
if !field.Value.IsValid() {
|
||||||
tableName = name
|
return errors.New("field value not valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !field.Value.CanAddr() {
|
||||||
|
return gorm.ErrUnaddressable
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue, ok := value.(reflect.Value)
|
||||||
|
if !ok {
|
||||||
|
reflectValue = reflect.ValueOf(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldValue := field.Value
|
||||||
|
if reflectValue.IsValid() {
|
||||||
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
} else {
|
} else {
|
||||||
for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
|
if fieldValue.Kind() == reflect.Ptr {
|
||||||
if v != nil {
|
if fieldValue.IsNil() {
|
||||||
if t, ok := v.(tabler); ok {
|
fieldValue.Set(reflect.New(field.StructField.Type.Elem()))
|
||||||
tableName = t.TableName()
|
|
||||||
} else if t, ok := v.(dbTabler); ok {
|
|
||||||
tableName = t.TableName(tx)
|
|
||||||
} else if s := schema.Parse(v); s != nil {
|
|
||||||
if s.TableName != "" {
|
|
||||||
tableName = s.TableName
|
|
||||||
} else {
|
|
||||||
tableName = schema.ToDBName(s.ModelType.Name())
|
|
||||||
if !tx.Config.SingularTable {
|
|
||||||
tableName = inflection.Plural(tableName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tableName != "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if tableName != "" {
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
if DefaultTableNameHandler != nil {
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
tableChan <- DefaultTableNameHandler(tx, tableName)
|
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||||
|
err = scanner.Scan(reflectValue.Interface())
|
||||||
} else {
|
} else {
|
||||||
tableChan <- tableName
|
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
tx.AddError(ErrInvalidTable)
|
|
||||||
}
|
}
|
||||||
}()
|
} else {
|
||||||
|
field.Value.Set(reflect.Zero(fieldValue.Type()))
|
||||||
|
}
|
||||||
|
|
||||||
return tableChan
|
field.IsBlank = IsBlank(field.Value)
|
||||||
}
|
return err
|
||||||
|
|
||||||
type tabler interface {
|
|
||||||
TableName() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type dbTabler interface {
|
|
||||||
TableName(*gorm.DB) string
|
|
||||||
}
|
}
|
||||||
|
@ -2,32 +2,7 @@ package model
|
|||||||
|
|
||||||
import "reflect"
|
import "reflect"
|
||||||
|
|
||||||
// ToSearchableMap convert attrs to searchable map
|
func IsBlank(value reflect.Value) bool {
|
||||||
func ToSearchableMap(attrs ...interface{}) (result interface{}) {
|
|
||||||
if len(attrs) > 1 {
|
|
||||||
if str, ok := attrs[0].(string); ok {
|
|
||||||
result = map[string]interface{}{str: attrs[1]}
|
|
||||||
}
|
|
||||||
} else if len(attrs) == 1 {
|
|
||||||
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
|
||||||
result = attr
|
|
||||||
}
|
|
||||||
|
|
||||||
if attr, ok := attrs[0].(interface{}); ok {
|
|
||||||
result = attr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func indirect(reflectValue reflect.Value) reflect.Value {
|
|
||||||
for reflectValue.Kind() == reflect.Ptr {
|
|
||||||
reflectValue = reflectValue.Elem()
|
|
||||||
}
|
|
||||||
return reflectValue
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBlank(value reflect.Value) bool {
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
return value.Len() == 0
|
return value.Len() == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user