diff --git a/dialect.go b/dialect.go index 033b9555..60ba9f32 100644 --- a/dialect.go +++ b/dialect.go @@ -34,8 +34,8 @@ type Dialect interface { // HasColumn check has column or not HasColumn(tableName string, columnName string) bool - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset int) string + // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql and oracle has special case + LimitAndOffsetSQL(limit, offset int) (whereSQL, suffixSQL string) // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` diff --git a/dialect_common.go b/dialect_common.go index 6d43fb84..22f8460f 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -122,13 +122,13 @@ func (s commonDialect) currentDatabase() (name string) { return } -func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { +func (commonDialect) LimitAndOffsetSQL(limit, offset int) (whereSQL, suffixSQL string) { if limit > 0 || offset > 0 { if limit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", limit) + suffixSQL += fmt.Sprintf(" LIMIT %d", limit) } if offset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", offset) + suffixSQL += fmt.Sprintf(" OFFSET %d", offset) } } return diff --git a/dialect_oracle.go b/dialect_oracle.go new file mode 100644 index 00000000..93c0a5c2 --- /dev/null +++ b/dialect_oracle.go @@ -0,0 +1,120 @@ +package gorm + +import ( + "crypto/sha1" + "fmt" + "reflect" + "regexp" + "strings" + "time" + "unicode/utf8" +) + +type oracle struct { + commonDialect +} + +func init() { + RegisterDialect("ora", &oracle{}) +} + +func (oracle) GetName() string { + return "ora" +} + +func (oracle) Quote(key string) string { + return fmt.Sprintf("\"%s\"", strings.ToUpper(key)) +} + +func (oracle) SelectFromDummyTable() string { + return "FROM dual" +} + +func (oracle) BindVar(i int) string { + return fmt.Sprintf(":%d", i) +} + +func (oracle) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "CHAR(1)" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + sqlType = "INTEGER" + case reflect.Int64, reflect.Uint64: + sqlType = "NUMBER" + case reflect.Float32, reflect.Float64: + sqlType = "FLOAT" + case reflect.String: + if size > 0 && size < 255 { + sqlType = fmt.Sprintf("VARCHAR(%d)", size) + } else { + sqlType = "VARCHAR(255)" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "TIMESTAMP" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for ora", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s oracle) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = :1 AND INDEX_NAME = :2", strings.ToUpper(tableName), strings.ToUpper(indexName)).Scan(&count) + return count > 0 +} + +func (s oracle) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE CONSTRAINT_TYPE = 'R' AND TABLE_NAME = :1 AND CONSTRAINT_NAME = :2", strings.ToUpper(tableName), strings.ToUpper(foreignKeyName)).Scan(&count) + return count > 0 +} + +func (s oracle) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = :1", strings.ToUpper(tableName)).Scan(&count) + return count > 0 +} + +func (s oracle) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = :1 AND COLUMN_NAME = :2", strings.ToUpper(tableName), strings.ToUpper(columnName)).Scan(&count) + return count > 0 +} + +func (oracle) LimitAndOffsetSQL(limit, offset int) (whereSQL, suffixSQL string) { + if limit > 0 { + whereSQL += fmt.Sprintf("ROWNUM <= %d", limit) + } + return +} + +func (s oracle) BuildForeignKeyName(tableName, field, dest string) string { + keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) + if utf8.RuneCountInString(keyName) <= 30 { + return keyName + } + h := sha1.New() + h.Write([]byte(keyName)) + bs := h.Sum(nil) + + // sha1 is 40 digits, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + result := fmt.Sprintf("%s%x", string(destRunes), bs) + if len(result) <= 30 { + return result + } + return result[:29] +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 34eda717..1c5d5f3a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -127,16 +127,16 @@ func (s mssql) currentDatabase() (name string) { return } -func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) { +func (mssql) LimitAndOffsetSQL(limit, offset int) (whereSQL, suffixSQL string) { if limit > 0 || offset > 0 { if offset < 0 { offset = 0 } - sql += fmt.Sprintf(" OFFSET %d ROWS", offset) + suffixSQL += fmt.Sprintf(" OFFSET %d ROWS", offset) if limit >= 0 { - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit) + suffixSQL += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit) } } return diff --git a/dialects/ora/oracle.go b/dialects/ora/oracle.go new file mode 100644 index 00000000..14dfd543 --- /dev/null +++ b/dialects/ora/oracle.go @@ -0,0 +1,3 @@ +package oracle + +import _ "gopkg.in/rana/ora.v3" diff --git a/main.go b/main.go index cd445555..2fc5afda 100644 --- a/main.go +++ b/main.go @@ -561,7 +561,15 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { +func (s *DB) AddForeignKey(field string, dest string, options ...string) *DB { + var onDelete string + var onUpdate string + if len(options) >= 1 { + onDelete = options[0] + } + if len(options) >= 2 { + onUpdate = options[1] + } scope := s.clone().NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db diff --git a/scope.go b/scope.go index 1e755626..1636e8ff 100644 --- a/scope.go +++ b/scope.go @@ -329,8 +329,12 @@ func (scope *Scope) QuotedTableName() (name string) { // CombinedConditionSql return combined condition sql func (scope *Scope) CombinedConditionSql() string { + limitWhereSQL, limitSuffixSQL := scope.limitAndOffsetSQL() + if len(limitWhereSQL) > 0 { + scope.Search = scope.Search.Where(limitWhereSQL) + } return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() + scope.havingSQL() + scope.orderSQL() + limitSuffixSQL } // Raw set raw sql @@ -469,7 +473,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { } for fieldIndex, field := range selectFields { - if field.DBName == column { + if strings.EqualFold(field.DBName, column) { if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { @@ -663,7 +667,7 @@ func (scope *Scope) whereSQL() (sql string) { ) if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { - sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) + sql := fmt.Sprintf("%s.%s IS NULL", quotedTableName, scope.Dialect().Quote("deleted_at")) primaryConditions = append(primaryConditions, sql) } @@ -735,7 +739,7 @@ func (scope *Scope) orderSQL() string { return " ORDER BY " + strings.Join(orders, ",") } -func (scope *Scope) limitAndOffsetSQL() string { +func (scope *Scope) limitAndOffsetSQL() (whereSQL, suffixSQL string) { return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) } @@ -1122,8 +1126,14 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() + var query = fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s", scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest) + if len(onDelete) > 0 { + query += fmt.Sprintf(" ON DELETE %s", onDelete) + } + if len(onUpdate) > 0 { + query += fmt.Sprintf(" ON UPDATE %s", onUpdate) + } + scope.Raw(query).Exec() } func (scope *Scope) removeIndex(indexName string) {