From 1082cd1fe044471648312326d928cbf1538414c6 Mon Sep 17 00:00:00 2001 From: Thomas Boerger Date: Mon, 21 Mar 2016 13:43:41 +0100 Subject: [PATCH 1/5] Integrated ql dialect --- dialect_ql.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++ dialects/ql/ql.go | 3 ++ 2 files changed, 92 insertions(+) create mode 100644 dialect_ql.go create mode 100644 dialects/ql/ql.go diff --git a/dialect_ql.go b/dialect_ql.go new file mode 100644 index 00000000..fc9a4855 --- /dev/null +++ b/dialect_ql.go @@ -0,0 +1,89 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type ql struct { + commonDialect +} + +func init() { + RegisterDialect("ql", &ql{}) + RegisterDialect("ql-mem", &ql{}) +} + +func (ql) GetName() string { + return "ql" +} + +// Get Data Type for Sqlite Dialect +func (ql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uint64: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "real" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetime" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + sqlType = "blob" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for ql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s ql) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM __Table WHERE Name = ?", tableName).Scan(&count) + return count > 0 +} + +func (s ql) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM __Column WHERE TableName = ? AND Name = ?", tableName, columnName).Scan(&count) + return count > 0 +} + +func (s ql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM __Index WHERE TableName = ? AND Name = ?", tableName, indexName).Scan(&count) + return count > 0 +} diff --git a/dialects/ql/ql.go b/dialects/ql/ql.go new file mode 100644 index 00000000..3db1b986 --- /dev/null +++ b/dialects/ql/ql.go @@ -0,0 +1,3 @@ +package ql + +import _ "github.com/cznic/ql/driver" From 97b28cbfeca3beab232b0b8e58e17c51dcca558b Mon Sep 17 00:00:00 2001 From: Thomas Boerger Date: Mon, 21 Mar 2016 13:58:13 +0100 Subject: [PATCH 2/5] Prepare the test suite --- main_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/main_test.go b/main_test.go index e9bffd0f..c8b6ad77 100644 --- a/main_test.go +++ b/main_test.go @@ -16,6 +16,7 @@ import ( _ "github.com/jinzhu/gorm/dialects/mysql" "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" + _ "github.com/jinzhu/gorm/dialects/ql" "github.com/jinzhu/now" ) @@ -59,6 +60,9 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mssql": fmt.Println("testing mssql...") db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") + case "ql": + fmt.Println("testing ql...") + db, err = gorm.Open("ql", "/tmp/gorm.ql") default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", "/tmp/gorm.db") From e84e58d3c02ed64de370b122707f2875a32bf1dd Mon Sep 17 00:00:00 2001 From: Thomas Boerger Date: Mon, 21 Mar 2016 15:15:22 +0100 Subject: [PATCH 3/5] Added ql to the testing dialects --- test_all.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_all.sh b/test_all.sh index 6c5593b3..1e2a4715 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "mysql" "sqlite" "ql") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test From 133176c3fb17787899318781be945ee82c685a71 Mon Sep 17 00:00:00 2001 From: Thomas Boerger Date: Mon, 21 Mar 2016 15:16:04 +0100 Subject: [PATCH 4/5] Added separate function to generate primary keys --- dialect.go | 2 ++ dialect_common.go | 4 ++++ dialects/mssql/mssql.go | 4 ++++ scope.go | 15 +++++++++++++-- 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/dialect.go b/dialect.go index 6c9405da..1c563083 100644 --- a/dialect.go +++ b/dialect.go @@ -20,6 +20,8 @@ type Dialect interface { BindVar(i int) string // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string + // PrimaryKeys used to define the primary keys of a table + PrimaryKeys(keys []string) string // DataTypeOf return data's sql type DataTypeOf(field *StructField) string diff --git a/dialect_common.go b/dialect_common.go index f009271b..72a6dd65 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -32,6 +32,10 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (commonDialect) PrimaryKeys(keys []string) string { + return fmt.Sprintf("PRIMARY KEY (%v)", strings.Join(keys, ",")) +} + func (commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 5b994f9d..e2b822d8 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -42,6 +42,10 @@ func (mssql) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (mssql) PrimaryKeys(keys []string) string { + return fmt.Sprintf("PRIMARY KEY (%v)", strings.Join(keys, ",")) +} + func (mssql) DataTypeOf(field *gorm.StructField) string { var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field) diff --git a/scope.go b/scope.go index da5f7ff3..0ed70d94 100644 --- a/scope.go +++ b/scope.go @@ -80,6 +80,17 @@ func (scope *Scope) Quote(str string) string { return scope.Dialect().Quote(str) } +// PrimaryKeys used to define the primary keys of a table +func (scope *Scope) PrimaryKeys(keys []string) string { + res := scope.Dialect().PrimaryKeys(keys) + + if res != "" { + return fmt.Sprintf(", %v", res) + } + + return "" +} + // Err add error to Scope func (scope *Scope) Err(err error) error { if err != nil { @@ -1027,7 +1038,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.PrimaryKeys(primaryKeys), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -1059,7 +1070,7 @@ func (scope *Scope) createTable() *Scope { var primaryKeyStr string if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + primaryKeyStr = scope.PrimaryKeys(primaryKeys) } scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() From 1d806d662c474caa72d86b21c9d1ba4166e34657 Mon Sep 17 00:00:00 2001 From: Thomas Boerger Date: Mon, 21 Mar 2016 15:16:39 +0100 Subject: [PATCH 5/5] Fixed ql types and integrated primary keys function --- dialect_ql.go | 63 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/dialect_ql.go b/dialect_ql.go index fc9a4855..886c62a3 100644 --- a/dialect_ql.go +++ b/dialect_ql.go @@ -20,37 +20,53 @@ func (ql) GetName() string { return "ql" } +func (ql) Quote(key string) string { + return fmt.Sprintf(`%s`, key) +} + +func (ql) PrimaryKeys(keys []string) string { + return "" +} + // Get Data Type for Sqlite Dialect func (ql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + var dataValue, sqlType, _, additionalType = ParseFieldStructForDialect(field) if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" + case reflect.Int: + sqlType = "int" + case reflect.Int8: + sqlType = "int8" + case reflect.Int16: + sqlType = "int16" + case reflect.Int32: + sqlType = "int32" + case reflect.Int64: + sqlType = "int64" + case reflect.Uint: + sqlType = "uint" + case reflect.Uint8: + sqlType = "uint8" + case reflect.Uint16: + sqlType = "uint16" + case reflect.Uint32: + sqlType = "uint32" + case reflect.Uint64: + sqlType = "uint64" + case reflect.Uintptr: + sqlType = "uint" + case reflect.Float32: + sqlType = "float32" + case reflect.Float64: + sqlType = "float64" case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } + sqlType = "string" case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" + sqlType = "time" } default: if _, ok := dataValue.Interface().([]byte); ok { @@ -87,3 +103,8 @@ func (s ql) HasIndex(tableName string, indexName string) bool { s.db.QueryRow("SELECT COUNT(*) FROM __Index WHERE TableName = ? AND Name = ?", tableName, indexName).Scan(&count) return count > 0 } + +func (s ql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("BEGIN TRANSACTION; DROP INDEX %v; COMMIT;", indexName)) + return err +}