Merge 6e51fe2b4a833134dc961511de0add7a36ef4667 into 0fd395ab37aefd2d50854f0556a4311dccc6f45a
This commit is contained in:
commit
44f72bb47b
24
dialect.go
24
dialect.go
@ -78,6 +78,30 @@ func GetDialect(name string) (dialect Dialect, ok bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// GetAllDialects returns a map of registered dialects keyed by their names
|
||||
func GetAllDialects() (dialects map[string]Dialect) {
|
||||
// Copy dialectsMap to protect it from being accidentally modified by clients
|
||||
dialects = make(map[string]Dialect)
|
||||
for k, v := range dialectsMap {
|
||||
dialects[k] = v
|
||||
}
|
||||
return dialects
|
||||
}
|
||||
|
||||
// UnregisterDialect removes a registered dialect, if present
|
||||
func UnregisterDialect(name string) {
|
||||
delete(dialectsMap, name)
|
||||
}
|
||||
|
||||
// GetAllDialectNames returns a list of registered dialect names
|
||||
func GetAllDialectNames() []string {
|
||||
names := []string{}
|
||||
for n := range dialectsMap {
|
||||
names = append(names, n)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ParseFieldStructForDialect get field's sql data type
|
||||
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||
// Get redirected field type
|
||||
|
157
dialect_test.go
Normal file
157
dialect_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
)
|
||||
|
||||
func TestRegisterGetAndUnregisterDialect(t *testing.T) {
|
||||
commonDialect, ok := gorm.GetDialect("common")
|
||||
if !ok {
|
||||
t.Error("Expected to find dialect 'common' registered, but it is missing")
|
||||
}
|
||||
gorm.RegisterDialect("dialect_for_TestUnregisterDialect", commonDialect)
|
||||
|
||||
// Check the test's dialect is there.
|
||||
testDialect, ok := gorm.GetDialect("dialect_for_TestUnregisterDialect")
|
||||
if !ok {
|
||||
t.Error("Expected to find the test dialect registered, but it is missing")
|
||||
}
|
||||
if testDialect != commonDialect {
|
||||
t.Error("Unexpected dialect returned by GetDialect")
|
||||
}
|
||||
|
||||
// Remove the test dialect.
|
||||
gorm.UnregisterDialect("dialect_for_TestUnregisterDialect")
|
||||
|
||||
// Check the test's dialect is now gone.
|
||||
testDialect, ok = gorm.GetDialect("dialect_for_TestUnregisterDialect")
|
||||
if ok {
|
||||
t.Errorf("Expected the test dialect to be removed, but it is still registered: %v", testDialect)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllDialectsEmpty(t *testing.T) {
|
||||
// Clear the old dialects map, and reset it after the test ends.
|
||||
oldDialectsMap := gorm.GetAllDialects()
|
||||
for name := range oldDialectsMap {
|
||||
gorm.UnregisterDialect(name)
|
||||
}
|
||||
defer func() {
|
||||
for name, dialect := range oldDialectsMap {
|
||||
gorm.RegisterDialect(name, dialect)
|
||||
}
|
||||
}()
|
||||
|
||||
// Empty case.
|
||||
dialects := gorm.GetAllDialects()
|
||||
if len(dialects) != 0 {
|
||||
t.Errorf("There should be no dialects registered when dialectsMap is empty, instead found: %v", dialects)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllDialectNamesEmpty(t *testing.T) {
|
||||
// Clear the old dialects map, and reset it after the test ends.
|
||||
oldDialectsMap := gorm.GetAllDialects()
|
||||
for name := range oldDialectsMap {
|
||||
gorm.UnregisterDialect(name)
|
||||
}
|
||||
defer func() {
|
||||
for name, dialect := range oldDialectsMap {
|
||||
gorm.RegisterDialect(name, dialect)
|
||||
}
|
||||
}()
|
||||
|
||||
// Empty case.
|
||||
allNames := gorm.GetAllDialectNames()
|
||||
if len(allNames) != 0 {
|
||||
t.Errorf("There should be no registered dialects when dialectsMap is empty, instead found: %v", allNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllDialects(t *testing.T) {
|
||||
// Clear the old dialects map, and reset it after the test ends.
|
||||
oldDialectsMap := gorm.GetAllDialects()
|
||||
for name := range oldDialectsMap {
|
||||
gorm.UnregisterDialect(name)
|
||||
}
|
||||
defer func() {
|
||||
for name, dialect := range oldDialectsMap {
|
||||
gorm.RegisterDialect(name, dialect)
|
||||
}
|
||||
}()
|
||||
|
||||
// Register some dialects.
|
||||
dialectNames := []string{
|
||||
"common",
|
||||
"mysql",
|
||||
"mssql",
|
||||
"postgres",
|
||||
"cloudsqlpostgres",
|
||||
"sqlite3",
|
||||
}
|
||||
for _, name := range dialectNames {
|
||||
oldDialect, ok := oldDialectsMap[name]
|
||||
if !ok {
|
||||
t.Errorf("Expected imports to register dialect '%s', but it is missing. Full map is: %v", name, oldDialectsMap)
|
||||
}
|
||||
gorm.RegisterDialect(name, oldDialect)
|
||||
}
|
||||
|
||||
// Check the returned map
|
||||
dialects := gorm.GetAllDialects()
|
||||
if len(dialects) != 6 {
|
||||
t.Errorf("Expected to find 6 dialects registered, instead found %d. Full map is: %v", len(dialects), dialects)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllDialectNames(t *testing.T) {
|
||||
// Clear the old dialects map, and reset it after the test ends.
|
||||
oldDialectsMap := gorm.GetAllDialects()
|
||||
for name := range oldDialectsMap {
|
||||
gorm.UnregisterDialect(name)
|
||||
}
|
||||
defer func() {
|
||||
for name, dialect := range oldDialectsMap {
|
||||
gorm.RegisterDialect(name, dialect)
|
||||
}
|
||||
}()
|
||||
|
||||
// Register some dialects.
|
||||
dialectNames := []string{
|
||||
"cloudsqlpostgres",
|
||||
"common",
|
||||
"mssql",
|
||||
"mysql",
|
||||
"postgres",
|
||||
"sqlite3",
|
||||
}
|
||||
for _, name := range dialectNames {
|
||||
oldDialect, ok := oldDialectsMap[name]
|
||||
if !ok {
|
||||
t.Errorf("Expected imports to register dialect '%s', but it is missing. Full map is: %v", name, oldDialectsMap)
|
||||
}
|
||||
gorm.RegisterDialect(name, oldDialect)
|
||||
}
|
||||
|
||||
// Check the returned map
|
||||
allNames := gorm.GetAllDialectNames()
|
||||
if len(allNames) != 6 {
|
||||
t.Errorf("Expected to find 6 dialect names, instead found %d. Full list is: %v", len(allNames), allNames)
|
||||
}
|
||||
|
||||
// Sort both dialectNames and allNames
|
||||
sort.Strings(dialectNames)
|
||||
sort.Strings(allNames)
|
||||
if !reflect.DeepEqual(dialectNames, allNames) {
|
||||
t.Errorf("Unexpected list of dialects returned. Expected: %v, instead found: %v", dialectNames, allNames)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user