From e509b3100daa35df7b7e80e8928bcf74aacf3a9e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jan 2020 06:35:25 +0800 Subject: [PATCH] Implement callbacks --- callbacks.go | 211 ++++++++++++++++++++++++++++++++++++++++ callbacks_test.go | 131 +++++++++++++++++++++++++ errors.go => helpers.go | 21 ++-- logger/logger.go | 46 +++++++++ model.go | 15 --- utils/utils.go | 20 ++++ 6 files changed, 422 insertions(+), 22 deletions(-) create mode 100644 callbacks.go create mode 100644 callbacks_test.go rename errors.go => helpers.go (55%) delete mode 100644 model.go create mode 100644 utils/utils.go diff --git a/callbacks.go b/callbacks.go new file mode 100644 index 00000000..d53e8049 --- /dev/null +++ b/callbacks.go @@ -0,0 +1,211 @@ +package gorm + +import ( + "fmt" + "log" + + "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/gorm/utils" +) + +// Callbacks gorm callbacks manager +type Callbacks struct { + creates []func(*DB) + queries []func(*DB) + updates []func(*DB) + deletes []func(*DB) + row []func(*DB) + raw []func(*DB) + db *DB + processors []*processor +} + +type processor struct { + kind string + name string + before string + after string + remove bool + replace bool + match func(*DB) bool + handler func(*DB) + callbacks *Callbacks +} + +func (cs *Callbacks) Create() *processor { + return &processor{callbacks: cs, kind: "create"} +} + +func (cs *Callbacks) Query() *processor { + return &processor{callbacks: cs, kind: "query"} +} + +func (cs *Callbacks) Update() *processor { + return &processor{callbacks: cs, kind: "update"} +} + +func (cs *Callbacks) Delete() *processor { + return &processor{callbacks: cs, kind: "delete"} +} + +func (cs *Callbacks) Row() *processor { + return &processor{callbacks: cs, kind: "row"} +} + +func (cs *Callbacks) Raw() *processor { + return &processor{callbacks: cs, kind: "raw"} +} + +func (p *processor) Before(name string) *processor { + p.before = name + return p +} + +func (p *processor) After(name string) *processor { + p.after = name + return p +} + +func (p *processor) Match(fc func(*DB) bool) *processor { + p.match = fc + return p +} + +func (p *processor) Get(name string) func(*DB) { + for i := len(p.callbacks.processors) - 1; i >= 0; i-- { + if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { + return v.handler + } + } + return nil +} + +func (p *processor) Register(name string, fn func(*DB)) { + p.name = name + p.handler = fn + p.callbacks.processors = append(p.callbacks.processors, p) + p.callbacks.compile(p.callbacks.db) +} + +func (p *processor) Remove(name string) { + logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + p.name = name + p.remove = true + p.callbacks.processors = append(p.callbacks.processors, p) + p.callbacks.compile(p.callbacks.db) +} + +func (p *processor) Replace(name string, fn func(*DB)) { + logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + p.name = name + p.handler = fn + p.replace = true + p.callbacks.processors = append(p.callbacks.processors, p) + p.callbacks.compile(p.callbacks.db) +} + +// getRIndex get right index from string slice +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i + } + } + return -1 +} + +func sortProcessors(ps []*processor) []func(*DB) { + var ( + allNames, sortedNames []string + sortProcessor func(*processor) error + ) + + for _, p := range ps { + // show warning message the callback name already exists + if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { + log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) + } + allNames = append(allNames, p.name) + } + + sortProcessor = func(p *processor) error { + if getRIndex(sortedNames, p.name) == -1 { // if not sorted + if p.before != "" { // if defined before callback + if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { + if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true { + // if before callback already sorted, append current callback just after it + sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) + } + } else if idx := getRIndex(allNames, p.before); idx != -1 { + // if before callback exists + ps[idx].after = p.name + } + } + + if p.after != "" { // if defined after callback + if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { + // if after callback sorted, append current callback to last + sortedNames = append(sortedNames, p.name) + } else if idx := getRIndex(allNames, p.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + if after := ps[idx]; after.before == "" { + after.before = p.name + sortProcessor(after) + } + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sortedNames, p.name) == -1 { + sortedNames = append(sortedNames, p.name) + } + } + + return nil + } + + for _, p := range ps { + sortProcessor(p) + } + + var fns []func(*DB) + for _, name := range sortedNames { + if idx := getRIndex(allNames, name); !ps[idx].remove { + fns = append(fns, ps[idx].handler) + } + } + + return fns +} + +// compile processors +func (cs *Callbacks) compile(db *DB) { + processors := map[string][]*processor{} + for _, p := range cs.processors { + if p.name != "" { + if p.match == nil || p.match(db) { + processors[p.kind] = append(processors[p.kind], p) + } + } + } + + for name, ps := range processors { + switch name { + case "create": + cs.creates = sortProcessors(ps) + case "query": + cs.queries = sortProcessors(ps) + case "update": + cs.updates = sortProcessors(ps) + case "delete": + cs.deletes = sortProcessors(ps) + case "row": + cs.row = sortProcessors(ps) + case "raw": + cs.raw = sortProcessors(ps) + } + } +} diff --git a/callbacks_test.go b/callbacks_test.go new file mode 100644 index 00000000..547cdca1 --- /dev/null +++ b/callbacks_test.go @@ -0,0 +1,131 @@ +package gorm + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "testing" +) + +func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) { + var got []string + + for _, f := range funcs { + got = append(got, getFuncName(f)) + } + + return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) +} + +func getFuncName(fc func(*DB)) string { + fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".") + return fnames[len(fnames)-1] +} + +func c1(*DB) {} +func c2(*DB) {} +func c3(*DB) {} +func c4(*DB) {} +func c5(*DB) {} + +func TestCallbacks(t *testing.T) { + type callback struct { + name string + before string + after string + remove bool + replace bool + err error + match func(*DB) bool + h func(*DB) + } + + datas := []struct { + callbacks []callback + results []string + }{ + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c4", "c5"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c5", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c1", "c3", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, + results: []string{"c1", "c5", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, + results: []string{"c1", "c4", "c3"}, + }, + } + + // func TestRegisterCallbackWithComplexOrder(t *testing.T) { + // var callback2 = &Callback{logger: defaultLogger} + + // callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) + // callback2.Delete().Before("create").Register("before_create1", beforeCreate1) + // callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) + // callback2.Delete().Register("after_create1", afterCreate1) + // callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) + + // if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + // t.Errorf("register callback with order") + // } + // } + + for idx, data := range datas { + callbacks := &Callbacks{} + + for _, c := range data.callbacks { + p := callbacks.Create() + + if c.name == "" { + c.name = getFuncName(c.h) + } + + if c.before != "" { + p = p.Before(c.before) + } + + if c.after != "" { + p = p.After(c.after) + } + + if c.match != nil { + p = p.Match(c.match) + } + + if c.remove { + p.Remove(c.name) + } else if c.replace { + p.Replace(c.name, c.h) + } else { + p.Register(c.name, c.h) + } + } + + if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok { + t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) + } + } +} diff --git a/errors.go b/helpers.go similarity index 55% rename from errors.go rename to helpers.go index c66408be..8f9df009 100644 --- a/errors.go +++ b/helpers.go @@ -1,6 +1,9 @@ package gorm -import "errors" +import ( + "errors" + "time" +) var ( // ErrRecordNotFound record not found error @@ -13,10 +16,14 @@ var ( ErrUnaddressable = errors.New("using unaddressable value") ) -type Error struct { - Err error -} - -func (e Error) Unwrap() error { - return e.Err +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embeded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index"` } diff --git a/logger/logger.go b/logger/logger.go index 389a6763..9d6e70bf 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,7 +1,15 @@ package logger +import ( + "fmt" + "log" + "os" +) + type LogLevel int +var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)} + const ( Info LogLevel = iota + 1 Warn @@ -11,4 +19,42 @@ const ( // Interface logger interface type Interface interface { LogMode(LogLevel) Interface + Info(string, ...interface{}) + Warn(string, ...interface{}) + Error(string, ...interface{}) +} + +// Writer log writer interface +type Writer interface { + Print(...interface{}) +} + +type Logger struct { + Writer + logLevel LogLevel +} + +func (logger Logger) LogMode(level LogLevel) Interface { + return Logger{Writer: logger.Writer, logLevel: level} +} + +// Info print info +func (logger Logger) Info(msg string, data ...interface{}) { + if logger.logLevel >= Info { + logger.Print("[info] " + fmt.Sprintf(msg, data...)) + } +} + +// Warn print warn messages +func (logger Logger) Warn(msg string, data ...interface{}) { + if logger.logLevel >= Warn { + logger.Print("[warn] " + fmt.Sprintf(msg, data...)) + } +} + +// Error print error messages +func (logger Logger) Error(msg string, data ...interface{}) { + if logger.logLevel >= Error { + logger.Print("[error] " + fmt.Sprintf(msg, data...)) + } } diff --git a/model.go b/model.go deleted file mode 100644 index 118d8f14..00000000 --- a/model.go +++ /dev/null @@ -1,15 +0,0 @@ -package gorm - -import "time" - -// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embeded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` -} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..81ac8b30 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,20 @@ +package utils + +import ( + "fmt" + "regexp" + "runtime" +) + +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) + +func FileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { + return fmt.Sprintf("%v:%v", file, line) + } + } + return "" +}