Move sync.Map from global variable into member of NamingStrategy.
This maintains backward compatibility by making the smap optional - the NamingStrategy still works if it is nil. gorm.Open activates it by calling Init() if the given Namer is a schema.NamingStrategy. Also, this changes the key stored in the smap to be the original name, instead of the replaced name.
This commit is contained in:
parent
96fe56218b
commit
4d33b2f584
6
gorm.go
6
gorm.go
@ -92,6 +92,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
||||
config.NamingStrategy = schema.NamingStrategy{}
|
||||
}
|
||||
|
||||
// Special case: initialize the smap if given Namer is a schema.NamingStrategy.
|
||||
if v, ok := config.NamingStrategy.(schema.NamingStrategy); ok {
|
||||
v.Init()
|
||||
config.NamingStrategy = v
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = logger.Default
|
||||
}
|
||||
|
@ -31,6 +31,31 @@ type NamingStrategy struct {
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
smap *safeSyncMap // Optional: gorm.Open initializes this by calling Init().
|
||||
}
|
||||
|
||||
// safeSyncMap is a sync.Map that allows Load and Store to be called with a nil receiver.
|
||||
type safeSyncMap sync.Map
|
||||
|
||||
// Load implements a nil-safe call to sync.Map's Load.
|
||||
func (smap *safeSyncMap) Load(name string) (interface{}, bool) {
|
||||
if smap == nil {
|
||||
return nil, false
|
||||
}
|
||||
return (*sync.Map)(smap).Load(name)
|
||||
}
|
||||
|
||||
// Store implements a nil-safe call to sync.Map's Store.
|
||||
func (smap *safeSyncMap) Store(name string, value interface{}) {
|
||||
if smap == nil {
|
||||
return
|
||||
}
|
||||
(*sync.Map)(smap).Store(name, value)
|
||||
}
|
||||
|
||||
// Init initializes a NamingStrategy instance smap ptr.
|
||||
func (ns *NamingStrategy) Init() {
|
||||
ns.smap = &safeSyncMap{}
|
||||
}
|
||||
|
||||
// TableName convert string to table name
|
||||
@ -87,7 +112,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
}
|
||||
|
||||
var (
|
||||
smap sync.Map
|
||||
// https://github.com/golang/lint/blob/master/lint.go#L770
|
||||
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
commonInitialismsReplacer *strings.Replacer
|
||||
@ -101,24 +125,20 @@ func init() {
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
}
|
||||
|
||||
// reset should be called before each unit test. It clears out the cached names from smap. TODO: make smap part of NamingStrategy instead of a global singleton.
|
||||
func reset() {
|
||||
smap = sync.Map{}
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) toDBName(name string) string {
|
||||
if name == "" {
|
||||
return ""
|
||||
} else if v, ok := smap.Load(name); ok {
|
||||
} else if v, ok := ns.smap.Load(name); ok {
|
||||
return v.(string)
|
||||
}
|
||||
|
||||
origName := name
|
||||
if ns.NameReplacer != nil {
|
||||
name = ns.NameReplacer.Replace(name)
|
||||
}
|
||||
|
||||
if ns.NoLowerCase {
|
||||
smap.Store(name, name) // TODO: should store with original name, not replaced name
|
||||
ns.smap.Store(origName, name)
|
||||
return name
|
||||
}
|
||||
|
||||
@ -159,6 +179,6 @@ func (ns NamingStrategy) toDBName(name string) string {
|
||||
buf.WriteByte(value[len(value)-1])
|
||||
}
|
||||
ret := buf.String()
|
||||
smap.Store(name, ret) // TODO: should store with original name, not replaced name
|
||||
ns.smap.Store(origName, ret)
|
||||
return ret
|
||||
}
|
||||
|
@ -6,8 +6,6 @@ import (
|
||||
)
|
||||
|
||||
func TestToDBName(t *testing.T) {
|
||||
reset()
|
||||
|
||||
var maps = map[string]string{
|
||||
"": "",
|
||||
"x": "x",
|
||||
@ -38,8 +36,6 @@ func TestToDBName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNamingStrategy(t *testing.T) {
|
||||
reset()
|
||||
|
||||
var ns = NamingStrategy{
|
||||
TablePrefix: "public.",
|
||||
SingularTable: true,
|
||||
@ -93,8 +89,6 @@ var testReplacer = CustomReplacer{
|
||||
}
|
||||
|
||||
func TestCustomReplacer(t *testing.T) {
|
||||
reset()
|
||||
|
||||
var ns = NamingStrategy{
|
||||
TablePrefix: "public.",
|
||||
SingularTable: true,
|
||||
@ -134,8 +128,6 @@ func TestCustomReplacer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||
reset()
|
||||
|
||||
var ns = NamingStrategy{
|
||||
TablePrefix: "public.",
|
||||
SingularTable: true,
|
||||
@ -173,3 +165,44 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||
t.Errorf("invalid column name generated, got %v", columdName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamingStrategySmapInit(t *testing.T) {
|
||||
ncalls := 0
|
||||
var testReplacer = CustomReplacer{
|
||||
func(name string) string {
|
||||
ncalls++
|
||||
return name
|
||||
},
|
||||
}
|
||||
|
||||
var ns = NamingStrategy{
|
||||
NameReplacer: testReplacer,
|
||||
}
|
||||
|
||||
ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap.
|
||||
if ncalls != 1 {
|
||||
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||
}
|
||||
|
||||
ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap.
|
||||
if ncalls != 2 {
|
||||
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||
}
|
||||
|
||||
// Now call Init() to create the smap. The next call will be cached.
|
||||
ns.Init()
|
||||
|
||||
ns.IndexName("public.table", "name") // This calls the Replacer: smap not populated yet.
|
||||
if ncalls != 3 {
|
||||
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||
}
|
||||
ns.IndexName("public.table", "name") // This does not call the Replacer. "name" is in the smap.
|
||||
if ncalls != 3 {
|
||||
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||
}
|
||||
|
||||
ns.IndexName("public.table", "name2") // This calls the Replacer, because it's a different name.
|
||||
if ncalls != 4 {
|
||||
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user