From 7e195671565f74b2558a45477dc56f32c4961f89 Mon Sep 17 00:00:00 2001 From: John Boone Date: Thu, 15 May 2025 12:51:02 -0400 Subject: [PATCH] feat: support case insensitive schema fields --- finisher_api.go | 2 +- go.mod | 7 +++++ go.sum | 10 +++++++ gorm.go | 37 +++++++++++++++----------- scan.go | 2 +- schema/field.go | 2 +- schema/relationship.go | 6 ++--- schema/schema.go | 56 +++++++++++++++++++++++++++++---------- schema/schema_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++ statement.go | 4 +-- 10 files changed, 149 insertions(+), 37 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 6802945c..bd04e6af 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -269,7 +269,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { db.assignInterfacesToValue(exprs) } default: - if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { + if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Struct: diff --git a/go.mod b/go.mod index 3060fc8f..aa679fd3 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,12 @@ go 1.18 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.5 + github.com/stretchr/testify v1.10.0 golang.org/x/text v0.20.0 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 9af11572..3eeeaad0 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,16 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gorm.go b/gorm.go index 63a28b37..f4a18cb4 100644 --- a/gorm.go +++ b/gorm.go @@ -57,6 +57,8 @@ type Config struct { TranslateError bool // PropagateUnscoped propagate Unscoped to every other nested statement PropagateUnscoped bool + // CaseInsensitiveSchemaFields enabling case insensitivity for schema fields + CaseInsensitiveSchemaFields bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -108,21 +110,22 @@ type DB struct { // Session session config when create session with Session() method type Session struct { - DryRun bool - PrepareStmt bool - NewDB bool - Initialized bool - SkipHooks bool - SkipDefaultTransaction bool - DisableNestedTransaction bool - AllowGlobalUpdate bool - FullSaveAssociations bool - PropagateUnscoped bool - QueryFields bool - Context context.Context - Logger logger.Interface - NowFunc func() time.Time - CreateBatchSize int + DryRun bool + PrepareStmt bool + NewDB bool + Initialized bool + SkipHooks bool + SkipDefaultTransaction bool + DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + PropagateUnscoped bool + QueryFields bool + CaseInsensitiveSchemaFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int } // Open initialize db session based on dialector @@ -258,6 +261,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.PropagateUnscoped = true } + if config.CaseInsensitiveSchemaFields { + txConfig.CaseInsensitiveSchemaFields = true + } + if config.Context != nil || config.PrepareStmt || config.SkipHooks { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/scan.go b/scan.go index 6dc55f62..4e58c971 100644 --- a/scan.go +++ b/scan.go @@ -210,7 +210,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if sch != nil { if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { - sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields) } if len(columns) == 1 { diff --git a/schema/field.go b/schema/field.go index d1a633ce..697a3005 100644 --- a/schema/field.go +++ b/schema/field.go @@ -397,7 +397,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil { schema.err = err } diff --git a/schema/relationship.go b/schema/relationship.go index def4a595..78eb8d08 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -77,7 +77,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { cacheStore := schema.cacheStore - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil { schema.err = err return nil } @@ -359,8 +359,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, - schema.namer); err != nil { + if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, + schema.namer, schema.FieldsCaseInsensitive); err != nil { schema.err = err } relation.JoinTable.Name = many2many diff --git a/schema/schema.go b/schema/schema.go index db236797..6387cd97 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -43,6 +43,7 @@ type Schema struct { FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database + FieldsCaseInsensitive bool Relationships Relationships CreateClauses []clause.Interface QueryClauses []clause.Interface @@ -78,9 +79,24 @@ func (schema Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } + if schema.FieldsCaseInsensitive { + for key, field := range schema.FieldsByDBName { + if strings.EqualFold(key, name) { + return field + } + } + } if field, ok := schema.FieldsByName[name]; ok { return field } + if schema.FieldsCaseInsensitive { + for key, field := range schema.FieldsByName { + if strings.EqualFold(key, name) { + return field + } + } + } + return nil } @@ -101,6 +117,13 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie if field, ok := schema.FieldsByBindName[find]; ok { return field } + if schema.FieldsCaseInsensitive { + for key, field := range schema.FieldsByBindName { + if strings.EqualFold(key, find) { + return field + } + } + } } return nil } @@ -115,11 +138,15 @@ type TablerWithNamer interface { // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { - return ParseWithSpecialTableName(dest, cacheStore, namer, "") + return ParseWithCaseInsensitivity(dest, cacheStore, namer, false) +} + +func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) { + return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "") } // ParseWithSpecialTableName get data type from dialector with extra schema table -func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { +func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -178,16 +205,17 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema := &Schema{ - Name: modelType.Name(), - ModelType: modelType, - Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, - Relationships: Relationships{Relations: map[string]*Relationship{}}, - cacheStore: cacheStore, - namer: namer, - initialized: make(chan struct{}), + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByBindName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + FieldsCaseInsensitive: caseInsensitive, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), } // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) @@ -403,7 +431,7 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect } } -func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() @@ -420,5 +448,5 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e return v.(*Schema), nil } - return Parse(dest, cacheStore, namer) + return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive) } diff --git a/schema/schema_test.go b/schema/schema_test.go index a7115f60..55763331 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -5,6 +5,7 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" @@ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") } } + +func TestLookupField(t *testing.T) { + type Product struct { + ProductID uint `gorm:"primaryKey;autoIncrement"` + Code string `gorm:"column:product_code"` + Name string + } + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + field := product.LookUpField("ProductID") + assert.NotNil(t, field) + field = product.LookUpField("productid") + assert.Nil(t, field) + field = product.LookUpField("product_code") + assert.NotNil(t, field) + field = product.LookUpField("PRODUCT_CODE") + assert.Nil(t, field) + + // Check case insensitivity + product.FieldsCaseInsensitive = true + field = product.LookUpField("productid") + assert.NotNil(t, field) + field = product.LookUpField("PRODUCT_CODE") + assert.NotNil(t, field) +} + +func TestLookupFieldByBindName(t *testing.T) { + type Product struct { + ID uint `gorm:"primaryKey;autoIncrement"` + } + type Sellable struct { + Name string + Product Product `gorm:"embedded;embeddedPrefix:product_"` + } + + product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err) + } + field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID") + assert.NotNil(t, field) + field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id") + assert.Nil(t, field) + field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id") + assert.Nil(t, field) + field = product.LookUpFieldByBindName([]string{"product", "id"}, "id") + assert.Nil(t, field) + + // Check case insensitivity + product.FieldsCaseInsensitive = true + field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id") + assert.NotNil(t, field) + field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id") + assert.NotNil(t, field) + field = product.LookUpFieldByBindName([]string{"product", "id"}, "id") + assert.NotNil(t, field) +} diff --git a/statement.go b/statement.go index 39e05d09..e393adbb 100644 --- a/statement.go +++ b/statement.go @@ -388,7 +388,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] reflectValue = reflectValue.Elem() } - if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil { selectedColumns := map[string]bool{} if idx == 0 { for _, v := range args[1:] { @@ -490,7 +490,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { } func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { - if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1]