diff --git a/.gitignore b/.gitignore index 01dc5ce0..71274210 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ documents _book +.DS_Store diff --git a/callback_create.go b/callback_create.go index f0709880..18b13c0c 100644 --- a/callback_create.go +++ b/callback_create.go @@ -109,8 +109,32 @@ func createCallback(scope *Scope) { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { + // Set primary value to primary field + // * If `LastInsertID` isn't supported, then pull the last inserted row from the db + // * Else, use the `LastInsertID` field stored in `result` + if !scope.Dialect().SupportLastInsertID() && primaryField != nil { + // Build the WHERE query for each inserted column + where_filter := make([]string, len(columns)) + for i, column := range columns { + where_filter[i] = fmt.Sprintf("%v = $%v", column, i + 1) + } + // Store the query in scope.SQL + scope.Raw(fmt.Sprintf( + "SELECT \"%v\" FROM %v WHERE %v ORDER BY \"%v\" DESC LIMIT 1", + primaryField.Name, + scope.QuotedTableName(), + strings.Join(where_filter, " AND "), + primaryField.Name, + )) + // Execute the query and store the results in primaryField & scope.Err + var id int64 + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(&id); err == nil { + primaryField.Set(id) + scope.Err(nil) + } else { + scope.Err(primaryField.Set(err.Error())) + } + } else if primaryField != nil && primaryField.IsBlank { if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { scope.Err(primaryField.Set(primaryValue)) } diff --git a/dialect.go b/dialect.go index e879588b..e4b5ab02 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,7 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + SupportLastInsertID() bool // BuildForeignKeyName returns a foreign key name for the given table, field and reference BuildForeignKeyName(tableName, field, dest string) string diff --git a/dialect_common.go b/dialect_common.go index a99627f2..f8619462 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -144,6 +144,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +func (commonDialect) SupportLastInsertID() bool { + return true +} + func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") diff --git a/dialect_mysql.go b/dialect_mysql.go index 271670b8..c5a5fdb8 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -158,3 +158,7 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { return fmt.Sprintf("%s%x", string(destRunes), bs) } + +func (mysql) SupportLastInsertID() bool { + return true +} diff --git a/dialect_redshift.go b/dialect_redshift.go new file mode 100644 index 00000000..9e2857ae --- /dev/null +++ b/dialect_redshift.go @@ -0,0 +1,112 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type redshift struct { + commonDialect +} + +func init() { + RegisterDialect("redshift", &redshift{}) +} + +func (redshift) GetName() string { + return "redshift" +} + +func (redshift) BindVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (redshift) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Float32: + sqlType = "float4" + case reflect.Float64: + sqlType = "float8" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; field.IsPrimaryKey || ok { + sqlType = "integer DISTKEY IDENTITY(1, 1)" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uintptr, reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; field.IsPrimaryKey || ok { + sqlType = "bigint DISTKEY IDENTITY(1, 1)" + } else { + sqlType = "bigint" + } + case reflect.String: + if _, ok := field.TagSettings["SIZE"]; !ok { + size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different + } + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "timestamp with time zone" + } + default: + sqlType = "" + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for redshift", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s redshift) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + return count > 0 +} + +func (s redshift) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s redshift) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + return count > 0 +} + +func (s redshift) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) + return count > 0 +} + +func (s redshift) CurrentDatabase() (name string) { + s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) + return +} + +func (s redshift) LastInsertIDReturningSuffix(tableName, key string) string { + return "" +} + +func (redshift) SupportLastInsertID() bool { + return false +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..66fcb086 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -105,3 +105,7 @@ func (s sqlite3) CurrentDatabase() (name string) { } return } + +func (sqlite3) SupportLastInsertID() bool { + return true +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 46b5ec9c..233eec8a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -168,3 +168,7 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func (mssql) SupportLastInsertID() bool { + return true +} diff --git a/dialects/redshift/redshift.go b/dialects/redshift/redshift.go new file mode 100644 index 00000000..09727bbb --- /dev/null +++ b/dialects/redshift/redshift.go @@ -0,0 +1,7 @@ +package redshift + +import ( + _ "database/sql" + _ "database/sql/driver" + _ "github.com/lib/pq" +)