From 725aa5b5ff4c0687b06d9a01096b8e4cf96b6c9e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Aug 2025 17:19:20 +0800 Subject: [PATCH] Fix data race, close #7287 #7110 #7539 #7108 --- schema/relationship.go | 15 ++- schema/relationship_test.go | 45 ++++++++ schema/schema.go | 210 ++++++++++++++++-------------------- tests/go.mod | 6 +- 4 files changed, 152 insertions(+), 124 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index f1ace924..0535bba4 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/schema/relationship_test.go b/schema/relationship_test.go index f1acf2d9..c706ac84 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -3,9 +3,11 @@ package schema_test import ( "sync" "testing" + "time" "gorm.io/gorm" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { @@ -996,3 +998,46 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { ) } } + +type InfoRelation struct { + ID int + Code string + Info1 []*Info1 `gorm:"foreignkey:Code;references:Code"` + Info2 []*Info2 `gorm:"foreignkey:Code;references:Code"` +} + +type Info1 struct { + CreatedAt time.Time + UpdatedAt time.Time + Code string + Relation []*InfoRelation `gorm:"foreignkey:Code;references:Code"` +} + +type Info2 struct { + CreatedAt time.Time + UpdatedAt time.Time + Code string + Relation []*InfoRelation `gorm:"foreignkey:Code;references:Code"` +} + +func TestDataRace(t *testing.T) { + syncMap := &sync.Map{} + for i := 0; i < 10; i++ { + go func() { + schema.Parse(&Info1{}, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) + }() + + go func() { + schema.Parse(&Info2{}, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) + }() + + go func() { + var result User + schema.Parse(&result, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) + }() + go func() { + var result tests.Account + schema.Parse(&result, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) + }() + } +} diff --git a/schema/schema.go b/schema/schema.go index 2a5c28e2..be8d8f07 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -60,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -75,7 +75,7 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } @@ -93,7 +93,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { if len(bindNames) == 0 { return nil } @@ -114,6 +114,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -129,30 +137,31 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if value.Kind() == reflect.Ptr && value.IsNil() { value = reflect.New(value.Type().Elem()) } + modelType := reflect.Indirect(value).Type() - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -163,28 +172,27 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { - tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { - tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } if specialTableName != "" && specialTableName != tableName { tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -284,10 +292,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -304,30 +339,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": - expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) - if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - } else { - logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) - // PASS - } - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -343,84 +354,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/tests/go.mod b/tests/go.mod index 30b0c16e..b490e423 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -27,13 +27,13 @@ require ( github.com/jackc/pgx/v5 v5.7.5 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.28 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/microsoft/go-mssqldb v1.9.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect - golang.org/x/crypto v0.40.0 // indirect + golang.org/x/crypto v0.41.0 // indirect golang.org/x/sync v0.16.0 // indirect - golang.org/x/text v0.27.0 // indirect + golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect )