fix: remove callback from callbacks if Remove() called
This commit is contained in:
parent
1b0aa802df
commit
8abaddf4b4
24
callbacks.go
24
callbacks.go
@ -186,12 +186,23 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) compile() (err error) {
|
func (p *processor) compile() (err error) {
|
||||||
var callbacks []*callback
|
var (
|
||||||
|
callbacks []*callback
|
||||||
|
removed []string
|
||||||
|
)
|
||||||
for _, callback := range p.callbacks {
|
for _, callback := range p.callbacks {
|
||||||
if callback.match == nil || callback.match(p.db) {
|
if callback.match == nil || callback.match(p.db) {
|
||||||
callbacks = append(callbacks, callback)
|
callbacks = append(callbacks, callback)
|
||||||
}
|
}
|
||||||
|
if callback.remove {
|
||||||
|
removed = append(removed, callback.name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(removed) > 0 {
|
||||||
|
callbacks = removeCallbacks(callbacks, removed)
|
||||||
|
}
|
||||||
|
|
||||||
p.callbacks = callbacks
|
p.callbacks = callbacks
|
||||||
|
|
||||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||||
@ -339,3 +350,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func removeCallbacks(cs []*callback, names []string) []*callback {
|
||||||
|
callbacks := make([]*callback, 0, len(cs))
|
||||||
|
for _, callback := range cs {
|
||||||
|
if utils.Contains(names, callback.name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callbacks = append(callbacks, callback)
|
||||||
|
}
|
||||||
|
return callbacks
|
||||||
|
}
|
||||||
|
@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||||
results: []string{"c1", "c5", "c3", "c4"},
|
results: []string{"c1", "c3", "c4", "c5"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||||
@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) {
|
|||||||
t.Errorf("callbacks tests failed, got %v", msg)
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCallbacksGet(t *testing.T) {
|
||||||
|
db, _ := gorm.Open(nil, nil)
|
||||||
|
createCallback := db.Callback().Create()
|
||||||
|
|
||||||
|
createCallback.Before("*").Register("c1", c1)
|
||||||
|
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
|
||||||
|
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c1")
|
||||||
|
if cb := createCallback.Get("c2"); cb != nil {
|
||||||
|
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbacksRemove(t *testing.T) {
|
||||||
|
db, _ := gorm.Open(nil, nil)
|
||||||
|
createCallback := db.Callback().Create()
|
||||||
|
|
||||||
|
createCallback.Before("*").Register("c1", c1)
|
||||||
|
createCallback.After("*").Register("c2", c2)
|
||||||
|
createCallback.Before("c4").Register("c3", c3)
|
||||||
|
createCallback.After("c2").Register("c4", c4)
|
||||||
|
|
||||||
|
// callbacks: []string{"c1", "c3", "c4", "c2"}
|
||||||
|
createCallback.Remove("c1")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c4")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c2")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c3")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user