refactor to use isOra()

This commit is contained in:
Jim Lambert 2020-02-15 16:12:09 -05:00
parent 81b1478ecf
commit 582cc3cc49
4 changed files with 11 additions and 11 deletions

View File

@ -50,7 +50,7 @@ func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
table := pa.Table(db) table := pa.Table(db)
where := fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= ", table, table) where := fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= ", table, table)
if isOracle(db) { if isOra(db) {
where = where + "DATE '0001-01-02'" where = where + "DATE '0001-01-02'"
} else { } else {
where = where + "'0001-01-02'" where = where + "'0001-01-02'"

View File

@ -574,7 +574,7 @@ func TestTransactionReadonly(t *testing.T) {
if dialect == "" { if dialect == "" {
dialect = "sqlite" dialect = "sqlite"
} }
if isOracle(DB) { if isOra(DB) {
dialect = "oracle" dialect = "oracle"
} }
switch dialect { switch dialect {
@ -1352,7 +1352,7 @@ func TestQueryHint1(t *testing.T) {
db := DB.New() db := DB.New()
q := "select 1" q := "select 1"
if isOracle(db) { if isOra(db) {
q = q + " from dual" q = q + " from dual"
} }
_, err := db.Model(User{}).Raw(q).Rows() _, err := db.Model(User{}).Raw(q).Rows()
@ -1462,8 +1462,8 @@ func parseTime(str string) *time.Time {
return &t return &t
} }
// isOracle gives tests an easy way to determine if the dialect uses Oracle for its RMDBS, since oracle could have multiple dialects for different drivers. // isOra gives tests an easy way to determine if the dialect uses Oracle for its RMDBS, since oracle could have multiple dialects for different drivers.
// and since tests run in their own namespace func isOra(db *gorm.DB) bool {
func isOracle(db *gorm.DB) bool { _, ok := db.Dialect().(gorm.OraDialect)
return db.Dialect().GetName() == "oci8" return ok
} }

View File

@ -185,7 +185,7 @@ func TestSearchWithPlainSQL(t *testing.T) {
} }
param := "?" param := "?"
if isOracle(scopedb) { if isOra(scopedb) {
param = "to_date(?, 'YYYY-MM-DD')" param = "to_date(?, 'YYYY-MM-DD')"
} }
scopedb.Where("birthday > "+param, "2002-10-10").Find(&users) scopedb.Where("birthday > "+param, "2002-10-10").Find(&users)
@ -773,7 +773,7 @@ func TestSelectWithEscapedFieldName(t *testing.T) {
DB.Save(&user1).Save(&user2).Save(&user3) DB.Save(&user1).Save(&user2).Save(&user3)
colName := "name" colName := "name"
if isOracle(DB) { if isOra(DB) {
colName = "NAME" // oracle upper cases all identifiers that aren't explicitly escaped when the table is created colName = "NAME" // oracle upper cases all identifiers that aren't explicitly escaped when the table is created
} }
var names []string var names []string
@ -794,7 +794,7 @@ func TestSelectWithVariables(t *testing.T) {
} else { } else {
columns, _ := rows.Columns() columns, _ := rows.Columns()
colName := "fake" colName := "fake"
if isOracle(DB) { if isOra(DB) {
colName = "FAKE" // oracle upper cases all identifiers that aren't explicitly escaped when the table is created colName = "FAKE" // oracle upper cases all identifiers that aren't explicitly escaped when the table is created
} }
if !reflect.DeepEqual(columns, []string{colName}) { if !reflect.DeepEqual(columns, []string{colName}) {

View File

@ -65,7 +65,7 @@ func TestValuer(t *testing.T) {
var user2 User var user2 User
if isOracle(DB) { if isOra(DB) {
where := fmt.Sprintf("name = ? AND %s AND %s", oracle.SearchBlob("password_hash"), oracle.SearchBlob("password")) where := fmt.Sprintf("name = ? AND %s AND %s", oracle.SearchBlob("password_hash"), oracle.SearchBlob("password"))
if err := DB.Where(where, name, "abc", "***pass1").First(&user2).Error; err != nil { if err := DB.Where(where, name, "abc", "***pass1").First(&user2).Error; err != nil {