diff --git a/callback.go b/callback.go index 1f0e3c79..4c92fb59 100644 --- a/callback.go +++ b/callback.go @@ -1,6 +1,9 @@ package gorm -import "fmt" +import ( + "context" + "fmt" +) // DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{logger: nopLogger{}} @@ -14,24 +17,24 @@ var DefaultCallback = &Callback{logger: nopLogger{}} // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { logger logger - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) + creates []*func(ctx context.Context, scope *Scope) + updates []*func(ctx context.Context, scope *Scope) + deletes []*func(ctx context.Context, scope *Scope) + queries []*func(ctx context.Context, scope *Scope) + rowQueries []*func(ctx context.Context, scope *Scope) processors []*CallbackProcessor } // CallbackProcessor contains callback informations type CallbackProcessor struct { logger logger - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler + name string // current callback's name + before string // register current callback before a callback + after string // register current callback after a callback + replace bool // replace callbacks with same name + remove bool // delete callbacks with same name + kind string // callback type: create, update, delete, query, row_query + processor *func(ctx context.Context, scope *Scope) // callback handler parent *Callback } @@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { } // Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { +func (cp *CallbackProcessor) Register(callbackName string, callback func(ctx context.Context, scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) @@ -123,7 +126,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("CreatedAt", now) // scope.SetColumn("UpdatedAt", now) // }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { +func (cp *CallbackProcessor) Replace(callbackName string, callback func(ctx context.Context, scope *Scope)) { cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback @@ -134,7 +137,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S // Get registered callback // db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { +func (cp *CallbackProcessor) Get(callbackName string) (callback func(ctx context.Context, scope *Scope)) { for _, p := range cp.parent.processors { if p.name == callbackName && p.kind == cp.kind { if p.remove { @@ -158,7 +161,7 @@ func getRIndex(strs []string, str string) int { } // sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { +func sortProcessors(cps []*CallbackProcessor) []*func(ctx context.Context, scope *Scope) { var ( allNames, sortedNames []string sortCallbackProcessor func(c *CallbackProcessor) @@ -211,7 +214,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { sortCallbackProcessor(cp) } - var sortedFuncs []*func(scope *Scope) + var sortedFuncs []*func(ctx context.Context, scope *Scope) for _, name := range sortedNames { if index := getRIndex(allNames, name); !cps[index].remove { sortedFuncs = append(sortedFuncs, cps[index].processor) diff --git a/callback_create.go b/callback_create.go index c4d25f37..f614486a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "fmt" "strings" ) @@ -19,7 +20,7 @@ func init() { } // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { +func beforeCreateCallback(_ctx context.Context, scope *Scope) { if !scope.HasError() { scope.CallMethod("BeforeSave") } @@ -29,7 +30,7 @@ func beforeCreateCallback(scope *Scope) { } // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { +func updateTimeStampForCreateCallback(_ctx context.Context, scope *Scope) { if !scope.HasError() { now := scope.db.nowFunc() @@ -48,7 +49,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { } // createCallback the callback used to insert data into database -func createCallback(scope *Scope) { +func createCallback(ctx context.Context, scope *Scope) { if !scope.HasError() { defer scope.trace(NowFunc()) @@ -100,10 +101,10 @@ func createCallback(scope *Scope) { returningColumn = scope.Quote(primaryField.DBName) } - lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) + lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(ctx, quotedTableName, returningColumn, columns) var lastInsertIDReturningSuffix string if lastInsertIDOutputInterstitial == "" { - lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(ctx, quotedTableName, returningColumn) } if len(columns) == 0 { @@ -130,7 +131,7 @@ func createCallback(scope *Scope) { // execute create sql: no primaryField if primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -146,7 +147,7 @@ func createCallback(scope *Scope) { // execute create sql: lastInsertID implemention for majority of dialects if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -162,7 +163,7 @@ func createCallback(scope *Scope) { // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + if err := scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { primaryField.IsBlank = false scope.db.RowsAffected = 1 } @@ -174,7 +175,7 @@ func createCallback(scope *Scope) { } // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { +func forceReloadAfterCreateCallback(_ctx context.Context, scope *Scope) { if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) for _, field := range scope.Fields() { @@ -187,7 +188,7 @@ func forceReloadAfterCreateCallback(scope *Scope) { } // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { +func afterCreateCallback(_ctx context.Context, scope *Scope) { if !scope.HasError() { scope.CallMethod("AfterCreate") } diff --git a/callback_delete.go b/callback_delete.go index 48b97acb..c040db78 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "errors" "fmt" ) @@ -15,7 +16,7 @@ func init() { } // beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { +func beforeDeleteCallback(_ctx context.Context, scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { scope.Err(errors.New("missing WHERE clause while deleting")) return @@ -26,7 +27,7 @@ func beforeDeleteCallback(scope *Scope) { } // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { +func deleteCallback(ctx context.Context, scope *Scope) { if !scope.HasError() { var extraOption string if str, ok := scope.Get("gorm:delete_option"); ok { @@ -43,20 +44,20 @@ func deleteCallback(scope *Scope) { scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), - )).Exec() + )).Exec(ctx) } else { scope.Raw(fmt.Sprintf( "DELETE FROM %v%v%v", scope.QuotedTableName(), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), - )).Exec() + )).Exec(ctx) } } } // afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { +func afterDeleteCallback(_ctx context.Context, scope *Scope) { if !scope.HasError() { scope.CallMethod("AfterDelete") } diff --git a/callback_query.go b/callback_query.go index 544afd63..57f377a9 100644 --- a/callback_query.go +++ b/callback_query.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "errors" "fmt" "reflect" @@ -14,7 +15,7 @@ func init() { } // queryCallback used to query data from database -func queryCallback(scope *Scope) { +func queryCallback(ctx context.Context, scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } @@ -69,7 +70,7 @@ func queryCallback(scope *Scope) { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if rows, err := scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil { defer rows.Close() columns, _ := rows.Columns() @@ -102,7 +103,7 @@ func queryCallback(scope *Scope) { } // afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { +func afterQueryCallback(_ctx context.Context, scope *Scope) { if !scope.HasError() { scope.CallMethod("AfterFind") } diff --git a/callback_query_preload.go b/callback_query_preload.go index a936180a..52b7b3a4 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "errors" "fmt" "reflect" @@ -9,7 +10,7 @@ import ( ) // preloadCallback used to preload associations -func preloadCallback(scope *Scope) { +func preloadCallback(ctx context.Context, scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } @@ -18,9 +19,9 @@ func preloadCallback(scope *Scope) { // If gorm:auto_preload IS NOT a bool then auto preload. // Else if it IS a bool, use the value if apb, ok := ap.(bool); !ok { - autoPreload(scope) + autoPreload(ctx, scope) } else if apb { - autoPreload(scope) + autoPreload(ctx, scope) } } @@ -62,13 +63,13 @@ func preloadCallback(scope *Scope) { switch field.Relationship.Kind { case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) + currentScope.handleHasOnePreload(ctx, field, currentPreloadConditions) case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) + currentScope.handleHasManyPreload(ctx, field, currentPreloadConditions) case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) + currentScope.handleBelongsToPreload(ctx, field, currentPreloadConditions) case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) + currentScope.handleManyToManyPreload(ctx, field, currentPreloadConditions) default: scope.Err(errors.New("unsupported relation")) } @@ -94,7 +95,7 @@ func preloadCallback(scope *Scope) { } } -func autoPreload(scope *Scope) { +func autoPreload(_ctx context.Context, scope *Scope) { for _, field := range scope.Fields() { if field.Relationship == nil { continue @@ -131,7 +132,7 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (* } // handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { +func (scope *Scope) handleHasOnePreload(_ctx context.Context, field *Field, conditions []interface{}) { relation := field.Relationship // get relations's primary keys @@ -183,7 +184,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } // handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { +func (scope *Scope) handleHasManyPreload(_ctx context.Context, field *Field, conditions []interface{}) { relation := field.Relationship // get relations's primary keys @@ -236,7 +237,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) } // handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { +func (scope *Scope) handleBelongsToPreload(_ctx context.Context, field *Field, conditions []interface{}) { relation := field.Relationship // preload conditions @@ -283,7 +284,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } // handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { +func (scope *Scope) handleManyToManyPreload(ctx context.Context, field *Field, conditions []interface{}) { var ( relation = field.Relationship joinTableHandler = relation.JoinTableHandler @@ -346,7 +347,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.New(elem.Addr().Interface()). InstanceSet("gorm:skip_query_callback", true). - callCallbacks(scope.db.parent.callbacks.queries) + callCallbacks(ctx, scope.db.parent.callbacks.queries) var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table diff --git a/callback_row_query.go b/callback_row_query.go index 323b1605..2fda6d3d 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "fmt" ) @@ -20,7 +21,7 @@ type RowsQueryResult struct { } // queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { +func rowQueryCallback(ctx context.Context, scope *Scope) { if result, ok := scope.InstanceGet("row_query_result"); ok { scope.prepareQuerySQL() @@ -33,9 +34,9 @@ func rowQueryCallback(scope *Scope) { } if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) + rowResult.Row = scope.SQLDB().QueryRowContext(ctx, scope.SQL, scope.SQLVars...) } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) + rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(ctx, scope.SQL, scope.SQLVars...) } } } diff --git a/callback_save.go b/callback_save.go index 3b4e0589..9627af2b 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,15 +1,16 @@ package gorm import ( + "context" "reflect" "strings" ) -func beginTransactionCallback(scope *Scope) { +func beginTransactionCallback(_ctx context.Context, scope *Scope) { scope.Begin() } -func commitOrRollbackTransactionCallback(scope *Scope) { +func commitOrRollbackTransactionCallback(_ctx context.Context, scope *Scope) { scope.CommitOrRollback() } @@ -64,7 +65,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea return } -func saveBeforeAssociationsCallback(scope *Scope) { +func saveBeforeAssociationsCallback(_ctx context.Context, scope *Scope) { for _, field := range scope.Fields() { autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) @@ -95,7 +96,7 @@ func saveBeforeAssociationsCallback(scope *Scope) { } } -func saveAfterAssociationsCallback(scope *Scope) { +func saveAfterAssociationsCallback(_ctx context.Context, scope *Scope) { for _, field := range scope.Fields() { autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) diff --git a/callback_system_test.go b/callback_system_test.go index 2482eda4..69886a0f 100644 --- a/callback_system_test.go +++ b/callback_system_test.go @@ -1,13 +1,14 @@ package gorm import ( + "context" "reflect" "runtime" "strings" "testing" ) -func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { +func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool { var names []string for _, f := range funcs { fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") @@ -16,11 +17,11 @@ func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { return reflect.DeepEqual(names, fnames) } -func create(s *Scope) {} -func beforeCreate1(s *Scope) {} -func beforeCreate2(s *Scope) {} -func afterCreate1(s *Scope) {} -func afterCreate2(s *Scope) {} +func create(ctx context.Context, s *Scope) {} +func beforeCreate1(ctx context.Context, s *Scope) {} +func beforeCreate2(ctx context.Context, s *Scope) {} +func afterCreate1(ctx context.Context, s *Scope) {} +func afterCreate2(ctx context.Context, s *Scope) {} func TestRegisterCallback(t *testing.T) { var callback = &Callback{logger: defaultLogger} @@ -83,7 +84,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) { } } -func replaceCreate(s *Scope) {} +func replaceCreate(ctx context.Context, s *Scope) {} func TestReplaceCallback(t *testing.T) { var callback = &Callback{logger: defaultLogger} diff --git a/callback_update.go b/callback_update.go index 699e534b..b4daefc3 100644 --- a/callback_update.go +++ b/callback_update.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "errors" "fmt" "sort" @@ -21,7 +22,7 @@ func init() { } // assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { +func assignUpdatingAttributesCallback(_ctx context.Context, scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { scope.InstanceSet("gorm:update_attrs", updateMaps) @@ -32,7 +33,7 @@ func assignUpdatingAttributesCallback(scope *Scope) { } // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { +func beforeUpdateCallback(_ctx context.Context, scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { scope.Err(errors.New("missing WHERE clause while updating")) return @@ -48,14 +49,14 @@ func beforeUpdateCallback(scope *Scope) { } // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { +func updateTimeStampForUpdateCallback(_ctx context.Context, scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } // updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { +func updateCallback(ctx context.Context, scope *Scope) { if !scope.HasError() { var sqls []string @@ -103,13 +104,13 @@ func updateCallback(scope *Scope) { strings.Join(sqls, ", "), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), - )).Exec() + )).Exec(ctx) } } } // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { +func afterUpdateCallback(_ctx context.Context, scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { scope.CallMethod("AfterUpdate") diff --git a/callbacks_test.go b/callbacks_test.go index bebd0e38..6e1f2148 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "errors" "reflect" "testing" @@ -182,22 +183,23 @@ func TestGetCallback(t *testing.T) { t.Errorf("`gorm:test_callback` should be nil") } - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) + DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) callback := DB.Callback().Create().Get("gorm:test_callback") if callback == nil { t.Errorf("`gorm:test_callback` should be non-nil") } - callback(scope) + ctx := context.Background() + callback(ctx, scope) if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) } - DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) + DB.Callback().Create().Replace("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) callback = DB.Callback().Create().Get("gorm:test_callback") if callback == nil { t.Errorf("`gorm:test_callback` should be non-nil") } - callback(scope) + callback(ctx, scope) if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) } @@ -207,12 +209,12 @@ func TestGetCallback(t *testing.T) { t.Errorf("`gorm:test_callback` should be nil") } - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) + DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) callback = DB.Callback().Create().Get("gorm:test_callback") if callback == nil { t.Errorf("`gorm:test_callback` should be non-nil") } - callback(scope) + callback(ctx, scope) if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) } @@ -220,7 +222,7 @@ func TestGetCallback(t *testing.T) { func TestUseDefaultCallback(t *testing.T) { createCallbackName := "gorm:test_use_default_callback_for_create" - gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { + gorm.DefaultCallback.Create().Register(createCallbackName, func(context.Context, *gorm.Scope) { // nop }) if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { @@ -233,16 +235,17 @@ func TestUseDefaultCallback(t *testing.T) { updateCallbackName := "gorm:test_use_default_callback_for_update" scopeValueName := "gorm:test_use_default_callback_for_update_value" - gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { + gorm.DefaultCallback.Update().Register(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) { scope.Set(scopeValueName, 1) }) - gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { + gorm.DefaultCallback.Update().Replace(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) { scope.Set(scopeValueName, 2) }) scope := DB.NewScope(nil) callback := gorm.DefaultCallback.Update().Get(updateCallbackName) - callback(scope) + ctx := context.Background() + callback(ctx, scope) if v, ok := scope.Get(scopeValueName); !ok || v != 2 { t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) } diff --git a/customize_column_test.go b/customize_column_test.go index c236ac24..529a3bcf 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "testing" "time" @@ -21,12 +22,13 @@ type CustomColumnAndIgnoredFieldClash struct { } func TestCustomizeColumn(t *testing.T) { + ctx := context.Background() col := "mapped_name" DB.DropTable(&CustomizeColumn{}) DB.AutoMigrate(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { + if !scope.Dialect().HasColumn(ctx, scope.TableName(), col) { t.Errorf("CustomizeColumn should have column %s", col) } diff --git a/dialect.go b/dialect.go index 749587f4..64496df5 100644 --- a/dialect.go +++ b/dialect.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "fmt" "reflect" @@ -24,26 +25,26 @@ type Dialect interface { DataTypeOf(field *StructField) string // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool + HasIndex(ctx context.Context, tableName string, indexName string) bool // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool + HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error + RemoveIndex(ctx context.Context, tableName string, indexName string) error // HasTable check has table or not - HasTable(tableName string) bool + HasTable(ctx context.Context, tableName string) bool // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool + HasColumn(ctx context.Context, tableName string, columnName string) bool // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error + ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) (string, error) // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` - LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string + LastInsertIDOutputInterstitial(ctx context.Context, tableName, columnName string, columns []string) string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string + LastInsertIDReturningSuffix(ctx context.Context, tableName, columnName string) string // DefaultValueStr DefaultValueStr() string @@ -54,7 +55,7 @@ type Dialect interface { NormalizeIndexAndColumn(indexName, columnName string) (string, string) // CurrentDatabase return current database name - CurrentDatabase() string + CurrentDatabase(ctx context.Context) string } var dialectsMap = map[string]Dialect{} @@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { +func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) return splitStrings[0], splitStrings[1] } - return dialect.CurrentDatabase(), tableName + return dialect.CurrentDatabase(ctx), tableName } diff --git a/dialect_common.go b/dialect_common.go index d549510c..3dfa58b6 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "fmt" "reflect" "regexp" @@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s commonDialect) HasIndex(tableName string, indexName string) bool { +func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) +func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName)) return err } -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s commonDialect) HasForeignKey(_ctx context.Context, tableName string, foreignKeyName string) bool { return false } -func (s commonDialect) HasTable(tableName string) bool { +func (s commonDialect) HasTable(ctx context.Context, tableName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } -func (s commonDialect) HasColumn(tableName string, columnName string) bool { +func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) +func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) return err } -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) +func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name) return } @@ -162,11 +163,11 @@ func (commonDialect) SelectFromDummyTable() string { return "" } -func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { +func (commonDialect) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string { return "" } -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { +func (commonDialect) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string { return "" } diff --git a/dialect_mysql.go b/dialect_mysql.go index b4467ffa..e8c27f99 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "crypto/sha1" "database/sql" "fmt" @@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) +func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) return err } -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) +func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) return err } @@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err return } -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s mysql) HasTable(tableName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) +func (s mysql) HasTable(ctx context.Context, tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) var name string // allow mysql database name with '-' character - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { if err == sql.ErrNoRows { return false } @@ -183,9 +184,9 @@ func (s mysql) HasTable(tableName string) bool { } } -func (s mysql) HasIndex(tableName string, indexName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { +func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { panic(err) } else { defer rows.Close() @@ -193,9 +194,9 @@ func (s mysql) HasIndex(tableName string, indexName string) bool { } } -func (s mysql) HasColumn(tableName string, columnName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { +func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { panic(err) } else { defer rows.Close() @@ -203,8 +204,8 @@ func (s mysql) HasColumn(tableName string, columnName string) bool { } } -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) +func (s mysql) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_postgres.go b/dialect_postgres.go index d2df3131..4dbaa2e6 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "encoding/json" "fmt" "reflect" @@ -91,40 +92,40 @@ func (s *postgres) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s postgres) HasIndex(tableName string, indexName string) bool { +func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) return count > 0 } -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s postgres) HasTable(tableName string) bool { +func (s postgres) HasTable(ctx context.Context, tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) return count > 0 } -func (s postgres) HasColumn(tableName string, columnName string) bool { +func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) return count > 0 } -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) +func (s postgres) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name) return } -func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { +func (s postgres) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, key string, columns []string) string { return "" } -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { +func (s postgres) LastInsertIDReturningSuffix(_ctx context.Context, tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 5f96c363..6db7191f 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "fmt" "reflect" "strings" @@ -70,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s sqlite3) HasIndex(tableName string, indexName string) bool { +func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) + s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) return count > 0 } -func (s sqlite3) HasTable(tableName string) bool { +func (s sqlite3) HasTable(ctx context.Context, tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) return count > 0 } -func (s sqlite3) HasColumn(tableName string, columnName string) bool { +func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) + s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) return count > 0 } -func (s sqlite3) CurrentDatabase() (name string) { +func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) @@ -97,7 +98,7 @@ func (s sqlite3) CurrentDatabase() (name string) { for i = 0; i < 3; i++ { ifaces[i] = &pointers[i] } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { + if err := s.db.QueryRowContext(ctx, "PRAGMA database_list").Scan(ifaces...); err != nil { return } if pointers[1] != nil { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a516ed4a..308f536d 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql/driver" "encoding/json" "errors" @@ -15,7 +16,7 @@ import ( "github.com/jinzhu/gorm" ) -func setIdentityInsert(scope *gorm.Scope) { +func setIdentityInsert(ctx context.Context, scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { for _, field := range scope.PrimaryFields() { if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { @@ -26,7 +27,7 @@ func setIdentityInsert(scope *gorm.Scope) { } } -func turnOffIdentityInsert(scope *gorm.Scope) { +func turnOffIdentityInsert(ctx context.Context, scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) @@ -122,49 +123,49 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { return field.IsPrimaryKey } -func (s mssql) HasIndex(tableName string, indexName string) bool { +func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) return count > 0 } -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) +func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) return err } -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow(`SELECT count(*) - FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id - inner join information_schema.tables as I on I.TABLE_NAME = T.name - WHERE F.name = ? + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, `SELECT count(*) + FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id + inner join information_schema.tables as I on I.TABLE_NAME = T.name + WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) return count > 0 } -func (s mssql) HasTable(tableName string) bool { +func (s mssql) HasTable(ctx context.Context, tableName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } -func (s mssql) HasColumn(tableName string, columnName string) bool { +func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } -func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) +func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) return err } -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) +func (s mssql) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name) return } @@ -198,7 +199,7 @@ func (mssql) SelectFromDummyTable() string { return "" } -func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { +func (mssql) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string { if len(columns) == 0 { // No OUTPUT to query return "" @@ -206,7 +207,7 @@ func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, column return fmt.Sprintf("OUTPUT Inserted.%v", columnName) } -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { +func (mssql) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string { // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id return "; SELECT SCOPE_IDENTITY()" } @@ -220,12 +221,12 @@ func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, stri return indexName, columnName } -func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { +func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) return splitStrings[0], splitStrings[1] } - return dialect.CurrentDatabase(), tableName + return dialect.CurrentDatabase(ctx), tableName } // JSON type to support easy handling of JSON data in character table fields diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 5f8ece57..e6ed3e65 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -1,6 +1,9 @@ package gorm_test -import "testing" +import ( + "context" + "testing" +) type BasePost struct { Id int64 @@ -27,9 +30,10 @@ type EngadgetPost struct { } func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { + ctx := context.Background() dialect := DB.NewScope(&EngadgetPost{}).Dialect() engadgetPostScope := DB.NewScope(&EngadgetPost{}) - if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { + if !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_email") { t.Errorf("should has prefix for embedded columns") } @@ -38,7 +42,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { } hnScope := DB.NewScope(&HNPost{}) - if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { + if !dialect.HasColumn(ctx, hnScope.TableName(), "user_id") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_name") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_email") { t.Errorf("should has prefix for embedded columns") } } diff --git a/interface.go b/interface.go index fe649231..8f119521 100644 --- a/interface.go +++ b/interface.go @@ -7,10 +7,10 @@ import ( // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } type sqlDb interface { diff --git a/main.go b/main.go index 3db87870..275ffb7b 100644 --- a/main.go +++ b/main.go @@ -327,51 +327,91 @@ func (s *DB) Assign(attrs ...interface{}) *DB { // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { + ctx := context.Background() + return s.FirstContext(ctx, out, where...) +} + +func (s *DB) FirstContext(ctx context.Context, out interface{}, where ...interface{}) *DB { newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db } // Take return a record that match given conditions, the order will depend on the database implementation func (s *DB) Take(out interface{}, where ...interface{}) *DB { + ctx := context.Background() + return s.TakeContext(ctx, out, where...) +} + +func (s *DB) TakeContext(ctx context.Context, out interface{}, where ...interface{}) *DB { newScope := s.NewScope(out) newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return newScope.inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db } // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { + ctx := context.Background() + return s.LastContext(ctx, out, where...) +} + +func (s *DB) LastContext(ctx context.Context, out interface{}, where ...interface{}) *DB { newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db } // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + ctx := context.Background() + return s.FindContext(ctx, out, where...) +} + +func (s *DB) FindContext(ctx context.Context, out interface{}, where ...interface{}) *DB { + return s.NewScope(out).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.queries).db } //Preloads preloads relations, don`t touch out func (s *DB) Preloads(out interface{}) *DB { - return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db + ctx := context.Background() + return s.PreloadsContext(ctx, out) +} + +func (s *DB) PreloadsContext(ctx context.Context, out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(ctx, s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + ctx := context.Background() + return s.ScanContext(ctx, dest) +} + +func (s *DB) ScanContext(ctx context.Context, dest interface{}) *DB { + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(ctx, s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() + ctx := context.Background() + return s.RowContext(ctx) +} + +func (s *DB) RowContext(ctx context.Context) *sql.Row { + return s.NewScope(s.Value).row(ctx) } // Rows return `*sql.Rows` with given conditions func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() + ctx := context.Background() + return s.RowsContext(ctx) +} + +func (s *DB) RowsContext(ctx context.Context) (*sql.Rows, error) { + return s.NewScope(s.Value).rows(ctx) } // ScanRows scan `*sql.Rows` to give struct @@ -393,12 +433,22 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { // var ages []int64 // db.Find(&users).Pluck("age", &ages) func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db + ctx := context.Background() + return s.PluckContext(ctx, column, value) +} + +func (s *DB) PluckContext(ctx context.Context, column string, value interface{}) *DB { + return s.NewScope(s.Value).pluck(ctx, column, value).db } // Count get how many records for a model func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db + ctx := context.Background() + return s.CountContext(ctx, value) +} + +func (s *DB) CountContext(ctx context.Context, value interface{}) *DB { + return s.NewScope(s.Value).count(ctx, value).db } // Related get related associations @@ -409,8 +459,13 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) // https://jinzhu.github.io/gorm/crud.html#firstorinit func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { + ctx := context.Background() + return s.FirstOrInitContext(ctx, out, where...) +} + +func (s *DB) FirstOrInitContext(ctx context.Context, out interface{}, where ...interface{}) *DB { c := s.clone() - if result := c.First(out, where...); result.Error != nil { + if result := c.FirstContext(ctx, out, where...); result.Error != nil { if !result.RecordNotFound() { return result } @@ -424,14 +479,19 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { // FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) // https://jinzhu.github.io/gorm/crud.html#firstorcreate func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { + ctx := context.Background() + return s.FirstOrCreateContext(ctx, out, where...) +} + +func (s *DB) FirstOrCreateContext(ctx context.Context, out interface{}, where ...interface{}) *DB { c := s.clone() if result := s.First(out, where...); result.Error != nil { if !result.RecordNotFound() { return result } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db + return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(ctx, c.parent.callbacks.creates).db } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db + return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(ctx, c.parent.callbacks.updates).db } return c } @@ -439,54 +499,89 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // WARNING when update with struct, GORM will not update fields that with zero value func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) + ctx := context.Background() + return s.UpdateContext(ctx, attrs...) +} + +func (s *DB) UpdateContext(ctx context.Context, attrs ...interface{}) *DB { + return s.UpdatesContext(ctx, toSearchableMap(attrs...), true) } // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { + ctx := context.Background() + return s.UpdatesContext(ctx, values, ignoreProtectedAttrs...) +} + +func (s *DB) UpdatesContext(ctx context.Context, values interface{}, ignoreProtectedAttrs ...bool) *DB { return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db + callCallbacks(ctx, s.parent.callbacks.updates).db } // UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) + ctx := context.Background() + return s.UpdateColumnContext(ctx, attrs...) +} + +func (s *DB) UpdateColumnContext(ctx context.Context, attrs ...interface{}) *DB { + return s.UpdateColumnsContext(ctx, toSearchableMap(attrs...)) } // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { + ctx := context.Background() + return s.UpdateColumnsContext(ctx, values) +} + +func (s *DB) UpdateColumnsContext(ctx context.Context, values interface{}) *DB { return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db + callCallbacks(ctx, s.parent.callbacks.updates).db } // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { + ctx := context.Background() + return s.SaveContext(ctx, value) +} + +func (s *DB) SaveContext(ctx context.Context, value interface{}) *DB { scope := s.NewScope(value) if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db + newDB := scope.callCallbacks(ctx, s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().Table(scope.TableName()).FirstOrCreate(value) + return s.New().Table(scope.TableName()).FirstOrCreateContext(ctx, value) } return newDB } - return scope.callCallbacks(s.parent.callbacks.creates).db + return scope.callCallbacks(ctx, s.parent.callbacks.creates).db } // Create insert the value into database func (s *DB) Create(value interface{}) *DB { + ctx := context.Background() + return s.CreateContext(ctx, value) +} + +func (s *DB) CreateContext(ctx context.Context, value interface{}) *DB { scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db + return scope.callCallbacks(ctx, s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + ctx := context.Background() + return s.DeleteContext(ctx, value, where...) +} + +func (s *DB) DeleteContext(ctx context.Context, value interface{}, where ...interface{}) *DB { + return s.NewScope(value).inlineCondition(where...).callCallbacks(ctx, s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -497,11 +592,16 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { + ctx := context.Background() + return s.ExecContext(ctx, sql, values...) +} + +func (s *DB) ExecContext(ctx context.Context, sql string, values ...interface{}) *DB { scope := s.NewScope(nil) generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) - return scope.Exec().db + return scope.Exec(ctx).db } // Model specify the model you would like to run db operations @@ -628,32 +728,47 @@ func (s *DB) RecordNotFound() bool { // CreateTable create table for models func (s *DB) CreateTable(models ...interface{}) *DB { + ctx := context.Background() + return s.CreateTableContext(ctx, models...) +} + +func (s *DB) CreateTableContext(ctx context.Context, models ...interface{}) *DB { db := s.Unscoped() for _, model := range models { - db = db.NewScope(model).createTable().db + db = db.NewScope(model).createTable(ctx).db } return db } // DropTable drop table for models func (s *DB) DropTable(values ...interface{}) *DB { + ctx := context.Background() + return s.DropTableContext(ctx, values...) +} + +func (s *DB) DropTableContext(ctx context.Context, values ...interface{}) *DB { db := s.clone() for _, value := range values { if tableName, ok := value.(string); ok { db = db.Table(tableName) } - db = db.NewScope(value).dropTable().db + db = db.NewScope(value).dropTable(ctx).db } return db } // DropTableIfExists drop table if it is exist func (s *DB) DropTableIfExists(values ...interface{}) *DB { + ctx := context.Background() + return s.DropTableIfExistsContext(ctx, values...) +} + +func (s *DB) DropTableIfExistsContext(ctx context.Context, values ...interface{}) *DB { db := s.clone() for _, value := range values { if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) + db.AddError(s.DropTableContext(ctx, value).Error) } } return db @@ -661,6 +776,11 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { + ctx := context.Background() + return s.HasTableContext(ctx, value) +} + +func (s *DB) HasTableContext(ctx context.Context, value interface{}) bool { var ( scope = s.NewScope(value) tableName string @@ -672,68 +792,108 @@ func (s *DB) HasTable(value interface{}) bool { tableName = scope.TableName() } - has := scope.Dialect().HasTable(tableName) + has := scope.Dialect().HasTable(ctx, tableName) s.AddError(scope.db.Error) return has } // AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data func (s *DB) AutoMigrate(values ...interface{}) *DB { + ctx := context.Background() + return s.AutoMigrateContext(ctx, values...) +} + +func (s *DB) AutoMigrateContext(ctx context.Context, values ...interface{}) *DB { db := s.Unscoped() for _, value := range values { - db = db.NewScope(value).autoMigrate().db + db = db.NewScope(value).autoMigrate(ctx).db } return db } // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { + ctx := context.Background() + return s.ModifyColumnContext(ctx, column, typ) +} + +func (s *DB) ModifyColumnContext(ctx context.Context, column string, typ string) *DB { scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) + scope.modifyColumn(ctx, column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { + ctx := context.Background() + return s.DropColumnContext(ctx, column) +} + +func (s *DB) DropColumnContext(ctx context.Context, column string) *DB { scope := s.NewScope(s.Value) - scope.dropColumn(column) + scope.dropColumn(ctx, column) return scope.db } // AddIndex add index for columns with given name func (s *DB) AddIndex(indexName string, columns ...string) *DB { + ctx := context.Background() + return s.AddIndexContext(ctx, indexName, columns...) +} + +func (s *DB) AddIndexContext(ctx context.Context, indexName string, columns ...string) *DB { scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) + scope.addIndex(ctx, false, indexName, columns...) return scope.db } // AddUniqueIndex add unique index for columns with given name func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { + ctx := context.Background() + return s.AddUniqueIndexContext(ctx, indexName, columns...) +} + +func (s *DB) AddUniqueIndexContext(ctx context.Context, indexName string, columns ...string) *DB { scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) + scope.addIndex(ctx, true, indexName, columns...) return scope.db } // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { + ctx := context.Background() + return s.RemoveIndexContext(ctx, indexName) +} + +func (s *DB) RemoveIndexContext(ctx context.Context, indexName string) *DB { scope := s.NewScope(s.Value) - scope.removeIndex(indexName) + scope.removeIndex(ctx, indexName) return scope.db } // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { + ctx := context.Background() + return s.AddForeignKeyContext(ctx, field, dest, onDelete, onUpdate) +} + +func (s *DB) AddForeignKeyContext(ctx context.Context, field string, dest string, onDelete string, onUpdate string) *DB { scope := s.NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) + scope.addForeignKey(ctx, field, dest, onDelete, onUpdate) return scope.db } // RemoveForeignKey Remove foreign key from the given scope, e.g: // db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") func (s *DB) RemoveForeignKey(field string, dest string) *DB { + ctx := context.Background() + return s.RemoveForeignKeyContext(ctx, field, dest) +} + +func (s *DB) RemoveForeignKeyContext(ctx context.Context, field string, dest string) *DB { scope := s.clone().NewScope(s.Value) - scope.removeForeignKey(field, dest) + scope.removeForeignKey(ctx, field, dest) return scope.db } @@ -784,6 +944,11 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { // SetJoinTableHandler set a model's join table handler for a relation func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { + ctx := context.Background() + s.SetJoinTableHandlerContext(ctx, source, column, handler) +} + +func (s *DB) SetJoinTableHandlerContext(ctx context.Context, source interface{}, column string, handler JoinTableHandlerInterface) { scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { @@ -792,7 +957,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { + if table := handler.Table(s); scope.Dialect().HasTable(ctx, table) { s.Table(table).AutoMigrate(handler) } } diff --git a/migration_test.go b/migration_test.go index d94ec9ec..416a7f0a 100644 --- a/migration_test.go +++ b/migration_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -302,12 +303,14 @@ func runMigration() { } func TestIndexes(t *testing.T) { + ctx := context.Background() + if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { t.Errorf("Got error when tried to create index: %+v", err) } scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") { t.Errorf("Email should have index idx_email_email") } @@ -315,7 +318,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") { t.Errorf("Email's index idx_email_email should be deleted") } @@ -323,7 +326,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -331,7 +334,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -339,7 +342,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -361,7 +364,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -381,6 +384,8 @@ type EmailWithIdx struct { } func TestAutoMigration(t *testing.T) { + ctx := context.Background() + DB.AutoMigrate(&Address{}) DB.DropTable(&EmailWithIdx{}) if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { @@ -391,11 +396,11 @@ func TestAutoMigration(t *testing.T) { DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) scope := DB.NewScope(&EmailWithIdx{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_agent") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_email_with_idxes_registered_at") { t.Errorf("Failed to create index") } @@ -462,6 +467,8 @@ type MultipleIndexes struct { } func TestMultipleIndexes(t *testing.T) { + ctx := context.Background() + if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) } @@ -474,23 +481,23 @@ func TestMultipleIndexes(t *testing.T) { DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_name") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_email") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multiple_indexes_email") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multipleindexes_user_other") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { + if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multiple_indexes_other") { t.Errorf("Failed to create index") } @@ -540,6 +547,8 @@ func TestModifyColumnType(t *testing.T) { } func TestIndexWithPrefixLength(t *testing.T) { + ctx := context.Background() + if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { t.Skip("Skipping this because only mysql support setting an index prefix length") } @@ -571,7 +580,7 @@ func TestIndexWithPrefixLength(t *testing.T) { if err := DB.CreateTable(table).Error; err != nil { t.Errorf("Failed to create %s table: %v", tableName, err) } - if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { + if !scope.Dialect().HasIndex(ctx, tableName, "idx_index_with_prefixes_length") { t.Errorf("Failed to create %s table index:", tableName) } }) diff --git a/preload_test.go b/preload_test.go index dd29fb5e..17d94adb 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "database/sql" "encoding/json" "os" @@ -1684,7 +1685,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { called := 0 - DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(ctx context.Context, scope *gorm.Scope) { called = called + 1 }) diff --git a/scope.go b/scope.go index d82cadbc..90e2ea4f 100644 --- a/scope.go +++ b/scope.go @@ -2,6 +2,7 @@ package gorm import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -357,11 +358,11 @@ func (scope *Scope) Raw(sql string) *Scope { } // Exec perform generated SQL -func (scope *Scope) Exec() *Scope { +func (scope *Scope) Exec(ctx context.Context) *Scope { defer scope.trace(NowFunc()) if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().ExecContext(ctx, scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } @@ -856,7 +857,7 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { return scope } -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { +func (scope *Scope) callCallbacks(ctx context.Context, funcs []*func(ctx context.Context, s *Scope)) *Scope { defer func() { if err := recover(); err != nil { if db, ok := scope.db.db.(sqlTx); ok { @@ -866,7 +867,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { } }() for _, f := range funcs { - (*f)(scope) + (*f)(ctx, scope) if scope.skipLeft { break } @@ -933,22 +934,22 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin return } -func (scope *Scope) row() *sql.Row { +func (scope *Scope) row(ctx context.Context) *sql.Row { defer scope.trace(NowFunc()) result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries) return result.Row } -func (scope *Scope) rows() (*sql.Rows, error) { +func (scope *Scope) rows(ctx context.Context) (*sql.Rows, error) { defer scope.trace(NowFunc()) result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.callCallbacks(ctx, scope.db.parent.callbacks.rowQueries) return result.Rows, result.Error } @@ -979,7 +980,7 @@ func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { return false } -func (scope *Scope) pluck(column string, value interface{}) *Scope { +func (scope *Scope) pluck(ctx context.Context, column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) @@ -994,7 +995,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { scope.Search.Select(column) } - rows, err := scope.rows() + rows, err := scope.rows(ctx) if scope.Err(err) == nil { defer rows.Close() for rows.Next() { @@ -1010,7 +1011,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { return scope } -func (scope *Scope) count(value interface{}) *Scope { +func (scope *Scope) count(ctx context.Context, value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { if len(scope.Search.group) != 0 { if len(scope.Search.havingConditions) != 0 { @@ -1027,7 +1028,7 @@ func (scope *Scope) count(value interface{}) *Scope { } } scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) + scope.Err(scope.row(ctx).Scan(value)) return scope } @@ -1124,11 +1125,11 @@ func (scope *Scope) getTableOptions() string { return " " + tableOptions.(string) } -func (scope *Scope) createJoinTable(field *StructField) { +func (scope *Scope) createJoinTable(ctx context.Context, field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { + if !scope.Dialect().HasTable(ctx, joinTable) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes, primaryKeys []string @@ -1156,11 +1157,11 @@ func (scope *Scope) createJoinTable(field *StructField) { scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) + scope.NewDB().Table(joinTable).AutoMigrateContext(ctx, joinTableHandler) } } -func (scope *Scope) createTable() *Scope { +func (scope *Scope) createTable(ctx context.Context) *Scope { var tags []string var primaryKeys []string var primaryKeyInColumnType = false @@ -1181,7 +1182,7 @@ func (scope *Scope) createTable() *Scope { if field.IsPrimaryKey { primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) } - scope.createJoinTable(field) + scope.createJoinTable(ctx, field) } var primaryKeyStr string @@ -1189,27 +1190,27 @@ func (scope *Scope) createTable() *Scope { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec(ctx) scope.autoIndex() return scope } -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() +func (scope *Scope) dropTable(ctx context.Context) *Scope { + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec(ctx) return scope } -func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) +func (scope *Scope) modifyColumn(ctx context.Context, column string, typ string) { + scope.db.AddError(scope.Dialect().ModifyColumn(ctx, scope.QuotedTableName(), scope.Quote(column), typ)) } -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() +func (scope *Scope) dropColumn(ctx context.Context, column string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec(ctx) } -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { +func (scope *Scope) addIndex(ctx context.Context, unique bool, indexName string, column ...string) { + if scope.Dialect().HasIndex(ctx, scope.TableName(), indexName) { return } @@ -1223,23 +1224,23 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { sqlCreate = "CREATE UNIQUE INDEX" } - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() + scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec(ctx) } -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { +func (scope *Scope) addForeignKey(ctx context.Context, field string, dest string, onDelete string, onUpdate string) { // Compatible with old generated key keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + if scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) { return } var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec(ctx) } -func (scope *Scope) removeForeignKey(field string, dest string) { +func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + if !scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) { return } var mysql mysql @@ -1250,28 +1251,28 @@ func (scope *Scope) removeForeignKey(field string, dest string) { query = `ALTER TABLE %s DROP CONSTRAINT %s;` } - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec(ctx) } -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) +func (scope *Scope) removeIndex(ctx context.Context, indexName string) { + scope.Dialect().RemoveIndex(ctx, scope.TableName(), indexName) } -func (scope *Scope) autoMigrate() *Scope { +func (scope *Scope) autoMigrate(ctx context.Context) *Scope { tableName := scope.TableName() quotedTableName := scope.QuotedTableName() - if !scope.Dialect().HasTable(tableName) { - scope.createTable() + if !scope.Dialect().HasTable(ctx, tableName) { + scope.createTable(ctx) } else { for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { + if !scope.Dialect().HasColumn(ctx, tableName, field.DBName) { if field.IsNormal { sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec(ctx) } } - scope.createJoinTable(field) + scope.createJoinTable(ctx, field) } scope.autoIndex() }