feat: support case insensitive schema fields

This commit is contained in:
John Boone 2025-05-15 12:51:02 -04:00
parent e5b867e785
commit 7e19567156
10 changed files with 149 additions and 37 deletions

View File

@ -269,7 +269,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
db.assignInterfacesToValue(exprs) db.assignInterfacesToValue(exprs)
} }
default: default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:

7
go.mod
View File

@ -5,5 +5,12 @@ go 1.18
require ( require (
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/stretchr/testify v1.10.0
golang.org/x/text v0.20.0 golang.org/x/text v0.20.0
) )
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,6 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

37
gorm.go
View File

@ -57,6 +57,8 @@ type Config struct {
TranslateError bool TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement // PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool PropagateUnscoped bool
// CaseInsensitiveSchemaFields enabling case insensitivity for schema fields
CaseInsensitiveSchemaFields bool
// ClauseBuilders clause builder // ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder ClauseBuilders map[string]clause.ClauseBuilder
@ -108,21 +110,22 @@ type DB struct {
// Session session config when create session with Session() method // Session session config when create session with Session() method
type Session struct { type Session struct {
DryRun bool DryRun bool
PrepareStmt bool PrepareStmt bool
NewDB bool NewDB bool
Initialized bool Initialized bool
SkipHooks bool SkipHooks bool
SkipDefaultTransaction bool SkipDefaultTransaction bool
DisableNestedTransaction bool DisableNestedTransaction bool
AllowGlobalUpdate bool AllowGlobalUpdate bool
FullSaveAssociations bool FullSaveAssociations bool
PropagateUnscoped bool PropagateUnscoped bool
QueryFields bool QueryFields bool
Context context.Context CaseInsensitiveSchemaFields bool
Logger logger.Interface Context context.Context
NowFunc func() time.Time Logger logger.Interface
CreateBatchSize int NowFunc func() time.Time
CreateBatchSize int
} }
// Open initialize db session based on dialector // Open initialize db session based on dialector
@ -258,6 +261,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.PropagateUnscoped = true txConfig.PropagateUnscoped = true
} }
if config.CaseInsensitiveSchemaFields {
txConfig.CaseInsensitiveSchemaFields = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks { if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone() tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx tx.Statement.DB = tx

View File

@ -210,7 +210,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if sch != nil { if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields)
} }
if len(columns) == 1 { if len(columns) == 1 {

View File

@ -397,7 +397,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{} cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true) cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil {
schema.err = err schema.err = err
} }

View File

@ -77,7 +77,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = err schema.err = err
return nil return nil
} }
@ -359,8 +359,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`, Tag: `gorm:"-"`,
}) })
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil { schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = err schema.err = err
} }
relation.JoinTable.Name = many2many relation.JoinTable.Name = many2many

View File

@ -43,6 +43,7 @@ type Schema struct {
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
FieldsCaseInsensitive bool
Relationships Relationships Relationships Relationships
CreateClauses []clause.Interface CreateClauses []clause.Interface
QueryClauses []clause.Interface QueryClauses []clause.Interface
@ -78,9 +79,24 @@ func (schema Schema) LookUpField(name string) *Field {
if field, ok := schema.FieldsByDBName[name]; ok { if field, ok := schema.FieldsByDBName[name]; ok {
return field return field
} }
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByDBName {
if strings.EqualFold(key, name) {
return field
}
}
}
if field, ok := schema.FieldsByName[name]; ok { if field, ok := schema.FieldsByName[name]; ok {
return field return field
} }
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByName {
if strings.EqualFold(key, name) {
return field
}
}
}
return nil return nil
} }
@ -101,6 +117,13 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie
if field, ok := schema.FieldsByBindName[find]; ok { if field, ok := schema.FieldsByBindName[find]; ok {
return field return field
} }
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByBindName {
if strings.EqualFold(key, find) {
return field
}
}
}
} }
return nil return nil
} }
@ -115,11 +138,15 @@ type TablerWithNamer interface {
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "") return ParseWithCaseInsensitivity(dest, cacheStore, namer, false)
}
func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "")
} }
// ParseWithSpecialTableName get data type from dialector with extra schema table // ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) {
if dest == nil { if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
@ -178,16 +205,17 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
schema := &Schema{ schema := &Schema{
Name: modelType.Name(), Name: modelType.Name(),
ModelType: modelType, ModelType: modelType,
Table: tableName, Table: tableName,
FieldsByName: map[string]*Field{}, FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{}, FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{}, FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}}, FieldsCaseInsensitive: caseInsensitive,
cacheStore: cacheStore, Relationships: Relationships{Relations: map[string]*Relationship{}},
namer: namer, cacheStore: cacheStore,
initialized: make(chan struct{}), namer: namer,
initialized: make(chan struct{}),
} }
// When the schema initialization is completed, the channel will be closed // When the schema initialization is completed, the channel will be closed
defer close(schema.initialized) defer close(schema.initialized)
@ -403,7 +431,7 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
} }
} }
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type() modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem() modelType = modelType.Elem()
@ -420,5 +448,5 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
return v.(*Schema), nil return v.(*Schema), nil
} }
return Parse(dest, cacheStore, namer) return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive)
} }

View File

@ -5,6 +5,7 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/stretchr/testify/assert"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests" "gorm.io/gorm/utils/tests"
@ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
} }
} }
func TestLookupField(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
Code string `gorm:"column:product_code"`
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
field := product.LookUpField("ProductID")
assert.NotNil(t, field)
field = product.LookUpField("productid")
assert.Nil(t, field)
field = product.LookUpField("product_code")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.Nil(t, field)
// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpField("productid")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.NotNil(t, field)
}
func TestLookupFieldByBindName(t *testing.T) {
type Product struct {
ID uint `gorm:"primaryKey;autoIncrement"`
}
type Sellable struct {
Name string
Product Product `gorm:"embedded;embeddedPrefix:product_"`
}
product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err)
}
field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.Nil(t, field)
// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.NotNil(t, field)
}

View File

@ -388,7 +388,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
reflectValue = reflectValue.Elem() reflectValue = reflectValue.Elem()
} }
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil {
selectedColumns := map[string]bool{} selectedColumns := map[string]bool{}
if idx == 0 { if idx == 0 {
for _, v := range args[1:] { for _, v := range args[1:] {
@ -490,7 +490,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
} }
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1] stmt.Table = tables[1]