gorm/dialect_test.go
Marwan Al Jubeh 6e51fe2b4a Add some functions for managing dialects
Adds the following functions:
- GetAllDialects: Returns a map of all registered dialects.
- UnregisterDialect: Removes a registered dialect.
- GetAllDialectNames: Returns a list of all registered dialect names.
2018-07-24 12:21:29 +01:00

158 lines
4.2 KiB
Go

package gorm_test
import (
"reflect"
"sort"
"testing"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// TestRegisterGetAndUnregisterDialect exercises the dialect registry round trip:
// a dialect registered under a new name is returned by GetDialect, and after
// UnregisterDialect the name is no longer found.
func TestRegisterGetAndUnregisterDialect(t *testing.T) {
	const testDialectName = "dialect_for_TestUnregisterDialect"

	commonDialect, ok := gorm.GetDialect("common")
	if !ok {
		// Nothing meaningful can be registered without it; stop here instead of
		// registering a nil dialect and producing confusing follow-on failures.
		t.Fatal("Expected to find dialect 'common' registered, but it is missing")
	}

	gorm.RegisterDialect(testDialectName, commonDialect)
	// Keep the temporary dialect from leaking into other tests if an assertion
	// below stops the test early. Only unregister if it is still present, so the
	// cleanup does not depend on how UnregisterDialect treats unknown names.
	defer func() {
		if _, stillRegistered := gorm.GetDialect(testDialectName); stillRegistered {
			gorm.UnregisterDialect(testDialectName)
		}
	}()

	// Check the test's dialect is there.
	testDialect, ok := gorm.GetDialect(testDialectName)
	if !ok {
		t.Fatal("Expected to find the test dialect registered, but it is missing")
	}
	if testDialect != commonDialect {
		t.Error("Unexpected dialect returned by GetDialect")
	}

	// Remove the test dialect and check it is now gone.
	gorm.UnregisterDialect(testDialectName)
	if testDialect, ok := gorm.GetDialect(testDialectName); ok {
		t.Errorf("Expected the test dialect to be removed, but it is still registered: %v", testDialect)
	}
}
// TestGetAllDialectsEmpty verifies that GetAllDialects reports nothing once
// every registered dialect has been unregistered. When the test returns, every
// entry from the map GetAllDialects returned up front is registered again.
func TestGetAllDialectsEmpty(t *testing.T) {
	snapshot := gorm.GetAllDialects()
	restore := func() {
		for dialectName, d := range snapshot {
			gorm.RegisterDialect(dialectName, d)
		}
	}

	for dialectName := range snapshot {
		gorm.UnregisterDialect(dialectName)
	}
	defer restore()

	// With the registry emptied, nothing should be reported.
	if remaining := gorm.GetAllDialects(); len(remaining) != 0 {
		t.Errorf("There should be no dialects registered when dialectsMap is empty, instead found: %v", remaining)
	}
}
// TestGetAllDialectNamesEmpty verifies that GetAllDialectNames yields no names
// while the dialect registry is empty. Every dialect is unregistered up front
// and re-registered from the map GetAllDialects returned when the test ends.
func TestGetAllDialectNamesEmpty(t *testing.T) {
	saved := gorm.GetAllDialects()
	for key := range saved {
		gorm.UnregisterDialect(key)
	}
	defer func() {
		// Put back everything that was listed before the test ran.
		for key, d := range saved {
			gorm.RegisterDialect(key, d)
		}
	}()

	names := gorm.GetAllDialectNames()
	if count := len(names); count > 0 {
		t.Errorf("There should be no registered dialects when dialectsMap is empty, instead found: %v", names)
	}
}
// TestGetAllDialects empties the dialect registry, re-registers a known set of
// dialects taken from the original registry, and checks that GetAllDialects
// returns exactly that set, with each name mapped to the dialect registered
// under it. The original registry entries are re-registered when the test ends.
func TestGetAllDialects(t *testing.T) {
	// Clear the old dialects map, and reset it after the test ends.
	oldDialectsMap := gorm.GetAllDialects()
	for name := range oldDialectsMap {
		gorm.UnregisterDialect(name)
	}
	defer func() {
		for name, dialect := range oldDialectsMap {
			gorm.RegisterDialect(name, dialect)
		}
	}()

	// Register some dialects.
	dialectNames := []string{
		"common",
		"mysql",
		"mssql",
		"postgres",
		"cloudsqlpostgres",
		"sqlite3",
	}
	for _, name := range dialectNames {
		oldDialect, ok := oldDialectsMap[name]
		if !ok {
			// Registering the nil zero value here would likely still satisfy the
			// count check below and hide the problem, so fail fast instead.
			// t.Fatalf still runs the deferred restore above.
			t.Fatalf("Expected imports to register dialect '%s', but it is missing. Full map is: %v", name, oldDialectsMap)
		}
		gorm.RegisterDialect(name, oldDialect)
	}

	// Check the returned map: derive the expected size from the list above
	// rather than hard-coding it, and verify each entry, not just the count.
	dialects := gorm.GetAllDialects()
	if len(dialects) != len(dialectNames) {
		t.Errorf("Expected to find %d dialects registered, instead found %d. Full map is: %v", len(dialectNames), len(dialects), dialects)
	}
	for _, name := range dialectNames {
		got, ok := dialects[name]
		if !ok {
			t.Errorf("Expected dialect '%s' in the returned map, but it is missing. Full map is: %v", name, dialects)
			continue
		}
		if want := oldDialectsMap[name]; got != want {
			t.Errorf("Unexpected dialect returned for '%s': expected %v, instead found %v", name, want, got)
		}
	}
}
// TestGetAllDialectNames empties the dialect registry, re-registers a known set
// of dialects taken from the original registry, and checks that
// GetAllDialectNames returns exactly those names (order-insensitive). The
// original registry entries are re-registered when the test ends.
func TestGetAllDialectNames(t *testing.T) {
	// Clear the old dialects map, and reset it after the test ends.
	oldDialectsMap := gorm.GetAllDialects()
	for name := range oldDialectsMap {
		gorm.UnregisterDialect(name)
	}
	defer func() {
		for name, dialect := range oldDialectsMap {
			gorm.RegisterDialect(name, dialect)
		}
	}()

	// Register some dialects.
	dialectNames := []string{
		"cloudsqlpostgres",
		"common",
		"mssql",
		"mysql",
		"postgres",
		"sqlite3",
	}
	for _, name := range dialectNames {
		oldDialect, ok := oldDialectsMap[name]
		if !ok {
			// Registering the nil zero value would still add the name and mask
			// the problem, so fail fast; the deferred restore still runs.
			t.Fatalf("Expected imports to register dialect '%s', but it is missing. Full map is: %v", name, oldDialectsMap)
		}
		gorm.RegisterDialect(name, oldDialect)
	}

	// Check the returned names; derive the expected count from the list above.
	allNames := gorm.GetAllDialectNames()
	if len(allNames) != len(dialectNames) {
		t.Errorf("Expected to find %d dialect names, instead found %d. Full list is: %v", len(dialectNames), len(allNames), allNames)
	}

	// The returned order is not part of the contract being tested, so compare
	// the two lists after sorting both.
	sort.Strings(dialectNames)
	sort.Strings(allNames)
	if !reflect.DeepEqual(dialectNames, allNames) {
		t.Errorf("Unexpected list of dialects returned. Expected: %v, instead found: %v", dialectNames, allNames)
	}
}