From fa9817b8057182b6f2a021919a331ef6ba2acd59 Mon Sep 17 00:00:00 2001 From: Glenn Nagel Date: Sun, 9 Oct 2016 13:59:29 -0400 Subject: [PATCH 1/2] Added redshift support & added Dialect.SupportLastInsertID interface method --- .gitignore | 1 + callback_create.go | 28 +++++++- dialect.go | 1 + dialect_common.go | 4 ++ dialect_mysql.go | 4 ++ dialect_redshift.go | 123 ++++++++++++++++++++++++++++++++++ dialect_sqlite3.go | 4 ++ dialects/mssql/mssql.go | 4 ++ dialects/redshift/redshift.go | 7 ++ 9 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 dialect_redshift.go create mode 100644 dialects/redshift/redshift.go diff --git a/.gitignore b/.gitignore index 01dc5ce0..71274210 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ documents _book +.DS_Store diff --git a/callback_create.go b/callback_create.go index 7a6dea94..1325b432 100644 --- a/callback_create.go +++ b/callback_create.go @@ -109,8 +109,32 @@ func createCallback(scope *Scope) { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { + // Set primary value to primary field + // * If `LastInsertID` isn't supported, then pull the last inserted row from the db + // * Else, use the `LastInsertID` field stored in `result` + if !scope.Dialect().SupportLastInsertID() && primaryField != nil { + // Build the WHERE query for each inserted column + where_filter := make([]string, len(columns)) + for i, column := range columns { + where_filter[i] = fmt.Sprintf("%v = $%v", column, i + 1) + } + // Store the query in scope.SQL + scope.Raw(fmt.Sprintf( + "SELECT \"%v\" FROM %v WHERE %v ORDER BY \"%v\" DESC LIMIT 1", + primaryField.Name, + scope.QuotedTableName(), + strings.Join(where_filter, " AND "), + primaryField.Name, + )) + // Execute the query and store the results in primaryField & scope.Err + var id int64 + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(&id); err == nil { + primaryField.Set(id) + scope.Err(nil) + } else { + scope.Err(primaryField.Set(err.Error())) + } + } else if primaryField != nil && primaryField.IsBlank { if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { scope.Err(primaryField.Set(primaryValue)) } diff --git a/dialect.go b/dialect.go index facde0d0..2ed65cb8 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,7 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + SupportLastInsertID() bool // BuildForeignKeyName returns a foreign key name for the given table, field and reference BuildForeignKeyName(tableName, field, dest string) string diff --git a/dialect_common.go b/dialect_common.go index 5b5682c5..e3752a97 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -145,6 +145,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +func (commonDialect) SupportLastInsertID() bool { + return true +} + func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") diff --git a/dialect_mysql.go b/dialect_mysql.go index 11b894b3..51855258 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -144,3 +144,7 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { return fmt.Sprintf("%s%x", string(destRunes), bs) } + +func (mysql) SupportLastInsertID() bool { + return true +} diff --git a/dialect_redshift.go b/dialect_redshift.go new file mode 100644 index 00000000..9fa096d5 --- /dev/null +++ b/dialect_redshift.go @@ -0,0 +1,123 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type redshift struct { + commonDialect +} + +func init() { + RegisterDialect("redshift", &redshift{}) +} + +func (redshift) GetName() string { + return "redshift" +} + +func (redshift) BindVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (redshift) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Float32: + sqlType = "float4" + case reflect.Float64: + sqlType = "float8" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if field.IsPrimaryKey { + delete(field.TagSettings, "AUTO_INCREMENT") + delete(field.TagSettings, "IDENTITY(1, 1)") + sqlType = "integer IDENTITY(1, 1)" + } else if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + delete(field.TagSettings, "AUTO_INCREMENT") + delete(field.TagSettings, "IDENTITY(1, 1)") + sqlType = "integer IDENTITY(1, 1)" + } else if _, ok := field.TagSettings["IDENTITY(1, 1)"]; ok { + delete(field.TagSettings, "AUTO_INCREMENT") + delete(field.TagSettings, "IDENTITY(1, 1)") + sqlType = "integer IDENTITY(1, 1)" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uintptr, reflect.Uint64: + if _, ok := field.TagSettings["IDENTITY(1, 1)"]; ok || field.IsPrimaryKey { + field.TagSettings["IDENTITY(1, 1)"] = "IDENTITY(1, 1)" + sqlType = "bigint" + } else { + sqlType = "bigint" + } + case reflect.String: + if _, ok := field.TagSettings["SIZE"]; !ok { + size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different + } + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "timestamp with time zone" + } + default: + sqlType = "" + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for redshift", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s redshift) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + return count > 0 +} + +func (s redshift) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s redshift) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + return count > 0 +} + +func (s redshift) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) + return count > 0 +} + +func (s redshift) CurrentDatabase() (name string) { + s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) + return +} + +func (s redshift) LastInsertIDReturningSuffix(tableName, key string) string { + return "" +} + +func (redshift) SupportLastInsertID() bool { + return false +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 2abcefa5..d812c0c9 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -106,3 +106,7 @@ func (s sqlite3) CurrentDatabase() (name string) { } return } + +func (sqlite3) SupportLastInsertID() bool { + return true +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a7bca6b8..7aafc7d5 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -149,3 +149,7 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func (mssql) SupportLastInsertID() bool { + return true +} diff --git a/dialects/redshift/redshift.go b/dialects/redshift/redshift.go new file mode 100644 index 00000000..09727bbb --- /dev/null +++ b/dialects/redshift/redshift.go @@ -0,0 +1,7 @@ +package redshift + +import ( + _ "database/sql" + _ "database/sql/driver" + _ "github.com/lib/pq" +) From ae1ad864b667cfb46145e6aa7d0fafe3c91fe049 Mon Sep 17 00:00:00 2001 From: Glenn Nagel Date: Sat, 15 Oct 2016 19:44:29 -0400 Subject: [PATCH 2/2] Updated redshift dialect to key off of `gorm:"auto_increment"` key --- dialect_redshift.go | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/dialect_redshift.go b/dialect_redshift.go index 9fa096d5..9e2857ae 100644 --- a/dialect_redshift.go +++ b/dialect_redshift.go @@ -35,25 +35,14 @@ func (redshift) DataTypeOf(field *StructField) string { case reflect.Float64: sqlType = "float8" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - if field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") - delete(field.TagSettings, "IDENTITY(1, 1)") - sqlType = "integer IDENTITY(1, 1)" - } else if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - delete(field.TagSettings, "AUTO_INCREMENT") - delete(field.TagSettings, "IDENTITY(1, 1)") - sqlType = "integer IDENTITY(1, 1)" - } else if _, ok := field.TagSettings["IDENTITY(1, 1)"]; ok { - delete(field.TagSettings, "AUTO_INCREMENT") - delete(field.TagSettings, "IDENTITY(1, 1)") - sqlType = "integer IDENTITY(1, 1)" + if _, ok := field.TagSettings["AUTO_INCREMENT"]; field.IsPrimaryKey || ok { + sqlType = "integer DISTKEY IDENTITY(1, 1)" } else { sqlType = "integer" } case reflect.Int64, reflect.Uintptr, reflect.Uint64: - if _, ok := field.TagSettings["IDENTITY(1, 1)"]; ok || field.IsPrimaryKey { - field.TagSettings["IDENTITY(1, 1)"] = "IDENTITY(1, 1)" - sqlType = "bigint" + if _, ok := field.TagSettings["AUTO_INCREMENT"]; field.IsPrimaryKey || ok { + sqlType = "bigint DISTKEY IDENTITY(1, 1)" } else { sqlType = "bigint" }