Fix data races in Callbacks
This commit is contained in:
parent
7bc3561503
commit
87b8577ce3
20
callback.go
20
callback.go
@ -1,6 +1,9 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "log"
|
import (
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
// DefaultCallback default callbacks defined by gorm
|
// DefaultCallback default callbacks defined by gorm
|
||||||
var DefaultCallback = &Callback{}
|
var DefaultCallback = &Callback{}
|
||||||
@ -13,6 +16,7 @@ var DefaultCallback = &Callback{}
|
|||||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
type Callback struct {
|
type Callback struct {
|
||||||
|
sync.Mutex
|
||||||
creates []*func(scope *Scope)
|
creates []*func(scope *Scope)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(scope *Scope)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(scope *Scope)
|
||||||
@ -23,6 +27,7 @@ type Callback struct {
|
|||||||
|
|
||||||
// CallbackProcessor contains callback informations
|
// CallbackProcessor contains callback informations
|
||||||
type CallbackProcessor struct {
|
type CallbackProcessor struct {
|
||||||
|
sync.RWMutex
|
||||||
name string // current callback's name
|
name string // current callback's name
|
||||||
before string // register current callback before a callback
|
before string // register current callback before a callback
|
||||||
after string // register current callback after a callback
|
after string // register current callback after a callback
|
||||||
@ -100,8 +105,15 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
|||||||
|
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.Lock()
|
||||||
|
cp.Lock()
|
||||||
|
processors := make([]*CallbackProcessor, 0)
|
||||||
|
copy(processors, cp.parent.processors)
|
||||||
|
processors = append(processors, cp)
|
||||||
|
cp.parent.processors = processors
|
||||||
|
cp.Unlock()
|
||||||
cp.parent.reorder()
|
cp.parent.reorder()
|
||||||
|
cp.parent.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove a registered callback
|
// Remove a registered callback
|
||||||
@ -200,7 +212,9 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
|
cp.Lock()
|
||||||
sortCallbackProcessor(cp)
|
sortCallbackProcessor(cp)
|
||||||
|
cp.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
var sortedFuncs []*func(scope *Scope)
|
var sortedFuncs []*func(scope *Scope)
|
||||||
@ -218,6 +232,7 @@ func (c *Callback) reorder() {
|
|||||||
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||||
|
|
||||||
for _, processor := range c.processors {
|
for _, processor := range c.processors {
|
||||||
|
processor.RLock()
|
||||||
if processor.name != "" {
|
if processor.name != "" {
|
||||||
switch processor.kind {
|
switch processor.kind {
|
||||||
case "create":
|
case "create":
|
||||||
@ -232,6 +247,7 @@ func (c *Callback) reorder() {
|
|||||||
rowQueries = append(rowQueries, processor)
|
rowQueries = append(rowQueries, processor)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
processor.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
c.creates = sortProcessors(creates)
|
c.creates = sortProcessors(creates)
|
||||||
|
@ -2,11 +2,14 @@ package gorm_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Product) BeforeCreate() (err error) {
|
func (s *Product) BeforeCreate() (err error) {
|
||||||
@ -175,3 +178,55 @@ func TestCallbacksWithErrors(t *testing.T) {
|
|||||||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func worker() error {
|
||||||
|
tempdir, err := ioutil.TempDir("", "testdb")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempdir)
|
||||||
|
|
||||||
|
gdb, err := gorm.Open("sqlite3", path.Join(tempdir, "gorm.db"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer gdb.Close()
|
||||||
|
|
||||||
|
gdb.Callback().Create().Before("gorm:create").Register("dummy:create_before", func(s *gorm.Scope) {})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorker(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if err := worker(); err != nil {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWorker(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if err := worker(); err != nil {
|
||||||
|
b.Fail()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user