From 35efe68ba71d571e64ccd1ee62830c30a53ed967 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Wed, 2 May 2018 07:37:51 -0700 Subject: [PATCH] add simple input validation on gorm.Open function (#1855) Simply check if the passed-in database source meets the expected types and, if not, early return with error. --- main.go | 2 ++ main_test.go | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/main.go b/main.go index ffee4ec6..c8a43e8c 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,8 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { dbSQL, err = sql.Open(driver, source) case SQLCommon: dbSQL = value + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } db = &DB{ diff --git a/main_test.go b/main_test.go index 66c46af0..265e0be7 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "testing" "time" @@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) { return } +func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { + stringRef := "foo" + testCases := []interface{}{42, time.Now(), &stringRef} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + _, err := gorm.Open("postgresql", tc) + if err == nil { + t.Error("Should got error with invalid database source") + } + if !strings.HasPrefix(err.Error(), "invalid database source:") { + t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) + } + }) + } +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"`