Added redshift support & added Dialect.SupportLastInsertID interface method

This commit is contained in:
Glenn Nagel 2016-10-09 13:59:29 -04:00
parent 39165d4980
commit fa9817b805
9 changed files with 174 additions and 2 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
documents
_book
.DS_Store

View File

@ -109,8 +109,32 @@ func createCallback(scope *Scope) {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
// Set primary value to primary field
// * If `LastInsertID` isn't supported, then pull the last inserted row from the db
// * Else, use the `LastInsertID` field stored in `result`
if !scope.Dialect().SupportLastInsertID() && primaryField != nil {
// Build the WHERE query for each inserted column
where_filter := make([]string, len(columns))
for i, column := range columns {
where_filter[i] = fmt.Sprintf("%v = $%v", column, i + 1)
}
// Store the query in scope.SQL
scope.Raw(fmt.Sprintf(
"SELECT \"%v\" FROM %v WHERE %v ORDER BY \"%v\" DESC LIMIT 1",
primaryField.Name,
scope.QuotedTableName(),
strings.Join(where_filter, " AND "),
primaryField.Name,
))
// Execute the query and store the results in primaryField & scope.Err
var id int64
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(&id); err == nil {
primaryField.Set(id)
scope.Err(nil)
} else {
scope.Err(primaryField.Set(err.Error()))
}
} else if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}

View File

@ -40,6 +40,7 @@ type Dialect interface {
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
SupportLastInsertID() bool
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string

View File

@ -145,6 +145,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s
return ""
}
func (commonDialect) SupportLastInsertID() bool {
return true
}
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")

View File

@ -144,3 +144,7 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
return fmt.Sprintf("%s%x", string(destRunes), bs)
}
func (mysql) SupportLastInsertID() bool {
return true
}

123
dialect_redshift.go Normal file
View File

@ -0,0 +1,123 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type redshift struct {
commonDialect
}
func init() {
RegisterDialect("redshift", &redshift{})
}
func (redshift) GetName() string {
return "redshift"
}
func (redshift) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (redshift) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Float32:
sqlType = "float4"
case reflect.Float64:
sqlType = "float8"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if field.IsPrimaryKey {
delete(field.TagSettings, "AUTO_INCREMENT")
delete(field.TagSettings, "IDENTITY(1, 1)")
sqlType = "integer IDENTITY(1, 1)"
} else if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
delete(field.TagSettings, "AUTO_INCREMENT")
delete(field.TagSettings, "IDENTITY(1, 1)")
sqlType = "integer IDENTITY(1, 1)"
} else if _, ok := field.TagSettings["IDENTITY(1, 1)"]; ok {
delete(field.TagSettings, "AUTO_INCREMENT")
delete(field.TagSettings, "IDENTITY(1, 1)")
sqlType = "integer IDENTITY(1, 1)"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uintptr, reflect.Uint64:
if _, ok := field.TagSettings["IDENTITY(1, 1)"]; ok || field.IsPrimaryKey {
field.TagSettings["IDENTITY(1, 1)"] = "IDENTITY(1, 1)"
sqlType = "bigint"
} else {
sqlType = "bigint"
}
case reflect.String:
if _, ok := field.TagSettings["SIZE"]; !ok {
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
}
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "timestamp with time zone"
}
default:
sqlType = ""
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for redshift", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s redshift) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
return count > 0
}
func (s redshift) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s redshift) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
return count > 0
}
func (s redshift) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
return count > 0
}
func (s redshift) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (s redshift) LastInsertIDReturningSuffix(tableName, key string) string {
return ""
}
func (redshift) SupportLastInsertID() bool {
return false
}

View File

@ -106,3 +106,7 @@ func (s sqlite3) CurrentDatabase() (name string) {
}
return
}
func (sqlite3) SupportLastInsertID() bool {
return true
}

View File

@ -149,3 +149,7 @@ func (mssql) SelectFromDummyTable() string {
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (mssql) SupportLastInsertID() bool {
return true
}

View File

@ -0,0 +1,7 @@
package redshift
import (
_ "database/sql"
_ "database/sql/driver"
_ "github.com/lib/pq"
)