diff --git a/schema/naming.go b/schema/naming.go index e6fb81b2..6248bde8 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -8,6 +8,8 @@ import ( "unicode/utf8" "github.com/jinzhu/inflection" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) // Namer namer interface @@ -121,7 +123,7 @@ var ( func init() { commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism)) } commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } @@ -186,9 +188,9 @@ func (ns NamingStrategy) toDBName(name string) string { } func (ns NamingStrategy) toSchemaName(name string) string { - result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") + result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { - result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } return result } diff --git a/schema/relationship.go b/schema/relationship.go index c11918a5..bdce5812 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/jinzhu/inflection" + "golang.org/x/text/cases" + "golang.org/x/text/language" "gorm.io/gorm/clause" ) @@ -301,9 +303,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, ownField := range ownForeignFields { - joinFieldName := strings.Title(schema.Name) + ownField.Name + joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name if len(joinForeignKeys) > idx { - joinFieldName = strings.Title(joinForeignKeys[idx]) + joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx]) } ownFieldsMap[joinFieldName] = ownField @@ -318,7 +320,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, relField := range refForeignFields { - joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name + joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -329,7 +331,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) + joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx]) } referFieldsMap[joinFieldName] = relField @@ -347,7 +349,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } joinTableFields = append(joinTableFields, reflect.StructField{ - Name: strings.Title(schema.Name) + field.Name, + Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name, Type: schema.ModelType, Tag: `gorm:"-"`, })