gorm/schema/naming.go
Michael Nussbaum 45ef1da7e4
Fix naming longer then 64 chars with dots in table (#5045)
Ensures that foreign key relationships and indexes are given
syntactically valid names when their name length exceeds 64 characters
and they contained dot characters within the name. This is most often
relevant when a Postgres table name is fully qualified by including its schema
as part of its name
2022-02-24 10:10:20 +08:00

182 lines
4.8 KiB
Go

package schema
import (
"crypto/sha1"
"encoding/hex"
"regexp"
"strings"
"unicode/utf8"
"github.com/jinzhu/inflection"
)
// Namer namer interface
type Namer interface {
TableName(table string) string
SchemaName(table string) string
ColumnName(table, column string) string
JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
}
// Replacer replacer interface like strings.Replacer
type Replacer interface {
Replace(name string) string
}
// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
SingularTable bool
NameReplacer Replacer
NoLowerCase bool
}
// TableName convert string to table name
func (ns NamingStrategy) TableName(str string) string {
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName
func (ns NamingStrategy) SchemaName(table string) string {
table = strings.TrimPrefix(table, ns.TablePrefix)
if ns.SingularTable {
return ns.toSchemaName(table)
}
return ns.toSchemaName(inflection.Singular(table))
}
// ColumnName convert string to column name
func (ns NamingStrategy) ColumnName(table, column string) string {
return ns.toDBName(column)
}
// JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string {
if !ns.NoLowerCase && strings.ToLower(str) == str {
return ns.TablePrefix + str
}
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name))
}
// CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string {
return ns.formatName("chk", table, column)
}
// IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.Replace(strings.Join([]string{
prefix, table, name,
}, "_"), ".", "_", -1)
if utf8.RuneCountInString(formattedName) > 64 {
h := sha1.New()
h.Write([]byte(formattedName))
bs := h.Sum(nil)
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
}
return formattedName
}
var (
// https://github.com/golang/lint/blob/master/lint.go#L770
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
commonInitialismsReplacer *strings.Replacer
)
func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
func (ns NamingStrategy) toDBName(name string) string {
if name == "" {
return ""
}
if ns.NameReplacer != nil {
tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
return name
}
var (
value = commonInitialismsReplacer.Replace(name)
buf strings.Builder
lastCase, nextCase, nextNumber bool // upper case == true
curCase = value[0] <= 'Z' && value[0] >= 'A'
)
for i, v := range value[:len(value)-1] {
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
if curCase {
if lastCase && (nextCase || nextNumber) {
buf.WriteRune(v + 32)
} else {
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
buf.WriteByte('_')
}
buf.WriteRune(v + 32)
}
} else {
buf.WriteRune(v)
}
lastCase = curCase
curCase = nextCase
}
if curCase {
if !lastCase && len(value) > 1 {
buf.WriteByte('_')
}
buf.WriteByte(value[len(value)-1] + 32)
} else {
buf.WriteByte(value[len(value)-1])
}
ret := buf.String()
return ret
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
}