From 8abaddf4b4bf9539f23989eb39af8592b321c34d Mon Sep 17 00:00:00 2001 From: snackmgmg <16898622+snackmgmg@users.noreply.github.com> Date: Tue, 19 Mar 2024 22:09:09 +0900 Subject: [PATCH] fix: remove callback from callbacks if Remove() called --- callbacks.go | 24 ++++++++++++++++++++- tests/callbacks_test.go | 48 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/callbacks.go b/callbacks.go index 195d1720..54a4be85 100644 --- a/callbacks.go +++ b/callbacks.go @@ -186,12 +186,23 @@ func (p *processor) Replace(name string, fn func(*DB)) error { } func (p *processor) compile() (err error) { - var callbacks []*callback + var ( + callbacks []*callback + removed []string + ) for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removed = append(removed, callback.name) + } } + + if len(removed) > 0 { + callbacks = removeCallbacks(callbacks, removed) + } + p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { @@ -339,3 +350,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, names []string) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if utils.Contains(names, callback.name) { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 4479da4c..f77209f1 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) { }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, + results: []string{"c1", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, @@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) { t.Errorf("callbacks tests failed, got %v", msg) } } + +func TestCallbacksGet(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { + t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) + } + + createCallback.Remove("c1") + if cb := createCallback.Get("c2"); cb != nil { + t.Errorf("callbacks test failed. got: %p, want: nil", cb) + } +} + +func TestCallbacksRemove(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + createCallback.After("*").Register("c2", c2) + createCallback.Before("c4").Register("c3", c3) + createCallback.After("c2").Register("c4", c4) + + // callbacks: []string{"c1", "c3", "c4", "c2"} + createCallback.Remove("c1") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c4") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c2") + if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c3") + if ok, msg := assertCallbacks(createCallback, []string{}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +}