add context funcs

This commit is contained in:
Daniel Gatis 2020-01-03 22:41:09 -03:00
parent 79a77d771d
commit c7f9cb552b
23 changed files with 474 additions and 272 deletions

View File

@ -1,6 +1,9 @@
package gorm package gorm
import "fmt" import (
"context"
"fmt"
)
// DefaultCallback default callbacks defined by gorm // DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{logger: nopLogger{}} var DefaultCallback = &Callback{logger: nopLogger{}}
@ -14,24 +17,24 @@ var DefaultCallback = &Callback{logger: nopLogger{}}
// Field `processors` contains all callback processors, will be used to generate above callbacks in order // Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct { type Callback struct {
logger logger logger logger
creates []*func(scope *Scope) creates []*func(ctx context.Context, scope *Scope)
updates []*func(scope *Scope) updates []*func(ctx context.Context, scope *Scope)
deletes []*func(scope *Scope) deletes []*func(ctx context.Context, scope *Scope)
queries []*func(scope *Scope) queries []*func(ctx context.Context, scope *Scope)
rowQueries []*func(scope *Scope) rowQueries []*func(ctx context.Context, scope *Scope)
processors []*CallbackProcessor processors []*CallbackProcessor
} }
// CallbackProcessor contains callback informations // CallbackProcessor contains callback informations
type CallbackProcessor struct { type CallbackProcessor struct {
logger logger logger logger
name string // current callback's name name string // current callback's name
before string // register current callback before a callback before string // register current callback before a callback
after string // register current callback after a callback after string // register current callback after a callback
replace bool // replace callbacks with same name replace bool // replace callbacks with same name
remove bool // delete callbacks with same name remove bool // delete callbacks with same name
kind string // callback type: create, update, delete, query, row_query kind string // callback type: create, update, delete, query, row_query
processor *func(scope *Scope) // callback handler processor *func(ctx context.Context, scope *Scope) // callback handler
parent *Callback parent *Callback
} }
@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
} }
// Register a new callback, refer `Callbacks.Create` // Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { func (cp *CallbackProcessor) Register(callbackName string, callback func(ctx context.Context, scope *Scope)) {
if cp.kind == "row_query" { if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
@ -123,7 +126,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// scope.SetColumn("CreatedAt", now) // scope.SetColumn("CreatedAt", now)
// scope.SetColumn("UpdatedAt", now) // scope.SetColumn("UpdatedAt", now)
// }) // })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { func (cp *CallbackProcessor) Replace(callbackName string, callback func(ctx context.Context, scope *Scope)) {
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName cp.name = callbackName
cp.processor = &callback cp.processor = &callback
@ -134,7 +137,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
// Get registered callback // Get registered callback
// db.Callback().Create().Get("gorm:create") // db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { func (cp *CallbackProcessor) Get(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
for _, p := range cp.parent.processors { for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind { if p.name == callbackName && p.kind == cp.kind {
if p.remove { if p.remove {
@ -158,7 +161,7 @@ func getRIndex(strs []string, str string) int {
} }
// sortProcessors sort callback processors based on its before, after, remove, replace // sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { func sortProcessors(cps []*CallbackProcessor) []*func(ctx context.Context, scope *Scope) {
var ( var (
allNames, sortedNames []string allNames, sortedNames []string
sortCallbackProcessor func(c *CallbackProcessor) sortCallbackProcessor func(c *CallbackProcessor)
@ -211,7 +214,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
sortCallbackProcessor(cp) sortCallbackProcessor(cp)
} }
var sortedFuncs []*func(scope *Scope) var sortedFuncs []*func(ctx context.Context, scope *Scope)
for _, name := range sortedNames { for _, name := range sortedNames {
if index := getRIndex(allNames, name); !cps[index].remove { if index := getRIndex(allNames, name); !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor) sortedFuncs = append(sortedFuncs, cps[index].processor)

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
) )
@ -19,7 +20,7 @@ func init() {
} }
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) { func beforeCreateCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("BeforeSave") scope.CallMethod("BeforeSave")
} }
@ -29,7 +30,7 @@ func beforeCreateCallback(scope *Scope) {
} }
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) { func updateTimeStampForCreateCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
now := scope.db.nowFunc() now := scope.db.nowFunc()
@ -48,7 +49,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
} }
// createCallback the callback used to insert data into database // createCallback the callback used to insert data into database
func createCallback(scope *Scope) { func createCallback(ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
@ -100,10 +101,10 @@ func createCallback(scope *Scope) {
returningColumn = scope.Quote(primaryField.DBName) returningColumn = scope.Quote(primaryField.DBName)
} }
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(ctx, quotedTableName, returningColumn, columns)
var lastInsertIDReturningSuffix string var lastInsertIDReturningSuffix string
if lastInsertIDOutputInterstitial == "" { if lastInsertIDOutputInterstitial == "" {
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(ctx, quotedTableName, returningColumn)
} }
if len(columns) == 0 { if len(columns) == 0 {
@ -130,7 +131,7 @@ func createCallback(scope *Scope) {
// execute create sql: no primaryField // execute create sql: no primaryField
if primaryField == nil { if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count // set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
@ -146,7 +147,7 @@ func createCallback(scope *Scope) {
// execute create sql: lastInsertID implemention for majority of dialects // execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count // set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
@ -162,7 +163,7 @@ func createCallback(scope *Scope) {
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() { if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { if err := scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false primaryField.IsBlank = false
scope.db.RowsAffected = 1 scope.db.RowsAffected = 1
} }
@ -174,7 +175,7 @@ func createCallback(scope *Scope) {
} }
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) { func forceReloadAfterCreateCallback(_ctx context.Context, scope *Scope) {
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
@ -187,7 +188,7 @@ func forceReloadAfterCreateCallback(scope *Scope) {
} }
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
func afterCreateCallback(scope *Scope) { func afterCreateCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("AfterCreate") scope.CallMethod("AfterCreate")
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
) )
@ -15,7 +16,7 @@ func init() {
} }
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) { func beforeDeleteCallback(_ctx context.Context, scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("missing WHERE clause while deleting")) scope.Err(errors.New("missing WHERE clause while deleting"))
return return
@ -26,7 +27,7 @@ func beforeDeleteCallback(scope *Scope) {
} }
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) { func deleteCallback(ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
var extraOption string var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok { if str, ok := scope.Get("gorm:delete_option"); ok {
@ -43,20 +44,20 @@ func deleteCallback(scope *Scope) {
scope.AddToVars(scope.db.nowFunc()), scope.AddToVars(scope.db.nowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(extraOption),
)).Exec() )).Exec(ctx)
} else { } else {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"DELETE FROM %v%v%v", "DELETE FROM %v%v%v",
scope.QuotedTableName(), scope.QuotedTableName(),
addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(extraOption),
)).Exec() )).Exec(ctx)
} }
} }
} }
// afterDeleteCallback will invoke `AfterDelete` method after deleting // afterDeleteCallback will invoke `AfterDelete` method after deleting
func afterDeleteCallback(scope *Scope) { func afterDeleteCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("AfterDelete") scope.CallMethod("AfterDelete")
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -14,7 +15,7 @@ func init() {
} }
// queryCallback used to query data from database // queryCallback used to query data from database
func queryCallback(scope *Scope) { func queryCallback(ctx context.Context, scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return return
} }
@ -69,7 +70,7 @@ func queryCallback(scope *Scope) {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
} }
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if rows, err := scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close() defer rows.Close()
columns, _ := rows.Columns() columns, _ := rows.Columns()
@ -102,7 +103,7 @@ func queryCallback(scope *Scope) {
} }
// afterQueryCallback will invoke `AfterFind` method after querying // afterQueryCallback will invoke `AfterFind` method after querying
func afterQueryCallback(scope *Scope) { func afterQueryCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("AfterFind") scope.CallMethod("AfterFind")
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -9,7 +10,7 @@ import (
) )
// preloadCallback used to preload associations // preloadCallback used to preload associations
func preloadCallback(scope *Scope) { func preloadCallback(ctx context.Context, scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return return
} }
@ -18,9 +19,9 @@ func preloadCallback(scope *Scope) {
// If gorm:auto_preload IS NOT a bool then auto preload. // If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value // Else if it IS a bool, use the value
if apb, ok := ap.(bool); !ok { if apb, ok := ap.(bool); !ok {
autoPreload(scope) autoPreload(ctx, scope)
} else if apb { } else if apb {
autoPreload(scope) autoPreload(ctx, scope)
} }
} }
@ -62,13 +63,13 @@ func preloadCallback(scope *Scope) {
switch field.Relationship.Kind { switch field.Relationship.Kind {
case "has_one": case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions) currentScope.handleHasOnePreload(ctx, field, currentPreloadConditions)
case "has_many": case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions) currentScope.handleHasManyPreload(ctx, field, currentPreloadConditions)
case "belongs_to": case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions) currentScope.handleBelongsToPreload(ctx, field, currentPreloadConditions)
case "many_to_many": case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions) currentScope.handleManyToManyPreload(ctx, field, currentPreloadConditions)
default: default:
scope.Err(errors.New("unsupported relation")) scope.Err(errors.New("unsupported relation"))
} }
@ -94,7 +95,7 @@ func preloadCallback(scope *Scope) {
} }
} }
func autoPreload(scope *Scope) { func autoPreload(_ctx context.Context, scope *Scope) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if field.Relationship == nil { if field.Relationship == nil {
continue continue
@ -131,7 +132,7 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*
} }
// handleHasOnePreload used to preload has one associations // handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasOnePreload(_ctx context.Context, field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
@ -183,7 +184,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
} }
// handleHasManyPreload used to preload has many associations // handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasManyPreload(_ctx context.Context, field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
@ -236,7 +237,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
} }
// handleBelongsToPreload used to preload belongs to associations // handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleBelongsToPreload(_ctx context.Context, field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
// preload conditions // preload conditions
@ -283,7 +284,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
} }
// handleManyToManyPreload used to preload many to many associations // handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleManyToManyPreload(ctx context.Context, field *Field, conditions []interface{}) {
var ( var (
relation = field.Relationship relation = field.Relationship
joinTableHandler = relation.JoinTableHandler joinTableHandler = relation.JoinTableHandler
@ -346,7 +347,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
scope.New(elem.Addr().Interface()). scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true). InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries) callCallbacks(ctx, scope.db.parent.callbacks.queries)
var foreignKeys = make([]interface{}, len(sourceKeys)) var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table // generate hashed forkey keys in join table

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
) )
@ -20,7 +21,7 @@ type RowsQueryResult struct {
} }
// queryCallback used to query data from database // queryCallback used to query data from database
func rowQueryCallback(scope *Scope) { func rowQueryCallback(ctx context.Context, scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok { if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL() scope.prepareQuerySQL()
@ -33,9 +34,9 @@ func rowQueryCallback(scope *Scope) {
} }
if rowResult, ok := result.(*RowQueryResult); ok { if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) rowResult.Row = scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok { } else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...)
} }
} }
} }

View File

@ -1,15 +1,16 @@
package gorm package gorm
import ( import (
"context"
"reflect" "reflect"
"strings" "strings"
) )
func beginTransactionCallback(scope *Scope) { func beginTransactionCallback(_ctx context.Context, scope *Scope) {
scope.Begin() scope.Begin()
} }
func commitOrRollbackTransactionCallback(scope *Scope) { func commitOrRollbackTransactionCallback(_ctx context.Context, scope *Scope) {
scope.CommitOrRollback() scope.CommitOrRollback()
} }
@ -64,7 +65,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
return return
} }
func saveBeforeAssociationsCallback(scope *Scope) { func saveBeforeAssociationsCallback(_ctx context.Context, scope *Scope) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
@ -95,7 +96,7 @@ func saveBeforeAssociationsCallback(scope *Scope) {
} }
} }
func saveAfterAssociationsCallback(scope *Scope) { func saveAfterAssociationsCallback(_ctx context.Context, scope *Scope) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)

View File

@ -1,13 +1,14 @@
package gorm package gorm
import ( import (
"context"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
) )
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
var names []string var names []string
for _, f := range funcs { for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
@ -16,11 +17,11 @@ func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
return reflect.DeepEqual(names, fnames) return reflect.DeepEqual(names, fnames)
} }
func create(s *Scope) {} func create(ctx context.Context, s *Scope) {}
func beforeCreate1(s *Scope) {} func beforeCreate1(ctx context.Context, s *Scope) {}
func beforeCreate2(s *Scope) {} func beforeCreate2(ctx context.Context, s *Scope) {}
func afterCreate1(s *Scope) {} func afterCreate1(ctx context.Context, s *Scope) {}
func afterCreate2(s *Scope) {} func afterCreate2(ctx context.Context, s *Scope) {}
func TestRegisterCallback(t *testing.T) { func TestRegisterCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger} var callback = &Callback{logger: defaultLogger}
@ -83,7 +84,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
} }
} }
func replaceCreate(s *Scope) {} func replaceCreate(ctx context.Context, s *Scope) {}
func TestReplaceCallback(t *testing.T) { func TestReplaceCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger} var callback = &Callback{logger: defaultLogger}

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sort" "sort"
@ -21,7 +22,7 @@ func init() {
} }
// assignUpdatingAttributesCallback assign updating attributes to model // assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) { func assignUpdatingAttributesCallback(_ctx context.Context, scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps) scope.InstanceSet("gorm:update_attrs", updateMaps)
@ -32,7 +33,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
} }
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) { func beforeUpdateCallback(_ctx context.Context, scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("missing WHERE clause while updating")) scope.Err(errors.New("missing WHERE clause while updating"))
return return
@ -48,14 +49,14 @@ func beforeUpdateCallback(scope *Scope) {
} }
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) { func updateTimeStampForUpdateCallback(_ctx context.Context, scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", scope.db.nowFunc()) scope.SetColumn("UpdatedAt", scope.db.nowFunc())
} }
} }
// updateCallback the callback used to update data to database // updateCallback the callback used to update data to database
func updateCallback(scope *Scope) { func updateCallback(ctx context.Context, scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
var sqls []string var sqls []string
@ -103,13 +104,13 @@ func updateCallback(scope *Scope) {
strings.Join(sqls, ", "), strings.Join(sqls, ", "),
addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(extraOption),
)).Exec() )).Exec(ctx)
} }
} }
} }
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
func afterUpdateCallback(scope *Scope) { func afterUpdateCallback(_ctx context.Context, scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("AfterUpdate") scope.CallMethod("AfterUpdate")

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"context"
"errors" "errors"
"reflect" "reflect"
"testing" "testing"
@ -182,22 +183,23 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback` should be nil") t.Errorf("`gorm:test_callback` should be nil")
} }
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback") callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil { if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil") t.Errorf("`gorm:test_callback` should be non-nil")
} }
callback(scope) ctx := context.Background()
callback(ctx, scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
} }
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) DB.Callback().Create().Replace("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback") callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil { if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil") t.Errorf("`gorm:test_callback` should be non-nil")
} }
callback(scope) callback(ctx, scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
} }
@ -207,12 +209,12 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback` should be nil") t.Errorf("`gorm:test_callback` should be nil")
} }
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback") callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil { if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil") t.Errorf("`gorm:test_callback` should be non-nil")
} }
callback(scope) callback(ctx, scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
} }
@ -220,7 +222,7 @@ func TestGetCallback(t *testing.T) {
func TestUseDefaultCallback(t *testing.T) { func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create" createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { gorm.DefaultCallback.Create().Register(createCallbackName, func(context.Context, *gorm.Scope) {
// nop // nop
}) })
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
@ -233,16 +235,17 @@ func TestUseDefaultCallback(t *testing.T) {
updateCallbackName := "gorm:test_use_default_callback_for_update" updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value" scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { gorm.DefaultCallback.Update().Register(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
scope.Set(scopeValueName, 1) scope.Set(scopeValueName, 1)
}) })
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { gorm.DefaultCallback.Update().Replace(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
scope.Set(scopeValueName, 2) scope.Set(scopeValueName, 2)
}) })
scope := DB.NewScope(nil) scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName) callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope) ctx := context.Background()
callback(ctx, scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 { if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
} }

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -21,12 +22,13 @@ type CustomColumnAndIgnoredFieldClash struct {
} }
func TestCustomizeColumn(t *testing.T) { func TestCustomizeColumn(t *testing.T) {
ctx := context.Background()
col := "mapped_name" col := "mapped_name"
DB.DropTable(&CustomizeColumn{}) DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{}) DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) { if !scope.Dialect().HasColumn(ctx, scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col) t.Errorf("CustomizeColumn should have column %s", col)
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
@ -24,26 +25,26 @@ type Dialect interface {
DataTypeOf(field *StructField) string DataTypeOf(field *StructField) string
// HasIndex check has index or not // HasIndex check has index or not
HasIndex(tableName string, indexName string) bool HasIndex(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not // HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index // RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error RemoveIndex(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not // HasTable check has table or not
HasTable(tableName string) bool HasTable(ctx context.Context, tableName string) bool
// HasColumn check has column or not // HasColumn check has column or not
HasColumn(tableName string, columnName string) bool HasColumn(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type // ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) (string, error) LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string LastInsertIDOutputInterstitial(ctx context.Context, tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string LastInsertIDReturningSuffix(ctx context.Context, tableName, columnName string) string
// DefaultValueStr // DefaultValueStr
DefaultValueStr() string DefaultValueStr() string
@ -54,7 +55,7 @@ type Dialect interface {
NormalizeIndexAndColumn(indexName, columnName string) (string, string) NormalizeIndexAndColumn(indexName, columnName string) (string, string)
// CurrentDatabase return current database name // CurrentDatabase return current database name
CurrentDatabase() string CurrentDatabase(ctx context.Context) string
} }
var dialectsMap = map[string]Dialect{} var dialectsMap = map[string]Dialect{}
@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType) return fieldValue, dataType, size, strings.TrimSpace(additionalType)
} }
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") { if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2) splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1] return splitStrings[0], splitStrings[1]
} }
return dialect.CurrentDatabase(), tableName return dialect.CurrentDatabase(ctx), tableName
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s commonDialect) HasIndex(tableName string, indexName string) bool { func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0 return count > 0
} }
func (s commonDialect) RemoveIndex(tableName string, indexName string) error { func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
return err return err
} }
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { func (s commonDialect) HasForeignKey(_ctx context.Context, tableName string, foreignKeyName string) bool {
return false return false
} }
func (s commonDialect) HasTable(tableName string) bool { func (s commonDialect) HasTable(ctx context.Context, tableName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s commonDialect) HasColumn(tableName string, columnName string) bool { func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0 return count > 0
} }
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err return err
} }
func (s commonDialect) CurrentDatabase() (name string) { func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name) s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return return
} }
@ -162,11 +163,11 @@ func (commonDialect) SelectFromDummyTable() string {
return "" return ""
} }
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { func (commonDialect) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
return "" return ""
} }
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { func (commonDialect) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
return "" return ""
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"crypto/sha1" "crypto/sha1"
"database/sql" "database/sql"
"fmt" "fmt"
@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s mysql) RemoveIndex(tableName string, indexName string) error { func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err return err
} }
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err return err
} }
@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err
return return
} }
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0 return count > 0
} }
func (s mysql) HasTable(tableName string) bool { func (s mysql) HasTable(ctx context.Context, tableName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
var name string var name string
// allow mysql database name with '-' character // allow mysql database name with '-' character
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false return false
} }
@ -183,9 +184,9 @@ func (s mysql) HasTable(tableName string) bool {
} }
} }
func (s mysql) HasIndex(tableName string, indexName string) bool { func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
panic(err) panic(err)
} else { } else {
defer rows.Close() defer rows.Close()
@ -193,9 +194,9 @@ func (s mysql) HasIndex(tableName string, indexName string) bool {
} }
} }
func (s mysql) HasColumn(tableName string, columnName string) bool { func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
panic(err) panic(err)
} else { } else {
defer rows.Close() defer rows.Close()
@ -203,8 +204,8 @@ func (s mysql) HasColumn(tableName string, columnName string) bool {
} }
} }
func (s mysql) CurrentDatabase() (name string) { func (s mysql) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name) s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return return
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
@ -91,40 +92,40 @@ func (s *postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s postgres) HasIndex(tableName string, indexName string) bool { func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0 return count > 0
} }
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0 return count > 0
} }
func (s postgres) HasTable(tableName string) bool { func (s postgres) HasTable(ctx context.Context, tableName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s postgres) HasColumn(tableName string, columnName string) bool { func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0 return count > 0
} }
func (s postgres) CurrentDatabase() (name string) { func (s postgres) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
return return
} }
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { func (s postgres) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, key string, columns []string) string {
return "" return ""
} }
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { func (s postgres) LastInsertIDReturningSuffix(_ctx context.Context, tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key) return fmt.Sprintf("RETURNING %v.%v", tableName, key)
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -70,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s sqlite3) HasIndex(tableName string, indexName string) bool { func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s sqlite3) HasTable(tableName string) bool { func (s sqlite3) HasTable(ctx context.Context, tableName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s sqlite3) HasColumn(tableName string, columnName string) bool { func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s sqlite3) CurrentDatabase() (name string) { func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) {
var ( var (
ifaces = make([]interface{}, 3) ifaces = make([]interface{}, 3)
pointers = make([]*string, 3) pointers = make([]*string, 3)
@ -97,7 +98,7 @@ func (s sqlite3) CurrentDatabase() (name string) {
for i = 0; i < 3; i++ { for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i] ifaces[i] = &pointers[i]
} }
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { if err := s.db.QueryRowContext(ctx, "PRAGMA database_list").Scan(ifaces...); err != nil {
return return
} }
if pointers[1] != nil { if pointers[1] != nil {

View File

@ -1,6 +1,7 @@
package mssql package mssql
import ( import (
"context"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
@ -15,7 +16,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
func setIdentityInsert(scope *gorm.Scope) { func setIdentityInsert(ctx context.Context, scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" { if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() { for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
@ -26,7 +27,7 @@ func setIdentityInsert(scope *gorm.Scope) {
} }
} }
func turnOffIdentityInsert(scope *gorm.Scope) { func turnOffIdentityInsert(ctx context.Context, scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" { if scope.Dialect().GetName() == "mssql" {
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
@ -122,49 +123,49 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
return field.IsPrimaryKey return field.IsPrimaryKey
} }
func (s mssql) HasIndex(tableName string, indexName string) bool { func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s mssql) RemoveIndex(tableName string, indexName string) error { func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err return err
} }
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow(`SELECT count(*) s.db.QueryRowContext(ctx, `SELECT count(*)
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
inner join information_schema.tables as I on I.TABLE_NAME = T.name inner join information_schema.tables as I on I.TABLE_NAME = T.name
WHERE F.name = ? WHERE F.name = ?
AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
return count > 0 return count > 0
} }
func (s mssql) HasTable(tableName string) bool { func (s mssql) HasTable(ctx context.Context, tableName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0 return count > 0
} }
func (s mssql) HasColumn(tableName string, columnName string) bool { func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0 return count > 0
} }
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err return err
} }
func (s mssql) CurrentDatabase() (name string) { func (s mssql) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name)
return return
} }
@ -198,7 +199,7 @@ func (mssql) SelectFromDummyTable() string {
return "" return ""
} }
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { func (mssql) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
if len(columns) == 0 { if len(columns) == 0 {
// No OUTPUT to query // No OUTPUT to query
return "" return ""
@ -206,7 +207,7 @@ func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, column
return fmt.Sprintf("OUTPUT Inserted.%v", columnName) return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
} }
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { func (mssql) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
return "; SELECT SCOPE_IDENTITY()" return "; SELECT SCOPE_IDENTITY()"
} }
@ -220,12 +221,12 @@ func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, stri
return indexName, columnName return indexName, columnName
} }
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") { if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2) splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1] return splitStrings[0], splitStrings[1]
} }
return dialect.CurrentDatabase(), tableName return dialect.CurrentDatabase(ctx), tableName
} }
// JSON type to support easy handling of JSON data in character table fields // JSON type to support easy handling of JSON data in character table fields

View File

@ -1,6 +1,9 @@
package gorm_test package gorm_test
import "testing" import (
"context"
"testing"
)
type BasePost struct { type BasePost struct {
Id int64 Id int64
@ -27,9 +30,10 @@ type EngadgetPost struct {
} }
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
ctx := context.Background()
dialect := DB.NewScope(&EngadgetPost{}).Dialect() dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{}) engadgetPostScope := DB.NewScope(&EngadgetPost{})
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { if !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns") t.Errorf("should has prefix for embedded columns")
} }
@ -38,7 +42,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
} }
hnScope := DB.NewScope(&HNPost{}) hnScope := DB.NewScope(&HNPost{})
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { if !dialect.HasColumn(ctx, hnScope.TableName(), "user_id") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_name") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns") t.Errorf("should has prefix for embedded columns")
} }
} }

View File

@ -7,10 +7,10 @@ import (
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface { type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
type sqlDb interface { type sqlDb interface {

237
main.go
View File

@ -327,51 +327,91 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.FirstContext(ctx, out, where...)
}
func (s *DB) FirstContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out) newScope := s.NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC"). return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take return a record that match given conditions, the order will depend on the database implementation
func (s *DB) Take(out interface{}, where ...interface{}) *DB { func (s *DB) Take(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.TakeContext(ctx, out, where...)
}
func (s *DB) TakeContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out) newScope := s.NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db return newScope.inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
} }
// Last find last record that match given conditions, order by primary key // Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB { func (s *DB) Last(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.LastContext(ctx, out, where...)
}
func (s *DB) LastContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out) newScope := s.NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC"). return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
} }
// Find find records that match given conditions // Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB { func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db ctx := context.Background()
return s.FindContext(ctx, out, where...)
}
func (s *DB) FindContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
return s.NewScope(out).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
} }
//Preloads preloads relations, don`t touch out //Preloads preloads relations, don`t touch out
func (s *DB) Preloads(out interface{}) *DB { func (s *DB) Preloads(out interface{}) *DB {
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db ctx := context.Background()
return s.PreloadsContext(ctx, out)
}
func (s *DB) PreloadsContext(ctx context.Context, out interface{}) *DB {
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(ctx, s.parent.callbacks.queries).db
} }
// Scan scan value to a struct // Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB { func (s *DB) Scan(dest interface{}) *DB {
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db ctx := context.Background()
return s.ScanContext(ctx, dest)
}
func (s *DB) ScanContext(ctx context.Context, dest interface{}) *DB {
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(ctx, s.parent.callbacks.queries).db
} }
// Row return `*sql.Row` with given conditions // Row return `*sql.Row` with given conditions
func (s *DB) Row() *sql.Row { func (s *DB) Row() *sql.Row {
return s.NewScope(s.Value).row() ctx := context.Background()
return s.RowContext(ctx)
}
func (s *DB) RowContext(ctx context.Context) *sql.Row {
return s.NewScope(s.Value).row(ctx)
} }
// Rows return `*sql.Rows` with given conditions // Rows return `*sql.Rows` with given conditions
func (s *DB) Rows() (*sql.Rows, error) { func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows() ctx := context.Background()
return s.RowsContext(ctx)
}
func (s *DB) RowsContext(ctx context.Context) (*sql.Rows, error) {
return s.NewScope(s.Value).rows(ctx)
} }
// ScanRows scan `*sql.Rows` to give struct // ScanRows scan `*sql.Rows` to give struct
@ -393,12 +433,22 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
// var ages []int64 // var ages []int64
// db.Find(&users).Pluck("age", &ages) // db.Find(&users).Pluck("age", &ages)
func (s *DB) Pluck(column string, value interface{}) *DB { func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db ctx := context.Background()
return s.PluckContext(ctx, column, value)
}
func (s *DB) PluckContext(ctx context.Context, column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(ctx, column, value).db
} }
// Count get how many records for a model // Count get how many records for a model
func (s *DB) Count(value interface{}) *DB { func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db ctx := context.Background()
return s.CountContext(ctx, value)
}
func (s *DB) CountContext(ctx context.Context, value interface{}) *DB {
return s.NewScope(s.Value).count(ctx, value).db
} }
// Related get related associations // Related get related associations
@ -409,8 +459,13 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/crud.html#firstorinit // https://jinzhu.github.io/gorm/crud.html#firstorinit
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.FirstOrInitContext(ctx, out, where...)
}
func (s *DB) FirstOrInitContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
c := s.clone() c := s.clone()
if result := c.First(out, where...); result.Error != nil { if result := c.FirstContext(ctx, out, where...); result.Error != nil {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
@ -424,14 +479,19 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) // FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/crud.html#firstorcreate // https://jinzhu.github.io/gorm/crud.html#firstorcreate
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.FirstOrCreateContext(ctx, out, where...)
}
func (s *DB) FirstOrCreateContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
c := s.clone() c := s.clone()
if result := s.First(out, where...); result.Error != nil { if result := s.First(out, where...); result.Error != nil {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(ctx, c.parent.callbacks.creates).db
} else if len(c.search.assignAttrs) > 0 { } else if len(c.search.assignAttrs) > 0 {
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(ctx, c.parent.callbacks.updates).db
} }
return c return c
} }
@ -439,54 +499,89 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
// WARNING when update with struct, GORM will not update fields that with zero value // WARNING when update with struct, GORM will not update fields that with zero value
func (s *DB) Update(attrs ...interface{}) *DB { func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true) ctx := context.Background()
return s.UpdateContext(ctx, attrs...)
}
func (s *DB) UpdateContext(ctx context.Context, attrs ...interface{}) *DB {
return s.UpdatesContext(ctx, toSearchableMap(attrs...), true)
} }
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
ctx := context.Background()
return s.UpdatesContext(ctx, values, ignoreProtectedAttrs...)
}
func (s *DB) UpdatesContext(ctx context.Context, values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.NewScope(s.Value). return s.NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values). InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db callCallbacks(ctx, s.parent.callbacks.updates).db
} }
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumn(attrs ...interface{}) *DB { func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
return s.UpdateColumns(toSearchableMap(attrs...)) ctx := context.Background()
return s.UpdateColumnContext(ctx, attrs...)
}
func (s *DB) UpdateColumnContext(ctx context.Context, attrs ...interface{}) *DB {
return s.UpdateColumnsContext(ctx, toSearchableMap(attrs...))
} }
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumns(values interface{}) *DB { func (s *DB) UpdateColumns(values interface{}) *DB {
ctx := context.Background()
return s.UpdateColumnsContext(ctx, values)
}
func (s *DB) UpdateColumnsContext(ctx context.Context, values interface{}) *DB {
return s.NewScope(s.Value). return s.NewScope(s.Value).
Set("gorm:update_column", true). Set("gorm:update_column", true).
Set("gorm:save_associations", false). Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values). InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db callCallbacks(ctx, s.parent.callbacks.updates).db
} }
// Save update value in database, if the value doesn't have primary key, will insert it // Save update value in database, if the value doesn't have primary key, will insert it
func (s *DB) Save(value interface{}) *DB { func (s *DB) Save(value interface{}) *DB {
ctx := context.Background()
return s.SaveContext(ctx, value)
}
func (s *DB) SaveContext(ctx context.Context, value interface{}) *DB {
scope := s.NewScope(value) scope := s.NewScope(value)
if !scope.PrimaryKeyZero() { if !scope.PrimaryKeyZero() {
newDB := scope.callCallbacks(s.parent.callbacks.updates).db newDB := scope.callCallbacks(ctx, s.parent.callbacks.updates).db
if newDB.Error == nil && newDB.RowsAffected == 0 { if newDB.Error == nil && newDB.RowsAffected == 0 {
return s.New().Table(scope.TableName()).FirstOrCreate(value) return s.New().Table(scope.TableName()).FirstOrCreateContext(ctx, value)
} }
return newDB return newDB
} }
return scope.callCallbacks(s.parent.callbacks.creates).db return scope.callCallbacks(ctx, s.parent.callbacks.creates).db
} }
// Create insert the value into database // Create insert the value into database
func (s *DB) Create(value interface{}) *DB { func (s *DB) Create(value interface{}) *DB {
ctx := context.Background()
return s.CreateContext(ctx, value)
}
func (s *DB) CreateContext(ctx context.Context, value interface{}) *DB {
scope := s.NewScope(value) scope := s.NewScope(value)
return scope.callCallbacks(s.parent.callbacks.creates).db return scope.callCallbacks(ctx, s.parent.callbacks.creates).db
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time // WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
func (s *DB) Delete(value interface{}, where ...interface{}) *DB { func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db ctx := context.Background()
return s.DeleteContext(ctx, value, where...)
}
func (s *DB) DeleteContext(ctx context.Context, value interface{}, where ...interface{}) *DB {
return s.NewScope(value).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.deletes).db
} }
// Raw use raw sql as conditions, won't run it unless invoked by other methods // Raw use raw sql as conditions, won't run it unless invoked by other methods
@ -497,11 +592,16 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
// Exec execute raw sql // Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB { func (s *DB) Exec(sql string, values ...interface{}) *DB {
ctx := context.Background()
return s.ExecContext(ctx, sql, values...)
}
func (s *DB) ExecContext(ctx context.Context, sql string, values ...interface{}) *DB {
scope := s.NewScope(nil) scope := s.NewScope(nil)
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSQL) scope.Raw(generatedSQL)
return scope.Exec().db return scope.Exec(ctx).db
} }
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
@ -628,32 +728,47 @@ func (s *DB) RecordNotFound() bool {
// CreateTable create table for models // CreateTable create table for models
func (s *DB) CreateTable(models ...interface{}) *DB { func (s *DB) CreateTable(models ...interface{}) *DB {
ctx := context.Background()
return s.CreateTableContext(ctx, models...)
}
func (s *DB) CreateTableContext(ctx context.Context, models ...interface{}) *DB {
db := s.Unscoped() db := s.Unscoped()
for _, model := range models { for _, model := range models {
db = db.NewScope(model).createTable().db db = db.NewScope(model).createTable(ctx).db
} }
return db return db
} }
// DropTable drop table for models // DropTable drop table for models
func (s *DB) DropTable(values ...interface{}) *DB { func (s *DB) DropTable(values ...interface{}) *DB {
ctx := context.Background()
return s.DropTableContext(ctx, values...)
}
func (s *DB) DropTableContext(ctx context.Context, values ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, value := range values {
if tableName, ok := value.(string); ok { if tableName, ok := value.(string); ok {
db = db.Table(tableName) db = db.Table(tableName)
} }
db = db.NewScope(value).dropTable().db db = db.NewScope(value).dropTable(ctx).db
} }
return db return db
} }
// DropTableIfExists drop table if it is exist // DropTableIfExists drop table if it is exist
func (s *DB) DropTableIfExists(values ...interface{}) *DB { func (s *DB) DropTableIfExists(values ...interface{}) *DB {
ctx := context.Background()
return s.DropTableIfExistsContext(ctx, values...)
}
func (s *DB) DropTableIfExistsContext(ctx context.Context, values ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, value := range values {
if s.HasTable(value) { if s.HasTable(value) {
db.AddError(s.DropTable(value).Error) db.AddError(s.DropTableContext(ctx, value).Error)
} }
} }
return db return db
@ -661,6 +776,11 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB {
// HasTable check has table or not // HasTable check has table or not
func (s *DB) HasTable(value interface{}) bool { func (s *DB) HasTable(value interface{}) bool {
ctx := context.Background()
return s.HasTableContext(ctx, value)
}
func (s *DB) HasTableContext(ctx context.Context, value interface{}) bool {
var ( var (
scope = s.NewScope(value) scope = s.NewScope(value)
tableName string tableName string
@ -672,68 +792,108 @@ func (s *DB) HasTable(value interface{}) bool {
tableName = scope.TableName() tableName = scope.TableName()
} }
has := scope.Dialect().HasTable(tableName) has := scope.Dialect().HasTable(ctx, tableName)
s.AddError(scope.db.Error) s.AddError(scope.db.Error)
return has return has
} }
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data // AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB { func (s *DB) AutoMigrate(values ...interface{}) *DB {
ctx := context.Background()
return s.AutoMigrateContext(ctx, values...)
}
func (s *DB) AutoMigrateContext(ctx context.Context, values ...interface{}) *DB {
db := s.Unscoped() db := s.Unscoped()
for _, value := range values { for _, value := range values {
db = db.NewScope(value).autoMigrate().db db = db.NewScope(value).autoMigrate(ctx).db
} }
return db return db
} }
// ModifyColumn modify column to type // ModifyColumn modify column to type
func (s *DB) ModifyColumn(column string, typ string) *DB { func (s *DB) ModifyColumn(column string, typ string) *DB {
ctx := context.Background()
return s.ModifyColumnContext(ctx, column, typ)
}
func (s *DB) ModifyColumnContext(ctx context.Context, column string, typ string) *DB {
scope := s.NewScope(s.Value) scope := s.NewScope(s.Value)
scope.modifyColumn(column, typ) scope.modifyColumn(ctx, column, typ)
return scope.db return scope.db
} }
// DropColumn drop a column // DropColumn drop a column
func (s *DB) DropColumn(column string) *DB { func (s *DB) DropColumn(column string) *DB {
ctx := context.Background()
return s.DropColumnContext(ctx, column)
}
func (s *DB) DropColumnContext(ctx context.Context, column string) *DB {
scope := s.NewScope(s.Value) scope := s.NewScope(s.Value)
scope.dropColumn(column) scope.dropColumn(ctx, column)
return scope.db return scope.db
} }
// AddIndex add index for columns with given name // AddIndex add index for columns with given name
func (s *DB) AddIndex(indexName string, columns ...string) *DB { func (s *DB) AddIndex(indexName string, columns ...string) *DB {
ctx := context.Background()
return s.AddIndexContext(ctx, indexName, columns...)
}
func (s *DB) AddIndexContext(ctx context.Context, indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value) scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(false, indexName, columns...) scope.addIndex(ctx, false, indexName, columns...)
return scope.db return scope.db
} }
// AddUniqueIndex add unique index for columns with given name // AddUniqueIndex add unique index for columns with given name
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
ctx := context.Background()
return s.AddUniqueIndexContext(ctx, indexName, columns...)
}
func (s *DB) AddUniqueIndexContext(ctx context.Context, indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value) scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(true, indexName, columns...) scope.addIndex(ctx, true, indexName, columns...)
return scope.db return scope.db
} }
// RemoveIndex remove index with name // RemoveIndex remove index with name
func (s *DB) RemoveIndex(indexName string) *DB { func (s *DB) RemoveIndex(indexName string) *DB {
ctx := context.Background()
return s.RemoveIndexContext(ctx, indexName)
}
func (s *DB) RemoveIndexContext(ctx context.Context, indexName string) *DB {
scope := s.NewScope(s.Value) scope := s.NewScope(s.Value)
scope.removeIndex(indexName) scope.removeIndex(ctx, indexName)
return scope.db return scope.db
} }
// AddForeignKey Add foreign key to the given scope, e.g: // AddForeignKey Add foreign key to the given scope, e.g:
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
ctx := context.Background()
return s.AddForeignKeyContext(ctx, field, dest, onDelete, onUpdate)
}
func (s *DB) AddForeignKeyContext(ctx context.Context, field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.NewScope(s.Value) scope := s.NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate) scope.addForeignKey(ctx, field, dest, onDelete, onUpdate)
return scope.db return scope.db
} }
// RemoveForeignKey Remove foreign key from the given scope, e.g: // RemoveForeignKey Remove foreign key from the given scope, e.g:
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") // db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
func (s *DB) RemoveForeignKey(field string, dest string) *DB { func (s *DB) RemoveForeignKey(field string, dest string) *DB {
ctx := context.Background()
return s.RemoveForeignKeyContext(ctx, field, dest)
}
func (s *DB) RemoveForeignKeyContext(ctx context.Context, field string, dest string) *DB {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.removeForeignKey(field, dest) scope.removeForeignKey(ctx, field, dest)
return scope.db return scope.db
} }
@ -784,6 +944,11 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
// SetJoinTableHandler set a model's join table handler for a relation // SetJoinTableHandler set a model's join table handler for a relation
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
ctx := context.Background()
s.SetJoinTableHandlerContext(ctx, source, column, handler)
}
func (s *DB) SetJoinTableHandlerContext(ctx context.Context, source interface{}, column string, handler JoinTableHandlerInterface) {
scope := s.NewScope(source) scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields { for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column { if field.Name == column || field.DBName == column {
@ -792,7 +957,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination) handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(table) { if table := handler.Table(s); scope.Dialect().HasTable(ctx, table) {
s.Table(table).AutoMigrate(handler) s.Table(table).AutoMigrate(handler)
} }
} }

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
@ -302,12 +303,14 @@ func runMigration() {
} }
func TestIndexes(t *testing.T) { func TestIndexes(t *testing.T) {
ctx := context.Background()
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
scope := DB.NewScope(&Email{}) scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email") t.Errorf("Email should have index idx_email_email")
} }
@ -315,7 +318,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted") t.Errorf("Email's index idx_email_email should be deleted")
} }
@ -323,7 +326,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
@ -331,7 +334,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
@ -339,7 +342,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
@ -361,7 +364,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
@ -381,6 +384,8 @@ type EmailWithIdx struct {
} }
func TestAutoMigration(t *testing.T) { func TestAutoMigration(t *testing.T) {
ctx := context.Background()
DB.AutoMigrate(&Address{}) DB.AutoMigrate(&Address{})
DB.DropTable(&EmailWithIdx{}) DB.DropTable(&EmailWithIdx{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
@ -391,11 +396,11 @@ func TestAutoMigration(t *testing.T) {
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
scope := DB.NewScope(&EmailWithIdx{}) scope := DB.NewScope(&EmailWithIdx{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
@ -462,6 +467,8 @@ type MultipleIndexes struct {
} }
func TestMultipleIndexes(t *testing.T) { func TestMultipleIndexes(t *testing.T) {
ctx := context.Background()
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
} }
@ -474,23 +481,23 @@ func TestMultipleIndexes(t *testing.T) {
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
scope := DB.NewScope(&MultipleIndexes{}) scope := DB.NewScope(&MultipleIndexes{})
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_name") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_email") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multiple_indexes_email") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multipleindexes_user_other") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multiple_indexes_other") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
@ -540,6 +547,8 @@ func TestModifyColumnType(t *testing.T) {
} }
func TestIndexWithPrefixLength(t *testing.T) { func TestIndexWithPrefixLength(t *testing.T) {
ctx := context.Background()
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
t.Skip("Skipping this because only mysql support setting an index prefix length") t.Skip("Skipping this because only mysql support setting an index prefix length")
} }
@ -571,7 +580,7 @@ func TestIndexWithPrefixLength(t *testing.T) {
if err := DB.CreateTable(table).Error; err != nil { if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err) t.Errorf("Failed to create %s table: %v", tableName, err)
} }
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { if !scope.Dialect().HasIndex(ctx, tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName) t.Errorf("Failed to create %s table index:", tableName)
} }
}) })

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"os" "os"
@ -1684,7 +1685,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
called := 0 called := 0
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(ctx context.Context, scope *gorm.Scope) {
called = called + 1 called = called + 1
}) })

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
@ -357,11 +358,11 @@ func (scope *Scope) Raw(sql string) *Scope {
} }
// Exec perform generated SQL // Exec perform generated SQL
func (scope *Scope) Exec() *Scope { func (scope *Scope) Exec(ctx context.Context) *Scope {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count scope.db.RowsAffected = count
} }
@ -856,7 +857,7 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
return scope return scope
} }
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { func (scope *Scope) callCallbacks(ctx context.Context, funcs []*func(ctx context.Context, s *Scope)) *Scope {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
if db, ok := scope.db.db.(sqlTx); ok { if db, ok := scope.db.db.(sqlTx); ok {
@ -866,7 +867,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
} }
}() }()
for _, f := range funcs { for _, f := range funcs {
(*f)(scope) (*f)(ctx, scope)
if scope.skipLeft { if scope.skipLeft {
break break
} }
@ -933,22 +934,22 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
return return
} }
func (scope *Scope) row() *sql.Row { func (scope *Scope) row(ctx context.Context) *sql.Row {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
result := &RowQueryResult{} result := &RowQueryResult{}
scope.InstanceSet("row_query_result", result) scope.InstanceSet("row_query_result", result)
scope.callCallbacks(scope.db.parent.callbacks.rowQueries) scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries)
return result.Row return result.Row
} }
func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) rows(ctx context.Context) (*sql.Rows, error) {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
result := &RowsQueryResult{} result := &RowsQueryResult{}
scope.InstanceSet("row_query_result", result) scope.InstanceSet("row_query_result", result)
scope.callCallbacks(scope.db.parent.callbacks.rowQueries) scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries)
return result.Rows, result.Error return result.Rows, result.Error
} }
@ -979,7 +980,7 @@ func (scope *Scope) isQueryForColumn(query interface{}, column string) bool {
return false return false
} }
func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) pluck(ctx context.Context, column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value)) dest := reflect.Indirect(reflect.ValueOf(value))
if dest.Kind() != reflect.Slice { if dest.Kind() != reflect.Slice {
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
@ -994,7 +995,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
scope.Search.Select(column) scope.Search.Select(column)
} }
rows, err := scope.rows() rows, err := scope.rows(ctx)
if scope.Err(err) == nil { if scope.Err(err) == nil {
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
@ -1010,7 +1011,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
return scope return scope
} }
func (scope *Scope) count(value interface{}) *Scope { func (scope *Scope) count(ctx context.Context, value interface{}) *Scope {
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
if len(scope.Search.group) != 0 { if len(scope.Search.group) != 0 {
if len(scope.Search.havingConditions) != 0 { if len(scope.Search.havingConditions) != 0 {
@ -1027,7 +1028,7 @@ func (scope *Scope) count(value interface{}) *Scope {
} }
} }
scope.Search.ignoreOrderQuery = true scope.Search.ignoreOrderQuery = true
scope.Err(scope.row().Scan(value)) scope.Err(scope.row(ctx).Scan(value))
return scope return scope
} }
@ -1124,11 +1125,11 @@ func (scope *Scope) getTableOptions() string {
return " " + tableOptions.(string) return " " + tableOptions.(string)
} }
func (scope *Scope) createJoinTable(field *StructField) { func (scope *Scope) createJoinTable(ctx context.Context, field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db) joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(joinTable) { if !scope.Dialect().HasTable(ctx, joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string var sqlTypes, primaryKeys []string
@ -1156,11 +1157,11 @@ func (scope *Scope) createJoinTable(field *StructField) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
} }
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) scope.NewDB().Table(joinTable).AutoMigrateContext(ctx, joinTableHandler)
} }
} }
func (scope *Scope) createTable() *Scope { func (scope *Scope) createTable(ctx context.Context) *Scope {
var tags []string var tags []string
var primaryKeys []string var primaryKeys []string
var primaryKeyInColumnType = false var primaryKeyInColumnType = false
@ -1181,7 +1182,7 @@ func (scope *Scope) createTable() *Scope {
if field.IsPrimaryKey { if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
} }
scope.createJoinTable(field) scope.createJoinTable(ctx, field)
} }
var primaryKeyStr string var primaryKeyStr string
@ -1189,27 +1190,27 @@ func (scope *Scope) createTable() *Scope {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
} }
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec(ctx)
scope.autoIndex() scope.autoIndex()
return scope return scope
} }
func (scope *Scope) dropTable() *Scope { func (scope *Scope) dropTable(ctx context.Context) *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec(ctx)
return scope return scope
} }
func (scope *Scope) modifyColumn(column string, typ string) { func (scope *Scope) modifyColumn(ctx context.Context, column string, typ string) {
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) scope.db.AddError(scope.Dialect().ModifyColumn(ctx, scope.QuotedTableName(), scope.Quote(column), typ))
} }
func (scope *Scope) dropColumn(column string) { func (scope *Scope) dropColumn(ctx context.Context, column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec(ctx)
} }
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addIndex(ctx context.Context, unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope.TableName(), indexName) { if scope.Dialect().HasIndex(ctx, scope.TableName(), indexName) {
return return
} }
@ -1223,23 +1224,23 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
sqlCreate = "CREATE UNIQUE INDEX" sqlCreate = "CREATE UNIQUE INDEX"
} }
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec(ctx)
} }
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(ctx context.Context, field string, dest string, onDelete string, onUpdate string) {
// Compatible with old generated key // Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { if scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
return return
} }
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec(ctx)
} }
func (scope *Scope) removeForeignKey(field string, dest string) { func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { if !scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
return return
} }
var mysql mysql var mysql mysql
@ -1250,28 +1251,28 @@ func (scope *Scope) removeForeignKey(field string, dest string) {
query = `ALTER TABLE %s DROP CONSTRAINT %s;` query = `ALTER TABLE %s DROP CONSTRAINT %s;`
} }
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec(ctx)
} }
func (scope *Scope) removeIndex(indexName string) { func (scope *Scope) removeIndex(ctx context.Context, indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName) scope.Dialect().RemoveIndex(ctx, scope.TableName(), indexName)
} }
func (scope *Scope) autoMigrate() *Scope { func (scope *Scope) autoMigrate(ctx context.Context) *Scope {
tableName := scope.TableName() tableName := scope.TableName()
quotedTableName := scope.QuotedTableName() quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(tableName) { if !scope.Dialect().HasTable(ctx, tableName) {
scope.createTable() scope.createTable(ctx)
} else { } else {
for _, field := range scope.GetModelStruct().StructFields { for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(tableName, field.DBName) { if !scope.Dialect().HasColumn(ctx, tableName, field.DBName) {
if field.IsNormal { if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field) sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec(ctx)
} }
} }
scope.createJoinTable(field) scope.createJoinTable(ctx, field)
} }
scope.autoIndex() scope.autoIndex()
} }