fix: remove callback from callbacks if Remove() called (#6916)
				
					
				
			* fix: remove callback from callbacks if Remove() called * reduce number of loops * remove unnecessary blank line
This commit is contained in:
		
							parent
							
								
									956f7ce843
								
							
						
					
					
						commit
						26195e6d16
					
				
							
								
								
									
										19
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { | |||||||
| 
 | 
 | ||||||
| func (p *processor) compile() (err error) { | func (p *processor) compile() (err error) { | ||||||
| 	var callbacks []*callback | 	var callbacks []*callback | ||||||
|  | 	removedMap := map[string]bool{} | ||||||
| 	for _, callback := range p.callbacks { | 	for _, callback := range p.callbacks { | ||||||
| 		if callback.match == nil || callback.match(p.db) { | 		if callback.match == nil || callback.match(p.db) { | ||||||
| 			callbacks = append(callbacks, callback) | 			callbacks = append(callbacks, callback) | ||||||
| 		} | 		} | ||||||
|  | 		if callback.remove { | ||||||
|  | 			removedMap[callback.name] = true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(removedMap) > 0 { | ||||||
|  | 		callbacks = removeCallbacks(callbacks, removedMap) | ||||||
| 	} | 	} | ||||||
| 	p.callbacks = callbacks | 	p.callbacks = callbacks | ||||||
| 
 | 
 | ||||||
| @ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | |||||||
| 
 | 
 | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { | ||||||
|  | 	callbacks := make([]*callback, 0, len(cs)) | ||||||
|  | 	for _, callback := range cs { | ||||||
|  | 		if nameMap[callback.name] { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		callbacks = append(callbacks, callback) | ||||||
|  | 	} | ||||||
|  | 	return callbacks | ||||||
|  | } | ||||||
|  | |||||||
| @ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, | 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, | ||||||
| 			results:   []string{"c1", "c5", "c3", "c4"}, | 			results:   []string{"c1", "c3", "c4", "c5"}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, | 			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, | ||||||
| @ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) { | |||||||
| 		t.Errorf("callbacks tests failed, got %v", msg) | 		t.Errorf("callbacks tests failed, got %v", msg) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestCallbacksGet(t *testing.T) { | ||||||
|  | 	db, _ := gorm.Open(nil, nil) | ||||||
|  | 	createCallback := db.Callback().Create() | ||||||
|  | 
 | ||||||
|  | 	createCallback.Before("*").Register("c1", c1) | ||||||
|  | 	if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { | ||||||
|  | 		t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	createCallback.Remove("c1") | ||||||
|  | 	if cb := createCallback.Get("c2"); cb != nil { | ||||||
|  | 		t.Errorf("callbacks test failed. got: %p, want: nil", cb) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCallbacksRemove(t *testing.T) { | ||||||
|  | 	db, _ := gorm.Open(nil, nil) | ||||||
|  | 	createCallback := db.Callback().Create() | ||||||
|  | 
 | ||||||
|  | 	createCallback.Before("*").Register("c1", c1) | ||||||
|  | 	createCallback.After("*").Register("c2", c2) | ||||||
|  | 	createCallback.Before("c4").Register("c3", c3) | ||||||
|  | 	createCallback.After("c2").Register("c4", c4) | ||||||
|  | 
 | ||||||
|  | 	// callbacks: []string{"c1", "c3", "c4", "c2"}
 | ||||||
|  | 	createCallback.Remove("c1") | ||||||
|  | 	if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { | ||||||
|  | 		t.Errorf("callbacks tests failed, got %v", msg) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	createCallback.Remove("c4") | ||||||
|  | 	if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { | ||||||
|  | 		t.Errorf("callbacks tests failed, got %v", msg) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	createCallback.Remove("c2") | ||||||
|  | 	if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { | ||||||
|  | 		t.Errorf("callbacks tests failed, got %v", msg) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	createCallback.Remove("c3") | ||||||
|  | 	if ok, msg := assertCallbacks(createCallback, []string{}); !ok { | ||||||
|  | 		t.Errorf("callbacks tests failed, got %v", msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 snackmgmg
						snackmgmg