diff --git a/gorm.go b/gorm.go index 1109e8cd..5b64fa04 100644 --- a/gorm.go +++ b/gorm.go @@ -92,6 +92,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.NamingStrategy = schema.NamingStrategy{} } + // Special case: initialize the smap if given Namer is a schema.NamingStrategy. + if v, ok := config.NamingStrategy.(schema.NamingStrategy); ok { + v.Init() + config.NamingStrategy = v + } + if config.Logger == nil { config.Logger = logger.Default } diff --git a/schema/naming.go b/schema/naming.go index 0d093b40..172c63a8 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -31,6 +31,31 @@ type NamingStrategy struct { SingularTable bool NameReplacer Replacer NoLowerCase bool + smap *safeSyncMap // Optional: gorm.Open initializes this by calling Init(). +} + +// safeSyncMap is a sync.Map that allows Load and Store to be called with a nil receiver. +type safeSyncMap sync.Map + +// Load implements a nil-safe call to sync.Map's Load. +func (smap *safeSyncMap) Load(name string) (interface{}, bool) { + if smap == nil { + return nil, false + } + return (*sync.Map)(smap).Load(name) +} + +// Store implements a nil-safe call to sync.Map's Store. +func (smap *safeSyncMap) Store(name string, value interface{}) { + if smap == nil { + return + } + (*sync.Map)(smap).Store(name, value) +} + +// Init initializes a NamingStrategy instance smap ptr. +func (ns *NamingStrategy) Init() { + ns.smap = &safeSyncMap{} } // TableName convert string to table name @@ -87,7 +112,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { } var ( - smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} commonInitialismsReplacer *strings.Replacer @@ -101,24 +125,20 @@ func init() { commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } -// reset should be called before each unit test. It clears out the cached names from smap. TODO: make smap part of NamingStrategy instead of a global singleton. -func reset() { - smap = sync.Map{} -} - func (ns NamingStrategy) toDBName(name string) string { if name == "" { return "" - } else if v, ok := smap.Load(name); ok { + } else if v, ok := ns.smap.Load(name); ok { return v.(string) } + origName := name if ns.NameReplacer != nil { name = ns.NameReplacer.Replace(name) } if ns.NoLowerCase { - smap.Store(name, name) // TODO: should store with original name, not replaced name + ns.smap.Store(origName, name) return name } @@ -159,6 +179,6 @@ func (ns NamingStrategy) toDBName(name string) string { buf.WriteByte(value[len(value)-1]) } ret := buf.String() - smap.Store(name, ret) // TODO: should store with original name, not replaced name + ns.smap.Store(origName, ret) return ret } diff --git a/schema/naming_test.go b/schema/naming_test.go index 0bada8b4..6096e11d 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -6,8 +6,6 @@ import ( ) func TestToDBName(t *testing.T) { - reset() - var maps = map[string]string{ "": "", "x": "x", @@ -38,8 +36,6 @@ func TestToDBName(t *testing.T) { } func TestNamingStrategy(t *testing.T) { - reset() - var ns = NamingStrategy{ TablePrefix: "public.", SingularTable: true, @@ -93,8 +89,6 @@ var testReplacer = CustomReplacer{ } func TestCustomReplacer(t *testing.T) { - reset() - var ns = NamingStrategy{ TablePrefix: "public.", SingularTable: true, @@ -134,8 +128,6 @@ func TestCustomReplacer(t *testing.T) { } func TestCustomReplacerWithNoLowerCase(t *testing.T) { - reset() - var ns = NamingStrategy{ TablePrefix: "public.", SingularTable: true, @@ -173,3 +165,44 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { t.Errorf("invalid column name generated, got %v", columdName) } } + +func TestNamingStrategySmapInit(t *testing.T) { + ncalls := 0 + var testReplacer = CustomReplacer{ + func(name string) string { + ncalls++ + return name + }, + } + + var ns = NamingStrategy{ + NameReplacer: testReplacer, + } + + ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap. + if ncalls != 1 { + t.Errorf("replacer function called invalid # of times, got %v", ncalls) + } + + ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap. + if ncalls != 2 { + t.Errorf("replacer function called invalid # of times, got %v", ncalls) + } + + // Now call Init() to create the smap. The next call will be cached. + ns.Init() + + ns.IndexName("public.table", "name") // This calls the Replacer: smap not populated yet. + if ncalls != 3 { + t.Errorf("replacer function called invalid # of times, got %v", ncalls) + } + ns.IndexName("public.table", "name") // This does not call the Replacer. "name" is in the smap. + if ncalls != 3 { + t.Errorf("replacer function called invalid # of times, got %v", ncalls) + } + + ns.IndexName("public.table", "name2") // This calls the Replacer, because it's a different name. + if ncalls != 4 { + t.Errorf("replacer function called invalid # of times, got %v", ncalls) + } +}