diff --git a/dialect_oracommon.go b/dialect_oracommon.go new file mode 100644 index 00000000..095562cb --- /dev/null +++ b/dialect_oracommon.go @@ -0,0 +1,486 @@ +package gorm + +import ( + "database/sql" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" +) + +/* + +Understanding Oracle + +It's a bit different than the other RDBMS databases and I'll just try to +hightlight a few of the important ones the dialect has to deal with: + +1. Oracle upper cases all non-quoted identifiers. That means the dialect +has to decide what to do: + 1. quote all identifiers which would require developers to quote every + identifer they passed in a string to gorm. + 2. only quote identifers that conflict with reserved words and leave all + other identifiers unquoted, which means Oracle will automatically + upper case them. This would allow developers to pass unquoted + identifiers in strings they passed to gorm and make the experience + align better with the other dialects. +We chose option #2. + +This design decision has the following side affects: + a. you must be case insensitive when matching column names, like in + the Scope.scan function + b. Devs will have to escape reserved words when they reference them + in things like: First(&CreditCard{}, `"number" = ?`) + + +2. Oracle handles last inserted id a bit differently, and requires a sql.Out +parameter to return the value in the oci8 driver. Since Oracle parameters +are positional, you need to know how many other bind variables there are before +adding the returning clause. We've implemented the +OraDialect.CreateWithReturningInto(*Scope) to handle this. + +3. Oracle doesn't let you specify "AS " when selecting a count + from a dynamic table, so you just omit it. (see Scope.count() ) + +4. Oracle handles foreign keys a bit differently: + A. REFERENCES is implicit + B. ON UPDATE is not supported + (see scope.addForeignKey() ) + +5. Searching a blob requires using a function from the dbms_lob package like + instr() and specifying the offset and number of matches. + (see oci8.SearchBlob() ) + +6 Trailing semicolons are not allowed at the end of Oracle sql statements + (so they were removed in the unit tests) + +*/ + +// OraCommon implements the Dialect and OraDialect interfaces in a common way for all Ora drivers... but many will need +// to override the OraDialect.CreateWithReturningInto. We've OraCommon public so dialect implementers can use it in modules +// outside of the gorm module/repo +type OraCommon struct { + db SQLCommon + DefaultForeignKeyNamer +} + +// GetName will panic when called, since ALL driver specific dialects must override GetName() to return the driver name. +func (OraCommon) GetName() string { + panic("driver specific dialect must override this to return the driver name.") +} + +func (*OraCommon) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + return strings.ToLower(value) != "false" + } + return field.IsPrimaryKey +} + +func (OraCommon) BindVar(i int) string { + return fmt.Sprintf(":%v", i) +} + +// Quote will only quote ora reserved words, which means all other identifiers will be upper cased automatically +func (OraCommon) Quote(key string) string { + if IsOraReservedWord(key) { + return fmt.Sprintf(`"%s"`, key) + } + return key +} + +func (s OraCommon) CurrentDatabase() string { + var name string + if err := s.db.QueryRow("SELECT ORA_DATABASE_NAME as \"Current Database\" FROM DUAL").Scan(&name); err != nil { + return "" // just return "", since the Dialect interface doesn't support returning an error for this func + } + return name +} + +func (OraCommon) DefaultValueStr() string { + return "VALUES (DEFAULT)" +} + +func (s OraCommon) HasColumn(tableName string, columnName string) bool { + var count int + _, tableName = s.oraCurrentDatabaseAndTable(tableName) + tableName = strings.ToUpper(tableName) + columnName = strings.ToUpper(columnName) + if err := s.db.QueryRow("SELECT count(*) FROM ALL_TAB_COLUMNS WHERE TABLE_NAME = :1 AND COLUMN_NAME = :2", tableName, columnName).Scan(&count); err == nil { + return count > 0 + } + return false +} + +func (s OraCommon) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + tableName = strings.ToUpper(tableName) + foreignKeyName = strings.ToUpper(foreignKeyName) + + if err := s.db.QueryRow(`SELECT count(*) FROM USER_CONSTRAINTS WHERE CONSTRAINT_NAME = :1 AND constraint_type = 'R' AND table_name = :2`, foreignKeyName, tableName).Scan(&count); err == nil { + return count > 0 + } + return false +} + +func (s OraCommon) HasIndex(tableName string, indexName string) bool { + var count int + tableName = strings.ToUpper(tableName) + indexName = strings.ToUpper(indexName) + if err := s.db.QueryRow("SELECT count(*) FROM ALL_INDEXES WHERE INDEX_NAME = :1 AND TABLE_NAME = :2", indexName, tableName).Scan(&count); err == nil { + return count > 0 + } + return false +} + +func (s OraCommon) HasTable(tableName string) bool { + var count int + _, tableName = s.oraCurrentDatabaseAndTable(tableName) + tableName = strings.ToUpper(tableName) + if err := s.db.QueryRow("select count(*) from user_tables where table_name = :1", tableName).Scan(&count); err == nil { + return count > 0 + } + return false +} + +func (OraCommon) LastInsertIDReturningSuffix(tableName, columnName string) string { + return "" +} + +func (OraCommon) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + return "" +} + +func (s OraCommon) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", tableName, columnName, typ)) + return err +} + +func (s OraCommon) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) + return err +} + +func (OraCommon) SelectFromDummyTable() string { + return "FROM DUAL" +} + +func (s *OraCommon) SetDB(db SQLCommon) { + s.db = db +} + +func (s OraCommon) oraCurrentDatabaseAndTable(tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return s.CurrentDatabase(), tableName +} + +func (s *OraCommon) DataTypeOf(field *StructField) string { + if _, found := field.TagSettingsGet("RESTRICT"); found { + field.TagSettingsDelete("RESTRICT") + } + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, + reflect.Uint16, reflect.Uintptr, reflect.Int64, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + if s.fieldCanAutoIncrement(field) { + sqlType = "NUMBER GENERATED BY DEFAULT AS IDENTITY" + } else { + switch dataValue.Kind() { + case reflect.Int8, + reflect.Uint8, + reflect.Uintptr: + sqlType = "SHORTINTEGER" + case reflect.Int, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint16, reflect.Uint32: + sqlType = "INTEGER" + case reflect.Int64, + reflect.Uint64: + sqlType = "INTEGER" + default: + sqlType = "NUMBER" + } + } + case reflect.Bool: + sqlType = "INTEGER" + case reflect.String: + if _, ok := field.TagSettingsGet("SIZE"); !ok { + size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different + } + switch { + case size > 0 && size < 4000: + sqlType = fmt.Sprintf("VARCHAR2(%d)", size) + case size == 0: + sqlType = "VARCHAR2 (1000)" // no size specified, so default to something that can be indexed + default: + sqlType = "CLOB" + } + + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "TIMESTAMP WITH TIME ZONE" + } + default: + if IsByteArrayOrSlice(dataValue) { + sqlType = "BLOB" + } + } + } + if strings.EqualFold(sqlType, "text") { + sqlType = "CLOB" + } + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + if strings.EqualFold(sqlType, "json") { + sqlType = "VARCHAR2 (4000)" + } + + // For oracle, we have to redo the order of the Default type from tag setting + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { + additionalType = fmt.Sprintf("%s %s %s", "DEFAULT", value, additionalType) + } + + if value, ok := field.TagSettingsGet("COMMENT"); ok { + additionalType = additionalType + " COMMENT " + value + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} +func (s OraCommon) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { + if limit != nil { + if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + // when only Limit() is called on a query, the offset is set to -1 for some reason + if offset != nil && offset != -1 { + if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + sql += fmt.Sprintf(" OFFSET %d ROWS ", parsedOffset) + } else { + return "", err + } + } + sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) + } else { + return "", err + } + } + return +} + +// NormalizeIndexAndColumn returns argument's index name and column name without doing anything +func (OraCommon) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + return indexName, columnName +} + +// CreateWithReturningInto to implement the OraDialect interface +func (OraCommon) CreateWithReturningInto(scope *Scope) { + var stringId string + var intId uint32 + primaryField := scope.PrimaryField() + + primaryIsString := false + out := sql.Out{ + Dest: &intId, + } + if primaryField.Field.Kind() == reflect.String { + out = sql.Out{ + Dest: &stringId, + } + primaryIsString = true + } + scope.SQLVars = append(scope.SQLVars, out) + scope.SQL = fmt.Sprintf("%s returning %s into :%d", scope.SQL, scope.Quote(primaryField.DBName), len(scope.SQLVars)) + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + scope.DB().RowsAffected, _ = result.RowsAffected() + if primaryIsString { + scope.Err(primaryField.Set(stringId)) + } else { + scope.Err(primaryField.Set(intId)) + } + } + // this should raise an error, but the gorm.createCallback() which calls it simply doesn't support returning an error +} + +// OraSearchBlob returns a where clause substring for searching fieldName and will require you to pass a parameter for the search value +func OraSearchBlob(fieldName string) string { + // oracle requires some hoop jumping to search []byte stored as BLOB + + const lobSearch = ` dbms_lob.instr (%s, -- the blob + utl_raw.cast_to_raw (?), -- the search string cast to raw + 1, -- where to start. i.e. offset + 1 -- Which occurrance i.e. 1=first + ) > 0 ` + return fmt.Sprintf(lobSearch, fieldName) +} + +var setupReserved sync.Once +var reservedWords map[string]struct{} + +// IsOraReservedWord is public intentionally so other DEVs can access it for implementation +func IsOraReservedWord(w string) bool { + setupReserved.Do( + func() { + words := strings.Split(reserved, "\n") + reservedWords = make(map[string]struct{}, len(words)) + for _, s := range words { + reservedWords[s] = struct{}{} + } + }, + ) + _, ok := reservedWords[strings.ToUpper(w)] + return ok +} + +const reserved = `AGGREGATE +AGGREGATES +ALL +ALLOW +ANALYZE +ANCESTOR +AND +ANY +AS +ASC +AT +AVG +BETWEEN +BINARY_DOUBLE +BINARY_FLOAT +BLOB +BRANCH +BUILD +BY +BYTE +CASE +CAST +CHAR +CHILD +CLEAR +CLOB +COMMIT +COMPILE +CONSIDER +COUNT +DATATYPE +DATE +DATE_MEASURE +DAY +DECIMAL +DELETE +DESC +DESCENDANT +DIMENSION +DISALLOW +DIVISION +DML +ELSE +END +ESCAPE +EXECUTE +FIRST +FLOAT +FOR +FROM +HIERARCHIES +HIERARCHY +HOUR +IGNORE +IN +INFINITE +INSERT +INTEGER +INTERVAL +INTO +IS +LAST +LEAF_DESCENDANT +LEAVES +LEVEL +LIKE +LIKEC +LIKE2 +LIKE4 +LOAD +LOCAL +LOG_SPEC +LONG +MAINTAIN +MAX +MEASURE +MEASURES +MEMBER +MEMBERS +MERGE +MLSLABEL +MIN +MINUTE +MODEL +MONTH +NAN +NCHAR +NCLOB +NO +NONE +NOT +NULL +NULLS +NUMBER +NVARCHAR2 +OF +OLAP +OLAP_DML_EXPRESSION +ON +ONLY +OPERATOR +OR +ORDER +OVER +OVERFLOW +PARALLEL +PARENT +PLSQL +PRUNE +RAW +RELATIVE +ROOT_ANCESTOR +ROWID +SCN +SECOND +SELF +SERIAL +SET +SOLVE +SOME +SORT +SPEC +SUM +SYNCH +TEXT_MEASURE +THEN +TIME +TIMESTAMP +TO +UNBRANCH +UPDATE +USING +VALIDATE +VALUES +VARCHAR2 +WHEN +WHERE +WITHIN +WITH +YEAR +ZERO +ZONE`