Move sync.Map from global variable into member of NamingStrategy.
This maintains backward compatibility by making the smap optional - the NamingStrategy still works if it is nil. gorm.Open activates it by calling Init() if the given Namer is a schema.NamingStrategy. Also, this changes the key stored in the smap to be the original name, instead of the replaced name.
This commit is contained in:
parent
96fe56218b
commit
4d33b2f584
6
gorm.go
6
gorm.go
@ -92,6 +92,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||||||
config.NamingStrategy = schema.NamingStrategy{}
|
config.NamingStrategy = schema.NamingStrategy{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Special case: initialize the smap if given Namer is a schema.NamingStrategy.
|
||||||
|
if v, ok := config.NamingStrategy.(schema.NamingStrategy); ok {
|
||||||
|
v.Init()
|
||||||
|
config.NamingStrategy = v
|
||||||
|
}
|
||||||
|
|
||||||
if config.Logger == nil {
|
if config.Logger == nil {
|
||||||
config.Logger = logger.Default
|
config.Logger = logger.Default
|
||||||
}
|
}
|
||||||
|
@ -31,6 +31,31 @@ type NamingStrategy struct {
|
|||||||
SingularTable bool
|
SingularTable bool
|
||||||
NameReplacer Replacer
|
NameReplacer Replacer
|
||||||
NoLowerCase bool
|
NoLowerCase bool
|
||||||
|
smap *safeSyncMap // Optional: gorm.Open initializes this by calling Init().
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeSyncMap is a sync.Map that allows Load and Store to be called with a nil receiver.
|
||||||
|
type safeSyncMap sync.Map
|
||||||
|
|
||||||
|
// Load implements a nil-safe call to sync.Map's Load.
|
||||||
|
func (smap *safeSyncMap) Load(name string) (interface{}, bool) {
|
||||||
|
if smap == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return (*sync.Map)(smap).Load(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store implements a nil-safe call to sync.Map's Store.
|
||||||
|
func (smap *safeSyncMap) Store(name string, value interface{}) {
|
||||||
|
if smap == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
(*sync.Map)(smap).Store(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes a NamingStrategy instance smap ptr.
|
||||||
|
func (ns *NamingStrategy) Init() {
|
||||||
|
ns.smap = &safeSyncMap{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName convert string to table name
|
// TableName convert string to table name
|
||||||
@ -87,7 +112,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
smap sync.Map
|
|
||||||
// https://github.com/golang/lint/blob/master/lint.go#L770
|
// https://github.com/golang/lint/blob/master/lint.go#L770
|
||||||
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
commonInitialismsReplacer *strings.Replacer
|
commonInitialismsReplacer *strings.Replacer
|
||||||
@ -101,24 +125,20 @@ func init() {
|
|||||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reset should be called before each unit test. It clears out the cached names from smap. TODO: make smap part of NamingStrategy instead of a global singleton.
|
|
||||||
func reset() {
|
|
||||||
smap = sync.Map{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ns NamingStrategy) toDBName(name string) string {
|
func (ns NamingStrategy) toDBName(name string) string {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return ""
|
return ""
|
||||||
} else if v, ok := smap.Load(name); ok {
|
} else if v, ok := ns.smap.Load(name); ok {
|
||||||
return v.(string)
|
return v.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
origName := name
|
||||||
if ns.NameReplacer != nil {
|
if ns.NameReplacer != nil {
|
||||||
name = ns.NameReplacer.Replace(name)
|
name = ns.NameReplacer.Replace(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ns.NoLowerCase {
|
if ns.NoLowerCase {
|
||||||
smap.Store(name, name) // TODO: should store with original name, not replaced name
|
ns.smap.Store(origName, name)
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,6 +179,6 @@ func (ns NamingStrategy) toDBName(name string) string {
|
|||||||
buf.WriteByte(value[len(value)-1])
|
buf.WriteByte(value[len(value)-1])
|
||||||
}
|
}
|
||||||
ret := buf.String()
|
ret := buf.String()
|
||||||
smap.Store(name, ret) // TODO: should store with original name, not replaced name
|
ns.smap.Store(origName, ret)
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestToDBName(t *testing.T) {
|
func TestToDBName(t *testing.T) {
|
||||||
reset()
|
|
||||||
|
|
||||||
var maps = map[string]string{
|
var maps = map[string]string{
|
||||||
"": "",
|
"": "",
|
||||||
"x": "x",
|
"x": "x",
|
||||||
@ -38,8 +36,6 @@ func TestToDBName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNamingStrategy(t *testing.T) {
|
func TestNamingStrategy(t *testing.T) {
|
||||||
reset()
|
|
||||||
|
|
||||||
var ns = NamingStrategy{
|
var ns = NamingStrategy{
|
||||||
TablePrefix: "public.",
|
TablePrefix: "public.",
|
||||||
SingularTable: true,
|
SingularTable: true,
|
||||||
@ -93,8 +89,6 @@ var testReplacer = CustomReplacer{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomReplacer(t *testing.T) {
|
func TestCustomReplacer(t *testing.T) {
|
||||||
reset()
|
|
||||||
|
|
||||||
var ns = NamingStrategy{
|
var ns = NamingStrategy{
|
||||||
TablePrefix: "public.",
|
TablePrefix: "public.",
|
||||||
SingularTable: true,
|
SingularTable: true,
|
||||||
@ -134,8 +128,6 @@ func TestCustomReplacer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||||
reset()
|
|
||||||
|
|
||||||
var ns = NamingStrategy{
|
var ns = NamingStrategy{
|
||||||
TablePrefix: "public.",
|
TablePrefix: "public.",
|
||||||
SingularTable: true,
|
SingularTable: true,
|
||||||
@ -173,3 +165,44 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
|||||||
t.Errorf("invalid column name generated, got %v", columdName)
|
t.Errorf("invalid column name generated, got %v", columdName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamingStrategySmapInit(t *testing.T) {
|
||||||
|
ncalls := 0
|
||||||
|
var testReplacer = CustomReplacer{
|
||||||
|
func(name string) string {
|
||||||
|
ncalls++
|
||||||
|
return name
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var ns = NamingStrategy{
|
||||||
|
NameReplacer: testReplacer,
|
||||||
|
}
|
||||||
|
|
||||||
|
ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap.
|
||||||
|
if ncalls != 1 {
|
||||||
|
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap.
|
||||||
|
if ncalls != 2 {
|
||||||
|
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now call Init() to create the smap. The next call will be cached.
|
||||||
|
ns.Init()
|
||||||
|
|
||||||
|
ns.IndexName("public.table", "name") // This calls the Replacer: smap not populated yet.
|
||||||
|
if ncalls != 3 {
|
||||||
|
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||||
|
}
|
||||||
|
ns.IndexName("public.table", "name") // This does not call the Replacer. "name" is in the smap.
|
||||||
|
if ncalls != 3 {
|
||||||
|
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
ns.IndexName("public.table", "name2") // This calls the Replacer, because it's a different name.
|
||||||
|
if ncalls != 4 {
|
||||||
|
t.Errorf("replacer function called invalid # of times, got %v", ncalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user