Implement callbacks
This commit is contained in:
		
							parent
							
								
									9d5b9834d9
								
							
						
					
					
						commit
						e509b3100d
					
				
							
								
								
									
										211
									
								
								callbacks.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										211
									
								
								callbacks.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,211 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/logger" | ||||
| 	"github.com/jinzhu/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| // Callbacks gorm callbacks manager
 | ||||
| type Callbacks struct { | ||||
| 	creates    []func(*DB) | ||||
| 	queries    []func(*DB) | ||||
| 	updates    []func(*DB) | ||||
| 	deletes    []func(*DB) | ||||
| 	row        []func(*DB) | ||||
| 	raw        []func(*DB) | ||||
| 	db         *DB | ||||
| 	processors []*processor | ||||
| } | ||||
| 
 | ||||
| type processor struct { | ||||
| 	kind      string | ||||
| 	name      string | ||||
| 	before    string | ||||
| 	after     string | ||||
| 	remove    bool | ||||
| 	replace   bool | ||||
| 	match     func(*DB) bool | ||||
| 	handler   func(*DB) | ||||
| 	callbacks *Callbacks | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Create() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "create"} | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Query() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "query"} | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Update() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "update"} | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Delete() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "delete"} | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Row() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "row"} | ||||
| } | ||||
| 
 | ||||
| func (cs *Callbacks) Raw() *processor { | ||||
| 	return &processor{callbacks: cs, kind: "raw"} | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Before(name string) *processor { | ||||
| 	p.before = name | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| func (p *processor) After(name string) *processor { | ||||
| 	p.after = name | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Match(fc func(*DB) bool) *processor { | ||||
| 	p.match = fc | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Get(name string) func(*DB) { | ||||
| 	for i := len(p.callbacks.processors) - 1; i >= 0; i-- { | ||||
| 		if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { | ||||
| 			return v.handler | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Register(name string, fn func(*DB)) { | ||||
| 	p.name = name | ||||
| 	p.handler = fn | ||||
| 	p.callbacks.processors = append(p.callbacks.processors, p) | ||||
| 	p.callbacks.compile(p.callbacks.db) | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Remove(name string) { | ||||
| 	logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) | ||||
| 	p.name = name | ||||
| 	p.remove = true | ||||
| 	p.callbacks.processors = append(p.callbacks.processors, p) | ||||
| 	p.callbacks.compile(p.callbacks.db) | ||||
| } | ||||
| 
 | ||||
| func (p *processor) Replace(name string, fn func(*DB)) { | ||||
| 	logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) | ||||
| 	p.name = name | ||||
| 	p.handler = fn | ||||
| 	p.replace = true | ||||
| 	p.callbacks.processors = append(p.callbacks.processors, p) | ||||
| 	p.callbacks.compile(p.callbacks.db) | ||||
| } | ||||
| 
 | ||||
| // getRIndex get right index from string slice
 | ||||
| func getRIndex(strs []string, str string) int { | ||||
| 	for i := len(strs) - 1; i >= 0; i-- { | ||||
| 		if strs[i] == str { | ||||
| 			return i | ||||
| 		} | ||||
| 	} | ||||
| 	return -1 | ||||
| } | ||||
| 
 | ||||
| func sortProcessors(ps []*processor) []func(*DB) { | ||||
| 	var ( | ||||
| 		allNames, sortedNames []string | ||||
| 		sortProcessor         func(*processor) error | ||||
| 	) | ||||
| 
 | ||||
| 	for _, p := range ps { | ||||
| 		// show warning message the callback name already exists
 | ||||
| 		if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { | ||||
| 			log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) | ||||
| 		} | ||||
| 		allNames = append(allNames, p.name) | ||||
| 	} | ||||
| 
 | ||||
| 	sortProcessor = func(p *processor) error { | ||||
| 		if getRIndex(sortedNames, p.name) == -1 { // if not sorted
 | ||||
| 			if p.before != "" { // if defined before callback
 | ||||
| 				if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { | ||||
| 					if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true { | ||||
| 						// if before callback already sorted, append current callback just after it
 | ||||
| 						sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) | ||||
| 					} else if curIdx > sortedIdx { | ||||
| 						return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) | ||||
| 					} | ||||
| 				} else if idx := getRIndex(allNames, p.before); idx != -1 { | ||||
| 					// if before callback exists
 | ||||
| 					ps[idx].after = p.name | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if p.after != "" { // if defined after callback
 | ||||
| 				if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { | ||||
| 					// if after callback sorted, append current callback to last
 | ||||
| 					sortedNames = append(sortedNames, p.name) | ||||
| 				} else if idx := getRIndex(allNames, p.after); idx != -1 { | ||||
| 					// if after callback exists but haven't sorted
 | ||||
| 					// set after callback's before callback to current callback
 | ||||
| 					if after := ps[idx]; after.before == "" { | ||||
| 						after.before = p.name | ||||
| 						sortProcessor(after) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// if current callback haven't been sorted, append it to last
 | ||||
| 			if getRIndex(sortedNames, p.name) == -1 { | ||||
| 				sortedNames = append(sortedNames, p.name) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	for _, p := range ps { | ||||
| 		sortProcessor(p) | ||||
| 	} | ||||
| 
 | ||||
| 	var fns []func(*DB) | ||||
| 	for _, name := range sortedNames { | ||||
| 		if idx := getRIndex(allNames, name); !ps[idx].remove { | ||||
| 			fns = append(fns, ps[idx].handler) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return fns | ||||
| } | ||||
| 
 | ||||
| // compile processors
 | ||||
| func (cs *Callbacks) compile(db *DB) { | ||||
| 	processors := map[string][]*processor{} | ||||
| 	for _, p := range cs.processors { | ||||
| 		if p.name != "" { | ||||
| 			if p.match == nil || p.match(db) { | ||||
| 				processors[p.kind] = append(processors[p.kind], p) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for name, ps := range processors { | ||||
| 		switch name { | ||||
| 		case "create": | ||||
| 			cs.creates = sortProcessors(ps) | ||||
| 		case "query": | ||||
| 			cs.queries = sortProcessors(ps) | ||||
| 		case "update": | ||||
| 			cs.updates = sortProcessors(ps) | ||||
| 		case "delete": | ||||
| 			cs.deletes = sortProcessors(ps) | ||||
| 		case "row": | ||||
| 			cs.row = sortProcessors(ps) | ||||
| 		case "raw": | ||||
| 			cs.raw = sortProcessors(ps) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										131
									
								
								callbacks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								callbacks_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,131 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) { | ||||
| 	var got []string | ||||
| 
 | ||||
| 	for _, f := range funcs { | ||||
| 		got = append(got, getFuncName(f)) | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) | ||||
| } | ||||
| 
 | ||||
| func getFuncName(fc func(*DB)) string { | ||||
| 	fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".") | ||||
| 	return fnames[len(fnames)-1] | ||||
| } | ||||
| 
 | ||||
| func c1(*DB) {} | ||||
| func c2(*DB) {} | ||||
| func c3(*DB) {} | ||||
| func c4(*DB) {} | ||||
| func c5(*DB) {} | ||||
| 
 | ||||
| func TestCallbacks(t *testing.T) { | ||||
| 	type callback struct { | ||||
| 		name    string | ||||
| 		before  string | ||||
| 		after   string | ||||
| 		remove  bool | ||||
| 		replace bool | ||||
| 		err     error | ||||
| 		match   func(*DB) bool | ||||
| 		h       func(*DB) | ||||
| 	} | ||||
| 
 | ||||
| 	datas := []struct { | ||||
| 		callbacks []callback | ||||
| 		results   []string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c4", "c5"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, | ||||
| 			results:   []string{"c1", "c2", "c3", "c5", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c5", "c2", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, | ||||
| 			results:   []string{"c1", "c3", "c5", "c2", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, | ||||
| 			results:   []string{"c1", "c5", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, | ||||
| 			results:   []string{"c1", "c4", "c3"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | ||||
| 	// 	var callback2 = &Callback{logger: defaultLogger}
 | ||||
| 
 | ||||
| 	// 	callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
 | ||||
| 	// 	callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
 | ||||
| 	// 	callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
 | ||||
| 	// 	callback2.Delete().Register("after_create1", afterCreate1)
 | ||||
| 	// 	callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
 | ||||
| 
 | ||||
| 	// 	if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
 | ||||
| 	// 		t.Errorf("register callback with order")
 | ||||
| 	// 	}
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	for idx, data := range datas { | ||||
| 		callbacks := &Callbacks{} | ||||
| 
 | ||||
| 		for _, c := range data.callbacks { | ||||
| 			p := callbacks.Create() | ||||
| 
 | ||||
| 			if c.name == "" { | ||||
| 				c.name = getFuncName(c.h) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.before != "" { | ||||
| 				p = p.Before(c.before) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.after != "" { | ||||
| 				p = p.After(c.after) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.match != nil { | ||||
| 				p = p.Match(c.match) | ||||
| 			} | ||||
| 
 | ||||
| 			if c.remove { | ||||
| 				p.Remove(c.name) | ||||
| 			} else if c.replace { | ||||
| 				p.Replace(c.name, c.h) | ||||
| 			} else { | ||||
| 				p.Register(c.name, c.h) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok { | ||||
| 			t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -1,6 +1,9 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import "errors" | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	// ErrRecordNotFound record not found error
 | ||||
| @ -13,10 +16,14 @@ var ( | ||||
| 	ErrUnaddressable = errors.New("using unaddressable value") | ||||
| ) | ||||
| 
 | ||||
| type Error struct { | ||||
| 	Err error | ||||
| } | ||||
| 
 | ||||
| func (e Error) Unwrap() error { | ||||
| 	return e.Err | ||||
| // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
 | ||||
| // It may be embeded into your model or you may build your own model without it
 | ||||
| //    type User struct {
 | ||||
| //      gorm.Model
 | ||||
| //    }
 | ||||
| type Model struct { | ||||
| 	ID        uint `gorm:"primary_key"` | ||||
| 	CreatedAt time.Time | ||||
| 	UpdatedAt time.Time | ||||
| 	DeletedAt *time.Time `gorm:"index"` | ||||
| } | ||||
| @ -1,7 +1,15 @@ | ||||
| package logger | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| ) | ||||
| 
 | ||||
| type LogLevel int | ||||
| 
 | ||||
| var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)} | ||||
| 
 | ||||
| const ( | ||||
| 	Info LogLevel = iota + 1 | ||||
| 	Warn | ||||
| @ -11,4 +19,42 @@ const ( | ||||
| // Interface logger interface
 | ||||
| type Interface interface { | ||||
| 	LogMode(LogLevel) Interface | ||||
| 	Info(string, ...interface{}) | ||||
| 	Warn(string, ...interface{}) | ||||
| 	Error(string, ...interface{}) | ||||
| } | ||||
| 
 | ||||
| // Writer log writer interface
 | ||||
| type Writer interface { | ||||
| 	Print(...interface{}) | ||||
| } | ||||
| 
 | ||||
| type Logger struct { | ||||
| 	Writer | ||||
| 	logLevel LogLevel | ||||
| } | ||||
| 
 | ||||
| func (logger Logger) LogMode(level LogLevel) Interface { | ||||
| 	return Logger{Writer: logger.Writer, logLevel: level} | ||||
| } | ||||
| 
 | ||||
| // Info print info
 | ||||
| func (logger Logger) Info(msg string, data ...interface{}) { | ||||
| 	if logger.logLevel >= Info { | ||||
| 		logger.Print("[info] " + fmt.Sprintf(msg, data...)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Warn print warn messages
 | ||||
| func (logger Logger) Warn(msg string, data ...interface{}) { | ||||
| 	if logger.logLevel >= Warn { | ||||
| 		logger.Print("[warn] " + fmt.Sprintf(msg, data...)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Error print error messages
 | ||||
| func (logger Logger) Error(msg string, data ...interface{}) { | ||||
| 	if logger.logLevel >= Error { | ||||
| 		logger.Print("[error] " + fmt.Sprintf(msg, data...)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										15
									
								
								model.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								model.go
									
									
									
									
									
								
							| @ -1,15 +0,0 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import "time" | ||||
| 
 | ||||
| // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
 | ||||
| // It may be embeded into your model or you may build your own model without it
 | ||||
| //    type User struct {
 | ||||
| //      gorm.Model
 | ||||
| //    }
 | ||||
| type Model struct { | ||||
| 	ID        uint `gorm:"primary_key"` | ||||
| 	CreatedAt time.Time | ||||
| 	UpdatedAt time.Time | ||||
| 	DeletedAt *time.Time `gorm:"index"` | ||||
| } | ||||
							
								
								
									
										20
									
								
								utils/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								utils/utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | ||||
| package utils | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"runtime" | ||||
| ) | ||||
| 
 | ||||
| var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) | ||||
| var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) | ||||
| 
 | ||||
| func FileWithLineNum() string { | ||||
| 	for i := 2; i < 15; i++ { | ||||
| 		_, file, line, ok := runtime.Caller(i) | ||||
| 		if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { | ||||
| 			return fmt.Sprintf("%v:%v", file, line) | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu