Move sync.Map from global variable into member of NamingStrategy.
This maintains backward compatibility by making the smap optional - the NamingStrategy still works if it is nil. gorm.Open activates it by calling Init() if the given Namer is a schema.NamingStrategy. Also, this changes the key stored in the smap to be the original name, instead of the replaced name.
This commit is contained in:
		
							parent
							
								
									96fe56218b
								
							
						
					
					
						commit
						4d33b2f584
					
				
							
								
								
									
										6
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								gorm.go
									
									
									
									
									
								
							@ -92,6 +92,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
 | 
			
		||||
		config.NamingStrategy = schema.NamingStrategy{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Special case: initialize the smap if given Namer is a schema.NamingStrategy.
 | 
			
		||||
	if v, ok := config.NamingStrategy.(schema.NamingStrategy); ok {
 | 
			
		||||
		v.Init()
 | 
			
		||||
		config.NamingStrategy = v
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.Logger == nil {
 | 
			
		||||
		config.Logger = logger.Default
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -31,6 +31,31 @@ type NamingStrategy struct {
 | 
			
		||||
	SingularTable bool
 | 
			
		||||
	NameReplacer  Replacer
 | 
			
		||||
	NoLowerCase   bool
 | 
			
		||||
	smap          *safeSyncMap // Optional: gorm.Open initializes this by calling Init().
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// safeSyncMap is a sync.Map that allows Load and Store to be called with a nil receiver.
 | 
			
		||||
type safeSyncMap sync.Map
 | 
			
		||||
 | 
			
		||||
// Load implements a nil-safe call to sync.Map's Load.
 | 
			
		||||
func (smap *safeSyncMap) Load(name string) (interface{}, bool) {
 | 
			
		||||
	if smap == nil {
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
	return (*sync.Map)(smap).Load(name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Store implements a nil-safe call to sync.Map's Store.
 | 
			
		||||
func (smap *safeSyncMap) Store(name string, value interface{}) {
 | 
			
		||||
	if smap == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	(*sync.Map)(smap).Store(name, value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Init initializes a NamingStrategy instance smap ptr.
 | 
			
		||||
func (ns *NamingStrategy) Init() {
 | 
			
		||||
	ns.smap = &safeSyncMap{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TableName convert string to table name
 | 
			
		||||
@ -87,7 +112,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	smap sync.Map
 | 
			
		||||
	// https://github.com/golang/lint/blob/master/lint.go#L770
 | 
			
		||||
	commonInitialisms         = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
 | 
			
		||||
	commonInitialismsReplacer *strings.Replacer
 | 
			
		||||
@ -101,24 +125,20 @@ func init() {
 | 
			
		||||
	commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// reset should be called before each unit test. It clears out the cached names from smap. TODO: make smap part of NamingStrategy instead of a global singleton.
 | 
			
		||||
func reset() {
 | 
			
		||||
	smap = sync.Map{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ns NamingStrategy) toDBName(name string) string {
 | 
			
		||||
	if name == "" {
 | 
			
		||||
		return ""
 | 
			
		||||
	} else if v, ok := smap.Load(name); ok {
 | 
			
		||||
	} else if v, ok := ns.smap.Load(name); ok {
 | 
			
		||||
		return v.(string)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	origName := name
 | 
			
		||||
	if ns.NameReplacer != nil {
 | 
			
		||||
		name = ns.NameReplacer.Replace(name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ns.NoLowerCase {
 | 
			
		||||
		smap.Store(name, name) // TODO: should store with original name, not replaced name
 | 
			
		||||
		ns.smap.Store(origName, name)
 | 
			
		||||
		return name
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -159,6 +179,6 @@ func (ns NamingStrategy) toDBName(name string) string {
 | 
			
		||||
		buf.WriteByte(value[len(value)-1])
 | 
			
		||||
	}
 | 
			
		||||
	ret := buf.String()
 | 
			
		||||
	smap.Store(name, ret) // TODO: should store with original name, not replaced name
 | 
			
		||||
	ns.smap.Store(origName, ret)
 | 
			
		||||
	return ret
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -6,8 +6,6 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestToDBName(t *testing.T) {
 | 
			
		||||
	reset()
 | 
			
		||||
 | 
			
		||||
	var maps = map[string]string{
 | 
			
		||||
		"":                          "",
 | 
			
		||||
		"x":                         "x",
 | 
			
		||||
@ -38,8 +36,6 @@ func TestToDBName(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNamingStrategy(t *testing.T) {
 | 
			
		||||
	reset()
 | 
			
		||||
 | 
			
		||||
	var ns = NamingStrategy{
 | 
			
		||||
		TablePrefix:   "public.",
 | 
			
		||||
		SingularTable: true,
 | 
			
		||||
@ -93,8 +89,6 @@ var testReplacer = CustomReplacer{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomReplacer(t *testing.T) {
 | 
			
		||||
	reset()
 | 
			
		||||
 | 
			
		||||
	var ns = NamingStrategy{
 | 
			
		||||
		TablePrefix:   "public.",
 | 
			
		||||
		SingularTable: true,
 | 
			
		||||
@ -134,8 +128,6 @@ func TestCustomReplacer(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
 | 
			
		||||
	reset()
 | 
			
		||||
 | 
			
		||||
	var ns = NamingStrategy{
 | 
			
		||||
		TablePrefix:   "public.",
 | 
			
		||||
		SingularTable: true,
 | 
			
		||||
@ -173,3 +165,44 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
 | 
			
		||||
		t.Errorf("invalid column name generated, got %v", columdName)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNamingStrategySmapInit(t *testing.T) {
 | 
			
		||||
	ncalls := 0
 | 
			
		||||
	var testReplacer = CustomReplacer{
 | 
			
		||||
		func(name string) string {
 | 
			
		||||
			ncalls++
 | 
			
		||||
			return name
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ns = NamingStrategy{
 | 
			
		||||
		NameReplacer: testReplacer,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap.
 | 
			
		||||
	if ncalls != 1 {
 | 
			
		||||
		t.Errorf("replacer function called invalid # of times, got %v", ncalls)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ns.IndexName("public.table", "name") // This calls the Replacer: there is no smap.
 | 
			
		||||
	if ncalls != 2 {
 | 
			
		||||
		t.Errorf("replacer function called invalid # of times, got %v", ncalls)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Now call Init() to create the smap. The next call will be cached.
 | 
			
		||||
	ns.Init()
 | 
			
		||||
 | 
			
		||||
	ns.IndexName("public.table", "name") // This calls the Replacer: smap not populated yet.
 | 
			
		||||
	if ncalls != 3 {
 | 
			
		||||
		t.Errorf("replacer function called invalid # of times, got %v", ncalls)
 | 
			
		||||
	}
 | 
			
		||||
	ns.IndexName("public.table", "name") // This does not call the Replacer. "name" is in the smap.
 | 
			
		||||
	if ncalls != 3 {
 | 
			
		||||
		t.Errorf("replacer function called invalid # of times, got %v", ncalls)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ns.IndexName("public.table", "name2") // This calls the Replacer, because it's a different name.
 | 
			
		||||
	if ncalls != 4 {
 | 
			
		||||
		t.Errorf("replacer function called invalid # of times, got %v", ncalls)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user