From b6a2710a15ee29bbacea10e17c67e63f4fa78608 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 5 Mar 2016 21:24:54 +0800 Subject: [PATCH] Don't execute SET IDENTITY_INSERT if dialect is not mssql --- dialect.go | 3 +++ dialect_common.go | 4 ++++ dialect_mssql.go | 4 ++++ dialect_mysql.go | 4 ++++ dialect_postgres.go | 4 ++++ dialect_sqlite3.go | 4 ++++ dialects/mssql/mssql.go | 5 +++-- 7 files changed, 26 insertions(+), 2 deletions(-) diff --git a/dialect.go b/dialect.go index cce68789..897e2570 100644 --- a/dialect.go +++ b/dialect.go @@ -10,6 +10,9 @@ import ( // Dialect interface contains behaviors that differ across SQL database type Dialect interface { + // GetName get dialect's name + GetName() string + // SetDB set db for dialect SetDB(db *sql.DB) diff --git a/dialect_common.go b/dialect_common.go index a1c8ff5c..7afc5c14 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -16,6 +16,10 @@ func init() { RegisterDialect("common", &commonDialect{}) } +func (commonDialect) GetName() string { + return "common" +} + func (s *commonDialect) SetDB(db *sql.DB) { s.db = db } diff --git a/dialect_mssql.go b/dialect_mssql.go index 2ecc27cc..7a59bb30 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -15,6 +15,10 @@ func init() { RegisterDialect("mssql", &mssql{}) } +func (mssql) GetName() string { + return "mssql" +} + func (mssql) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/dialect_mysql.go b/dialect_mysql.go index 9e530a9a..d98c33a3 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -15,6 +15,10 @@ func init() { RegisterDialect("mysql", &mysql{}) } +func (mysql) GetName() string { + return "mysql" +} + func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } diff --git a/dialect_postgres.go b/dialect_postgres.go index 3d188a65..96627d92 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -15,6 +15,10 @@ func init() { RegisterDialect("postgres", &postgres{}) } +func (postgres) GetName() string { + return "postgres" +} + func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 41e45517..56c847b5 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -16,6 +16,10 @@ func init() { RegisterDialect("sqlite3", &sqlite3{}) } +func (sqlite3) GetName() string { + return "sqlite3" +} + // Get Data Type for Sqlite Dialect func (sqlite3) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 03ad60fb..d628ef7c 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -8,10 +8,11 @@ import ( ) func setIdentityInsert(scope *gorm.Scope) { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) + if scope.Dialect().GetName() == "mssql" { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) + } } func init() { - gorm.DefaultCallback.Update().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) }