Don't execute SET IDENTITY_INSERT if dialect is not mssql

This commit is contained in:
Jinzhu 2016-03-05 21:24:54 +08:00
parent 2522f03c1f
commit b6a2710a15
7 changed files with 26 additions and 2 deletions

View File

@ -10,6 +10,9 @@ import (
// Dialect interface contains behaviors that differ across SQL database // Dialect interface contains behaviors that differ across SQL database
type Dialect interface { type Dialect interface {
// GetName get dialect's name
GetName() string
// SetDB set db for dialect // SetDB set db for dialect
SetDB(db *sql.DB) SetDB(db *sql.DB)

View File

@ -16,6 +16,10 @@ func init() {
RegisterDialect("common", &commonDialect{}) RegisterDialect("common", &commonDialect{})
} }
func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db *sql.DB) { func (s *commonDialect) SetDB(db *sql.DB) {
s.db = db s.db = db
} }

View File

@ -15,6 +15,10 @@ func init() {
RegisterDialect("mssql", &mssql{}) RegisterDialect("mssql", &mssql{})
} }
func (mssql) GetName() string {
return "mssql"
}
func (mssql) DataTypeOf(field *StructField) string { func (mssql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)

View File

@ -15,6 +15,10 @@ func init() {
RegisterDialect("mysql", &mysql{}) RegisterDialect("mysql", &mysql{})
} }
func (mysql) GetName() string {
return "mysql"
}
func (mysql) Quote(key string) string { func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }

View File

@ -15,6 +15,10 @@ func init() {
RegisterDialect("postgres", &postgres{}) RegisterDialect("postgres", &postgres{})
} }
func (postgres) GetName() string {
return "postgres"
}
func (postgres) BindVar(i int) string { func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i) return fmt.Sprintf("$%v", i)
} }

View File

@ -16,6 +16,10 @@ func init() {
RegisterDialect("sqlite3", &sqlite3{}) RegisterDialect("sqlite3", &sqlite3{})
} }
func (sqlite3) GetName() string {
return "sqlite3"
}
// Get Data Type for Sqlite Dialect // Get Data Type for Sqlite Dialect
func (sqlite3) DataTypeOf(field *StructField) string { func (sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)

View File

@ -8,10 +8,11 @@ import (
) )
func setIdentityInsert(scope *gorm.Scope) { func setIdentityInsert(scope *gorm.Scope) {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) if scope.Dialect().GetName() == "mssql" {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
}
} }
func init() { func init() {
gorm.DefaultCallback.Update().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
} }