From 6e51fe2b4a833134dc961511de0add7a36ef4667 Mon Sep 17 00:00:00 2001 From: Marwan Al Jubeh Date: Tue, 24 Jul 2018 12:21:29 +0100 Subject: [PATCH] Add some functions for managing dialects Adds the following functions: - GetAllDialects: Returns a map of all registered dialects. - UnregisterDialect: Removes a registered dialect. - GetAllDialectNames: Returns a list of all registered dialect names. --- dialect.go | 24 ++++++++ dialect_test.go | 157 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 dialect_test.go diff --git a/dialect.go b/dialect.go index 506a6e86..b685fd28 100644 --- a/dialect.go +++ b/dialect.go @@ -78,6 +78,30 @@ func GetDialect(name string) (dialect Dialect, ok bool) { return } +// GetAllDialects returns a map of registered dialects keyed by their names +func GetAllDialects() (dialects map[string]Dialect) { + // Copy dialectsMap to protect it from being accidentally modified by clients + dialects = make(map[string]Dialect) + for k, v := range dialectsMap { + dialects[k] = v + } + return dialects +} + +// UnregisterDialect removes a registered dialect, if present +func UnregisterDialect(name string) { + delete(dialectsMap, name) +} + +// GetAllDialectNames returns a list of registered dialect names +func GetAllDialectNames() []string { + names := []string{} + for n := range dialectsMap { + names = append(names, n) + } + return names +} + // ParseFieldStructForDialect get field's sql data type var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type diff --git a/dialect_test.go b/dialect_test.go new file mode 100644 index 00000000..563a53a9 --- /dev/null +++ b/dialect_test.go @@ -0,0 +1,157 @@ +package gorm_test + +import ( + "reflect" + "sort" + "testing" + + "github.com/jinzhu/gorm" + + _ "github.com/jinzhu/gorm/dialects/mssql" + _ "github.com/jinzhu/gorm/dialects/mysql" + _ "github.com/jinzhu/gorm/dialects/postgres" + _ "github.com/jinzhu/gorm/dialects/sqlite" +) + +func TestRegisterGetAndUnregisterDialect(t *testing.T) { + commonDialect, ok := gorm.GetDialect("common") + if !ok { + t.Error("Expected to find dialect 'common' registered, but it is missing") + } + gorm.RegisterDialect("dialect_for_TestUnregisterDialect", commonDialect) + + // Check the test's dialect is there. + testDialect, ok := gorm.GetDialect("dialect_for_TestUnregisterDialect") + if !ok { + t.Error("Expected to find the test dialect registered, but it is missing") + } + if testDialect != commonDialect { + t.Error("Unexpected dialect returned by GetDialect") + } + + // Remove the test dialect. + gorm.UnregisterDialect("dialect_for_TestUnregisterDialect") + + // Check the test's dialect is now gone. + testDialect, ok = gorm.GetDialect("dialect_for_TestUnregisterDialect") + if ok { + t.Errorf("Expected the test dialect to be removed, but it is still registered: %v", testDialect) + } +} + +func TestGetAllDialectsEmpty(t *testing.T) { + // Clear the old dialects map, and reset it after the test ends. + oldDialectsMap := gorm.GetAllDialects() + for name := range oldDialectsMap { + gorm.UnregisterDialect(name) + } + defer func() { + for name, dialect := range oldDialectsMap { + gorm.RegisterDialect(name, dialect) + } + }() + + // Empty case. + dialects := gorm.GetAllDialects() + if len(dialects) != 0 { + t.Errorf("There should be no dialects registered when dialectsMap is empty, instead found: %v", dialects) + } +} + +func TestGetAllDialectNamesEmpty(t *testing.T) { + // Clear the old dialects map, and reset it after the test ends. + oldDialectsMap := gorm.GetAllDialects() + for name := range oldDialectsMap { + gorm.UnregisterDialect(name) + } + defer func() { + for name, dialect := range oldDialectsMap { + gorm.RegisterDialect(name, dialect) + } + }() + + // Empty case. + allNames := gorm.GetAllDialectNames() + if len(allNames) != 0 { + t.Errorf("There should be no registered dialects when dialectsMap is empty, instead found: %v", allNames) + } +} + +func TestGetAllDialects(t *testing.T) { + // Clear the old dialects map, and reset it after the test ends. + oldDialectsMap := gorm.GetAllDialects() + for name := range oldDialectsMap { + gorm.UnregisterDialect(name) + } + defer func() { + for name, dialect := range oldDialectsMap { + gorm.RegisterDialect(name, dialect) + } + }() + + // Register some dialects. + dialectNames := []string{ + "common", + "mysql", + "mssql", + "postgres", + "cloudsqlpostgres", + "sqlite3", + } + for _, name := range dialectNames { + oldDialect, ok := oldDialectsMap[name] + if !ok { + t.Errorf("Expected imports to register dialect '%s', but it is missing. Full map is: %v", name, oldDialectsMap) + } + gorm.RegisterDialect(name, oldDialect) + } + + // Check the returned map + dialects := gorm.GetAllDialects() + if len(dialects) != 6 { + t.Errorf("Expected to find 6 dialects registered, instead found %d. Full map is: %v", len(dialects), dialects) + } +} + +func TestGetAllDialectNames(t *testing.T) { + // Clear the old dialects map, and reset it after the test ends. + oldDialectsMap := gorm.GetAllDialects() + for name := range oldDialectsMap { + gorm.UnregisterDialect(name) + } + defer func() { + for name, dialect := range oldDialectsMap { + gorm.RegisterDialect(name, dialect) + } + }() + + // Register some dialects. + dialectNames := []string{ + "cloudsqlpostgres", + "common", + "mssql", + "mysql", + "postgres", + "sqlite3", + } + for _, name := range dialectNames { + oldDialect, ok := oldDialectsMap[name] + if !ok { + t.Errorf("Expected imports to register dialect '%s', but it is missing. Full map is: %v", name, oldDialectsMap) + } + gorm.RegisterDialect(name, oldDialect) + } + + // Check the returned map + allNames := gorm.GetAllDialectNames() + if len(allNames) != 6 { + t.Errorf("Expected to find 6 dialect names, instead found %d. Full list is: %v", len(allNames), allNames) + } + + // Sort both dialectNames and allNames + sort.Strings(dialectNames) + sort.Strings(allNames) + if !reflect.DeepEqual(dialectNames, allNames) { + t.Errorf("Unexpected list of dialects returned. Expected: %v, instead found: %v", dialectNames, allNames) + } +}