add simple input validation on gorm.Open function
Simply check if the passed-in database source meets the expected types and, if not, early return with error.
This commit is contained in:
		
							parent
							
								
									6842b49a1a
								
							
						
					
					
						commit
						98cac0c476
					
				
							
								
								
									
										2
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								main.go
									
									
									
									
									
								
							@ -61,6 +61,8 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
 | 
			
		||||
		dbSQL, err = sql.Open(driver, source)
 | 
			
		||||
	case SQLCommon:
 | 
			
		||||
		dbSQL = value
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db = &DB{
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										17
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								main_test.go
									
									
									
									
									
								
							@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
 | 
			
		||||
	stringRef := "foo"
 | 
			
		||||
	testCases := []interface{}{42, time.Now(), &stringRef}
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
 | 
			
		||||
			_, err := gorm.Open("postgresql", tc)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				t.Error("Should got error with invalid database source")
 | 
			
		||||
			}
 | 
			
		||||
			if !strings.HasPrefix(err.Error(), "invalid database source:") {
 | 
			
		||||
				t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStringPrimaryKey(t *testing.T) {
 | 
			
		||||
	type UUIDStruct struct {
 | 
			
		||||
		ID   string `gorm:"primary_key"`
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user