diff --git a/dialect.go b/dialect.go index facde0d0..28ee8f82 100644 --- a/dialect.go +++ b/dialect.go @@ -20,6 +20,8 @@ type Dialect interface { BindVar(i int) string // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string + // PrimaryKeys used to define the primary keys of a table + PrimaryKeys(keys []string) string // DataTypeOf return data's sql type DataTypeOf(field *StructField) string diff --git a/dialect_common.go b/dialect_common.go index 5b5682c5..ac43c1b7 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,6 +39,10 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (commonDialect) PrimaryKeys(keys []string) string { + return fmt.Sprintf("PRIMARY KEY (%v)", strings.Join(keys, ",")) +} + func (commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/dialect_ql.go b/dialect_ql.go new file mode 100644 index 00000000..886c62a3 --- /dev/null +++ b/dialect_ql.go @@ -0,0 +1,110 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type ql struct { + commonDialect +} + +func init() { + RegisterDialect("ql", &ql{}) + RegisterDialect("ql-mem", &ql{}) +} + +func (ql) GetName() string { + return "ql" +} + +func (ql) Quote(key string) string { + return fmt.Sprintf(`%s`, key) +} + +func (ql) PrimaryKeys(keys []string) string { + return "" +} + +// Get Data Type for Sqlite Dialect +func (ql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, _, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bool" + case reflect.Int: + sqlType = "int" + case reflect.Int8: + sqlType = "int8" + case reflect.Int16: + sqlType = "int16" + case reflect.Int32: + sqlType = "int32" + case reflect.Int64: + sqlType = "int64" + case reflect.Uint: + sqlType = "uint" + case reflect.Uint8: + sqlType = "uint8" + case reflect.Uint16: + sqlType = "uint16" + case reflect.Uint32: + sqlType = "uint32" + case reflect.Uint64: + sqlType = "uint64" + case reflect.Uintptr: + sqlType = "uint" + case reflect.Float32: + sqlType = "float32" + case reflect.Float64: + sqlType = "float64" + case reflect.String: + sqlType = "string" + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "time" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + sqlType = "blob" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for ql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s ql) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM __Table WHERE Name = ?", tableName).Scan(&count) + return count > 0 +} + +func (s ql) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM __Column WHERE TableName = ? AND Name = ?", tableName, columnName).Scan(&count) + return count > 0 +} + +func (s ql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM __Index WHERE TableName = ? AND Name = ?", tableName, indexName).Scan(&count) + return count > 0 +} + +func (s ql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("BEGIN TRANSACTION; DROP INDEX %v; COMMIT;", indexName)) + return err +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a7bca6b8..bbcc70d5 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -44,6 +44,10 @@ func (mssql) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (mssql) PrimaryKeys(keys []string) string { + return fmt.Sprintf("PRIMARY KEY (%v)", strings.Join(keys, ",")) +} + func (mssql) DataTypeOf(field *gorm.StructField) string { var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field) diff --git a/dialects/ql/ql.go b/dialects/ql/ql.go new file mode 100644 index 00000000..3db1b986 --- /dev/null +++ b/dialects/ql/ql.go @@ -0,0 +1,3 @@ +package ql + +import _ "github.com/cznic/ql/driver" diff --git a/main_test.go b/main_test.go index 9869a7ad..20594460 100644 --- a/main_test.go +++ b/main_test.go @@ -17,6 +17,7 @@ import ( _ "github.com/jinzhu/gorm/dialects/mysql" "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" + _ "github.com/jinzhu/gorm/dialects/ql" "github.com/jinzhu/now" ) @@ -60,6 +61,9 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mssql": fmt.Println("testing mssql...") db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") + case "ql": + fmt.Println("testing ql...") + db, err = gorm.Open("ql", "/tmp/gorm.ql") default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) diff --git a/scope.go b/scope.go index 0a3d6e6f..a063070e 100644 --- a/scope.go +++ b/scope.go @@ -80,6 +80,17 @@ func (scope *Scope) Quote(str string) string { return scope.Dialect().Quote(str) } +// PrimaryKeys used to define the primary keys of a table +func (scope *Scope) PrimaryKeys(keys []string) string { + res := scope.Dialect().PrimaryKeys(keys) + + if res != "" { + return fmt.Sprintf(", %v", res) + } + + return "" +} + // Err add error to Scope func (scope *Scope) Err(err error) error { if err != nil { @@ -1071,7 +1082,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.PrimaryKeys(primaryKeys), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -1103,7 +1114,7 @@ func (scope *Scope) createTable() *Scope { var primaryKeyStr string if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + primaryKeyStr = scope.PrimaryKeys(primaryKeys) } scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() diff --git a/test_all.sh b/test_all.sh index 6c5593b3..1e2a4715 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "mysql" "sqlite" "ql") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test