feat: support case insensitive schema fields

This commit is contained in:
John Boone 2025-05-15 12:51:02 -04:00
parent e5b867e785
commit 7e19567156
10 changed files with 149 additions and 37 deletions

View File

@ -269,7 +269,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:

7
go.mod
View File

@ -5,5 +5,12 @@ go 1.18
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5
github.com/stretchr/testify v1.10.0
golang.org/x/text v0.20.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,6 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -57,6 +57,8 @@ type Config struct {
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// CaseInsensitiveSchemaFields enabling case insensitivity for schema fields
CaseInsensitiveSchemaFields bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -119,6 +121,7 @@ type Session struct {
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
CaseInsensitiveSchemaFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
@ -258,6 +261,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.PropagateUnscoped = true
}
if config.CaseInsensitiveSchemaFields {
txConfig.CaseInsensitiveSchemaFields = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx

View File

@ -210,7 +210,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields)
}
if len(columns) == 1 {

View File

@ -397,7 +397,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
}

View File

@ -77,7 +77,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
return nil
}
@ -359,8 +359,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many

View File

@ -43,6 +43,7 @@ type Schema struct {
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
FieldsCaseInsensitive bool
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
@ -78,9 +79,24 @@ func (schema Schema) LookUpField(name string) *Field {
if field, ok := schema.FieldsByDBName[name]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByDBName {
if strings.EqualFold(key, name) {
return field
}
}
}
if field, ok := schema.FieldsByName[name]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByName {
if strings.EqualFold(key, name) {
return field
}
}
}
return nil
}
@ -101,6 +117,13 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByBindName {
if strings.EqualFold(key, find) {
return field
}
}
}
}
return nil
}
@ -115,11 +138,15 @@ type TablerWithNamer interface {
// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
return ParseWithCaseInsensitivity(dest, cacheStore, namer, false)
}
func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "")
}
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
@ -184,6 +211,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
FieldsCaseInsensitive: caseInsensitive,
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
@ -403,7 +431,7 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
}
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
@ -420,5 +448,5 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
return v.(*Schema), nil
}
return Parse(dest, cacheStore, namer)
return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive)
}

View File

@ -5,6 +5,7 @@ import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
@ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
}
}
func TestLookupField(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
Code string `gorm:"column:product_code"`
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
field := product.LookUpField("ProductID")
assert.NotNil(t, field)
field = product.LookUpField("productid")
assert.Nil(t, field)
field = product.LookUpField("product_code")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.Nil(t, field)
// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpField("productid")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.NotNil(t, field)
}
func TestLookupFieldByBindName(t *testing.T) {
type Product struct {
ID uint `gorm:"primaryKey;autoIncrement"`
}
type Sellable struct {
Name string
Product Product `gorm:"embedded;embeddedPrefix:product_"`
}
product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err)
}
field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.Nil(t, field)
// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.NotNil(t, field)
}

View File

@ -388,7 +388,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
reflectValue = reflectValue.Elem()
}
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil {
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
@ -490,7 +490,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]