Refactor callbacks
This commit is contained in:
		
							parent
							
								
									e509b3100d
								
							
						
					
					
						commit
						5959c81be6
					
				
							
								
								
									
										285
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										285
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -2,26 +2,36 @@ package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| // Callbacks gorm callbacks manager
 | ||||
| type Callbacks struct { | ||||
| 	creates    []func(*DB) | ||||
| 	queries    []func(*DB) | ||||
| 	updates    []func(*DB) | ||||
| 	deletes    []func(*DB) | ||||
| 	row        []func(*DB) | ||||
| 	raw        []func(*DB) | ||||
| 	db         *DB | ||||
| 	processors []*processor | ||||
| func InitializeCallbacks() *callbacks { | ||||
| 	return &callbacks{ | ||||
| 		processors: map[string]*processor{ | ||||
| 			"create": &processor{}, | ||||
| 			"query":  &processor{}, | ||||
| 			"update": &processor{}, | ||||
| 			"delete": &processor{}, | ||||
| 			"row":    &processor{}, | ||||
| 			"raw":    &processor{}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // callbacks gorm callbacks manager
 | ||||
| type callbacks struct { | ||||
| 	processors map[string]*processor | ||||
| } | ||||
| 
 | ||||
| type processor struct { | ||||
| 	kind      string | ||||
| 	db        *DB | ||||
| 	fns       []func(*DB) | ||||
| 	callbacks []*callback | ||||
| } | ||||
| 
 | ||||
| type callback struct { | ||||
| 	name      string | ||||
| 	before    string | ||||
| 	after     string | ||||
| @ -29,79 +39,111 @@ type processor struct { | ||||
| 	replace   bool | ||||
| 	match     func(*DB) bool | ||||
| 	handler   func(*DB) | ||||
| 	callbacks *Callbacks | ||||
| 	processor *processor | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Create() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "create"} | ||||
| func (cs *callbacks) Create() *processor { | ||||
| 	return cs.processors["create"] | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Query() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "query"} | ||||
| func (cs *callbacks) Query() *processor { | ||||
| 	return cs.processors["query"] | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Update() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "update"} | ||||
| func (cs *callbacks) Update() *processor { | ||||
| 	return cs.processors["update"] | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Delete() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "delete"} | ||||
| func (cs *callbacks) Delete() *processor { | ||||
| 	return cs.processors["delete"] | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Row() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "row"} | ||||
| func (cs *callbacks) Row() *processor { | ||||
| 	return cs.processors["row"] | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Raw() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "raw"} | ||||
| func (cs *callbacks) Raw() *processor { | ||||
| 	return cs.processors["raw"] | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Before(name string) *processor { | ||||
| 	p.before = name | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| func (p *processor) After(name string) *processor { | ||||
| 	p.after = name | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Match(fc func(*DB) bool) *processor { | ||||
| 	p.match = fc | ||||
| 	return p | ||||
| func (p *processor) Execute(db *DB) { | ||||
| 	for _, f := range p.fns { | ||||
| 		f(db) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Get(name string) func(*DB) { | ||||
| 	for i := len(p.callbacks.processors) - 1; i >= 0; i-- { | ||||
| 		if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { | ||||
| 	for i := len(p.callbacks) - 1; i >= 0; i-- { | ||||
| 		if v := p.callbacks[i]; v.name == name && !v.remove { | ||||
| 			return v.handler | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Register(name string, fn func(*DB)) { | ||||
| 	p.name = name | ||||
| 	p.handler = fn | ||||
| 	p.callbacks.processors = append(p.callbacks.processors, p) | ||||
| 	p.callbacks.compile(p.callbacks.db) | ||||
| func (p *processor) Before(name string) *callback { | ||||
| 	return &callback{before: name, processor: p} | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Remove(name string) { | ||||
| 	logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) | ||||
| 	p.name = name | ||||
| 	p.remove = true | ||||
| 	p.callbacks.processors = append(p.callbacks.processors, p) | ||||
| 	p.callbacks.compile(p.callbacks.db) | ||||
| func (p *processor) After(name string) *callback { | ||||
| 	return &callback{after: name, processor: p} | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Replace(name string, fn func(*DB)) { | ||||
| 	logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) | ||||
| 	p.name = name | ||||
| 	p.handler = fn | ||||
| 	p.replace = true | ||||
| 	p.callbacks.processors = append(p.callbacks.processors, p) | ||||
| 	p.callbacks.compile(p.callbacks.db) | ||||
| func (p *processor) Match(fc func(*DB) bool) *callback { | ||||
| 	return &callback{match: fc, processor: p} | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Register(name string, fn func(*DB)) error { | ||||
| 	return (&callback{processor: p}).Register(name, fn) | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Remove(name string) error { | ||||
| 	return (&callback{processor: p}).Remove(name) | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Replace(name string, fn func(*DB)) error { | ||||
| 	return (&callback{processor: p}).Replace(name, fn) | ||||
| } | ||||
| 
 | ||||
| func (p *processor) compile(db *DB) (err error) { | ||||
| 	if p.fns, err = sortCallbacks(p.callbacks); err != nil { | ||||
| 		logger.Default.Error("Got error when compile callbacks, got %v", err) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (c *callback) Before(name string) *callback { | ||||
| 	c.before = name | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| func (c *callback) After(name string) *callback { | ||||
| 	c.after = name | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| func (c *callback) Register(name string, fn func(*DB)) error { | ||||
| 	c.name = name | ||||
| 	c.handler = fn | ||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||
| 	return c.processor.compile(c.processor.db) | ||||
| } | ||||
| 
 | ||||
| func (c *callback) Remove(name string) error { | ||||
| 	logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) | ||||
| 	c.name = name | ||||
| 	c.remove = true | ||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||
| 	return c.processor.compile(c.processor.db) | ||||
| } | ||||
| 
 | ||||
| func (c *callback) Replace(name string, fn func(*DB)) error { | ||||
| 	logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) | ||||
| 	c.name = name | ||||
| 	c.handler = fn | ||||
| 	c.replace = true | ||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||
| 	return c.processor.compile(c.processor.db) | ||||
| } | ||||
| 
 | ||||
| // getRIndex get right index from string slice
 | ||||
| @ -114,98 +156,81 @@ func getRIndex(strs []string, str string) int { | ||||
| 	return -1 | ||||
| } | ||||
| 
 | ||||
| func sortProcessors(ps []*processor) []func(*DB) { | ||||
| func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | ||||
| 	var ( | ||||
| 		allNames, sortedNames []string | ||||
| 		sortProcessor         func(*processor) error | ||||
| 		names, sorted []string | ||||
| 		sortCallback  func(*callback) error | ||||
| 	) | ||||
| 
 | ||||
| 	for _, p := range ps { | ||||
| 	for _, c := range cs { | ||||
| 		// show warning message the callback name already exists
 | ||||
| 		if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { | ||||
| 			log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) | ||||
| 		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { | ||||
| 			logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) | ||||
| 		} | ||||
| 		allNames = append(allNames, p.name) | ||||
| 		names = append(names, c.name) | ||||
| 	} | ||||
| 
 | ||||
| 	sortProcessor = func(p *processor) error { | ||||
| 		if getRIndex(sortedNames, p.name) == -1 { // if not sorted
 | ||||
| 			if p.before != "" { // if defined before callback
 | ||||
| 				if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { | ||||
| 					if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true { | ||||
| 						// if before callback already sorted, append current callback just after it
 | ||||
| 						sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) | ||||
| 					} else if curIdx > sortedIdx { | ||||
| 						return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) | ||||
| 					} | ||||
| 				} else if idx := getRIndex(allNames, p.before); idx != -1 { | ||||
| 					// if before callback exists
 | ||||
| 					ps[idx].after = p.name | ||||
| 	sortCallback = func(c *callback) error { | ||||
| 		if c.before != "" { // if defined before callback
 | ||||
| 			if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { | ||||
| 				if curIdx := getRIndex(sorted, c.name); curIdx == -1 { | ||||
| 					// if before callback already sorted, append current callback just after it
 | ||||
| 					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) | ||||
| 				} else if curIdx > sortedIdx { | ||||
| 					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) | ||||
| 				} | ||||
| 			} else if idx := getRIndex(names, c.before); idx != -1 { | ||||
| 				// if before callback exists
 | ||||
| 				cs[idx].after = c.name | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 			if p.after != "" { // if defined after callback
 | ||||
| 				if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { | ||||
| 		if c.after != "" { // if defined after callback
 | ||||
| 			if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { | ||||
| 				if curIdx := getRIndex(sorted, c.name); curIdx == -1 { | ||||
| 					// if after callback sorted, append current callback to last
 | ||||
| 					sortedNames = append(sortedNames, p.name) | ||||
| 				} else if idx := getRIndex(allNames, p.after); idx != -1 { | ||||
| 					// if after callback exists but haven't sorted
 | ||||
| 					// set after callback's before callback to current callback
 | ||||
| 					if after := ps[idx]; after.before == "" { | ||||
| 						after.before = p.name | ||||
| 						sortProcessor(after) | ||||
| 					} | ||||
| 					sorted = append(sorted, c.name) | ||||
| 				} else if curIdx < sortedIdx { | ||||
| 					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) | ||||
| 				} | ||||
| 			} else if idx := getRIndex(names, c.after); idx != -1 { | ||||
| 				// if after callback exists but haven't sorted
 | ||||
| 				// set after callback's before callback to current callback
 | ||||
| 				after := cs[idx] | ||||
| 
 | ||||
| 				if after.before == "" { | ||||
| 					after.before = c.name | ||||
| 				} | ||||
| 
 | ||||
| 				if err := sortCallback(after); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				if err := sortCallback(c); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 			// if current callback haven't been sorted, append it to last
 | ||||
| 			if getRIndex(sortedNames, p.name) == -1 { | ||||
| 				sortedNames = append(sortedNames, p.name) | ||||
| 			} | ||||
| 		// if current callback haven't been sorted, append it to last
 | ||||
| 		if getRIndex(sorted, c.name) == -1 { | ||||
| 			sorted = append(sorted, c.name) | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	for _, p := range ps { | ||||
| 		sortProcessor(p) | ||||
| 	} | ||||
| 
 | ||||
| 	var fns []func(*DB) | ||||
| 	for _, name := range sortedNames { | ||||
| 		if idx := getRIndex(allNames, name); !ps[idx].remove { | ||||
| 			fns = append(fns, ps[idx].handler) | ||||
| 	for _, c := range cs { | ||||
| 		if err = sortCallback(c); err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return fns | ||||
| } | ||||
| 
 | ||||
| // compile processors
 | ||||
| func (cs *Callbacks) compile(db *DB) { | ||||
| 	processors := map[string][]*processor{} | ||||
| 	for _, p := range cs.processors { | ||||
| 		if p.name != "" { | ||||
| 			if p.match == nil || p.match(db) { | ||||
| 				processors[p.kind] = append(processors[p.kind], p) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for name, ps := range processors { | ||||
| 		switch name { | ||||
| 		case "create": | ||||
| 			cs.creates = sortProcessors(ps) | ||||
| 		case "query": | ||||
| 			cs.queries = sortProcessors(ps) | ||||
| 		case "update": | ||||
| 			cs.updates = sortProcessors(ps) | ||||
| 		case "delete": | ||||
| 			cs.deletes = sortProcessors(ps) | ||||
| 		case "row": | ||||
| 			cs.row = sortProcessors(ps) | ||||
| 		case "raw": | ||||
| 			cs.raw = sortProcessors(ps) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, name := range sorted { | ||||
| 		if idx := getRIndex(names, name); !cs[idx].remove { | ||||
| 			fns = append(fns, cs[idx].handler) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -1,131 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) { | ||||
| 	var got []string | ||||
| 
 | ||||
| 	for _, f := range funcs { | ||||
| 		got = append(got, getFuncName(f)) | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) | ||||
| } | ||||
| 
 | ||||
| func getFuncName(fc func(*DB)) string { | ||||
| 	fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".") | ||||
| 	return fnames[len(fnames)-1] | ||||
| } | ||||
| 
 | ||||
| func c1(*DB) {} | ||||
| func c2(*DB) {} | ||||
| func c3(*DB) {} | ||||
| func c4(*DB) {} | ||||
| func c5(*DB) {} | ||||
| 
 | ||||
| func TestCallbacks(t *testing.T) { | ||||
| 	type callback struct { | ||||
| 		name    string | ||||
| 		before  string | ||||
| 		after   string | ||||
| 		remove  bool | ||||
| 		replace bool | ||||
| 		err     error | ||||
| 		match   func(*DB) bool | ||||
| 		h       func(*DB) | ||||
| 	} | ||||
| 
 | ||||
| 	datas := []struct { | ||||
| 		callbacks []callback | ||||
| 		results   []string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c4", "c5"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c5", "c2", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c3", "c5", "c2", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, | ||||
| 			results:   []string{"c1", "c5", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, | ||||
| 			results:   []string{"c1", "c4", "c3"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | ||||
| 	// 	var callback2 = &Callback{logger: defaultLogger}
 | ||||
| 
 | ||||
| 	// 	callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
 | ||||
| 	// 	callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
 | ||||
| 	// 	callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
 | ||||
| 	// 	callback2.Delete().Register("after_create1", afterCreate1)
 | ||||
| 	// 	callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
 | ||||
| 
 | ||||
| 	// 	if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
 | ||||
| 	// 		t.Errorf("register callback with order")
 | ||||
| 	// 	}
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	for idx, data := range datas { | ||||
| 		callbacks := &Callbacks{} | ||||
| 
 | ||||
| 		for _, c := range data.callbacks { | ||||
| 			p := callbacks.Create() | ||||
| 
 | ||||
| 			if c.name == "" { | ||||
| 				c.name = getFuncName(c.h) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.before != "" { | ||||
| 				p = p.Before(c.before) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.after != "" { | ||||
| 				p = p.After(c.after) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.match != nil { | ||||
| 				p = p.Match(c.match) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.remove { | ||||
| 				p.Remove(c.name) | ||||
| 			} else if c.replace { | ||||
| 				p.Replace(c.name, c.h) | ||||
| 			} else { | ||||
| 				p.Register(c.name, c.h) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok { | ||||
| 			t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										158
									
								
								tests/callbacks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								tests/callbacks_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,158 @@ | ||||
| package gorm_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { | ||||
| 	var ( | ||||
| 		got   []string | ||||
| 		funcs = reflect.ValueOf(v).Elem().FieldByName("fns") | ||||
| 	) | ||||
| 
 | ||||
| 	for i := 0; i < funcs.Len(); i++ { | ||||
| 		got = append(got, getFuncName(funcs.Index(i))) | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) | ||||
| } | ||||
| 
 | ||||
| func getFuncName(fc interface{}) string { | ||||
| 	reflectValue, ok := fc.(reflect.Value) | ||||
| 	if !ok { | ||||
| 		reflectValue = reflect.ValueOf(fc) | ||||
| 	} | ||||
| 
 | ||||
| 	fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".") | ||||
| 	return fnames[len(fnames)-1] | ||||
| } | ||||
| 
 | ||||
| func c1(*gorm.DB) {} | ||||
| func c2(*gorm.DB) {} | ||||
| func c3(*gorm.DB) {} | ||||
| func c4(*gorm.DB) {} | ||||
| func c5(*gorm.DB) {} | ||||
| 
 | ||||
| func TestCallbacks(t *testing.T) { | ||||
| 	type callback struct { | ||||
| 		name    string | ||||
| 		before  string | ||||
| 		after   string | ||||
| 		remove  bool | ||||
| 		replace bool | ||||
| 		err     string | ||||
| 		match   func(*gorm.DB) bool | ||||
| 		h       func(*gorm.DB) | ||||
| 	} | ||||
| 
 | ||||
| 	datas := []struct { | ||||
| 		callbacks []callback | ||||
| 		err       string | ||||
| 		results   []string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c4", "c5"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c5", "c2", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c3", "c1", "c5", "c2", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c3", "c1", "c5", "c2", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, | ||||
| 			err:       "conflicting", | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, | ||||
| 			results:   []string{"c1", "c5", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, | ||||
| 			results:   []string{"c1", "c4", "c3"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, data := range datas { | ||||
| 		var err error | ||||
| 		callbacks := gorm.InitializeCallbacks() | ||||
| 
 | ||||
| 		for _, c := range data.callbacks { | ||||
| 			var v interface{} = callbacks.Create() | ||||
| 			callMethod := func(s interface{}, name string, args ...interface{}) { | ||||
| 				var argValues []reflect.Value | ||||
| 				for _, arg := range args { | ||||
| 					argValues = append(argValues, reflect.ValueOf(arg)) | ||||
| 				} | ||||
| 
 | ||||
| 				results := reflect.ValueOf(s).MethodByName(name).Call(argValues) | ||||
| 				if len(results) > 0 { | ||||
| 					v = results[0].Interface() | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if c.name == "" { | ||||
| 				c.name = getFuncName(c.h) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.before != "" { | ||||
| 				callMethod(v, "Before", c.before) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.after != "" { | ||||
| 				callMethod(v, "After", c.after) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.match != nil { | ||||
| 				callMethod(v, "Match", c.match) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.remove { | ||||
| 				callMethod(v, "Remove", c.name) | ||||
| 			} else if c.replace { | ||||
| 				callMethod(v, "Replace", c.name, c.h) | ||||
| 			} else { | ||||
| 				callMethod(v, "Register", c.name, c.h) | ||||
| 			} | ||||
| 
 | ||||
| 			if e, ok := v.(error); !ok || e != nil { | ||||
| 				err = e | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if len(data.err) > 0 && err == nil { | ||||
| 			t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err) | ||||
| 		} else if len(data.err) == 0 && err != nil { | ||||
| 			t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err) | ||||
| 		} | ||||
| 
 | ||||
| 		if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok { | ||||
| 			t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu