Merge aa0e8e721cbc5f0dcc4521d6a3fc1f7e214d9442 into 6cc2c01268c24fdf7b30cfe4b1a8fb5a10922c65

This commit is contained in:
John Boone 2025-08-04 13:24:03 -04:00 committed by GitHub
commit 256f236244
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 149 additions and 37 deletions

View File

@ -270,7 +270,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:

7
go.mod
View File

@ -5,5 +5,12 @@ go 1.18
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5
github.com/stretchr/testify v1.10.0
golang.org/x/text v0.20.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,6 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

37
gorm.go
View File

@ -59,6 +59,8 @@ type Config struct {
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// CaseInsensitiveSchemaFields enabling case insensitivity for schema fields
CaseInsensitiveSchemaFields bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -110,21 +112,22 @@ type DB struct {
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
CaseInsensitiveSchemaFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
}
// Open initialize db session based on dialector
@ -276,6 +279,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.PropagateUnscoped = true
}
if config.CaseInsensitiveSchemaFields {
txConfig.CaseInsensitiveSchemaFields = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx

View File

@ -211,7 +211,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields)
}
if len(columns) == 1 {

View File

@ -398,7 +398,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
}

View File

@ -77,7 +77,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil
}
@ -359,8 +359,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many

View File

@ -44,6 +44,7 @@ type Schema struct {
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
FieldsCaseInsensitive bool
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
@ -79,9 +80,24 @@ func (schema Schema) LookUpField(name string) *Field {
if field, ok := schema.FieldsByDBName[name]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByDBName {
if strings.EqualFold(key, name) {
return field
}
}
}
if field, ok := schema.FieldsByName[name]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByName {
if strings.EqualFold(key, name) {
return field
}
}
}
return nil
}
@ -102,6 +118,13 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByBindName {
if strings.EqualFold(key, find) {
return field
}
}
}
}
return nil
}
@ -116,11 +139,15 @@ type TablerWithNamer interface {
// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
return ParseWithCaseInsensitivity(dest, cacheStore, namer, false)
}
func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "")
}
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
@ -179,16 +206,17 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
}
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
FieldsCaseInsensitive: caseInsensitive,
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
@ -410,7 +438,7 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
}
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
@ -427,5 +455,5 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
return v.(*Schema), nil
}
return Parse(dest, cacheStore, namer)
return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive)
}

View File

@ -5,6 +5,7 @@ import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
@ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
}
}
func TestLookupField(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
Code string `gorm:"column:product_code"`
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
field := product.LookUpField("ProductID")
assert.NotNil(t, field)
field = product.LookUpField("productid")
assert.Nil(t, field)
field = product.LookUpField("product_code")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.Nil(t, field)
// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpField("productid")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.NotNil(t, field)
}
func TestLookupFieldByBindName(t *testing.T) {
type Product struct {
ID uint `gorm:"primaryKey;autoIncrement"`
}
type Sellable struct {
Name string
Product Product `gorm:"embedded;embeddedPrefix:product_"`
}
product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err)
}
field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.Nil(t, field)
// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.NotNil(t, field)
}

View File

@ -408,7 +408,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
reflectValue = reflectValue.Elem()
}
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil {
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
@ -510,7 +510,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]