Merge aa0e8e721cbc5f0dcc4521d6a3fc1f7e214d9442 into 6cc2c01268c24fdf7b30cfe4b1a8fb5a10922c65
This commit is contained in:
commit
256f236244
@ -270,7 +270,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
db.assignInterfacesToValue(exprs)
|
||||
}
|
||||
default:
|
||||
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
|
||||
if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
|
7
go.mod
7
go.mod
@ -5,5 +5,12 @@ go 1.18
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/text v0.20.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -1,6 +1,16 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
37
gorm.go
37
gorm.go
@ -59,6 +59,8 @@ type Config struct {
|
||||
TranslateError bool
|
||||
// PropagateUnscoped propagate Unscoped to every other nested statement
|
||||
PropagateUnscoped bool
|
||||
// CaseInsensitiveSchemaFields enabling case insensitivity for schema fields
|
||||
CaseInsensitiveSchemaFields bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
@ -110,21 +112,22 @@ type DB struct {
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
NewDB bool
|
||||
Initialized bool
|
||||
SkipHooks bool
|
||||
SkipDefaultTransaction bool
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
PropagateUnscoped bool
|
||||
QueryFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
CreateBatchSize int
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
NewDB bool
|
||||
Initialized bool
|
||||
SkipHooks bool
|
||||
SkipDefaultTransaction bool
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
PropagateUnscoped bool
|
||||
QueryFields bool
|
||||
CaseInsensitiveSchemaFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
CreateBatchSize int
|
||||
}
|
||||
|
||||
// Open initialize db session based on dialector
|
||||
@ -276,6 +279,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||
txConfig.PropagateUnscoped = true
|
||||
}
|
||||
|
||||
if config.CaseInsensitiveSchemaFields {
|
||||
txConfig.CaseInsensitiveSchemaFields = true
|
||||
}
|
||||
|
||||
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
|
2
scan.go
2
scan.go
@ -211,7 +211,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
|
||||
if sch != nil {
|
||||
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields)
|
||||
}
|
||||
|
||||
if len(columns) == 1 {
|
||||
|
@ -398,7 +398,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
|
||||
cacheStore := &sync.Map{}
|
||||
cacheStore.Store(embeddedCacheKey, true)
|
||||
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
|
||||
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
|
||||
cacheStore := schema.cacheStore
|
||||
|
||||
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
|
||||
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
|
||||
return nil
|
||||
}
|
||||
@ -359,8 +359,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
|
||||
schema.namer); err != nil {
|
||||
if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
|
||||
schema.namer, schema.FieldsCaseInsensitive); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
|
@ -44,6 +44,7 @@ type Schema struct {
|
||||
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||
FieldsCaseInsensitive bool
|
||||
Relationships Relationships
|
||||
CreateClauses []clause.Interface
|
||||
QueryClauses []clause.Interface
|
||||
@ -79,9 +80,24 @@ func (schema Schema) LookUpField(name string) *Field {
|
||||
if field, ok := schema.FieldsByDBName[name]; ok {
|
||||
return field
|
||||
}
|
||||
if schema.FieldsCaseInsensitive {
|
||||
for key, field := range schema.FieldsByDBName {
|
||||
if strings.EqualFold(key, name) {
|
||||
return field
|
||||
}
|
||||
}
|
||||
}
|
||||
if field, ok := schema.FieldsByName[name]; ok {
|
||||
return field
|
||||
}
|
||||
if schema.FieldsCaseInsensitive {
|
||||
for key, field := range schema.FieldsByName {
|
||||
if strings.EqualFold(key, name) {
|
||||
return field
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -102,6 +118,13 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie
|
||||
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||
return field
|
||||
}
|
||||
if schema.FieldsCaseInsensitive {
|
||||
for key, field := range schema.FieldsByBindName {
|
||||
if strings.EqualFold(key, find) {
|
||||
return field
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -116,11 +139,15 @@ type TablerWithNamer interface {
|
||||
|
||||
// Parse get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||
return ParseWithCaseInsensitivity(dest, cacheStore, namer, false)
|
||||
}
|
||||
|
||||
func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "")
|
||||
}
|
||||
|
||||
// ParseWithSpecialTableName get data type from dialector with extra schema table
|
||||
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
|
||||
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) {
|
||||
if dest == nil {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
@ -179,16 +206,17 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByBindName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByBindName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
FieldsCaseInsensitive: caseInsensitive,
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
}
|
||||
// When the schema initialization is completed, the channel will be closed
|
||||
defer close(schema.initialized)
|
||||
@ -410,7 +438,7 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
|
||||
}
|
||||
}
|
||||
|
||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
@ -427,5 +455,5 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
|
||||
return v.(*Schema), nil
|
||||
}
|
||||
|
||||
return Parse(dest, cacheStore, namer)
|
||||
return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
@ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
|
||||
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupField(t *testing.T) {
|
||||
type Product struct {
|
||||
ProductID uint `gorm:"primaryKey;autoIncrement"`
|
||||
Code string `gorm:"column:product_code"`
|
||||
Name string
|
||||
}
|
||||
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
|
||||
}
|
||||
field := product.LookUpField("ProductID")
|
||||
assert.NotNil(t, field)
|
||||
field = product.LookUpField("productid")
|
||||
assert.Nil(t, field)
|
||||
field = product.LookUpField("product_code")
|
||||
assert.NotNil(t, field)
|
||||
field = product.LookUpField("PRODUCT_CODE")
|
||||
assert.Nil(t, field)
|
||||
|
||||
// Check case insensitivity
|
||||
product.FieldsCaseInsensitive = true
|
||||
field = product.LookUpField("productid")
|
||||
assert.NotNil(t, field)
|
||||
field = product.LookUpField("PRODUCT_CODE")
|
||||
assert.NotNil(t, field)
|
||||
}
|
||||
|
||||
func TestLookupFieldByBindName(t *testing.T) {
|
||||
type Product struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
}
|
||||
type Sellable struct {
|
||||
Name string
|
||||
Product Product `gorm:"embedded;embeddedPrefix:product_"`
|
||||
}
|
||||
|
||||
product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err)
|
||||
}
|
||||
field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID")
|
||||
assert.NotNil(t, field)
|
||||
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
|
||||
assert.Nil(t, field)
|
||||
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
|
||||
assert.Nil(t, field)
|
||||
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
|
||||
assert.Nil(t, field)
|
||||
|
||||
// Check case insensitivity
|
||||
product.FieldsCaseInsensitive = true
|
||||
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
|
||||
assert.NotNil(t, field)
|
||||
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
|
||||
assert.NotNil(t, field)
|
||||
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
|
||||
assert.NotNil(t, field)
|
||||
}
|
||||
|
@ -408,7 +408,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil {
|
||||
selectedColumns := map[string]bool{}
|
||||
if idx == 0 {
|
||||
for _, v := range args[1:] {
|
||||
@ -510,7 +510,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
}
|
||||
|
||||
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
|
||||
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
|
||||
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" {
|
||||
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
||||
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
||||
stmt.Table = tables[1]
|
||||
|
Loading…
x
Reference in New Issue
Block a user