From 133176c3fb17787899318781be945ee82c685a71 Mon Sep 17 00:00:00 2001 From: Thomas Boerger Date: Mon, 21 Mar 2016 15:16:04 +0100 Subject: [PATCH] Added separate function to generate primary keys --- dialect.go | 2 ++ dialect_common.go | 4 ++++ dialects/mssql/mssql.go | 4 ++++ scope.go | 15 +++++++++++++-- 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/dialect.go b/dialect.go index 6c9405da..1c563083 100644 --- a/dialect.go +++ b/dialect.go @@ -20,6 +20,8 @@ type Dialect interface { BindVar(i int) string // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string + // PrimaryKeys used to define the primary keys of a table + PrimaryKeys(keys []string) string // DataTypeOf return data's sql type DataTypeOf(field *StructField) string diff --git a/dialect_common.go b/dialect_common.go index f009271b..72a6dd65 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -32,6 +32,10 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (commonDialect) PrimaryKeys(keys []string) string { + return fmt.Sprintf("PRIMARY KEY (%v)", strings.Join(keys, ",")) +} + func (commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 5b994f9d..e2b822d8 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -42,6 +42,10 @@ func (mssql) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (mssql) PrimaryKeys(keys []string) string { + return fmt.Sprintf("PRIMARY KEY (%v)", strings.Join(keys, ",")) +} + func (mssql) DataTypeOf(field *gorm.StructField) string { var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field) diff --git a/scope.go b/scope.go index da5f7ff3..0ed70d94 100644 --- a/scope.go +++ b/scope.go @@ -80,6 +80,17 @@ func (scope *Scope) Quote(str string) string { return scope.Dialect().Quote(str) } +// PrimaryKeys used to define the primary keys of a table +func (scope *Scope) PrimaryKeys(keys []string) string { + res := scope.Dialect().PrimaryKeys(keys) + + if res != "" { + return fmt.Sprintf(", %v", res) + } + + return "" +} + // Err add error to Scope func (scope *Scope) Err(err error) error { if err != nil { @@ -1027,7 +1038,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.PrimaryKeys(primaryKeys), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -1059,7 +1070,7 @@ func (scope *Scope) createTable() *Scope { var primaryKeyStr string if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + primaryKeyStr = scope.PrimaryKeys(primaryKeys) } scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()