diff --git a/model_struct.go b/model_struct.go index 1089a380..79df79f7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,26 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} +func newModelStructsMap() sync.Map { return sync.Map{} } var modelStructsMap = newModelStructsMap() @@ -46,15 +27,10 @@ type ModelStruct struct { StructFields []*StructField ModelType reflect.Type defaultTableName string - - l sync.Mutex } // TableName get model's table name func (s *ModelStruct) TableName(db *DB) string { - s.l.Lock() - defer s.l.Unlock() - if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { @@ -157,8 +133,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -606,7 +582,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct }