keep old public api
This commit is contained in:
parent
88e0495956
commit
75d6dc912c
35
callback.go
35
callback.go
@ -96,7 +96,16 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register a new callback, refer `Callbacks.Create`
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
|
callbackContext := func(ctx context.Context, scope *Scope) {
|
||||||
|
callback(scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
cp.RegisterContext(callbackName, callbackContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterContext same as Register
|
||||||
|
func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
||||||
@ -126,7 +135,16 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
// scope.SetColumn("CreatedAt", now)
|
// scope.SetColumn("CreatedAt", now)
|
||||||
// scope.SetColumn("UpdatedAt", now)
|
// scope.SetColumn("UpdatedAt", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
|
callbackContext := func(ctx context.Context, scope *Scope) {
|
||||||
|
callback(scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
cp.ReplaceContext(callbackName, callbackContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceContext same as Replace
|
||||||
|
func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
||||||
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
@ -137,7 +155,18 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(ctx cont
|
|||||||
|
|
||||||
// Get registered callback
|
// Get registered callback
|
||||||
// db.Callback().Create().Get("gorm:create")
|
// db.Callback().Create().Get("gorm:create")
|
||||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||||
|
c := cp.GetContext(callbackName)
|
||||||
|
|
||||||
|
callback = func(scope *Scope) {
|
||||||
|
ctx := context.Background()
|
||||||
|
c(ctx, scope)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetContext same as Get
|
||||||
|
func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
|
||||||
for _, p := range cp.parent.processors {
|
for _, p := range cp.parent.processors {
|
||||||
if p.name == callbackName && p.kind == cp.kind {
|
if p.name == callbackName && p.kind == cp.kind {
|
||||||
if p.remove {
|
if p.remove {
|
||||||
|
@ -8,15 +8,15 @@ import (
|
|||||||
|
|
||||||
// Define callbacks for creating
|
// Define callbacks for creating
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
DefaultCallback.Create().RegisterContext("gorm:begin_transaction", beginTransactionCallback)
|
||||||
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
DefaultCallback.Create().RegisterContext("gorm:before_create", beforeCreateCallback)
|
||||||
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
DefaultCallback.Create().RegisterContext("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||||
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
DefaultCallback.Create().RegisterContext("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||||
DefaultCallback.Create().Register("gorm:create", createCallback)
|
DefaultCallback.Create().RegisterContext("gorm:create", createCallback)
|
||||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
DefaultCallback.Create().RegisterContext("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||||
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
DefaultCallback.Create().RegisterContext("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||||
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
DefaultCallback.Create().RegisterContext("gorm:after_create", afterCreateCallback)
|
||||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
DefaultCallback.Create().RegisterContext("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||||
@ -101,10 +101,10 @@ func createCallback(ctx context.Context, scope *Scope) {
|
|||||||
returningColumn = scope.Quote(primaryField.DBName)
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(ctx, quotedTableName, returningColumn, columns)
|
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitialContext(ctx, quotedTableName, returningColumn, columns)
|
||||||
var lastInsertIDReturningSuffix string
|
var lastInsertIDReturningSuffix string
|
||||||
if lastInsertIDOutputInterstitial == "" {
|
if lastInsertIDOutputInterstitial == "" {
|
||||||
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(ctx, quotedTableName, returningColumn)
|
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffixContext(ctx, quotedTableName, returningColumn)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
|
@ -8,11 +8,11 @@ import (
|
|||||||
|
|
||||||
// Define callbacks for deleting
|
// Define callbacks for deleting
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
DefaultCallback.Delete().RegisterContext("gorm:begin_transaction", beginTransactionCallback)
|
||||||
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
DefaultCallback.Delete().RegisterContext("gorm:before_delete", beforeDeleteCallback)
|
||||||
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
DefaultCallback.Delete().RegisterContext("gorm:delete", deleteCallback)
|
||||||
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
DefaultCallback.Delete().RegisterContext("gorm:after_delete", afterDeleteCallback)
|
||||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
DefaultCallback.Delete().RegisterContext("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||||
|
@ -9,9 +9,9 @@ import (
|
|||||||
|
|
||||||
// Define callbacks for querying
|
// Define callbacks for querying
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
DefaultCallback.Query().RegisterContext("gorm:query", queryCallback)
|
||||||
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
DefaultCallback.Query().RegisterContext("gorm:preload", preloadCallback)
|
||||||
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
DefaultCallback.Query().RegisterContext("gorm:after_query", afterQueryCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryCallback used to query data from database
|
// queryCallback used to query data from database
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
// Define callbacks for row query
|
// Define callbacks for row query
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
|
DefaultCallback.RowQuery().RegisterContext("gorm:row_query", rowQueryCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
type RowQueryResult struct {
|
type RowQueryResult struct {
|
||||||
|
@ -17,11 +17,11 @@ func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) b
|
|||||||
return reflect.DeepEqual(names, fnames)
|
return reflect.DeepEqual(names, fnames)
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(ctx context.Context, s *Scope) {}
|
func create(s *Scope) {}
|
||||||
func beforeCreate1(ctx context.Context, s *Scope) {}
|
func beforeCreate1(s *Scope) {}
|
||||||
func beforeCreate2(ctx context.Context, s *Scope) {}
|
func beforeCreate2(s *Scope) {}
|
||||||
func afterCreate1(ctx context.Context, s *Scope) {}
|
func afterCreate1(s *Scope) {}
|
||||||
func afterCreate2(ctx context.Context, s *Scope) {}
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
@ -84,7 +84,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceCreate(ctx context.Context, s *Scope) {}
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
@ -10,15 +10,15 @@ import (
|
|||||||
|
|
||||||
// Define callbacks for updating
|
// Define callbacks for updating
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
DefaultCallback.Update().RegisterContext("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||||
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
DefaultCallback.Update().RegisterContext("gorm:begin_transaction", beginTransactionCallback)
|
||||||
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
DefaultCallback.Update().RegisterContext("gorm:before_update", beforeUpdateCallback)
|
||||||
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
DefaultCallback.Update().RegisterContext("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||||
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
DefaultCallback.Update().RegisterContext("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||||
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
DefaultCallback.Update().RegisterContext("gorm:update", updateCallback)
|
||||||
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
DefaultCallback.Update().RegisterContext("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||||
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
DefaultCallback.Update().RegisterContext("gorm:after_update", afterUpdateCallback)
|
||||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
DefaultCallback.Update().RegisterContext("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
@ -183,23 +182,22 @@ func TestGetCallback(t *testing.T) {
|
|||||||
t.Errorf("`gorm:test_callback` should be nil")
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||||
callback := DB.Callback().Create().Get("gorm:test_callback")
|
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
callback(scope)
|
||||||
callback(ctx, scope)
|
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Replace("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(ctx, scope)
|
callback(scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
@ -209,12 +207,12 @@ func TestGetCallback(t *testing.T) {
|
|||||||
t.Errorf("`gorm:test_callback` should be nil")
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Register("gorm:test_callback", func(ctx context.Context, scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(ctx, scope)
|
callback(scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
@ -222,7 +220,7 @@ func TestGetCallback(t *testing.T) {
|
|||||||
|
|
||||||
func TestUseDefaultCallback(t *testing.T) {
|
func TestUseDefaultCallback(t *testing.T) {
|
||||||
createCallbackName := "gorm:test_use_default_callback_for_create"
|
createCallbackName := "gorm:test_use_default_callback_for_create"
|
||||||
gorm.DefaultCallback.Create().Register(createCallbackName, func(context.Context, *gorm.Scope) {
|
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
|
||||||
// nop
|
// nop
|
||||||
})
|
})
|
||||||
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
||||||
@ -235,17 +233,16 @@ func TestUseDefaultCallback(t *testing.T) {
|
|||||||
|
|
||||||
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
||||||
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
||||||
gorm.DefaultCallback.Update().Register(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
|
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
|
||||||
scope.Set(scopeValueName, 1)
|
scope.Set(scopeValueName, 1)
|
||||||
})
|
})
|
||||||
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(ctx context.Context, scope *gorm.Scope) {
|
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
|
||||||
scope.Set(scopeValueName, 2)
|
scope.Set(scopeValueName, 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
scope := DB.NewScope(nil)
|
scope := DB.NewScope(nil)
|
||||||
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
||||||
ctx := context.Background()
|
callback(scope)
|
||||||
callback(ctx, scope)
|
|
||||||
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
||||||
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -22,13 +21,12 @@ type CustomColumnAndIgnoredFieldClash struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomizeColumn(t *testing.T) {
|
func TestCustomizeColumn(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
col := "mapped_name"
|
col := "mapped_name"
|
||||||
DB.DropTable(&CustomizeColumn{})
|
DB.DropTable(&CustomizeColumn{})
|
||||||
DB.AutoMigrate(&CustomizeColumn{})
|
DB.AutoMigrate(&CustomizeColumn{})
|
||||||
|
|
||||||
scope := DB.NewScope(&CustomizeColumn{})
|
scope := DB.NewScope(&CustomizeColumn{})
|
||||||
if !scope.Dialect().HasColumn(ctx, scope.TableName(), col) {
|
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||||
t.Errorf("CustomizeColumn should have column %s", col)
|
t.Errorf("CustomizeColumn should have column %s", col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
39
dialect.go
39
dialect.go
@ -23,28 +23,43 @@ type Dialect interface {
|
|||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
// DataTypeOf return data's sql type
|
// DataTypeOf return data's sql type
|
||||||
DataTypeOf(field *StructField) string
|
DataTypeOf(field *StructField) string
|
||||||
|
|
||||||
// HasIndex check has index or not
|
// HasIndex check has index or not
|
||||||
HasIndex(ctx context.Context, tableName string, indexName string) bool
|
HasIndex(tableName string, indexName string) bool
|
||||||
|
// HasIndexContext same as HasIndex
|
||||||
|
HasIndexContext(ctx context.Context, tableName string, indexName string) bool
|
||||||
// HasForeignKey check has foreign key or not
|
// HasForeignKey check has foreign key or not
|
||||||
HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool
|
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||||
|
// HasForeignKeyContext same as HasForeignKey
|
||||||
|
HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool
|
||||||
// RemoveIndex remove index
|
// RemoveIndex remove index
|
||||||
RemoveIndex(ctx context.Context, tableName string, indexName string) error
|
RemoveIndex(tableName string, indexName string) error
|
||||||
|
// RemoveIndexContext same as RemoveIndex
|
||||||
|
RemoveIndexContext(ctx context.Context, tableName string, indexName string) error
|
||||||
// HasTable check has table or not
|
// HasTable check has table or not
|
||||||
HasTable(ctx context.Context, tableName string) bool
|
HasTable(tableName string) bool
|
||||||
|
// HasTableContext same as HasTable
|
||||||
|
HasTableContext(ctx context.Context, tableName string) bool
|
||||||
// HasColumn check has column or not
|
// HasColumn check has column or not
|
||||||
HasColumn(ctx context.Context, tableName string, columnName string) bool
|
HasColumn(tableName string, columnName string) bool
|
||||||
|
// HasColumnContext same as HasColumn
|
||||||
|
HasColumnContext(ctx context.Context, tableName string, columnName string) bool
|
||||||
// ModifyColumn modify column's type
|
// ModifyColumn modify column's type
|
||||||
ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error
|
ModifyColumn(tableName string, columnName string, typ string) error
|
||||||
|
// ModifyColumnContext same as ModifyColumn
|
||||||
|
ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error
|
||||||
|
|
||||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||||
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
||||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||||
LastInsertIDOutputInterstitial(ctx context.Context, tableName, columnName string, columns []string) string
|
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
||||||
|
// LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial
|
||||||
|
LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIDReturningSuffix(ctx context.Context, tableName, columnName string) string
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
|
// LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix
|
||||||
|
LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string
|
||||||
// DefaultValueStr
|
// DefaultValueStr
|
||||||
DefaultValueStr() string
|
DefaultValueStr() string
|
||||||
|
|
||||||
@ -55,7 +70,9 @@ type Dialect interface {
|
|||||||
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase(ctx context.Context) string
|
CurrentDatabase() string
|
||||||
|
// CurrentDatabaseContext same as CurrentDatabase
|
||||||
|
CurrentDatabaseContext(ctx context.Context) string
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
@ -144,5 +161,5 @@ func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName str
|
|||||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
return splitStrings[0], splitStrings[1]
|
return splitStrings[0], splitStrings[1]
|
||||||
}
|
}
|
||||||
return dialect.CurrentDatabase(ctx), tableName
|
return dialect.CurrentDatabaseContext(ctx), tableName
|
||||||
}
|
}
|
||||||
|
@ -100,42 +100,77 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasIndexContext(ctx, tableName, indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
|
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.RemoveIndexContext(ctx, tableName, indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
||||||
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasForeignKey(_ctx context.Context, tableName string, foreignKeyName string) bool {
|
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasTable(ctx context.Context, tableName string) bool {
|
func (s commonDialect) HasTable(tableName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasTableContext(ctx, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasColumnContext(ctx, tableName, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
|
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) {
|
func (s commonDialect) CurrentDatabase() (name string) {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.CurrentDatabaseContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -163,11 +198,21 @@ func (commonDialect) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
|
func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
|
func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,12 +130,12 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
|
func (s mysql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
||||||
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
|
func (s mysql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -163,14 +163,14 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
|
func (s mysql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasTable(ctx context.Context, tableName string) bool {
|
func (s mysql) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
var name string
|
var name string
|
||||||
// allow mysql database name with '-' character
|
// allow mysql database name with '-' character
|
||||||
@ -184,7 +184,7 @@ func (s mysql) HasTable(ctx context.Context, tableName string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
func (s mysql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
|
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -194,7 +194,7 @@ func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
func (s mysql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
|
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -204,7 +204,7 @@ func (s mysql) HasColumn(ctx context.Context, tableName string, columnName strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) CurrentDatabase(ctx context.Context) (name string) {
|
func (s mysql) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -92,40 +92,40 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
func (s postgres) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
|
func (s postgres) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasTable(ctx context.Context, tableName string) bool {
|
func (s postgres) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
func (s postgres) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) CurrentDatabase(ctx context.Context) (name string) {
|
func (s postgres) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, key string, columns []string) string {
|
func (s postgres) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, key string, columns []string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIDReturningSuffix(_ctx context.Context, tableName, key string) string {
|
func (s postgres) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
func (s sqlite3) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasTable(ctx context.Context, tableName string) bool {
|
func (s sqlite3) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
func (s sqlite3) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) {
|
func (s sqlite3) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
var (
|
var (
|
||||||
ifaces = make([]interface{}, 3)
|
ifaces = make([]interface{}, 3)
|
||||||
pointers = make([]*string, 3)
|
pointers = make([]*string, 3)
|
||||||
|
@ -36,8 +36,8 @@ func turnOffIdentityInsert(ctx context.Context, scope *gorm.Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
|
gorm.DefaultCallback.Create().After("gorm:begin_transaction").RegisterContext("mssql:set_identity_insert", setIdentityInsert)
|
||||||
gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
|
gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").RegisterContext("mssql:turn_off_identity_insert", turnOffIdentityInsert)
|
||||||
gorm.RegisterDialect("mssql", &mssql{})
|
gorm.RegisterDialect("mssql", &mssql{})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,18 +123,33 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
|||||||
return field.IsPrimaryKey
|
return field.IsPrimaryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasIndexContext(ctx, tableName, indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
|
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.RemoveIndexContext(ctx, tableName, indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
||||||
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, `SELECT count(*)
|
s.db.QueryRowContext(ctx, `SELECT count(*)
|
||||||
@ -145,26 +160,47 @@ func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyNa
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTable(ctx context.Context, tableName string) bool {
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasTableContext(ctx, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.HasColumnContext(ctx, tableName, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
|
||||||
s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
|
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
|
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) CurrentDatabase(ctx context.Context) (name string) {
|
func (s mssql) CurrentDatabase() (name string) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s.CurrentDatabaseContext(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -199,7 +235,12 @@ func (mssql) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, columnName string, columns []string) string {
|
func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
// No OUTPUT to query
|
// No OUTPUT to query
|
||||||
return ""
|
return ""
|
||||||
@ -207,7 +248,12 @@ func (mssql) LastInsertIDOutputInterstitial(_ctx context.Context, tableName, col
|
|||||||
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDReturningSuffix(_ctx context.Context, tableName, columnName string) string {
|
func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
|
ctx := context.Background()
|
||||||
|
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {
|
||||||
// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
|
// https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id
|
||||||
return "; SELECT SCOPE_IDENTITY()"
|
return "; SELECT SCOPE_IDENTITY()"
|
||||||
}
|
}
|
||||||
@ -226,7 +272,7 @@ func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableNam
|
|||||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
return splitStrings[0], splitStrings[1]
|
return splitStrings[0], splitStrings[1]
|
||||||
}
|
}
|
||||||
return dialect.CurrentDatabase(ctx), tableName
|
return dialect.CurrentDatabaseContext(ctx), tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
// JSON type to support easy handling of JSON data in character table fields
|
// JSON type to support easy handling of JSON data in character table fields
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import "testing"
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BasePost struct {
|
type BasePost struct {
|
||||||
Id int64
|
Id int64
|
||||||
@ -30,10 +27,9 @@ type EngadgetPost struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
||||||
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
||||||
if !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(ctx, engadgetPostScope.TableName(), "author_email") {
|
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
|
||||||
t.Errorf("should has prefix for embedded columns")
|
t.Errorf("should has prefix for embedded columns")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,7 +38,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hnScope := DB.NewScope(&HNPost{})
|
hnScope := DB.NewScope(&HNPost{})
|
||||||
if !dialect.HasColumn(ctx, hnScope.TableName(), "user_id") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_name") || !dialect.HasColumn(ctx, hnScope.TableName(), "user_email") {
|
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
|
||||||
t.Errorf("should has prefix for embedded columns")
|
t.Errorf("should has prefix for embedded columns")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,10 @@ import (
|
|||||||
|
|
||||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||||
type SQLCommon interface {
|
type SQLCommon interface {
|
||||||
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
|
Prepare(query string) (*sql.Stmt, error)
|
||||||
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
4
main.go
4
main.go
@ -817,7 +817,7 @@ func (s *DB) HasTableContext(ctx context.Context, value interface{}) bool {
|
|||||||
tableName = scope.TableName()
|
tableName = scope.TableName()
|
||||||
}
|
}
|
||||||
|
|
||||||
has := scope.Dialect().HasTable(ctx, tableName)
|
has := scope.Dialect().HasTableContext(ctx, tableName)
|
||||||
s.AddError(scope.db.Error)
|
s.AddError(scope.db.Error)
|
||||||
return has
|
return has
|
||||||
}
|
}
|
||||||
@ -991,7 +991,7 @@ func (s *DB) SetJoinTableHandlerContext(ctx context.Context, source interface{},
|
|||||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
handler.Setup(field.Relationship, many2many, source, destination)
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
field.Relationship.JoinTableHandler = handler
|
field.Relationship.JoinTableHandler = handler
|
||||||
if table := handler.Table(s); scope.Dialect().HasTable(ctx, table) {
|
if table := handler.Table(s); scope.Dialect().HasTableContext(ctx, table) {
|
||||||
s.Table(table).AutoMigrate(handler)
|
s.Table(table).AutoMigrate(handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
@ -303,14 +302,12 @@ func runMigration() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexes(t *testing.T) {
|
func TestIndexes(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scope := DB.NewScope(&Email{})
|
scope := DB.NewScope(&Email{})
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email should have index idx_email_email")
|
t.Errorf("Email should have index idx_email_email")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,7 +315,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email's index idx_email_email should be deleted")
|
t.Errorf("Email's index idx_email_email should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -326,7 +323,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,7 +331,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,7 +339,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -364,7 +361,7 @@ func TestIndexes(t *testing.T) {
|
|||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -384,8 +381,6 @@ type EmailWithIdx struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoMigration(t *testing.T) {
|
func TestAutoMigration(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
DB.AutoMigrate(&Address{})
|
DB.AutoMigrate(&Address{})
|
||||||
DB.DropTable(&EmailWithIdx{})
|
DB.DropTable(&EmailWithIdx{})
|
||||||
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
||||||
@ -396,11 +391,11 @@ func TestAutoMigration(t *testing.T) {
|
|||||||
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
||||||
|
|
||||||
scope := DB.NewScope(&EmailWithIdx{})
|
scope := DB.NewScope(&EmailWithIdx{})
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_email_agent") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_email_with_idxes_registered_at") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -467,8 +462,6 @@ type MultipleIndexes struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMultipleIndexes(t *testing.T) {
|
func TestMultipleIndexes(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
||||||
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
||||||
}
|
}
|
||||||
@ -481,23 +474,23 @@ func TestMultipleIndexes(t *testing.T) {
|
|||||||
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
||||||
|
|
||||||
scope := DB.NewScope(&MultipleIndexes{})
|
scope := DB.NewScope(&MultipleIndexes{})
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_name") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multipleindexes_user_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "uix_multiple_indexes_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multipleindexes_user_other") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(ctx, scope.TableName(), "idx_multiple_indexes_other") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -547,8 +540,6 @@ func TestModifyColumnType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexWithPrefixLength(t *testing.T) {
|
func TestIndexWithPrefixLength(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
||||||
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
||||||
}
|
}
|
||||||
@ -580,7 +571,7 @@ func TestIndexWithPrefixLength(t *testing.T) {
|
|||||||
if err := DB.CreateTable(table).Error; err != nil {
|
if err := DB.CreateTable(table).Error; err != nil {
|
||||||
t.Errorf("Failed to create %s table: %v", tableName, err)
|
t.Errorf("Failed to create %s table: %v", tableName, err)
|
||||||
}
|
}
|
||||||
if !scope.Dialect().HasIndex(ctx, tableName, "idx_index_with_prefixes_length") {
|
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
|
||||||
t.Errorf("Failed to create %s table index:", tableName)
|
t.Errorf("Failed to create %s table index:", tableName)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
@ -1685,7 +1684,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
|
|||||||
|
|
||||||
called := 0
|
called := 0
|
||||||
|
|
||||||
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(ctx context.Context, scope *gorm.Scope) {
|
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
|
||||||
called = called + 1
|
called = called + 1
|
||||||
})
|
})
|
||||||
|
|
||||||
|
16
scope.go
16
scope.go
@ -1129,7 +1129,7 @@ func (scope *Scope) createJoinTable(ctx context.Context, field *StructField) {
|
|||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
joinTable := joinTableHandler.Table(scope.db)
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
if !scope.Dialect().HasTable(ctx, joinTable) {
|
if !scope.Dialect().HasTableContext(ctx, joinTable) {
|
||||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||||
|
|
||||||
var sqlTypes, primaryKeys []string
|
var sqlTypes, primaryKeys []string
|
||||||
@ -1202,7 +1202,7 @@ func (scope *Scope) dropTable(ctx context.Context) *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) modifyColumn(ctx context.Context, column string, typ string) {
|
func (scope *Scope) modifyColumn(ctx context.Context, column string, typ string) {
|
||||||
scope.db.AddError(scope.Dialect().ModifyColumn(ctx, scope.QuotedTableName(), scope.Quote(column), typ))
|
scope.db.AddError(scope.Dialect().ModifyColumnContext(ctx, scope.QuotedTableName(), scope.Quote(column), typ))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropColumn(ctx context.Context, column string) {
|
func (scope *Scope) dropColumn(ctx context.Context, column string) {
|
||||||
@ -1210,7 +1210,7 @@ func (scope *Scope) dropColumn(ctx context.Context, column string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addIndex(ctx context.Context, unique bool, indexName string, column ...string) {
|
func (scope *Scope) addIndex(ctx context.Context, unique bool, indexName string, column ...string) {
|
||||||
if scope.Dialect().HasIndex(ctx, scope.TableName(), indexName) {
|
if scope.Dialect().HasIndexContext(ctx, scope.TableName(), indexName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1231,7 +1231,7 @@ func (scope *Scope) addForeignKey(ctx context.Context, field string, dest string
|
|||||||
// Compatible with old generated key
|
// Compatible with old generated key
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
|
|
||||||
if scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
|
if scope.Dialect().HasForeignKeyContext(ctx, scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||||
@ -1240,7 +1240,7 @@ func (scope *Scope) addForeignKey(ctx context.Context, field string, dest string
|
|||||||
|
|
||||||
func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest string) {
|
func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest string) {
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
if !scope.Dialect().HasForeignKey(ctx, scope.TableName(), keyName) {
|
if !scope.Dialect().HasForeignKeyContext(ctx, scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var mysql mysql
|
var mysql mysql
|
||||||
@ -1255,18 +1255,18 @@ func (scope *Scope) removeForeignKey(ctx context.Context, field string, dest str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeIndex(ctx context.Context, indexName string) {
|
func (scope *Scope) removeIndex(ctx context.Context, indexName string) {
|
||||||
scope.Dialect().RemoveIndex(ctx, scope.TableName(), indexName)
|
scope.Dialect().RemoveIndexContext(ctx, scope.TableName(), indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) autoMigrate(ctx context.Context) *Scope {
|
func (scope *Scope) autoMigrate(ctx context.Context) *Scope {
|
||||||
tableName := scope.TableName()
|
tableName := scope.TableName()
|
||||||
quotedTableName := scope.QuotedTableName()
|
quotedTableName := scope.QuotedTableName()
|
||||||
|
|
||||||
if !scope.Dialect().HasTable(ctx, tableName) {
|
if !scope.Dialect().HasTableContext(ctx, tableName) {
|
||||||
scope.createTable(ctx)
|
scope.createTable(ctx)
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if !scope.Dialect().HasColumn(ctx, tableName, field.DBName) {
|
if !scope.Dialect().HasColumnContext(ctx, tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec(ctx)
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec(ctx)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user