add context funcs
This commit is contained in:
parent
79a77d771d
commit
c7f9cb552b
27
callback.go
27
callback.go
@ -1,6 +1,9 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// DefaultCallback default callbacks defined by gorm
|
// DefaultCallback default callbacks defined by gorm
|
||||||
var DefaultCallback = &Callback{logger: nopLogger{}}
|
var DefaultCallback = &Callback{logger: nopLogger{}}
|
||||||
@ -14,11 +17,11 @@ var DefaultCallback = &Callback{logger: nopLogger{}}
|
|||||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
type Callback struct {
|
type Callback struct {
|
||||||
logger logger
|
logger logger
|
||||||
creates []*func(scope *Scope)
|
creates []*func(ctx context.Context, scope *Scope)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(ctx context.Context, scope *Scope)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(ctx context.Context, scope *Scope)
|
||||||
queries []*func(scope *Scope)
|
queries []*func(ctx context.Context, scope *Scope)
|
||||||
rowQueries []*func(scope *Scope)
|
rowQueries []*func(ctx context.Context, scope *Scope)
|
||||||
processors []*CallbackProcessor
|
processors []*CallbackProcessor
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,7 +34,7 @@ type CallbackProcessor struct {
|
|||||||
replace bool // replace callbacks with same name
|
replace bool // replace callbacks with same name
|
||||||
remove bool // delete callbacks with same name
|
remove bool // delete callbacks with same name
|
||||||
kind string // callback type: create, update, delete, query, row_query
|
kind string // callback type: create, update, delete, query, row_query
|
||||||
processor *func(scope *Scope) // callback handler
|
processor *func(ctx context.Context, scope *Scope) // callback handler
|
||||||
parent *Callback
|
parent *Callback
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register a new callback, refer `Callbacks.Create`
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
||||||
@ -123,7 +126,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
// scope.SetColumn("CreatedAt", now)
|
// scope.SetColumn("CreatedAt", now)
|
||||||
// scope.SetColumn("UpdatedAt", now)
|
// scope.SetColumn("UpdatedAt", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
||||||
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
@ -134,7 +137,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
|||||||
|
|
||||||
// Get registered callback
|
// Get registered callback
|
||||||
// db.Callback().Create().Get("gorm:create")
|
// db.Callback().Create().Get("gorm:create")
|
||||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
|
||||||
for _, p := range cp.parent.processors {
|
for _, p := range cp.parent.processors {
|
||||||
if p.name == callbackName && p.kind == cp.kind {
|
if p.name == callbackName && p.kind == cp.kind {
|
||||||
if p.remove {
|
if p.remove {
|
||||||
@ -158,7 +161,7 @@ func getRIndex(strs []string, str string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sortProcessors sort callback processors based on its before, after, remove, replace
|
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
func sortProcessors(cps []*CallbackProcessor) []*func(ctx context.Context, scope *Scope) {
|
||||||
var (
|
var (
|
||||||
allNames, sortedNames []string
|
allNames, sortedNames []string
|
||||||
sortCallbackProcessor func(c *CallbackProcessor)
|
sortCallbackProcessor func(c *CallbackProcessor)
|
||||||
@ -211,7 +214,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
|||||||
sortCallbackProcessor(cp)
|
sortCallbackProcessor(cp)
|
||||||
}
|
}
|
||||||
|
|
||||||
var sortedFuncs []*func(scope *Scope)
|
var sortedFuncs []*func(ctx context.Context, scope *Scope)
|
||||||
for _, name := range sortedNames {
|
for _, name := range sortedNames {
|
||||||
if index := getRIndex(allNames, name); !cps[index].remove {
|
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -19,7 +20,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||||
func beforeCreateCallback(scope *Scope) {
|
func beforeCreateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
scope.CallMethod("BeforeSave")
|
scope.CallMethod("BeforeSave")
|
||||||
}
|
}
|
||||||
@ -29,7 +30,7 @@ func beforeCreateCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
func updateTimeStampForCreateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
now := scope.db.nowFunc()
|
now := scope.db.nowFunc()
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createCallback the callback used to insert data into database
|
// createCallback the callback used to insert data into database
|
||||||
func createCallback(scope *Scope) {
|
func createCallback(ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
@ -100,10 +101,10 @@ func createCallback(scope *Scope) {
|
|||||||
returningColumn = scope.Quote(primaryField.DBName)
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
|
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(ctx, quotedTableName, returningColumn, columns)
|
||||||
var lastInsertIDReturningSuffix string
|
var lastInsertIDReturningSuffix string
|
||||||
if lastInsertIDOutputInterstitial == "" {
|
if lastInsertIDOutputInterstitial == "" {
|
||||||
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(ctx, quotedTableName, returningColumn)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
@ -130,7 +131,7 @@ func createCallback(scope *Scope) {
|
|||||||
|
|
||||||
// execute create sql: no primaryField
|
// execute create sql: no primaryField
|
||||||
if primaryField == nil {
|
if primaryField == nil {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
// set rows affected count
|
// set rows affected count
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
@ -146,7 +147,7 @@ func createCallback(scope *Scope) {
|
|||||||
|
|
||||||
// execute create sql: lastInsertID implemention for majority of dialects
|
// execute create sql: lastInsertID implemention for majority of dialects
|
||||||
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
|
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
// set rows affected count
|
// set rows affected count
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
@ -162,7 +163,7 @@ func createCallback(scope *Scope) {
|
|||||||
|
|
||||||
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
||||||
if primaryField.Field.CanAddr() {
|
if primaryField.Field.CanAddr() {
|
||||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
if err := scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
primaryField.IsBlank = false
|
primaryField.IsBlank = false
|
||||||
scope.db.RowsAffected = 1
|
scope.db.RowsAffected = 1
|
||||||
}
|
}
|
||||||
@ -174,7 +175,7 @@ func createCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
func forceReloadAfterCreateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||||
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
@ -187,7 +188,7 @@ func forceReloadAfterCreateCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||||
func afterCreateCallback(scope *Scope) {
|
func afterCreateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
scope.CallMethod("AfterCreate")
|
scope.CallMethod("AfterCreate")
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
@ -15,7 +16,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||||
func beforeDeleteCallback(scope *Scope) {
|
func beforeDeleteCallback(_ctx context.Context, scope *Scope) {
|
||||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||||
scope.Err(errors.New("missing WHERE clause while deleting"))
|
scope.Err(errors.New("missing WHERE clause while deleting"))
|
||||||
return
|
return
|
||||||
@ -26,7 +27,7 @@ func beforeDeleteCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||||
func deleteCallback(scope *Scope) {
|
func deleteCallback(ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
var extraOption string
|
var extraOption string
|
||||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||||
@ -43,20 +44,20 @@ func deleteCallback(scope *Scope) {
|
|||||||
scope.AddToVars(scope.db.nowFunc()),
|
scope.AddToVars(scope.db.nowFunc()),
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
)).Exec()
|
)).Exec(ctx)
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"DELETE FROM %v%v%v",
|
"DELETE FROM %v%v%v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
)).Exec()
|
)).Exec(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||||
func afterDeleteCallback(scope *Scope) {
|
func afterDeleteCallback(_ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
scope.CallMethod("AfterDelete")
|
scope.CallMethod("AfterDelete")
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -14,7 +15,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// queryCallback used to query data from database
|
// queryCallback used to query data from database
|
||||||
func queryCallback(scope *Scope) {
|
func queryCallback(ctx context.Context, scope *Scope) {
|
||||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -69,7 +70,7 @@ func queryCallback(scope *Scope) {
|
|||||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if rows, err := scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
@ -102,7 +103,7 @@ func queryCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// afterQueryCallback will invoke `AfterFind` method after querying
|
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||||
func afterQueryCallback(scope *Scope) {
|
func afterQueryCallback(_ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
scope.CallMethod("AfterFind")
|
scope.CallMethod("AfterFind")
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -9,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// preloadCallback used to preload associations
|
// preloadCallback used to preload associations
|
||||||
func preloadCallback(scope *Scope) {
|
func preloadCallback(ctx context.Context, scope *Scope) {
|
||||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -18,9 +19,9 @@ func preloadCallback(scope *Scope) {
|
|||||||
// If gorm:auto_preload IS NOT a bool then auto preload.
|
// If gorm:auto_preload IS NOT a bool then auto preload.
|
||||||
// Else if it IS a bool, use the value
|
// Else if it IS a bool, use the value
|
||||||
if apb, ok := ap.(bool); !ok {
|
if apb, ok := ap.(bool); !ok {
|
||||||
autoPreload(scope)
|
autoPreload(ctx, scope)
|
||||||
} else if apb {
|
} else if apb {
|
||||||
autoPreload(scope)
|
autoPreload(ctx, scope)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,13 +63,13 @@ func preloadCallback(scope *Scope) {
|
|||||||
|
|
||||||
switch field.Relationship.Kind {
|
switch field.Relationship.Kind {
|
||||||
case "has_one":
|
case "has_one":
|
||||||
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
currentScope.handleHasOnePreload(ctx, field, currentPreloadConditions)
|
||||||
case "has_many":
|
case "has_many":
|
||||||
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
currentScope.handleHasManyPreload(ctx, field, currentPreloadConditions)
|
||||||
case "belongs_to":
|
case "belongs_to":
|
||||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
currentScope.handleBelongsToPreload(ctx, field, currentPreloadConditions)
|
||||||
case "many_to_many":
|
case "many_to_many":
|
||||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
currentScope.handleManyToManyPreload(ctx, field, currentPreloadConditions)
|
||||||
default:
|
default:
|
||||||
scope.Err(errors.New("unsupported relation"))
|
scope.Err(errors.New("unsupported relation"))
|
||||||
}
|
}
|
||||||
@ -94,7 +95,7 @@ func preloadCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func autoPreload(scope *Scope) {
|
func autoPreload(_ctx context.Context, scope *Scope) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if field.Relationship == nil {
|
if field.Relationship == nil {
|
||||||
continue
|
continue
|
||||||
@ -131,7 +132,7 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleHasOnePreload used to preload has one associations
|
// handleHasOnePreload used to preload has one associations
|
||||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
func (scope *Scope) handleHasOnePreload(_ctx context.Context, field *Field, conditions []interface{}) {
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
|
|
||||||
// get relations's primary keys
|
// get relations's primary keys
|
||||||
@ -183,7 +184,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleHasManyPreload used to preload has many associations
|
// handleHasManyPreload used to preload has many associations
|
||||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
func (scope *Scope) handleHasManyPreload(_ctx context.Context, field *Field, conditions []interface{}) {
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
|
|
||||||
// get relations's primary keys
|
// get relations's primary keys
|
||||||
@ -236,7 +237,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleBelongsToPreload used to preload belongs to associations
|
// handleBelongsToPreload used to preload belongs to associations
|
||||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
func (scope *Scope) handleBelongsToPreload(_ctx context.Context, field *Field, conditions []interface{}) {
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
|
|
||||||
// preload conditions
|
// preload conditions
|
||||||
@ -283,7 +284,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleManyToManyPreload used to preload many to many associations
|
// handleManyToManyPreload used to preload many to many associations
|
||||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
func (scope *Scope) handleManyToManyPreload(ctx context.Context, field *Field, conditions []interface{}) {
|
||||||
var (
|
var (
|
||||||
relation = field.Relationship
|
relation = field.Relationship
|
||||||
joinTableHandler = relation.JoinTableHandler
|
joinTableHandler = relation.JoinTableHandler
|
||||||
@ -346,7 +347,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
|
|
||||||
scope.New(elem.Addr().Interface()).
|
scope.New(elem.Addr().Interface()).
|
||||||
InstanceSet("gorm:skip_query_callback", true).
|
InstanceSet("gorm:skip_query_callback", true).
|
||||||
callCallbacks(scope.db.parent.callbacks.queries)
|
callCallbacks(ctx, scope.db.parent.callbacks.queries)
|
||||||
|
|
||||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||||
// generate hashed forkey keys in join table
|
// generate hashed forkey keys in join table
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
@ -20,7 +21,7 @@ type RowsQueryResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// queryCallback used to query data from database
|
// queryCallback used to query data from database
|
||||||
func rowQueryCallback(scope *Scope) {
|
func rowQueryCallback(ctx context.Context, scope *Scope) {
|
||||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||||
scope.prepareQuerySQL()
|
scope.prepareQuerySQL()
|
||||||
|
|
||||||
@ -33,9 +34,9 @@ func rowQueryCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
if rowResult, ok := result.(*RowQueryResult); ok {
|
||||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
rowResult.Row = scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...)
|
||||||
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
|
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
|
||||||
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func beginTransactionCallback(scope *Scope) {
|
func beginTransactionCallback(_ctx context.Context, scope *Scope) {
|
||||||
scope.Begin()
|
scope.Begin()
|
||||||
}
|
}
|
||||||
|
|
||||||
func commitOrRollbackTransactionCallback(scope *Scope) {
|
func commitOrRollbackTransactionCallback(_ctx context.Context, scope *Scope) {
|
||||||
scope.CommitOrRollback()
|
scope.CommitOrRollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,7 +65,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
func saveBeforeAssociationsCallback(_ctx context.Context, scope *Scope) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ func saveBeforeAssociationsCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveAfterAssociationsCallback(scope *Scope) {
|
func saveAfterAssociationsCallback(_ctx context.Context, scope *Scope) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||||
|
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
|
func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
|
||||||
var names []string
|
var names []string
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
||||||
@ -16,11 +17,11 @@ func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
|
|||||||
return reflect.DeepEqual(names, fnames)
|
return reflect.DeepEqual(names, fnames)
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(s *Scope) {}
|
func create(ctx context.Context, s *Scope) {}
|
||||||
func beforeCreate1(s *Scope) {}
|
func beforeCreate1(ctx context.Context, s *Scope) {}
|
||||||
func beforeCreate2(s *Scope) {}
|
func beforeCreate2(ctx context.Context, s *Scope) {}
|
||||||
func afterCreate1(s *Scope) {}
|
func afterCreate1(ctx context.Context, s *Scope) {}
|
||||||
func afterCreate2(s *Scope) {}
|
func afterCreate2(ctx context.Context, s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
@ -83,7 +84,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceCreate(s *Scope) {}
|
func replaceCreate(ctx context.Context, s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
@ -21,7 +22,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
func assignUpdatingAttributesCallback(_ctx context.Context, scope *Scope) {
|
||||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||||
@ -32,7 +33,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||||
func beforeUpdateCallback(scope *Scope) {
|
func beforeUpdateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||||
scope.Err(errors.New("missing WHERE clause while updating"))
|
scope.Err(errors.New("missing WHERE clause while updating"))
|
||||||
return
|
return
|
||||||
@ -48,14 +49,14 @@ func beforeUpdateCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
func updateTimeStampForUpdateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
|
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateCallback the callback used to update data to database
|
// updateCallback the callback used to update data to database
|
||||||
func updateCallback(scope *Scope) {
|
func updateCallback(ctx context.Context, scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
|
|
||||||
@ -103,13 +104,13 @@ func updateCallback(scope *Scope) {
|
|||||||
strings.Join(sqls, ", "),
|
strings.Join(sqls, ", "),
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
)).Exec()
|
)).Exec(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||||
func afterUpdateCallback(scope *Scope) {
|
func afterUpdateCallback(_ctx context.Context, scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
scope.CallMethod("AfterUpdate")
|
scope.CallMethod("AfterUpdate")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
@ -182,22 +183,23 @@ func TestGetCallback(t *testing.T) {
|
|||||||
t.Errorf("`gorm:test_callback` should be nil")
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||||
callback := DB.Callback().Create().Get("gorm:test_callback")
|
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(scope)
|
ctx := context.Background()
|
||||||
|
callback(ctx, scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
DB.Callback().Create().Replace("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(scope)
|
callback(ctx, scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
@ -207,12 +209,12 @@ func TestGetCallback(t *testing.T) {
|
|||||||
t.Errorf("`gorm:test_callback` should be nil")
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(scope)
|
callback(ctx, scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
@ -220,7 +222,7 @@ func TestGetCallback(t *testing.T) {
|
|||||||
|
|
||||||
func TestUseDefaultCallback(t *testing.T) {
|
func TestUseDefaultCallback(t *testing.T) {
|
||||||
createCallbackName := "gorm:test_use_default_callback_for_create"
|
createCallbackName := "gorm:test_use_default_callback_for_create"
|
||||||
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
|
gorm.DefaultCallback.Create().Register(createCallbackName, func(context.Context, *gorm.Scope) {
|
||||||
// nop
|
// nop
|
||||||
})
|
})
|
||||||
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
||||||
@ -233,16 +235,17 @@ func TestUseDefaultCallback(t *testing.T) {
|
|||||||
|
|
||||||
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
||||||
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
||||||
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
|
gorm.DefaultCallback.Update().Register(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
|
||||||
scope.Set(scopeValueName, 1)
|
scope.Set(scopeValueName, 1)
|
||||||
})
|
})
|
||||||
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
|
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
|
||||||
scope.Set(scopeValueName, 2)
|
scope.Set(scopeValueName, 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
scope := DB.NewScope(nil)
|
scope := DB.NewScope(nil)
|
||||||
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
||||||
callback(scope)
|
ctx := context.Background()
|
||||||
|
callback(ctx, scope)
|
||||||
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
||||||
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -21,12 +22,13 @@ type CustomColumnAndIgnoredFieldClash struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomizeColumn(t *testing.T) {
|
func TestCustomizeColumn(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
col := "mapped_name"
|
col := "mapped_name"
|
||||||
DB.DropTable(&CustomizeColumn{})
|
DB.DropTable(&CustomizeColumn{})
|
||||||
DB.AutoMigrate(&CustomizeColumn{})
|
DB.AutoMigrate(&CustomizeColumn{})
|
||||||
|
|
||||||
scope := DB.NewScope(&CustomizeColumn{})
|
scope := DB.NewScope(&CustomizeColumn{})
|
||||||
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
if !scope.Dialect().HasColumn(ctx, scope.TableName(), col) {
|
||||||
t.Errorf("CustomizeColumn should have column %s", col)
|
t.Errorf("CustomizeColumn should have column %s", col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
23
dialect.go
23
dialect.go
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -24,26 +25,26 @@ type Dialect interface {
|
|||||||
DataTypeOf(field *StructField) string
|
DataTypeOf(field *StructField) string
|
||||||
|
|
||||||
// HasIndex check has index or not
|
// HasIndex check has index or not
|
||||||
HasIndex(tableName string, indexName string) bool
|
HasIndex(ctx context.Context, tableName string, indexName string) bool
|
||||||
// HasForeignKey check has foreign key or not
|
// HasForeignKey check has foreign key or not
|
||||||
HasForeignKey(tableName string, foreignKeyName string) bool
|
HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool
|
||||||
// RemoveIndex remove index
|
// RemoveIndex remove index
|
||||||
RemoveIndex(tableName string, indexName string) error
|
RemoveIndex(ctx context.Context, tableName string, indexName string) error
|
||||||
// HasTable check has table or not
|
// HasTable check has table or not
|
||||||
HasTable(tableName string) bool
|
HasTable(ctx context.Context, tableName string) bool
|
||||||
// HasColumn check has column or not
|
// HasColumn check has column or not
|
||||||
HasColumn(tableName string, columnName string) bool
|
HasColumn(ctx context.Context, tableName string, columnName string) bool
|
||||||
// ModifyColumn modify column's type
|
// ModifyColumn modify column's type
|
||||||
ModifyColumn(tableName string, columnName string, typ string) error
|
ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error
|
||||||
|
|
||||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||||
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
||||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||||
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
LastInsertIDOutputInterstitial(ctx context.Context, tableName, columnName string, columns []string) string
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
LastInsertIDReturningSuffix(ctx context.Context, tableName, columnName string) string
|
||||||
// DefaultValueStr
|
// DefaultValueStr
|
||||||
DefaultValueStr() string
|
DefaultValueStr() string
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ type Dialect interface {
|
|||||||
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase() string
|
CurrentDatabase(ctx context.Context) string
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
|
|||||||
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
|
func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) {
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
return splitStrings[0], splitStrings[1]
|
return splitStrings[0], splitStrings[1]
|
||||||
}
|
}
|
||||||
return dialect.CurrentDatabase(), tableName
|
return dialect.CurrentDatabase(ctx), tableName
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s commonDialect) HasForeignKey(_ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasTable(tableName string) bool {
|
func (s commonDialect) HasTable(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
|
func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) CurrentDatabase() (name string) {
|
func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) {
|
||||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,11 +163,11 @@ func (commonDialect) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
func (commonDialect) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (commonDialect) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
|
func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasTable(tableName string) bool {
|
func (s mysql) HasTable(ctx context.Context, tableName string) bool {
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
var name string
|
var name string
|
||||||
// allow mysql database name with '-' character
|
// allow mysql database name with '-' character
|
||||||
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
|
if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -183,9 +184,9 @@ func (s mysql) HasTable(tableName string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasIndex(tableName string, indexName string) bool {
|
func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
|
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
} else {
|
} else {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
@ -193,9 +194,9 @@ func (s mysql) HasIndex(tableName string, indexName string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasColumn(tableName string, columnName string) bool {
|
func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
|
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
} else {
|
} else {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
@ -203,8 +204,8 @@ func (s mysql) HasColumn(tableName string, columnName string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) CurrentDatabase() (name string) {
|
func (s mysql) CurrentDatabase(ctx context.Context) (name string) {
|
||||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -91,40 +92,40 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasTable(tableName string) bool {
|
func (s postgres) HasTable(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) CurrentDatabase() (name string) {
|
func (s postgres) CurrentDatabase(ctx context.Context) (name string) {
|
||||||
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
|
func (s postgres) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, key string, columns []string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
func (s postgres) LastInsertIDReturningSuffix(_ctx context.Context, tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -70,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasTable(tableName string) bool {
|
func (s sqlite3) HasTable(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) CurrentDatabase() (name string) {
|
func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) {
|
||||||
var (
|
var (
|
||||||
ifaces = make([]interface{}, 3)
|
ifaces = make([]interface{}, 3)
|
||||||
pointers = make([]*string, 3)
|
pointers = make([]*string, 3)
|
||||||
@ -97,7 +98,7 @@ func (s sqlite3) CurrentDatabase() (name string) {
|
|||||||
for i = 0; i < 3; i++ {
|
for i = 0; i < 3; i++ {
|
||||||
ifaces[i] = &pointers[i]
|
ifaces[i] = &pointers[i]
|
||||||
}
|
}
|
||||||
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
if err := s.db.QueryRowContext(ctx, "PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if pointers[1] != nil {
|
if pointers[1] != nil {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -15,7 +16,7 @@ import (
|
|||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setIdentityInsert(scope *gorm.Scope) {
|
func setIdentityInsert(ctx context.Context, scope *gorm.Scope) {
|
||||||
if scope.Dialect().GetName() == "mssql" {
|
if scope.Dialect().GetName() == "mssql" {
|
||||||
for _, field := range scope.PrimaryFields() {
|
for _, field := range scope.PrimaryFields() {
|
||||||
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
|
||||||
@ -26,7 +27,7 @@ func setIdentityInsert(scope *gorm.Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func turnOffIdentityInsert(scope *gorm.Scope) {
|
func turnOffIdentityInsert(ctx context.Context, scope *gorm.Scope) {
|
||||||
if scope.Dialect().GetName() == "mssql" {
|
if scope.Dialect().GetName() == "mssql" {
|
||||||
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
|
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
|
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
|
||||||
@ -122,21 +123,21 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
|||||||
return field.IsPrimaryKey
|
return field.IsPrimaryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow(`SELECT count(*)
|
s.db.QueryRowContext(ctx, `SELECT count(*)
|
||||||
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
|
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
|
||||||
inner join information_schema.tables as I on I.TABLE_NAME = T.name
|
inner join information_schema.tables as I on I.TABLE_NAME = T.name
|
||||||
WHERE F.name = ?
|
WHERE F.name = ?
|
||||||
@ -144,27 +145,27 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTable(tableName string) bool {
|
func (s mssql) HasTable(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
|
func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) CurrentDatabase() (name string) {
|
func (s mssql) CurrentDatabase(ctx context.Context) (name string) {
|
||||||
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,7 +199,7 @@ func (mssql) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
func (mssql) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
// No OUTPUT to query
|
// No OUTPUT to query
|
||||||
return ""
|
return ""
|
||||||
@ -206,7 +207,7 @@ func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, column
|
|||||||
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (mssql) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
|
||||||
// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
|
// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
|
||||||
return "; SELECT SCOPE_IDENTITY()"
|
return "; SELECT SCOPE_IDENTITY()"
|
||||||
}
|
}
|
||||||
@ -220,12 +221,12 @@ func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, stri
|
|||||||
return indexName, columnName
|
return indexName, columnName
|
||||||
}
|
}
|
||||||
|
|
||||||
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableName string) (string, string) {
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
return splitStrings[0], splitStrings[1]
|
return splitStrings[0], splitStrings[1]
|
||||||
}
|
}
|
||||||
return dialect.CurrentDatabase(), tableName
|
return dialect.CurrentDatabase(ctx), tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
// JSON type to support easy handling of JSON data in character table fields
|
// JSON type to support easy handling of JSON data in character table fields
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
type BasePost struct {
|
type BasePost struct {
|
||||||
Id int64
|
Id int64
|
||||||
@ -27,9 +30,10 @@ type EngadgetPost struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
||||||
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
||||||
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
|
if !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_email") {
|
||||||
t.Errorf("should has prefix for embedded columns")
|
t.Errorf("should has prefix for embedded columns")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +42,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hnScope := DB.NewScope(&HNPost{})
|
hnScope := DB.NewScope(&HNPost{})
|
||||||
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
|
if !dialect.HasColumn(ctx, hnScope.TableName(), "user_id") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_name") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_email") {
|
||||||
t.Errorf("should has prefix for embedded columns")
|
t.Errorf("should has prefix for embedded columns")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,10 +7,10 @@ import (
|
|||||||
|
|
||||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||||
type SQLCommon interface {
|
type SQLCommon interface {
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
Prepare(query string) (*sql.Stmt, error)
|
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlDb interface {
|
type sqlDb interface {
|
||||||
|
237
main.go
237
main.go
@ -327,51 +327,91 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
|||||||
|
|
||||||
// First find first record that match given conditions, order by primary key
|
// First find first record that match given conditions, order by primary key
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.FirstContext(ctx, out, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) FirstContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.NewScope(out)
|
newScope := s.NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
|
|
||||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||||
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.TakeContext(ctx, out, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) TakeContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.NewScope(out)
|
newScope := s.NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
return newScope.inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Last find last record that match given conditions, order by primary key
|
// Last find last record that match given conditions, order by primary key
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.LastContext(ctx, out, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) LastContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.NewScope(out)
|
newScope := s.NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find find records that match given conditions
|
// Find find records that match given conditions
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
ctx := context.Background()
|
||||||
|
return s.FindContext(ctx, out, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) FindContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
|
||||||
|
return s.NewScope(out).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
//Preloads preloads relations, don`t touch out
|
//Preloads preloads relations, don`t touch out
|
||||||
func (s *DB) Preloads(out interface{}) *DB {
|
func (s *DB) Preloads(out interface{}) *DB {
|
||||||
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
|
ctx := context.Background()
|
||||||
|
return s.PreloadsContext(ctx, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) PreloadsContext(ctx context.Context, out interface{}) *DB {
|
||||||
|
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(ctx, s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan scan value to a struct
|
// Scan scan value to a struct
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
ctx := context.Background()
|
||||||
|
return s.ScanContext(ctx, dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) ScanContext(ctx context.Context, dest interface{}) *DB {
|
||||||
|
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(ctx, s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Row return `*sql.Row` with given conditions
|
// Row return `*sql.Row` with given conditions
|
||||||
func (s *DB) Row() *sql.Row {
|
func (s *DB) Row() *sql.Row {
|
||||||
return s.NewScope(s.Value).row()
|
ctx := context.Background()
|
||||||
|
return s.RowContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) RowContext(ctx context.Context) *sql.Row {
|
||||||
|
return s.NewScope(s.Value).row(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rows return `*sql.Rows` with given conditions
|
// Rows return `*sql.Rows` with given conditions
|
||||||
func (s *DB) Rows() (*sql.Rows, error) {
|
func (s *DB) Rows() (*sql.Rows, error) {
|
||||||
return s.NewScope(s.Value).rows()
|
ctx := context.Background()
|
||||||
|
return s.RowsContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) RowsContext(ctx context.Context) (*sql.Rows, error) {
|
||||||
|
return s.NewScope(s.Value).rows(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanRows scan `*sql.Rows` to give struct
|
// ScanRows scan `*sql.Rows` to give struct
|
||||||
@ -393,12 +433,22 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
|||||||
// var ages []int64
|
// var ages []int64
|
||||||
// db.Find(&users).Pluck("age", &ages)
|
// db.Find(&users).Pluck("age", &ages)
|
||||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||||
return s.NewScope(s.Value).pluck(column, value).db
|
ctx := context.Background()
|
||||||
|
return s.PluckContext(ctx, column, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) PluckContext(ctx context.Context, column string, value interface{}) *DB {
|
||||||
|
return s.NewScope(s.Value).pluck(ctx, column, value).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count get how many records for a model
|
// Count get how many records for a model
|
||||||
func (s *DB) Count(value interface{}) *DB {
|
func (s *DB) Count(value interface{}) *DB {
|
||||||
return s.NewScope(s.Value).count(value).db
|
ctx := context.Background()
|
||||||
|
return s.CountContext(ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) CountContext(ctx context.Context, value interface{}) *DB {
|
||||||
|
return s.NewScope(s.Value).count(ctx, value).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Related get related associations
|
// Related get related associations
|
||||||
@ -409,8 +459,13 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
|||||||
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||||
// https://jinzhu.github.io/gorm/crud.html#firstorinit
|
// https://jinzhu.github.io/gorm/crud.html#firstorinit
|
||||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.FirstOrInitContext(ctx, out, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) FirstOrInitContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if result := c.First(out, where...); result.Error != nil {
|
if result := c.FirstContext(ctx, out, where...); result.Error != nil {
|
||||||
if !result.RecordNotFound() {
|
if !result.RecordNotFound() {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
@ -424,14 +479,19 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||||||
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
|
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||||
// https://jinzhu.github.io/gorm/crud.html#firstorcreate
|
// https://jinzhu.github.io/gorm/crud.html#firstorcreate
|
||||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.FirstOrCreateContext(ctx, out, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) FirstOrCreateContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if result := s.First(out, where...); result.Error != nil {
|
if result := s.First(out, where...); result.Error != nil {
|
||||||
if !result.RecordNotFound() {
|
if !result.RecordNotFound() {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db
|
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(ctx, c.parent.callbacks.creates).db
|
||||||
} else if len(c.search.assignAttrs) > 0 {
|
} else if len(c.search.assignAttrs) > 0 {
|
||||||
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db
|
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(ctx, c.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
@ -439,54 +499,89 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
|||||||
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
// WARNING when update with struct, GORM will not update fields that with zero value
|
// WARNING when update with struct, GORM will not update fields that with zero value
|
||||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
return s.Updates(toSearchableMap(attrs...), true)
|
ctx := context.Background()
|
||||||
|
return s.UpdateContext(ctx, attrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) UpdateContext(ctx context.Context, attrs ...interface{}) *DB {
|
||||||
|
return s.UpdatesContext(ctx, toSearchableMap(attrs...), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.UpdatesContext(ctx, values, ignoreProtectedAttrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) UpdatesContext(ctx context.Context, values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
return s.NewScope(s.Value).
|
return s.NewScope(s.Value).
|
||||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callbacks.updates).db
|
callCallbacks(ctx, s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||||
return s.UpdateColumns(toSearchableMap(attrs...))
|
ctx := context.Background()
|
||||||
|
return s.UpdateColumnContext(ctx, attrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) UpdateColumnContext(ctx context.Context, attrs ...interface{}) *DB {
|
||||||
|
return s.UpdateColumnsContext(ctx, toSearchableMap(attrs...))
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.UpdateColumnsContext(ctx, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) UpdateColumnsContext(ctx context.Context, values interface{}) *DB {
|
||||||
return s.NewScope(s.Value).
|
return s.NewScope(s.Value).
|
||||||
Set("gorm:update_column", true).
|
Set("gorm:update_column", true).
|
||||||
Set("gorm:save_associations", false).
|
Set("gorm:save_associations", false).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callbacks.updates).db
|
callCallbacks(ctx, s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.SaveContext(ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) SaveContext(ctx context.Context, value interface{}) *DB {
|
||||||
scope := s.NewScope(value)
|
scope := s.NewScope(value)
|
||||||
if !scope.PrimaryKeyZero() {
|
if !scope.PrimaryKeyZero() {
|
||||||
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
newDB := scope.callCallbacks(ctx, s.parent.callbacks.updates).db
|
||||||
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
||||||
return s.New().Table(scope.TableName()).FirstOrCreate(value)
|
return s.New().Table(scope.TableName()).FirstOrCreateContext(ctx, value)
|
||||||
}
|
}
|
||||||
return newDB
|
return newDB
|
||||||
}
|
}
|
||||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
return scope.callCallbacks(ctx, s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create insert the value into database
|
// Create insert the value into database
|
||||||
func (s *DB) Create(value interface{}) *DB {
|
func (s *DB) Create(value interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.CreateContext(ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) CreateContext(ctx context.Context, value interface{}) *DB {
|
||||||
scope := s.NewScope(value)
|
scope := s.NewScope(value)
|
||||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
return scope.callCallbacks(ctx, s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||||
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
|
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
|
||||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||||
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
ctx := context.Background()
|
||||||
|
return s.DeleteContext(ctx, value, where...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) DeleteContext(ctx context.Context, value interface{}, where ...interface{}) *DB {
|
||||||
|
return s.NewScope(value).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.deletes).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
||||||
@ -497,11 +592,16 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
|||||||
|
|
||||||
// Exec execute raw sql
|
// Exec execute raw sql
|
||||||
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.ExecContext(ctx, sql, values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) ExecContext(ctx context.Context, sql string, values ...interface{}) *DB {
|
||||||
scope := s.NewScope(nil)
|
scope := s.NewScope(nil)
|
||||||
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
|
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
|
||||||
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
||||||
scope.Raw(generatedSQL)
|
scope.Raw(generatedSQL)
|
||||||
return scope.Exec().db
|
return scope.Exec(ctx).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model specify the model you would like to run db operations
|
// Model specify the model you would like to run db operations
|
||||||
@ -628,32 +728,47 @@ func (s *DB) RecordNotFound() bool {
|
|||||||
|
|
||||||
// CreateTable create table for models
|
// CreateTable create table for models
|
||||||
func (s *DB) CreateTable(models ...interface{}) *DB {
|
func (s *DB) CreateTable(models ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.CreateTableContext(ctx, models...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) CreateTableContext(ctx context.Context, models ...interface{}) *DB {
|
||||||
db := s.Unscoped()
|
db := s.Unscoped()
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
db = db.NewScope(model).createTable().db
|
db = db.NewScope(model).createTable(ctx).db
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropTable drop table for models
|
// DropTable drop table for models
|
||||||
func (s *DB) DropTable(values ...interface{}) *DB {
|
func (s *DB) DropTable(values ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.DropTableContext(ctx, values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) DropTableContext(ctx context.Context, values ...interface{}) *DB {
|
||||||
db := s.clone()
|
db := s.clone()
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if tableName, ok := value.(string); ok {
|
if tableName, ok := value.(string); ok {
|
||||||
db = db.Table(tableName)
|
db = db.Table(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db = db.NewScope(value).dropTable().db
|
db = db.NewScope(value).dropTable(ctx).db
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropTableIfExists drop table if it is exist
|
// DropTableIfExists drop table if it is exist
|
||||||
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.DropTableIfExistsContext(ctx, values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) DropTableIfExistsContext(ctx context.Context, values ...interface{}) *DB {
|
||||||
db := s.clone()
|
db := s.clone()
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if s.HasTable(value) {
|
if s.HasTable(value) {
|
||||||
db.AddError(s.DropTable(value).Error)
|
db.AddError(s.DropTableContext(ctx, value).Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
@ -661,6 +776,11 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
|||||||
|
|
||||||
// HasTable check has table or not
|
// HasTable check has table or not
|
||||||
func (s *DB) HasTable(value interface{}) bool {
|
func (s *DB) HasTable(value interface{}) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasTableContext(ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) HasTableContext(ctx context.Context, value interface{}) bool {
|
||||||
var (
|
var (
|
||||||
scope = s.NewScope(value)
|
scope = s.NewScope(value)
|
||||||
tableName string
|
tableName string
|
||||||
@ -672,68 +792,108 @@ func (s *DB) HasTable(value interface{}) bool {
|
|||||||
tableName = scope.TableName()
|
tableName = scope.TableName()
|
||||||
}
|
}
|
||||||
|
|
||||||
has := scope.Dialect().HasTable(tableName)
|
has := scope.Dialect().HasTable(ctx, tableName)
|
||||||
s.AddError(scope.db.Error)
|
s.AddError(scope.db.Error)
|
||||||
return has
|
return has
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
|
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
|
||||||
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.AutoMigrateContext(ctx, values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) AutoMigrateContext(ctx context.Context, values ...interface{}) *DB {
|
||||||
db := s.Unscoped()
|
db := s.Unscoped()
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
db = db.NewScope(value).autoMigrate().db
|
db = db.NewScope(value).autoMigrate(ctx).db
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModifyColumn modify column to type
|
// ModifyColumn modify column to type
|
||||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.ModifyColumnContext(ctx, column, typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) ModifyColumnContext(ctx context.Context, column string, typ string) *DB {
|
||||||
scope := s.NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.modifyColumn(column, typ)
|
scope.modifyColumn(ctx, column, typ)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropColumn drop a column
|
// DropColumn drop a column
|
||||||
func (s *DB) DropColumn(column string) *DB {
|
func (s *DB) DropColumn(column string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.DropColumnContext(ctx, column)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) DropColumnContext(ctx context.Context, column string) *DB {
|
||||||
scope := s.NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.dropColumn(column)
|
scope.dropColumn(ctx, column)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddIndex add index for columns with given name
|
// AddIndex add index for columns with given name
|
||||||
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
|
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.AddIndexContext(ctx, indexName, columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) AddIndexContext(ctx context.Context, indexName string, columns ...string) *DB {
|
||||||
scope := s.Unscoped().NewScope(s.Value)
|
scope := s.Unscoped().NewScope(s.Value)
|
||||||
scope.addIndex(false, indexName, columns...)
|
scope.addIndex(ctx, false, indexName, columns...)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddUniqueIndex add unique index for columns with given name
|
// AddUniqueIndex add unique index for columns with given name
|
||||||
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.AddUniqueIndexContext(ctx, indexName, columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) AddUniqueIndexContext(ctx context.Context, indexName string, columns ...string) *DB {
|
||||||
scope := s.Unscoped().NewScope(s.Value)
|
scope := s.Unscoped().NewScope(s.Value)
|
||||||
scope.addIndex(true, indexName, columns...)
|
scope.addIndex(ctx, true, indexName, columns...)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveIndex remove index with name
|
// RemoveIndex remove index with name
|
||||||
func (s *DB) RemoveIndex(indexName string) *DB {
|
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.RemoveIndexContext(ctx, indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) RemoveIndexContext(ctx context.Context, indexName string) *DB {
|
||||||
scope := s.NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.removeIndex(indexName)
|
scope.removeIndex(ctx, indexName)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddForeignKey Add foreign key to the given scope, e.g:
|
// AddForeignKey Add foreign key to the given scope, e.g:
|
||||||
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.AddForeignKeyContext(ctx, field, dest, onDelete, onUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) AddForeignKeyContext(ctx context.Context, field string, dest string, onDelete string, onUpdate string) *DB {
|
||||||
scope := s.NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
scope.addForeignKey(ctx, field, dest, onDelete, onUpdate)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveForeignKey Remove foreign key from the given scope, e.g:
|
// RemoveForeignKey Remove foreign key from the given scope, e.g:
|
||||||
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
|
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
|
||||||
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
|
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.RemoveForeignKeyContext(ctx, field, dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) RemoveForeignKeyContext(ctx context.Context, field string, dest string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
scope.removeForeignKey(field, dest)
|
scope.removeForeignKey(ctx, field, dest)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -784,6 +944,11 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
|
|||||||
|
|
||||||
// SetJoinTableHandler set a model's join table handler for a relation
|
// SetJoinTableHandler set a model's join table handler for a relation
|
||||||
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s.SetJoinTableHandlerContext(ctx, source, column, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) SetJoinTableHandlerContext(ctx context.Context, source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||||
scope := s.NewScope(source)
|
scope := s.NewScope(source)
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if field.Name == column || field.DBName == column {
|
if field.Name == column || field.DBName == column {
|
||||||
@ -792,7 +957,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
handler.Setup(field.Relationship, many2many, source, destination)
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
field.Relationship.JoinTableHandler = handler
|
field.Relationship.JoinTableHandler = handler
|
||||||
if table := handler.Table(s); scope.Dialect().HasTable(table) {
|
if table := handler.Table(s); scope.Dialect().HasTable(ctx, table) {
|
||||||
s.Table(table).AutoMigrate(handler)
|
s.Table(table).AutoMigrate(handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
@ -302,12 +303,14 @@ func runMigration() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexes(t *testing.T) {
|
func TestIndexes(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scope := DB.NewScope(&Email{})
|
scope := DB.NewScope(&Email{})
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email should have index idx_email_email")
|
t.Errorf("Email should have index idx_email_email")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,7 +318,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email's index idx_email_email should be deleted")
|
t.Errorf("Email's index idx_email_email should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -323,7 +326,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -331,7 +334,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,7 +342,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -361,7 +364,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -381,6 +384,8 @@ type EmailWithIdx struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoMigration(t *testing.T) {
|
func TestAutoMigration(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
DB.AutoMigrate(&Address{})
|
DB.AutoMigrate(&Address{})
|
||||||
DB.DropTable(&EmailWithIdx{})
|
DB.DropTable(&EmailWithIdx{})
|
||||||
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
||||||
@ -391,11 +396,11 @@ func TestAutoMigration(t *testing.T) {
|
|||||||
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
||||||
|
|
||||||
scope := DB.NewScope(&EmailWithIdx{})
|
scope := DB.NewScope(&EmailWithIdx{})
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_agent") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_email_with_idxes_registered_at") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -462,6 +467,8 @@ type MultipleIndexes struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMultipleIndexes(t *testing.T) {
|
func TestMultipleIndexes(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
||||||
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
||||||
}
|
}
|
||||||
@ -474,23 +481,23 @@ func TestMultipleIndexes(t *testing.T) {
|
|||||||
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
||||||
|
|
||||||
scope := DB.NewScope(&MultipleIndexes{})
|
scope := DB.NewScope(&MultipleIndexes{})
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_name") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_email") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multiple_indexes_email") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multipleindexes_user_other") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
|
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multiple_indexes_other") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -540,6 +547,8 @@ func TestModifyColumnType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexWithPrefixLength(t *testing.T) {
|
func TestIndexWithPrefixLength(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
||||||
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
||||||
}
|
}
|
||||||
@ -571,7 +580,7 @@ func TestIndexWithPrefixLength(t *testing.T) {
|
|||||||
if err := DB.CreateTable(table).Error; err != nil {
|
if err := DB.CreateTable(table).Error; err != nil {
|
||||||
t.Errorf("Failed to create %s table: %v", tableName, err)
|
t.Errorf("Failed to create %s table: %v", tableName, err)
|
||||||
}
|
}
|
||||||
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
|
if !scope.Dialect().HasIndex(ctx, tableName, "idx_index_with_prefixes_length") {
|
||||||
t.Errorf("Failed to create %s table index:", tableName)
|
t.Errorf("Failed to create %s table index:", tableName)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
@ -1684,7 +1685,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
|
|||||||
|
|
||||||
called := 0
|
called := 0
|
||||||
|
|
||||||
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
|
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(ctx context.Context, scope *gorm.Scope) {
|
||||||
called = called + 1
|
called = called + 1
|
||||||
})
|
})
|
||||||
|
|
||||||
|
83
scope.go
83
scope.go
@ -2,6 +2,7 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
@ -357,11 +358,11 @@ func (scope *Scope) Raw(sql string) *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec perform generated SQL
|
// Exec perform generated SQL
|
||||||
func (scope *Scope) Exec() *Scope {
|
func (scope *Scope) Exec(ctx context.Context) *Scope {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||||
scope.db.RowsAffected = count
|
scope.db.RowsAffected = count
|
||||||
}
|
}
|
||||||
@ -856,7 +857,7 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
|||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
func (scope *Scope) callCallbacks(ctx context.Context, funcs []*func(ctx context.Context, s *Scope)) *Scope {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
if db, ok := scope.db.db.(sqlTx); ok {
|
if db, ok := scope.db.db.(sqlTx); ok {
|
||||||
@ -866,7 +867,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
(*f)(scope)
|
(*f)(ctx, scope)
|
||||||
if scope.skipLeft {
|
if scope.skipLeft {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -933,22 +934,22 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) row() *sql.Row {
|
func (scope *Scope) row(ctx context.Context) *sql.Row {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
result := &RowQueryResult{}
|
result := &RowQueryResult{}
|
||||||
scope.InstanceSet("row_query_result", result)
|
scope.InstanceSet("row_query_result", result)
|
||||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries)
|
||||||
|
|
||||||
return result.Row
|
return result.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
func (scope *Scope) rows(ctx context.Context) (*sql.Rows, error) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
result := &RowsQueryResult{}
|
result := &RowsQueryResult{}
|
||||||
scope.InstanceSet("row_query_result", result)
|
scope.InstanceSet("row_query_result", result)
|
||||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries)
|
||||||
|
|
||||||
return result.Rows, result.Error
|
return result.Rows, result.Error
|
||||||
}
|
}
|
||||||
@ -979,7 +980,7 @@ func (scope *Scope) isQueryForColumn(query interface{}, column string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
func (scope *Scope) pluck(ctx context.Context, column string, value interface{}) *Scope {
|
||||||
dest := reflect.Indirect(reflect.ValueOf(value))
|
dest := reflect.Indirect(reflect.ValueOf(value))
|
||||||
if dest.Kind() != reflect.Slice {
|
if dest.Kind() != reflect.Slice {
|
||||||
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
||||||
@ -994,7 +995,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|||||||
scope.Search.Select(column)
|
scope.Search.Select(column)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := scope.rows()
|
rows, err := scope.rows(ctx)
|
||||||
if scope.Err(err) == nil {
|
if scope.Err(err) == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@ -1010,7 +1011,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) count(value interface{}) *Scope {
|
func (scope *Scope) count(ctx context.Context, value interface{}) *Scope {
|
||||||
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
|
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
|
||||||
if len(scope.Search.group) != 0 {
|
if len(scope.Search.group) != 0 {
|
||||||
if len(scope.Search.havingConditions) != 0 {
|
if len(scope.Search.havingConditions) != 0 {
|
||||||
@ -1027,7 +1028,7 @@ func (scope *Scope) count(value interface{}) *Scope {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
scope.Search.ignoreOrderQuery = true
|
scope.Search.ignoreOrderQuery = true
|
||||||
scope.Err(scope.row().Scan(value))
|
scope.Err(scope.row(ctx).Scan(value))
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1124,11 +1125,11 @@ func (scope *Scope) getTableOptions() string {
|
|||||||
return " " + tableOptions.(string)
|
return " " + tableOptions.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
func (scope *Scope) createJoinTable(ctx context.Context, field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
joinTable := joinTableHandler.Table(scope.db)
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
if !scope.Dialect().HasTable(joinTable) {
|
if !scope.Dialect().HasTable(ctx, joinTable) {
|
||||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||||
|
|
||||||
var sqlTypes, primaryKeys []string
|
var sqlTypes, primaryKeys []string
|
||||||
@ -1156,11 +1157,11 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||||||
|
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
||||||
}
|
}
|
||||||
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
scope.NewDB().Table(joinTable).AutoMigrateContext(ctx, joinTableHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) createTable() *Scope {
|
func (scope *Scope) createTable(ctx context.Context) *Scope {
|
||||||
var tags []string
|
var tags []string
|
||||||
var primaryKeys []string
|
var primaryKeys []string
|
||||||
var primaryKeyInColumnType = false
|
var primaryKeyInColumnType = false
|
||||||
@ -1181,7 +1182,7 @@ func (scope *Scope) createTable() *Scope {
|
|||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||||
}
|
}
|
||||||
scope.createJoinTable(field)
|
scope.createJoinTable(ctx, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
var primaryKeyStr string
|
var primaryKeyStr string
|
||||||
@ -1189,27 +1190,27 @@ func (scope *Scope) createTable() *Scope {
|
|||||||
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec(ctx)
|
||||||
|
|
||||||
scope.autoIndex()
|
scope.autoIndex()
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropTable() *Scope {
|
func (scope *Scope) dropTable(ctx context.Context) *Scope {
|
||||||
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec(ctx)
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
func (scope *Scope) modifyColumn(ctx context.Context, column string, typ string) {
|
||||||
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
|
scope.db.AddError(scope.Dialect().ModifyColumn(ctx, scope.QuotedTableName(), scope.Quote(column), typ))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropColumn(column string) {
|
func (scope *Scope) dropColumn(ctx context.Context, column string) {
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
func (scope *Scope) addIndex(ctx context.Context, unique bool, indexName string, column ...string) {
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
if scope.Dialect().HasIndex(ctx, scope.TableName(), indexName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1223,23 +1224,23 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|||||||
sqlCreate = "CREATE UNIQUE INDEX"
|
sqlCreate = "CREATE UNIQUE INDEX"
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
func (scope *Scope) addForeignKey(ctx context.Context, field string, dest string, onDelete string, onUpdate string) {
|
||||||
// Compatible with old generated key
|
// Compatible with old generated key
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
|
|
||||||
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
if scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeForeignKey(field string, dest string) {
|
func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest string) {
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
if !scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var mysql mysql
|
var mysql mysql
|
||||||
@ -1250,28 +1251,28 @@ func (scope *Scope) removeForeignKey(field string, dest string) {
|
|||||||
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeIndex(indexName string) {
|
func (scope *Scope) removeIndex(ctx context.Context, indexName string) {
|
||||||
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
scope.Dialect().RemoveIndex(ctx, scope.TableName(), indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) autoMigrate() *Scope {
|
func (scope *Scope) autoMigrate(ctx context.Context) *Scope {
|
||||||
tableName := scope.TableName()
|
tableName := scope.TableName()
|
||||||
quotedTableName := scope.QuotedTableName()
|
quotedTableName := scope.QuotedTableName()
|
||||||
|
|
||||||
if !scope.Dialect().HasTable(tableName) {
|
if !scope.Dialect().HasTable(ctx, tableName) {
|
||||||
scope.createTable()
|
scope.createTable(ctx)
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(ctx, tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scope.createJoinTable(field)
|
scope.createJoinTable(ctx, field)
|
||||||
}
|
}
|
||||||
scope.autoIndex()
|
scope.autoIndex()
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user