Fix data races in Callbacks

This commit is contained in:
Emir Beganovic 2019-04-14 13:20:25 +04:00
parent 7bc3561503
commit 87b8577ce3
2 changed files with 76 additions and 5 deletions

View File

@ -1,6 +1,9 @@
package gorm
import "log"
import (
"log"
"sync"
)
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
@ -13,6 +16,7 @@ var DefaultCallback = &Callback{}
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
sync.Mutex
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
@ -23,6 +27,7 @@ type Callback struct {
// CallbackProcessor contains callback informations
type CallbackProcessor struct {
sync.RWMutex
name string // current callback's name
before string // register current callback before a callback
after string // register current callback after a callback
@ -100,8 +105,15 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.Lock()
cp.Lock()
processors := make([]*CallbackProcessor, 0)
copy(processors, cp.parent.processors)
processors = append(processors, cp)
cp.parent.processors = processors
cp.Unlock()
cp.parent.reorder()
cp.parent.Unlock()
}
// Remove a registered callback
@ -200,7 +212,9 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
}
for _, cp := range cps {
cp.Lock()
sortCallbackProcessor(cp)
cp.Unlock()
}
var sortedFuncs []*func(scope *Scope)
@ -218,6 +232,7 @@ func (c *Callback) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
for _, processor := range c.processors {
processor.RLock()
if processor.name != "" {
switch processor.kind {
case "create":
@ -232,6 +247,7 @@ func (c *Callback) reorder() {
rowQueries = append(rowQueries, processor)
}
}
processor.RUnlock()
}
c.creates = sortProcessors(creates)

View File

@ -2,11 +2,14 @@ package gorm_test
import (
"errors"
"io/ioutil"
"os"
"path"
"reflect"
"sync"
"testing"
"github.com/jinzhu/gorm"
"reflect"
"testing"
)
func (s *Product) BeforeCreate() (err error) {
@ -175,3 +178,55 @@ func TestCallbacksWithErrors(t *testing.T) {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}
func worker() error {
tempdir, err := ioutil.TempDir("", "testdb")
if err != nil {
return err
}
defer os.RemoveAll(tempdir)
gdb, err := gorm.Open("sqlite3", path.Join(tempdir, "gorm.db"))
if err != nil {
return err
}
defer gdb.Close()
gdb.Callback().Create().Before("gorm:create").Register("dummy:create_before", func(s *gorm.Scope) {})
return nil
}
func TestWorker(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := worker(); err != nil {
t.Fail()
}
}()
}
wg.Wait()
}
func BenchmarkWorker(b *testing.B) {
b.ReportAllocs()
var wg sync.WaitGroup
for i := 0; i < b.N; i++ {
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := worker(); err != nil {
b.Fail()
}
}()
}
wg.Wait()
}
}