Merge 0b21fdde50f974580c6c19faeae898bab3cf909a into 0b8c9f29a9e28708c17e0ef1aed71487e8bd356c

This commit is contained in:
André Martins 2015-06-04 20:55:15 +00:00
commit 1f7073945c
5 changed files with 89 additions and 7 deletions

View File

@ -80,8 +80,17 @@ func Create(scope *Scope) {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
scope.db.RowsAffected, _ = results.RowsAffected()
}
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
} else {
//Using the Scan on Crate we get a sql: expected 0 destination arguments in Scan, not 1
if _, ok := scope.Dialect().(*crate); ok {
if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan()) == nil {
scope.db.RowsAffected = 1
}
} else {
if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
}
}
}
}
}

View File

@ -99,3 +99,7 @@ func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string)
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}
func (commonDialect) QueryTerminator() string {
return ";"
}

62
crate.go Normal file
View File

@ -0,0 +1,62 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type crate struct {
commonDialect
}
func (crate) SupportLastInsertId() bool {
return false
}
func (crate) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
//Crate doesn't support autoIncrease
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int8, reflect.Int16, reflect.Uint8, reflect.Uint16:
return "short"
case reflect.Int, reflect.Int32, reflect.Uint, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "double"
case reflect.Float32, reflect.Float64:
return "float"
case reflect.String:
return "string"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "timestamp NULL"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "object"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (c crate) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND schema_name = 'doc'", tableName).Row().Scan(&count)
return count > 0
}
func (c crate) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE schema_name = 'doc' AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
func (crate) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (crate) QueryTerminator() string {
return ""
}

View File

@ -17,6 +17,7 @@ type Dialect interface {
HasColumn(scope *Scope, tableName string, columnName string) bool
HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string)
QueryTerminator() string
}
func NewDialect(driver string) Dialect {
@ -32,6 +33,8 @@ func NewDialect(driver string) Dialect {
d = &sqlite3{}
case "mssql":
d = &mssql{}
case "crate":
d = &crate{}
default:
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
d = &commonDialect{}

View File

@ -479,12 +479,16 @@ func (scope *Scope) createTable() *Scope {
}
scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", scope.Quote(strings.Join(primaryKeys, "\", \"")))
}
//Workaround for a crate's bug
if _, ok := scope.Dialect().(*crate); ok {
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) with (number_of_replicas=0)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
} else {
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
return scope
}
@ -529,7 +533,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s` + scope.Dialect().QueryTerminator()
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), keyName, field, dest, onDelete, onUpdate)).Exec()
}
@ -548,7 +552,7 @@ func (scope *Scope) autoMigrate() *Scope {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v%v", quotedTableName, field.DBName, sqlTag, scope.Dialect().QueryTerminator())).Exec()
}
}
scope.createJoinTable(field)