From 62d74645c957334708640e8594c4d066d3444e80 Mon Sep 17 00:00:00 2001 From: Jim Lambert Date: Sat, 15 Feb 2020 11:33:23 -0500 Subject: [PATCH] refactor oracle special cases to use isOralce() testing func --- join_table_test.go | 2 +- main_test.go | 10 ++++++++-- query_test.go | 6 +++--- scope_test.go | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/join_table_test.go b/join_table_test.go index 47fda43e..db09c8e6 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -50,7 +50,7 @@ func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) where := fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= ", table, table) - if db.Dialect().GetName() == "oci8" { + if isOracle(db) { where = where + "DATE '0001-01-02'" } else { where = where + "'0001-01-02'" diff --git a/main_test.go b/main_test.go index 5815b51d..67a9056f 100644 --- a/main_test.go +++ b/main_test.go @@ -574,7 +574,7 @@ func TestTransactionReadonly(t *testing.T) { if dialect == "" { dialect = "sqlite" } - if DB.Dialect().GetName() == "oci8" { + if isOracle(DB) { dialect = "oracle" } switch dialect { @@ -1352,7 +1352,7 @@ func TestQueryHint1(t *testing.T) { db := DB.New() q := "select 1" - if db.Dialect().GetName() == "oci8" { + if isOracle(db) { q = q + " from dual" } _, err := db.Model(User{}).Raw(q).Rows() @@ -1461,3 +1461,9 @@ func parseTime(str string) *time.Time { t := now.New(time.Now().UTC()).MustParse(str) return &t } + +// isOracle gives tests an easy way to determine if the dialect uses Oracle for its RMDBS, since oracle could have multiple dialects for different drivers. +// and since tests run in their own namespace +func isOracle(db *gorm.DB) bool { + return db.Dialect().GetName() == "oci8" +} diff --git a/query_test.go b/query_test.go index 429a4189..2c8320a6 100644 --- a/query_test.go +++ b/query_test.go @@ -185,7 +185,7 @@ func TestSearchWithPlainSQL(t *testing.T) { } param := "?" - if scopedb.Dialect().GetName() == "oci8" { + if isOracle(scopedb) { param = "to_date(?, 'YYYY-MM-DD')" } scopedb.Where("birthday > "+param, "2002-10-10").Find(&users) @@ -773,7 +773,7 @@ func TestSelectWithEscapedFieldName(t *testing.T) { DB.Save(&user1).Save(&user2).Save(&user3) colName := "name" - if DB.Dialect().GetName() == "oci8" { + if isOracle(DB) { colName = "NAME" // oracle upper cases all identifiers that aren't explicitly escaped when the table is created } var names []string @@ -794,7 +794,7 @@ func TestSelectWithVariables(t *testing.T) { } else { columns, _ := rows.Columns() colName := "fake" - if DB.Dialect().GetName() == "oci8" { + if isOracle(DB) { colName = "FAKE" // oracle upper cases all identifiers that aren't explicitly escaped when the table is created } if !reflect.DeepEqual(columns, []string{colName}) { diff --git a/scope_test.go b/scope_test.go index fea2b299..a08b5d4f 100644 --- a/scope_test.go +++ b/scope_test.go @@ -65,7 +65,7 @@ func TestValuer(t *testing.T) { var user2 User - if DB.Dialect().GetName() == "oci8" { + if isOracle(DB) { where := fmt.Sprintf("name = ? AND %s AND %s", oracle.SearchBlob("password_hash"), oracle.SearchBlob("password")) if err := DB.Where(where, name, "abc", "***pass1").First(&user2).Error; err != nil {