From 0b21fdde50f974580c6c19faeae898bab3cf909a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Martins?= Date: Thu, 14 May 2015 18:48:29 +0100 Subject: [PATCH] Implementing crate dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: André Martins --- callback_create.go | 13 ++++++++-- common_dialect.go | 4 +++ crate.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++ dialect.go | 3 +++ scope_private.go | 14 +++++++---- 5 files changed, 89 insertions(+), 7 deletions(-) create mode 100644 crate.go diff --git a/callback_create.go b/callback_create.go index b21df08b..81caba39 100644 --- a/callback_create.go +++ b/callback_create.go @@ -80,8 +80,17 @@ func Create(scope *Scope) { if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil { scope.db.RowsAffected, _ = results.RowsAffected() } - } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { - scope.db.RowsAffected = 1 + } else { + //Using the Scan on Crate we get a sql: expected 0 destination arguments in Scan, not 1 + if _, ok := scope.Dialect().(*crate); ok { + if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan()) == nil { + scope.db.RowsAffected = 1 + } + } else { + if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { + scope.db.RowsAffected = 1 + } + } } } } diff --git a/common_dialect.go b/common_dialect.go index 281df8a7..26a1085f 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -99,3 +99,7 @@ func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) func (commonDialect) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } + +func (commonDialect) QueryTerminator() string { + return ";" +} diff --git a/crate.go b/crate.go new file mode 100644 index 00000000..118a5200 --- /dev/null +++ b/crate.go @@ -0,0 +1,62 @@ +package gorm + +import ( + "fmt" + "reflect" + "time" +) + +type crate struct { + commonDialect +} + +func (crate) SupportLastInsertId() bool { + return false +} + +func (crate) SqlTag(value reflect.Value, size int, autoIncrease bool) string { + //Crate doesn't support autoIncrease + switch value.Kind() { + case reflect.Bool: + return "boolean" + case reflect.Int8, reflect.Int16, reflect.Uint8, reflect.Uint16: + return "short" + case reflect.Int, reflect.Int32, reflect.Uint, reflect.Uint32, reflect.Uintptr: + return "integer" + case reflect.Int64, reflect.Uint64: + return "double" + case reflect.Float32, reflect.Float64: + return "float" + case reflect.String: + return "string" + case reflect.Struct: + if _, ok := value.Interface().(time.Time); ok { + return "timestamp NULL" + } + default: + if _, ok := value.Interface().([]byte); ok { + return "object" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) +} + +func (c crate) HasTable(scope *Scope, tableName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND schema_name = 'doc'", tableName).Row().Scan(&count) + return count > 0 +} + +func (c crate) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE schema_name = 'doc' AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + return count > 0 +} + +func (crate) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) +} + +func (crate) QueryTerminator() string { + return "" +} diff --git a/dialect.go b/dialect.go index f3221075..fb1563e0 100644 --- a/dialect.go +++ b/dialect.go @@ -17,6 +17,7 @@ type Dialect interface { HasColumn(scope *Scope, tableName string, columnName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool RemoveIndex(scope *Scope, indexName string) + QueryTerminator() string } func NewDialect(driver string) Dialect { @@ -32,6 +33,8 @@ func NewDialect(driver string) Dialect { d = &sqlite3{} case "mssql": d = &mssql{} + case "crate": + d = &crate{} default: fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) d = &commonDialect{} diff --git a/scope_private.go b/scope_private.go index 4ecefe3a..77d0694f 100644 --- a/scope_private.go +++ b/scope_private.go @@ -479,12 +479,16 @@ func (scope *Scope) createTable() *Scope { } scope.createJoinTable(field) } - var primaryKeyStr string if len(primaryKeys) > 0 { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", scope.Quote(strings.Join(primaryKeys, "\", \""))) + } + //Workaround for a crate's bug + if _, ok := scope.Dialect().(*crate); ok { + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) with (number_of_replicas=0)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() + } else { + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() return scope } @@ -529,7 +533,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { var table = scope.TableName() var keyName = fmt.Sprintf("%s_%s_foreign", table, field) - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` + var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s` + scope.Dialect().QueryTerminator() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), keyName, field, dest, onDelete, onUpdate)).Exec() } @@ -548,7 +552,7 @@ func (scope *Scope) autoMigrate() *Scope { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if field.IsNormal { sqlTag := scope.generateSqlTag(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v%v", quotedTableName, field.DBName, sqlTag, scope.Dialect().QueryTerminator())).Exec() } } scope.createJoinTable(field)