add context funcs

This commit is contained in:
Daniel Gatis 2020-01-03 22:41:09 -03:00
parent 79a77d771d
commit c7f9cb552b
23 changed files with 474 additions and 272 deletions

View File

@ -1,6 +1,9 @@
package gorm
import "fmt"
import (
"context"
"fmt"
)
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{logger: nopLogger{}}
@ -14,11 +17,11 @@ var DefaultCallback = &Callback{logger: nopLogger{}}
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
logger logger
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
creates []*func(ctx context.Context, scope *Scope)
updates []*func(ctx context.Context, scope *Scope)
deletes []*func(ctx context.Context, scope *Scope)
queries []*func(ctx context.Context, scope *Scope)
rowQueries []*func(ctx context.Context, scope *Scope)
processors []*CallbackProcessor
}
@ -31,7 +34,7 @@ type CallbackProcessor struct {
replace bool // replace callbacks with same name
remove bool // delete callbacks with same name
kind string // callback type: create, update, delete, query, row_query
processor *func(scope *Scope) // callback handler
processor *func(ctx context.Context, scope *Scope) // callback handler
parent *Callback
}
@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
}
// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
func (cp *CallbackProcessor) Register(callbackName string, callback func(ctx context.Context, scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
@ -123,7 +126,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// scope.SetColumn("CreatedAt", now)
// scope.SetColumn("UpdatedAt", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
func (cp *CallbackProcessor) Replace(callbackName string, callback func(ctx context.Context, scope *Scope)) {
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
@ -134,7 +137,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
func (cp *CallbackProcessor) Get(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
@ -158,7 +161,7 @@ func getRIndex(strs []string, str string) int {
}
// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
func sortProcessors(cps []*CallbackProcessor) []*func(ctx context.Context, scope *Scope) {
var (
allNames, sortedNames []string
sortCallbackProcessor func(c *CallbackProcessor)
@ -211,7 +214,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
sortCallbackProcessor(cp)
}
var sortedFuncs []*func(scope *Scope)
var sortedFuncs []*func(ctx context.Context, scope *Scope)
for _, name := range sortedNames {
if index := getRIndex(allNames, name); !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"fmt"
"strings"
)
@ -19,7 +20,7 @@ func init() {
}
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
func beforeCreateCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
@ -29,7 +30,7 @@ func beforeCreateCallback(scope *Scope) {
}
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
func updateTimeStampForCreateCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() {
now := scope.db.nowFunc()
@ -48,7 +49,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
}
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
func createCallback(ctx context.Context, scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
@ -100,10 +101,10 @@ func createCallback(scope *Scope) {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(ctx, quotedTableName, returningColumn, columns)
var lastInsertIDReturningSuffix string
if lastInsertIDOutputInterstitial == "" {
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(ctx, quotedTableName, returningColumn)
}
if len(columns) == 0 {
@ -130,7 +131,7 @@ func createCallback(scope *Scope) {
// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@ -146,7 +147,7 @@ func createCallback(scope *Scope) {
// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@ -162,7 +163,7 @@ func createCallback(scope *Scope) {
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if err := scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
@ -174,7 +175,7 @@ func createCallback(scope *Scope) {
}
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) {
func forceReloadAfterCreateCallback(_ctx context.Context, scope *Scope) {
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
for _, field := range scope.Fields() {
@ -187,7 +188,7 @@ func forceReloadAfterCreateCallback(scope *Scope) {
}
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
func afterCreateCallback(scope *Scope) {
func afterCreateCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterCreate")
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"errors"
"fmt"
)
@ -15,7 +16,7 @@ func init() {
}
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
func beforeDeleteCallback(_ctx context.Context, scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("missing WHERE clause while deleting"))
return
@ -26,7 +27,7 @@ func beforeDeleteCallback(scope *Scope) {
}
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) {
func deleteCallback(ctx context.Context, scope *Scope) {
if !scope.HasError() {
var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok {
@ -43,20 +44,20 @@ func deleteCallback(scope *Scope) {
scope.AddToVars(scope.db.nowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
)).Exec(ctx)
} else {
scope.Raw(fmt.Sprintf(
"DELETE FROM %v%v%v",
scope.QuotedTableName(),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
)).Exec(ctx)
}
}
}
// afterDeleteCallback will invoke `AfterDelete` method after deleting
func afterDeleteCallback(scope *Scope) {
func afterDeleteCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterDelete")
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"errors"
"fmt"
"reflect"
@ -14,7 +15,7 @@ func init() {
}
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
func queryCallback(ctx context.Context, scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
@ -69,7 +70,7 @@ func queryCallback(scope *Scope) {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if rows, err := scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()
columns, _ := rows.Columns()
@ -102,7 +103,7 @@ func queryCallback(scope *Scope) {
}
// afterQueryCallback will invoke `AfterFind` method after querying
func afterQueryCallback(scope *Scope) {
func afterQueryCallback(_ctx context.Context, scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterFind")
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"errors"
"fmt"
"reflect"
@ -9,7 +10,7 @@ import (
)
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
func preloadCallback(ctx context.Context, scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
@ -18,9 +19,9 @@ func preloadCallback(scope *Scope) {
// If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value
if apb, ok := ap.(bool); !ok {
autoPreload(scope)
autoPreload(ctx, scope)
} else if apb {
autoPreload(scope)
autoPreload(ctx, scope)
}
}
@ -62,13 +63,13 @@ func preloadCallback(scope *Scope) {
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions)
currentScope.handleHasOnePreload(ctx, field, currentPreloadConditions)
case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions)
currentScope.handleHasManyPreload(ctx, field, currentPreloadConditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
currentScope.handleBelongsToPreload(ctx, field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
currentScope.handleManyToManyPreload(ctx, field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
@ -94,7 +95,7 @@ func preloadCallback(scope *Scope) {
}
}
func autoPreload(scope *Scope) {
func autoPreload(_ctx context.Context, scope *Scope) {
for _, field := range scope.Fields() {
if field.Relationship == nil {
continue
@ -131,7 +132,7 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*
}
// handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
func (scope *Scope) handleHasOnePreload(_ctx context.Context, field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
@ -183,7 +184,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
}
// handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
func (scope *Scope) handleHasManyPreload(_ctx context.Context, field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
@ -236,7 +237,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
}
// handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
func (scope *Scope) handleBelongsToPreload(_ctx context.Context, field *Field, conditions []interface{}) {
relation := field.Relationship
// preload conditions
@ -283,7 +284,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
}
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
func (scope *Scope) handleManyToManyPreload(ctx context.Context, field *Field, conditions []interface{}) {
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
@ -346,7 +347,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries)
callCallbacks(ctx, scope.db.parent.callbacks.queries)
var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"database/sql"
"fmt"
)
@ -20,7 +21,7 @@ type RowsQueryResult struct {
}
// queryCallback used to query data from database
func rowQueryCallback(scope *Scope) {
func rowQueryCallback(ctx context.Context, scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()
@ -33,9 +34,9 @@ func rowQueryCallback(scope *Scope) {
}
if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
rowResult.Row = scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...)
}
}
}

View File

@ -1,15 +1,16 @@
package gorm
import (
"context"
"reflect"
"strings"
)
func beginTransactionCallback(scope *Scope) {
func beginTransactionCallback(_ctx context.Context, scope *Scope) {
scope.Begin()
}
func commitOrRollbackTransactionCallback(scope *Scope) {
func commitOrRollbackTransactionCallback(_ctx context.Context, scope *Scope) {
scope.CommitOrRollback()
}
@ -64,7 +65,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
return
}
func saveBeforeAssociationsCallback(scope *Scope) {
func saveBeforeAssociationsCallback(_ctx context.Context, scope *Scope) {
for _, field := range scope.Fields() {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
@ -95,7 +96,7 @@ func saveBeforeAssociationsCallback(scope *Scope) {
}
}
func saveAfterAssociationsCallback(scope *Scope) {
func saveAfterAssociationsCallback(_ctx context.Context, scope *Scope) {
for _, field := range scope.Fields() {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)

View File

@ -1,13 +1,14 @@
package gorm
import (
"context"
"reflect"
"runtime"
"strings"
"testing"
)
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
@ -16,11 +17,11 @@ func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
return reflect.DeepEqual(names, fnames)
}
func create(s *Scope) {}
func beforeCreate1(s *Scope) {}
func beforeCreate2(s *Scope) {}
func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {}
func create(ctx context.Context, s *Scope) {}
func beforeCreate1(ctx context.Context, s *Scope) {}
func beforeCreate2(ctx context.Context, s *Scope) {}
func afterCreate1(ctx context.Context, s *Scope) {}
func afterCreate2(ctx context.Context, s *Scope) {}
func TestRegisterCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger}
@ -83,7 +84,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
}
}
func replaceCreate(s *Scope) {}
func replaceCreate(ctx context.Context, s *Scope) {}
func TestReplaceCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"errors"
"fmt"
"sort"
@ -21,7 +22,7 @@ func init() {
}
// assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) {
func assignUpdatingAttributesCallback(_ctx context.Context, scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps)
@ -32,7 +33,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
}
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
func beforeUpdateCallback(_ctx context.Context, scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("missing WHERE clause while updating"))
return
@ -48,14 +49,14 @@ func beforeUpdateCallback(scope *Scope) {
}
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) {
func updateTimeStampForUpdateCallback(_ctx context.Context, scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
}
}
// updateCallback the callback used to update data to database
func updateCallback(scope *Scope) {
func updateCallback(ctx context.Context, scope *Scope) {
if !scope.HasError() {
var sqls []string
@ -103,13 +104,13 @@ func updateCallback(scope *Scope) {
strings.Join(sqls, ", "),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
)).Exec(ctx)
}
}
}
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
func afterUpdateCallback(scope *Scope) {
func afterUpdateCallback(_ctx context.Context, scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("AfterUpdate")

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"context"
"errors"
"reflect"
"testing"
@ -182,22 +183,23 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
ctx := context.Background()
callback(ctx, scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
DB.Callback().Create().Replace("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
callback(ctx, scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}
@ -207,12 +209,12 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
callback(ctx, scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
@ -220,7 +222,7 @@ func TestGetCallback(t *testing.T) {
func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
gorm.DefaultCallback.Create().Register(createCallbackName, func(context.Context, *gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
@ -233,16 +235,17 @@ func TestUseDefaultCallback(t *testing.T) {
updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
gorm.DefaultCallback.Update().Register(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})
scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
ctx := context.Background()
callback(ctx, scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"context"
"testing"
"time"
@ -21,12 +22,13 @@ type CustomColumnAndIgnoredFieldClash struct {
}
func TestCustomizeColumn(t *testing.T) {
ctx := context.Background()
col := "mapped_name"
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) {
if !scope.Dialect().HasColumn(ctx, scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"reflect"
@ -24,26 +25,26 @@ type Dialect interface {
DataTypeOf(field *StructField) string
// HasIndex check has index or not
HasIndex(tableName string, indexName string) bool
HasIndex(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool
HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error
RemoveIndex(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not
HasTable(tableName string) bool
HasTable(ctx context.Context, tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
HasColumn(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
LastInsertIDOutputInterstitial(ctx context.Context, tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
LastInsertIDReturningSuffix(ctx context.Context, tableName, columnName string) string
// DefaultValueStr
DefaultValueStr() string
@ -54,7 +55,7 @@ type Dialect interface {
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
// CurrentDatabase return current database name
CurrentDatabase() string
CurrentDatabase(ctx context.Context) string
}
var dialectsMap = map[string]Dialect{}
@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
return dialect.CurrentDatabase(ctx), tableName
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"fmt"
"reflect"
"regexp"
@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s commonDialect) HasForeignKey(_ctx context.Context, tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
func (s commonDialect) HasTable(ctx context.Context, tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return
}
@ -162,11 +163,11 @@ func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
func (commonDialect) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
func (commonDialect) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
return ""
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"crypto/sha1"
"database/sql"
"fmt"
@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err
}
@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err
return
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) HasTable(tableName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
func (s mysql) HasTable(ctx context.Context, tableName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
var name string
// allow mysql database name with '-' character
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err == sql.ErrNoRows {
return false
}
@ -183,9 +184,9 @@ func (s mysql) HasTable(tableName string) bool {
}
}
func (s mysql) HasIndex(tableName string, indexName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
panic(err)
} else {
defer rows.Close()
@ -193,9 +194,9 @@ func (s mysql) HasIndex(tableName string, indexName string) bool {
}
}
func (s mysql) HasColumn(tableName string, columnName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
panic(err)
} else {
defer rows.Close()
@ -203,8 +204,8 @@ func (s mysql) HasColumn(tableName string, columnName string) bool {
}
}
func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
func (s mysql) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"encoding/json"
"fmt"
"reflect"
@ -91,40 +92,40 @@ func (s *postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s postgres) HasIndex(tableName string, indexName string) bool {
func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
func (s postgres) HasTable(ctx context.Context, tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
func (s postgres) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
func (s postgres) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, key string, columns []string) string {
return ""
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
func (s postgres) LastInsertIDReturningSuffix(_ctx context.Context, tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"fmt"
"reflect"
"strings"
@ -70,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
func (s sqlite3) HasTable(ctx context.Context, tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) CurrentDatabase() (name string) {
func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
@ -97,7 +98,7 @@ func (s sqlite3) CurrentDatabase() (name string) {
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
if err := s.db.QueryRowContext(ctx, "PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {

View File

@ -1,6 +1,7 @@
package mssql
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
@ -15,7 +16,7 @@ import (
"github.com/jinzhu/gorm"
)
func setIdentityInsert(scope *gorm.Scope) {
func setIdentityInsert(ctx context.Context, scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
@ -26,7 +27,7 @@ func setIdentityInsert(scope *gorm.Scope) {
}
}
func turnOffIdentityInsert(scope *gorm.Scope) {
func turnOffIdentityInsert(ctx context.Context, scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
@ -122,21 +123,21 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
return field.IsPrimaryKey
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow(`SELECT count(*)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, `SELECT count(*)
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
inner join information_schema.tables as I on I.TABLE_NAME = T.name
WHERE F.name = ?
@ -144,27 +145,27 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return count > 0
}
func (s mssql) HasTable(tableName string) bool {
func (s mssql) HasTable(ctx context.Context, tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
func (s mssql) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
@ -198,7 +199,7 @@ func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
func (mssql) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
if len(columns) == 0 {
// No OUTPUT to query
return ""
@ -206,7 +207,7 @@ func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, column
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
func (mssql) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
return "; SELECT SCOPE_IDENTITY()"
}
@ -220,12 +221,12 @@ func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, stri
return indexName, columnName
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
return dialect.CurrentDatabase(ctx), tableName
}
// JSON type to support easy handling of JSON data in character table fields

View File

@ -1,6 +1,9 @@
package gorm_test
import "testing"
import (
"context"
"testing"
)
type BasePost struct {
Id int64
@ -27,9 +30,10 @@ type EngadgetPost struct {
}
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
ctx := context.Background()
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{})
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
if !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns")
}
@ -38,7 +42,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
}
hnScope := DB.NewScope(&HNPost{})
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
if !dialect.HasColumn(ctx, hnScope.TableName(), "user_id") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_name") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns")
}
}

View File

@ -7,10 +7,10 @@ import (
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type sqlDb interface {

237
main.go
View File

@ -327,51 +327,91 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
// First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.FirstContext(ctx, out, where...)
}
func (s *DB) FirstContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
}
// Take return a record that match given conditions, the order will depend on the database implementation
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.TakeContext(ctx, out, where...)
}
func (s *DB) TakeContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
return newScope.inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
}
// Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.LastContext(ctx, out, where...)
}
func (s *DB) LastContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
}
// Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
ctx := context.Background()
return s.FindContext(ctx, out, where...)
}
func (s *DB) FindContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
return s.NewScope(out).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db
}
//Preloads preloads relations, don`t touch out
func (s *DB) Preloads(out interface{}) *DB {
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
ctx := context.Background()
return s.PreloadsContext(ctx, out)
}
func (s *DB) PreloadsContext(ctx context.Context, out interface{}) *DB {
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(ctx, s.parent.callbacks.queries).db
}
// Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB {
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
ctx := context.Background()
return s.ScanContext(ctx, dest)
}
func (s *DB) ScanContext(ctx context.Context, dest interface{}) *DB {
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(ctx, s.parent.callbacks.queries).db
}
// Row return `*sql.Row` with given conditions
func (s *DB) Row() *sql.Row {
return s.NewScope(s.Value).row()
ctx := context.Background()
return s.RowContext(ctx)
}
func (s *DB) RowContext(ctx context.Context) *sql.Row {
return s.NewScope(s.Value).row(ctx)
}
// Rows return `*sql.Rows` with given conditions
func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows()
ctx := context.Background()
return s.RowsContext(ctx)
}
func (s *DB) RowsContext(ctx context.Context) (*sql.Rows, error) {
return s.NewScope(s.Value).rows(ctx)
}
// ScanRows scan `*sql.Rows` to give struct
@ -393,12 +433,22 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db
ctx := context.Background()
return s.PluckContext(ctx, column, value)
}
func (s *DB) PluckContext(ctx context.Context, column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(ctx, column, value).db
}
// Count get how many records for a model
func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db
ctx := context.Background()
return s.CountContext(ctx, value)
}
func (s *DB) CountContext(ctx context.Context, value interface{}) *DB {
return s.NewScope(s.Value).count(ctx, value).db
}
// Related get related associations
@ -409,8 +459,13 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/crud.html#firstorinit
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.FirstOrInitContext(ctx, out, where...)
}
func (s *DB) FirstOrInitContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
c := s.clone()
if result := c.First(out, where...); result.Error != nil {
if result := c.FirstContext(ctx, out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
@ -424,14 +479,19 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/crud.html#firstorcreate
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
ctx := context.Background()
return s.FirstOrCreateContext(ctx, out, where...)
}
func (s *DB) FirstOrCreateContext(ctx context.Context, out interface{}, where ...interface{}) *DB {
c := s.clone()
if result := s.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(ctx, c.parent.callbacks.creates).db
} else if len(c.search.assignAttrs) > 0 {
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(ctx, c.parent.callbacks.updates).db
}
return c
}
@ -439,54 +499,89 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
// WARNING when update with struct, GORM will not update fields that with zero value
func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true)
ctx := context.Background()
return s.UpdateContext(ctx, attrs...)
}
func (s *DB) UpdateContext(ctx context.Context, attrs ...interface{}) *DB {
return s.UpdatesContext(ctx, toSearchableMap(attrs...), true)
}
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
ctx := context.Background()
return s.UpdatesContext(ctx, values, ignoreProtectedAttrs...)
}
func (s *DB) UpdatesContext(ctx context.Context, values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db
callCallbacks(ctx, s.parent.callbacks.updates).db
}
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
return s.UpdateColumns(toSearchableMap(attrs...))
ctx := context.Background()
return s.UpdateColumnContext(ctx, attrs...)
}
func (s *DB) UpdateColumnContext(ctx context.Context, attrs ...interface{}) *DB {
return s.UpdateColumnsContext(ctx, toSearchableMap(attrs...))
}
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumns(values interface{}) *DB {
ctx := context.Background()
return s.UpdateColumnsContext(ctx, values)
}
func (s *DB) UpdateColumnsContext(ctx context.Context, values interface{}) *DB {
return s.NewScope(s.Value).
Set("gorm:update_column", true).
Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db
callCallbacks(ctx, s.parent.callbacks.updates).db
}
// Save update value in database, if the value doesn't have primary key, will insert it
func (s *DB) Save(value interface{}) *DB {
ctx := context.Background()
return s.SaveContext(ctx, value)
}
func (s *DB) SaveContext(ctx context.Context, value interface{}) *DB {
scope := s.NewScope(value)
if !scope.PrimaryKeyZero() {
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
newDB := scope.callCallbacks(ctx, s.parent.callbacks.updates).db
if newDB.Error == nil && newDB.RowsAffected == 0 {
return s.New().Table(scope.TableName()).FirstOrCreate(value)
return s.New().Table(scope.TableName()).FirstOrCreateContext(ctx, value)
}
return newDB
}
return scope.callCallbacks(s.parent.callbacks.creates).db
return scope.callCallbacks(ctx, s.parent.callbacks.creates).db
}
// Create insert the value into database
func (s *DB) Create(value interface{}) *DB {
ctx := context.Background()
return s.CreateContext(ctx, value)
}
func (s *DB) CreateContext(ctx context.Context, value interface{}) *DB {
scope := s.NewScope(value)
return scope.callCallbacks(s.parent.callbacks.creates).db
return scope.callCallbacks(ctx, s.parent.callbacks.creates).db
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
ctx := context.Background()
return s.DeleteContext(ctx, value, where...)
}
func (s *DB) DeleteContext(ctx context.Context, value interface{}, where ...interface{}) *DB {
return s.NewScope(value).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.deletes).db
}
// Raw use raw sql as conditions, won't run it unless invoked by other methods
@ -497,11 +592,16 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
// Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB {
ctx := context.Background()
return s.ExecContext(ctx, sql, values...)
}
func (s *DB) ExecContext(ctx context.Context, sql string, values ...interface{}) *DB {
scope := s.NewScope(nil)
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSQL)
return scope.Exec().db
return scope.Exec(ctx).db
}
// Model specify the model you would like to run db operations
@ -628,32 +728,47 @@ func (s *DB) RecordNotFound() bool {
// CreateTable create table for models
func (s *DB) CreateTable(models ...interface{}) *DB {
ctx := context.Background()
return s.CreateTableContext(ctx, models...)
}
func (s *DB) CreateTableContext(ctx context.Context, models ...interface{}) *DB {
db := s.Unscoped()
for _, model := range models {
db = db.NewScope(model).createTable().db
db = db.NewScope(model).createTable(ctx).db
}
return db
}
// DropTable drop table for models
func (s *DB) DropTable(values ...interface{}) *DB {
ctx := context.Background()
return s.DropTableContext(ctx, values...)
}
func (s *DB) DropTableContext(ctx context.Context, values ...interface{}) *DB {
db := s.clone()
for _, value := range values {
if tableName, ok := value.(string); ok {
db = db.Table(tableName)
}
db = db.NewScope(value).dropTable().db
db = db.NewScope(value).dropTable(ctx).db
}
return db
}
// DropTableIfExists drop table if it is exist
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
ctx := context.Background()
return s.DropTableIfExistsContext(ctx, values...)
}
func (s *DB) DropTableIfExistsContext(ctx context.Context, values ...interface{}) *DB {
db := s.clone()
for _, value := range values {
if s.HasTable(value) {
db.AddError(s.DropTable(value).Error)
db.AddError(s.DropTableContext(ctx, value).Error)
}
}
return db
@ -661,6 +776,11 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB {
// HasTable check has table or not
func (s *DB) HasTable(value interface{}) bool {
ctx := context.Background()
return s.HasTableContext(ctx, value)
}
func (s *DB) HasTableContext(ctx context.Context, value interface{}) bool {
var (
scope = s.NewScope(value)
tableName string
@ -672,68 +792,108 @@ func (s *DB) HasTable(value interface{}) bool {
tableName = scope.TableName()
}
has := scope.Dialect().HasTable(tableName)
has := scope.Dialect().HasTable(ctx, tableName)
s.AddError(scope.db.Error)
return has
}
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB {
ctx := context.Background()
return s.AutoMigrateContext(ctx, values...)
}
func (s *DB) AutoMigrateContext(ctx context.Context, values ...interface{}) *DB {
db := s.Unscoped()
for _, value := range values {
db = db.NewScope(value).autoMigrate().db
db = db.NewScope(value).autoMigrate(ctx).db
}
return db
}
// ModifyColumn modify column to type
func (s *DB) ModifyColumn(column string, typ string) *DB {
ctx := context.Background()
return s.ModifyColumnContext(ctx, column, typ)
}
func (s *DB) ModifyColumnContext(ctx context.Context, column string, typ string) *DB {
scope := s.NewScope(s.Value)
scope.modifyColumn(column, typ)
scope.modifyColumn(ctx, column, typ)
return scope.db
}
// DropColumn drop a column
func (s *DB) DropColumn(column string) *DB {
ctx := context.Background()
return s.DropColumnContext(ctx, column)
}
func (s *DB) DropColumnContext(ctx context.Context, column string) *DB {
scope := s.NewScope(s.Value)
scope.dropColumn(column)
scope.dropColumn(ctx, column)
return scope.db
}
// AddIndex add index for columns with given name
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
ctx := context.Background()
return s.AddIndexContext(ctx, indexName, columns...)
}
func (s *DB) AddIndexContext(ctx context.Context, indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(false, indexName, columns...)
scope.addIndex(ctx, false, indexName, columns...)
return scope.db
}
// AddUniqueIndex add unique index for columns with given name
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
ctx := context.Background()
return s.AddUniqueIndexContext(ctx, indexName, columns...)
}
func (s *DB) AddUniqueIndexContext(ctx context.Context, indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(true, indexName, columns...)
scope.addIndex(ctx, true, indexName, columns...)
return scope.db
}
// RemoveIndex remove index with name
func (s *DB) RemoveIndex(indexName string) *DB {
ctx := context.Background()
return s.RemoveIndexContext(ctx, indexName)
}
func (s *DB) RemoveIndexContext(ctx context.Context, indexName string) *DB {
scope := s.NewScope(s.Value)
scope.removeIndex(indexName)
scope.removeIndex(ctx, indexName)
return scope.db
}
// AddForeignKey Add foreign key to the given scope, e.g:
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
ctx := context.Background()
return s.AddForeignKeyContext(ctx, field, dest, onDelete, onUpdate)
}
func (s *DB) AddForeignKeyContext(ctx context.Context, field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
scope.addForeignKey(ctx, field, dest, onDelete, onUpdate)
return scope.db
}
// RemoveForeignKey Remove foreign key from the given scope, e.g:
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
ctx := context.Background()
return s.RemoveForeignKeyContext(ctx, field, dest)
}
func (s *DB) RemoveForeignKeyContext(ctx context.Context, field string, dest string) *DB {
scope := s.clone().NewScope(s.Value)
scope.removeForeignKey(field, dest)
scope.removeForeignKey(ctx, field, dest)
return scope.db
}
@ -784,6 +944,11 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
// SetJoinTableHandler set a model's join table handler for a relation
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
ctx := context.Background()
s.SetJoinTableHandlerContext(ctx, source, column, handler)
}
func (s *DB) SetJoinTableHandlerContext(ctx context.Context, source interface{}, column string, handler JoinTableHandlerInterface) {
scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column {
@ -792,7 +957,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(table) {
if table := handler.Table(s); scope.Dialect().HasTable(ctx, table) {
s.Table(table).AutoMigrate(handler)
}
}

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@ -302,12 +303,14 @@ func runMigration() {
}
func TestIndexes(t *testing.T) {
ctx := context.Background()
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
@ -315,7 +318,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
@ -323,7 +326,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
@ -331,7 +334,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
@ -339,7 +342,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
@ -361,7 +364,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
@ -381,6 +384,8 @@ type EmailWithIdx struct {
}
func TestAutoMigration(t *testing.T) {
ctx := context.Background()
DB.AutoMigrate(&Address{})
DB.DropTable(&EmailWithIdx{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
@ -391,11 +396,11 @@ func TestAutoMigration(t *testing.T) {
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
scope := DB.NewScope(&EmailWithIdx{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index")
}
@ -462,6 +467,8 @@ type MultipleIndexes struct {
}
func TestMultipleIndexes(t *testing.T) {
ctx := context.Background()
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
}
@ -474,23 +481,23 @@ func TestMultipleIndexes(t *testing.T) {
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
scope := DB.NewScope(&MultipleIndexes{})
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_name") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_email") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multiple_indexes_email") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multipleindexes_user_other") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multiple_indexes_other") {
t.Errorf("Failed to create index")
}
@ -540,6 +547,8 @@ func TestModifyColumnType(t *testing.T) {
}
func TestIndexWithPrefixLength(t *testing.T) {
ctx := context.Background()
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
t.Skip("Skipping this because only mysql support setting an index prefix length")
}
@ -571,7 +580,7 @@ func TestIndexWithPrefixLength(t *testing.T) {
if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err)
}
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
if !scope.Dialect().HasIndex(ctx, tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName)
}
})

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"context"
"database/sql"
"encoding/json"
"os"
@ -1684,7 +1685,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
called := 0
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(ctx context.Context, scope *gorm.Scope) {
called = called + 1
})

View File

@ -2,6 +2,7 @@ package gorm
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"errors"
@ -357,11 +358,11 @@ func (scope *Scope) Raw(sql string) *Scope {
}
// Exec perform generated SQL
func (scope *Scope) Exec() *Scope {
func (scope *Scope) Exec(ctx context.Context) *Scope {
defer scope.trace(NowFunc())
if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count
}
@ -856,7 +857,7 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
return scope
}
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
func (scope *Scope) callCallbacks(ctx context.Context, funcs []*func(ctx context.Context, s *Scope)) *Scope {
defer func() {
if err := recover(); err != nil {
if db, ok := scope.db.db.(sqlTx); ok {
@ -866,7 +867,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
}
}()
for _, f := range funcs {
(*f)(scope)
(*f)(ctx, scope)
if scope.skipLeft {
break
}
@ -933,22 +934,22 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
return
}
func (scope *Scope) row() *sql.Row {
func (scope *Scope) row(ctx context.Context) *sql.Row {
defer scope.trace(NowFunc())
result := &RowQueryResult{}
scope.InstanceSet("row_query_result", result)
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries)
return result.Row
}
func (scope *Scope) rows() (*sql.Rows, error) {
func (scope *Scope) rows(ctx context.Context) (*sql.Rows, error) {
defer scope.trace(NowFunc())
result := &RowsQueryResult{}
scope.InstanceSet("row_query_result", result)
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries)
return result.Rows, result.Error
}
@ -979,7 +980,7 @@ func (scope *Scope) isQueryForColumn(query interface{}, column string) bool {
return false
}
func (scope *Scope) pluck(column string, value interface{}) *Scope {
func (scope *Scope) pluck(ctx context.Context, column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
if dest.Kind() != reflect.Slice {
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
@ -994,7 +995,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
scope.Search.Select(column)
}
rows, err := scope.rows()
rows, err := scope.rows(ctx)
if scope.Err(err) == nil {
defer rows.Close()
for rows.Next() {
@ -1010,7 +1011,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
return scope
}
func (scope *Scope) count(value interface{}) *Scope {
func (scope *Scope) count(ctx context.Context, value interface{}) *Scope {
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
if len(scope.Search.group) != 0 {
if len(scope.Search.havingConditions) != 0 {
@ -1027,7 +1028,7 @@ func (scope *Scope) count(value interface{}) *Scope {
}
}
scope.Search.ignoreOrderQuery = true
scope.Err(scope.row().Scan(value))
scope.Err(scope.row(ctx).Scan(value))
return scope
}
@ -1124,11 +1125,11 @@ func (scope *Scope) getTableOptions() string {
return " " + tableOptions.(string)
}
func (scope *Scope) createJoinTable(field *StructField) {
func (scope *Scope) createJoinTable(ctx context.Context, field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(joinTable) {
if !scope.Dialect().HasTable(ctx, joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string
@ -1156,11 +1157,11 @@ func (scope *Scope) createJoinTable(field *StructField) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
}
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
scope.NewDB().Table(joinTable).AutoMigrateContext(ctx, joinTableHandler)
}
}
func (scope *Scope) createTable() *Scope {
func (scope *Scope) createTable(ctx context.Context) *Scope {
var tags []string
var primaryKeys []string
var primaryKeyInColumnType = false
@ -1181,7 +1182,7 @@ func (scope *Scope) createTable() *Scope {
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
scope.createJoinTable(ctx, field)
}
var primaryKeyStr string
@ -1189,27 +1190,27 @@ func (scope *Scope) createTable() *Scope {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec(ctx)
scope.autoIndex()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
func (scope *Scope) dropTable(ctx context.Context) *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec(ctx)
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
func (scope *Scope) modifyColumn(ctx context.Context, column string, typ string) {
scope.db.AddError(scope.Dialect().ModifyColumn(ctx, scope.QuotedTableName(), scope.Quote(column), typ))
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
func (scope *Scope) dropColumn(ctx context.Context, column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec(ctx)
}
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
func (scope *Scope) addIndex(ctx context.Context, unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(ctx, scope.TableName(), indexName) {
return
}
@ -1223,23 +1224,23 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
sqlCreate = "CREATE UNIQUE INDEX"
}
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec(ctx)
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
func (scope *Scope) addForeignKey(ctx context.Context, field string, dest string, onDelete string, onUpdate string) {
// Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
if scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec(ctx)
}
func (scope *Scope) removeForeignKey(field string, dest string) {
func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
if !scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
return
}
var mysql mysql
@ -1250,28 +1251,28 @@ func (scope *Scope) removeForeignKey(field string, dest string) {
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
}
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec(ctx)
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
func (scope *Scope) removeIndex(ctx context.Context, indexName string) {
scope.Dialect().RemoveIndex(ctx, scope.TableName(), indexName)
}
func (scope *Scope) autoMigrate() *Scope {
func (scope *Scope) autoMigrate(ctx context.Context) *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(tableName) {
scope.createTable()
if !scope.Dialect().HasTable(ctx, tableName) {
scope.createTable(ctx)
} else {
for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(tableName, field.DBName) {
if !scope.Dialect().HasColumn(ctx, tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec(ctx)
}
}
scope.createJoinTable(field)
scope.createJoinTable(ctx, field)
}
scope.autoIndex()
}