From 87b8577ce37933f160824d66bce4b1e822826e0c Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 13:20:25 +0400 Subject: [PATCH] Fix data races in Callbacks --- callback.go | 20 ++++++++++++++-- callbacks_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/callback.go b/callback.go index a4382147..ea8c5737 100644 --- a/callback.go +++ b/callback.go @@ -1,6 +1,9 @@ package gorm -import "log" +import ( + "log" + "sync" +) // DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{} @@ -13,6 +16,7 @@ var DefaultCallback = &Callback{} // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { + sync.Mutex creates []*func(scope *Scope) updates []*func(scope *Scope) deletes []*func(scope *Scope) @@ -23,6 +27,7 @@ type Callback struct { // CallbackProcessor contains callback informations type CallbackProcessor struct { + sync.RWMutex name string // current callback's name before string // register current callback before a callback after string // register current callback after a callback @@ -100,8 +105,15 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * cp.name = callbackName cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.Lock() + cp.Lock() + processors := make([]*CallbackProcessor, 0) + copy(processors, cp.parent.processors) + processors = append(processors, cp) + cp.parent.processors = processors + cp.Unlock() cp.parent.reorder() + cp.parent.Unlock() } // Remove a registered callback @@ -200,7 +212,9 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { } for _, cp := range cps { + cp.Lock() sortCallbackProcessor(cp) + cp.Unlock() } var sortedFuncs []*func(scope *Scope) @@ -218,6 +232,7 @@ func (c *Callback) reorder() { var creates, updates, deletes, queries, rowQueries []*CallbackProcessor for _, processor := range c.processors { + processor.RLock() if processor.name != "" { switch processor.kind { case "create": @@ -232,6 +247,7 @@ func (c *Callback) reorder() { rowQueries = append(rowQueries, processor) } } + processor.RUnlock() } c.creates = sortProcessors(creates) diff --git a/callbacks_test.go b/callbacks_test.go index a58913d7..95b8518c 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -2,11 +2,14 @@ package gorm_test import ( "errors" + "io/ioutil" + "os" + "path" + "reflect" + "sync" + "testing" "github.com/jinzhu/gorm" - - "reflect" - "testing" ) func (s *Product) BeforeCreate() (err error) { @@ -175,3 +178,55 @@ func TestCallbacksWithErrors(t *testing.T) { t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") } } + +func worker() error { + tempdir, err := ioutil.TempDir("", "testdb") + if err != nil { + return err + } + defer os.RemoveAll(tempdir) + + gdb, err := gorm.Open("sqlite3", path.Join(tempdir, "gorm.db")) + if err != nil { + return err + } + defer gdb.Close() + + gdb.Callback().Create().Before("gorm:create").Register("dummy:create_before", func(s *gorm.Scope) {}) + return nil +} + +func TestWorker(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + if err := worker(); err != nil { + t.Fail() + } + }() + } + wg.Wait() +} + +func BenchmarkWorker(b *testing.B) { + b.ReportAllocs() + var wg sync.WaitGroup + for i := 0; i < b.N; i++ { + for i := 0; i < 100; i++ { + + wg.Add(1) + go func() { + defer wg.Done() + + if err := worker(); err != nil { + b.Fail() + } + }() + } + wg.Wait() + + } +}