diff --git a/callbacks/preload.go b/callbacks/preload.go index 225cda28..bf35e3f2 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -138,14 +138,14 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati } tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } } case reflect.Struct, reflect.Pointer: reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } default: diff --git a/callbacks/query.go b/callbacks/query.go index 548bf709..a46a44c7 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -291,7 +291,7 @@ func Preload(db *gorm.DB) { return } - db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) + db.AddError(preloadEntryPoint(tx, joins, tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) } } diff --git a/schema/relationship.go b/schema/relationship.go index f1ace924..a0bade79 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -160,7 +160,7 @@ func (schema *Schema) setRelation(relation *Relationship) { if len(relation.Field.EmbeddedBindNames) <= 1 { return } - relationships := &schema.Relationships + relationships := schema.Relationships for i, name := range relation.Field.EmbeddedBindNames { if i < len(relation.Field.EmbeddedBindNames)-1 { if relationships.EmbeddedRelations == nil { diff --git a/schema/schema.go b/schema/schema.go index 203feec9..a0655eaa 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -44,7 +44,7 @@ type Schema struct { FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database - Relationships Relationships + Relationships *Relationships CreateClauses []clause.Interface QueryClauses []clause.Interface UpdateClauses []clause.Interface @@ -58,7 +58,7 @@ type Schema struct { initialized chan struct{} namer Namer cacheStore *sync.Map - rwmu sync.RWMutex + mux sync.RWMutex } func (schema Schema) String() string { @@ -77,8 +77,8 @@ func (schema Schema) MakeSlice() reflect.Value { } func (schema Schema) LookUpField(name string) *Field { - schema.rwmu.RLock() - defer schema.rwmu.RUnlock() + schema.mux.RLock() + defer schema.mux.RUnlock() if field, ok := schema.FieldsByDBName[name]; ok { return field } @@ -102,12 +102,12 @@ func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Fie } for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name - schema.rwmu.RLock() + schema.mux.RLock() if field, ok := schema.FieldsByBindName[find]; ok { - schema.rwmu.RUnlock() + schema.mux.RUnlock() return field } - schema.rwmu.RUnlock() + schema.mux.RUnlock() } return nil } @@ -191,7 +191,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam FieldsByName: map[string]*Field{}, FieldsByBindName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, - Relationships: Relationships{Relations: map[string]*Relationship{}}, + Relationships: &Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, initialized: make(chan struct{}), diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index bc326686..24b57b8b 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -221,7 +221,7 @@ func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationship t.Errorf("failed to find relation by name %s", n) } else { checkSchemaRelation(t, &schema.Schema{ - Relationships: schema.Relationships{ + Relationships: &schema.Relationships{ Relations: map[string]*schema.Relationship{n: rel}, }, }, r)