diff --git a/dialects/oracle/oracle.go b/dialects/oracle/oracle.go index 9483ebc4..6bbef519 100644 --- a/dialects/oracle/oracle.go +++ b/dialects/oracle/oracle.go @@ -10,19 +10,33 @@ import ( _ "github.com/godror/godror" "github.com/jinzhu/gorm" + _ "github.com/mattn/go-oci8" + _ "gopkg.in/rana/ora.v4" ) +//const dialectName = "godror" +// const dialectName = "ora" +const dialectName = "oci8" + type oracle struct { db gorm.SQLCommon gorm.DefaultForeignKeyNamer } func init() { - gorm.RegisterDialect("godror", &oracle{}) + gorm.RegisterDialect(dialectName, &oracle{}) + +} + +func (s *oracle) fieldCanAutoIncrement(field *gorm.StructField) bool { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + return strings.ToLower(value) != "false" + } + return field.IsPrimaryKey } func (oracle) GetName() string { - return "goracle" + return dialectName } func (oracle) BindVar(i int) string { @@ -30,12 +44,13 @@ func (oracle) BindVar(i int) string { } func (oracle) Quote(key string) string { - return fmt.Sprintf("%s", key) + return key } -func (s oracle) CurrentDatabase() (name string) { +func (s oracle) CurrentDatabase() string { + var name string s.db.QueryRow("SELECT ORA_DATABASE_NAME as \"Current Database\" FROM DUAL").Scan(&name) - return + return name } func (oracle) DefaultValueStr() string { @@ -44,40 +59,53 @@ func (oracle) DefaultValueStr() string { func (s oracle) HasColumn(tableName string, columnName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - fmt.Printf(currentDatabase) - s.db.QueryRow("SELECT count(*) FROM ALL_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", tableName, columnName).Scan(&count) - return count > 0 + _, tableName = currentDatabaseAndTable(&s, tableName) + tableName = strings.ToUpper(tableName) + columnName = strings.ToUpper(columnName) + if err := s.db.QueryRow("SELECT count(*) FROM ALL_TAB_COLUMNS WHERE TABLE_NAME = :1 AND COLUMN_NAME = :2", tableName, columnName).Scan(&count); err == nil { + return count > 0 + } else { + fmt.Println(err) + } + return false } func (s oracle) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - fmt.Printf(currentDatabase) - s.db.QueryRow(`SELECT count(*) + _, tableName = currentDatabaseAndTable(&s, tableName) + tableName = strings.ToUpper(tableName) + if err := s.db.QueryRow(`SELECT count(*) FROM all_cons_columns a JOIN all_constraints c ON a.owner = c.owner AND a.constraint_name = c.constraint_name - AND a.constraint_name = ? + AND a.constraint_name = :1 JOIN all_constraints c_pk ON c.r_owner = c_pk.owner AND c.r_constraint_name = c_pk.constraint_name WHERE c.constraint_type = 'R' - AND a.table_name = ?;`, foreignKeyName, tableName).Scan(&count) - return count > 0 + AND a.table_name = :2 ;`, foreignKeyName, tableName).Scan(&count); err == nil { + return count > 0 + } + return false } func (s oracle) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM ALL_INDEXES WHERE INDEX_NAME = ? AND TABLE_NAME = ?", indexName, tableName).Scan(&count) - return count > 0 + tableName = strings.ToUpper(tableName) + indexName = strings.ToUpper(indexName) + if err := s.db.QueryRow("SELECT count(*) FROM ALL_INDEXES WHERE INDEX_NAME = :1 AND TABLE_NAME = :2", indexName, tableName).Scan(&count); err == nil { + return count > 0 + } + return false } func (s oracle) HasTable(tableName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - fmt.Printf(currentDatabase) - s.db.QueryRow("SELECT DISTINCT OBJECT_NAME FROM USER_OBJECTS WHERE OBJECT_TYPE = 'TABLE' AND OBJECT_NAME = ?", tableName).Scan(&count) - return count > 0 + _, tableName = currentDatabaseAndTable(&s, tableName) + tableName = strings.ToUpper(tableName) + if err := s.db.QueryRow("select count(*) from user_tables where table_name = :1", tableName).Scan(&count); err == nil { + return count > 0 + } + return false } func (oracle) LastInsertIDReturningSuffix(tableName, columnName string) string { @@ -99,7 +127,7 @@ func (s oracle) RemoveIndex(tableName string, indexName string) error { } func (oracle) SelectFromDummyTable() string { - return "SELECT * FROM DUAL" + return "FROM DUAL" } func (s *oracle) SetDB(db gorm.SQLCommon) { @@ -122,17 +150,24 @@ func (s *oracle) DataTypeOf(field *gorm.StructField) string { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - sqlType = "NUMBER" + if s.fieldCanAutoIncrement(field) { + sqlType = "NUMBER GENERATED BY DEFAULT AS IDENTITY" + } else { + sqlType = "NUMBER" + } case reflect.String: if _, ok := field.TagSettingsGet("SIZE"); !ok { size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different } - - if size > 0 && size < 4000 { + switch { + case size > 0 && size < 4000: sqlType = fmt.Sprintf("VARCHAR2(%d)", size) - } else { + case size == 0: + sqlType = "VARCHAR2 (4000)" + default: sqlType = "CLOB" } + case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { sqlType = "TIMESTAMP WITH TIME ZONE" @@ -143,15 +178,7 @@ func (s *oracle) DataTypeOf(field *gorm.StructField) string { } default: if gorm.IsByteArrayOrSlice(dataValue) { - sqlType = "VARCHAR2" - - /*if isUUID(dataValue) { - sqlType = "VARCHAR2" - } - - if isJSON(dataValue) { - sqlType = "CLOB" - }*/ + sqlType = "VARCHAR2 (4000)" } } } @@ -163,6 +190,22 @@ func (s *oracle) DataTypeOf(field *gorm.StructField) string { if strings.TrimSpace(additionalType) == "" { return sqlType } + if strings.EqualFold(sqlType, "json") { + sqlType = "VARCHAR2 (4000)" + } + + // For oracle, we have to redo the order of the Default type from tag setting + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { + additionalType = fmt.Sprintf("%s %s %s", "DEFAULT", value, additionalType) + // additionalType = additionalType + " DEFAULT " + value + } + + if value, ok := field.TagSettingsGet("COMMENT"); ok { + additionalType = additionalType + " COMMENT " + value + } return fmt.Sprintf("%v %v", sqlType, additionalType) } func (s oracle) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {