Merge aa0e8e721cbc5f0dcc4521d6a3fc1f7e214d9442 into 6cc2c01268c24fdf7b30cfe4b1a8fb5a10922c65
This commit is contained in:
		
						commit
						256f236244
					
				| @ -270,7 +270,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { | ||||
| 				db.assignInterfacesToValue(exprs) | ||||
| 			} | ||||
| 		default: | ||||
| 			if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { | ||||
| 			if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil { | ||||
| 				reflectValue := reflect.Indirect(reflect.ValueOf(value)) | ||||
| 				switch reflectValue.Kind() { | ||||
| 				case reflect.Struct: | ||||
|  | ||||
							
								
								
									
										7
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								go.mod
									
									
									
									
									
								
							| @ -5,5 +5,12 @@ go 1.18 | ||||
| require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
| 	github.com/jinzhu/now v1.1.5 | ||||
| 	github.com/stretchr/testify v1.10.0 | ||||
| 	golang.org/x/text v0.20.0 | ||||
| ) | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|  | ||||
							
								
								
									
										10
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.sum
									
									
									
									
									
								
							| @ -1,6 +1,16 @@ | ||||
| github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||||
| github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||
| github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= | ||||
| github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= | ||||
| golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= | ||||
| golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
| gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
|  | ||||
							
								
								
									
										37
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								gorm.go
									
									
									
									
									
								
							| @ -59,6 +59,8 @@ type Config struct { | ||||
| 	TranslateError bool | ||||
| 	// PropagateUnscoped propagate Unscoped to every other nested statement
 | ||||
| 	PropagateUnscoped bool | ||||
| 	// CaseInsensitiveSchemaFields enabling case insensitivity for schema fields
 | ||||
| 	CaseInsensitiveSchemaFields bool | ||||
| 
 | ||||
| 	// ClauseBuilders clause builder
 | ||||
| 	ClauseBuilders map[string]clause.ClauseBuilder | ||||
| @ -110,21 +112,22 @@ type DB struct { | ||||
| 
 | ||||
| // Session session config when create session with Session() method
 | ||||
| type Session struct { | ||||
| 	DryRun                   bool | ||||
| 	PrepareStmt              bool | ||||
| 	NewDB                    bool | ||||
| 	Initialized              bool | ||||
| 	SkipHooks                bool | ||||
| 	SkipDefaultTransaction   bool | ||||
| 	DisableNestedTransaction bool | ||||
| 	AllowGlobalUpdate        bool | ||||
| 	FullSaveAssociations     bool | ||||
| 	PropagateUnscoped        bool | ||||
| 	QueryFields              bool | ||||
| 	Context                  context.Context | ||||
| 	Logger                   logger.Interface | ||||
| 	NowFunc                  func() time.Time | ||||
| 	CreateBatchSize          int | ||||
| 	DryRun                      bool | ||||
| 	PrepareStmt                 bool | ||||
| 	NewDB                       bool | ||||
| 	Initialized                 bool | ||||
| 	SkipHooks                   bool | ||||
| 	SkipDefaultTransaction      bool | ||||
| 	DisableNestedTransaction    bool | ||||
| 	AllowGlobalUpdate           bool | ||||
| 	FullSaveAssociations        bool | ||||
| 	PropagateUnscoped           bool | ||||
| 	QueryFields                 bool | ||||
| 	CaseInsensitiveSchemaFields bool | ||||
| 	Context                     context.Context | ||||
| 	Logger                      logger.Interface | ||||
| 	NowFunc                     func() time.Time | ||||
| 	CreateBatchSize             int | ||||
| } | ||||
| 
 | ||||
| // Open initialize db session based on dialector
 | ||||
| @ -276,6 +279,10 @@ func (db *DB) Session(config *Session) *DB { | ||||
| 		txConfig.PropagateUnscoped = true | ||||
| 	} | ||||
| 
 | ||||
| 	if config.CaseInsensitiveSchemaFields { | ||||
| 		txConfig.CaseInsensitiveSchemaFields = true | ||||
| 	} | ||||
| 
 | ||||
| 	if config.Context != nil || config.PrepareStmt || config.SkipHooks { | ||||
| 		tx.Statement = tx.Statement.clone() | ||||
| 		tx.Statement.DB = tx | ||||
|  | ||||
							
								
								
									
										2
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								scan.go
									
									
									
									
									
								
							| @ -211,7 +211,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 
 | ||||
| 		if sch != nil { | ||||
| 			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { | ||||
| 				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) | ||||
| 				sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields) | ||||
| 			} | ||||
| 
 | ||||
| 			if len(columns) == 1 { | ||||
|  | ||||
| @ -398,7 +398,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 
 | ||||
| 			cacheStore := &sync.Map{} | ||||
| 			cacheStore.Store(embeddedCacheKey, true) | ||||
| 			if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { | ||||
| 			if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil { | ||||
| 				schema.err = err | ||||
| 			} | ||||
| 
 | ||||
|  | ||||
| @ -77,7 +77,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { | ||||
| 
 | ||||
| 	cacheStore := schema.cacheStore | ||||
| 
 | ||||
| 	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { | ||||
| 	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil { | ||||
| 		schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) | ||||
| 		return nil | ||||
| 	} | ||||
| @ -359,8 +359,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | ||||
| 		Tag:  `gorm:"-"`, | ||||
| 	}) | ||||
| 
 | ||||
| 	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, | ||||
| 		schema.namer); err != nil { | ||||
| 	if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, | ||||
| 		schema.namer, schema.FieldsCaseInsensitive); err != nil { | ||||
| 		schema.err = err | ||||
| 	} | ||||
| 	relation.JoinTable.Name = many2many | ||||
|  | ||||
| @ -44,6 +44,7 @@ type Schema struct { | ||||
| 	FieldsByBindName          map[string]*Field // embedded fields is 'Embed.Field'
 | ||||
| 	FieldsByDBName            map[string]*Field | ||||
| 	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
 | ||||
| 	FieldsCaseInsensitive     bool | ||||
| 	Relationships             Relationships | ||||
| 	CreateClauses             []clause.Interface | ||||
| 	QueryClauses              []clause.Interface | ||||
| @ -79,9 +80,24 @@ func (schema Schema) LookUpField(name string) *Field { | ||||
| 	if field, ok := schema.FieldsByDBName[name]; ok { | ||||
| 		return field | ||||
| 	} | ||||
| 	if schema.FieldsCaseInsensitive { | ||||
| 		for key, field := range schema.FieldsByDBName { | ||||
| 			if strings.EqualFold(key, name) { | ||||
| 				return field | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if field, ok := schema.FieldsByName[name]; ok { | ||||
| 		return field | ||||
| 	} | ||||
| 	if schema.FieldsCaseInsensitive { | ||||
| 		for key, field := range schema.FieldsByName { | ||||
| 			if strings.EqualFold(key, name) { | ||||
| 				return field | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| @ -102,6 +118,13 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie | ||||
| 		if field, ok := schema.FieldsByBindName[find]; ok { | ||||
| 			return field | ||||
| 		} | ||||
| 		if schema.FieldsCaseInsensitive { | ||||
| 			for key, field := range schema.FieldsByBindName { | ||||
| 				if strings.EqualFold(key, find) { | ||||
| 					return field | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @ -116,11 +139,15 @@ type TablerWithNamer interface { | ||||
| 
 | ||||
| // Parse get data type from dialector
 | ||||
| func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { | ||||
| 	return ParseWithSpecialTableName(dest, cacheStore, namer, "") | ||||
| 	return ParseWithCaseInsensitivity(dest, cacheStore, namer, false) | ||||
| } | ||||
| 
 | ||||
| func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) { | ||||
| 	return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "") | ||||
| } | ||||
| 
 | ||||
| // ParseWithSpecialTableName get data type from dialector with extra schema table
 | ||||
| func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { | ||||
| func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) { | ||||
| 	if dest == nil { | ||||
| 		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||
| 	} | ||||
| @ -179,16 +206,17 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | ||||
| 	} | ||||
| 
 | ||||
| 	schema := &Schema{ | ||||
| 		Name:             modelType.Name(), | ||||
| 		ModelType:        modelType, | ||||
| 		Table:            tableName, | ||||
| 		FieldsByName:     map[string]*Field{}, | ||||
| 		FieldsByBindName: map[string]*Field{}, | ||||
| 		FieldsByDBName:   map[string]*Field{}, | ||||
| 		Relationships:    Relationships{Relations: map[string]*Relationship{}}, | ||||
| 		cacheStore:       cacheStore, | ||||
| 		namer:            namer, | ||||
| 		initialized:      make(chan struct{}), | ||||
| 		Name:                  modelType.Name(), | ||||
| 		ModelType:             modelType, | ||||
| 		Table:                 tableName, | ||||
| 		FieldsByName:          map[string]*Field{}, | ||||
| 		FieldsByBindName:      map[string]*Field{}, | ||||
| 		FieldsByDBName:        map[string]*Field{}, | ||||
| 		FieldsCaseInsensitive: caseInsensitive, | ||||
| 		Relationships:         Relationships{Relations: map[string]*Relationship{}}, | ||||
| 		cacheStore:            cacheStore, | ||||
| 		namer:                 namer, | ||||
| 		initialized:           make(chan struct{}), | ||||
| 	} | ||||
| 	// When the schema initialization is completed, the channel will be closed
 | ||||
| 	defer close(schema.initialized) | ||||
| @ -410,7 +438,7 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { | ||||
| func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) { | ||||
| 	modelType := reflect.ValueOf(dest).Type() | ||||
| 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { | ||||
| 		modelType = modelType.Elem() | ||||
| @ -427,5 +455,5 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e | ||||
| 		return v.(*Schema), nil | ||||
| 	} | ||||
| 
 | ||||
| 	return Parse(dest, cacheStore, namer) | ||||
| 	return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive) | ||||
| } | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils/tests" | ||||
| @ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { | ||||
| 		t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestLookupField(t *testing.T) { | ||||
| 	type Product struct { | ||||
| 		ProductID uint   `gorm:"primaryKey;autoIncrement"` | ||||
| 		Code      string `gorm:"column:product_code"` | ||||
| 		Name      string | ||||
| 	} | ||||
| 	product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) | ||||
| 	} | ||||
| 	field := product.LookUpField("ProductID") | ||||
| 	assert.NotNil(t, field) | ||||
| 	field = product.LookUpField("productid") | ||||
| 	assert.Nil(t, field) | ||||
| 	field = product.LookUpField("product_code") | ||||
| 	assert.NotNil(t, field) | ||||
| 	field = product.LookUpField("PRODUCT_CODE") | ||||
| 	assert.Nil(t, field) | ||||
| 
 | ||||
| 	// Check case insensitivity
 | ||||
| 	product.FieldsCaseInsensitive = true | ||||
| 	field = product.LookUpField("productid") | ||||
| 	assert.NotNil(t, field) | ||||
| 	field = product.LookUpField("PRODUCT_CODE") | ||||
| 	assert.NotNil(t, field) | ||||
| } | ||||
| 
 | ||||
| func TestLookupFieldByBindName(t *testing.T) { | ||||
| 	type Product struct { | ||||
| 		ID uint `gorm:"primaryKey;autoIncrement"` | ||||
| 	} | ||||
| 	type Sellable struct { | ||||
| 		Name    string | ||||
| 		Product Product `gorm:"embedded;embeddedPrefix:product_"` | ||||
| 	} | ||||
| 
 | ||||
| 	product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err) | ||||
| 	} | ||||
| 	field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID") | ||||
| 	assert.NotNil(t, field) | ||||
| 	field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id") | ||||
| 	assert.Nil(t, field) | ||||
| 	field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id") | ||||
| 	assert.Nil(t, field) | ||||
| 	field = product.LookUpFieldByBindName([]string{"product", "id"}, "id") | ||||
| 	assert.Nil(t, field) | ||||
| 
 | ||||
| 	// Check case insensitivity
 | ||||
| 	product.FieldsCaseInsensitive = true | ||||
| 	field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id") | ||||
| 	assert.NotNil(t, field) | ||||
| 	field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id") | ||||
| 	assert.NotNil(t, field) | ||||
| 	field = product.LookUpFieldByBindName([]string{"product", "id"}, "id") | ||||
| 	assert.NotNil(t, field) | ||||
| } | ||||
|  | ||||
| @ -408,7 +408,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 				reflectValue = reflectValue.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { | ||||
| 			if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil { | ||||
| 				selectedColumns := map[string]bool{} | ||||
| 				if idx == 0 { | ||||
| 					for _, v := range args[1:] { | ||||
| @ -510,7 +510,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { | ||||
| } | ||||
| 
 | ||||
| func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { | ||||
| 	if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { | ||||
| 	if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" { | ||||
| 		if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { | ||||
| 			stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} | ||||
| 			stmt.Table = tables[1] | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 John Boone
						John Boone