From 728c0d4470ec02629483fe90b11f7a0dec17bded Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 19:32:27 +0800 Subject: [PATCH] Add callbacks --- callbacks.go | 29 ++++++++++++++++++----------- callbacks/callbacks.go | 39 +++++++++++++++++++++++++++++++++------ callbacks/create.go | 16 +--------------- callbacks/delete.go | 12 ++++++++++++ callbacks/transaction.go | 9 +++++++++ callbacks/update.go | 12 ++++++++++++ dialects/sqlite/go.mod | 5 ----- dialects/sqlite/go.sum | 2 -- go.mod | 5 +---- go.sum | 2 -- gorm.go | 3 ++- statement.go | 14 +++++++++++--- tests/callbacks_test.go | 4 ++-- 13 files changed, 101 insertions(+), 51 deletions(-) create mode 100644 callbacks/delete.go create mode 100644 callbacks/transaction.go create mode 100644 callbacks/update.go delete mode 100644 dialects/sqlite/go.mod delete mode 100644 dialects/sqlite/go.sum delete mode 100644 go.sum diff --git a/callbacks.go b/callbacks.go index 22d2eda3..51ee150f 100644 --- a/callbacks.go +++ b/callbacks.go @@ -9,15 +9,15 @@ import ( "github.com/jinzhu/gorm/utils" ) -func InitializeCallbacks() *callbacks { +func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": &processor{}, - "query": &processor{}, - "update": &processor{}, - "delete": &processor{}, - "row": &processor{}, - "raw": &processor{}, + "create": &processor{db: db}, + "query": &processor{db: db}, + "update": &processor{db: db}, + "delete": &processor{db: db}, + "row": &processor{db: db}, + "raw": &processor{db: db}, }, } } @@ -118,7 +118,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error { return (&callback{processor: p}).Replace(name, fn) } -func (p *processor) compile(db *DB) (err error) { +func (p *processor) compile() (err error) { + var callbacks []*callback + for _, callback := range p.callbacks { + if callback.match == nil || callback.match(p.db) { + callbacks = append(callbacks, callback) + } + } + if p.fns, err = sortCallbacks(p.callbacks); err != nil { logger.Default.Error("Got error when compile callbacks, got %v", err) } @@ -139,7 +146,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { c.name = name c.handler = fn c.processor.callbacks = append(c.processor.callbacks, c) - return c.processor.compile(c.processor.db) + return c.processor.compile() } func (c *callback) Remove(name string) error { @@ -147,7 +154,7 @@ func (c *callback) Remove(name string) error { c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) - return c.processor.compile(c.processor.db) + return c.processor.compile() } func (c *callback) Replace(name string, fn func(*DB)) error { @@ -156,7 +163,7 @@ func (c *callback) Replace(name string, fn func(*DB)) error { c.handler = fn c.replace = true c.processor.callbacks = append(c.processor.callbacks, c) - return c.processor.compile(c.processor.db) + return c.processor.compile() } // getRIndex get right index from string slice diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 7fd12cb7..a3e5245b 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -3,10 +3,37 @@ package callbacks import "github.com/jinzhu/gorm" func RegisterDefaultCallbacks(db *gorm.DB) { - callback := db.Callback() - callback.Create().Register("gorm:before_create", BeforeCreate) - callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) - callback.Create().Register("gorm:create", Create) - callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) - callback.Create().Register("gorm:after_create", AfterCreate) + enableTransaction := func(db *gorm.DB) bool { + return !db.SkipDefaultTransaction + } + + createCallback := db.Callback().Create() + createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + createCallback.Register("gorm:before_create", BeforeCreate) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + createCallback.Register("gorm:create", Create) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + createCallback.Register("gorm:after_create", AfterCreate) + createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + queryCallback := db.Callback().Query() + queryCallback.Register("gorm:query", BeforeCreate) + queryCallback.Register("gorm:preload", Preload) + queryCallback.Register("gorm:after_query", AfterQuery) + + deleteCallback := db.Callback().Delete() + deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:after_delete", AfterDelete) + deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + updateCallback := db.Callback().Update() + updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:before_update", BeforeUpdate) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + updateCallback.Register("gorm:update", Update) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/callbacks/create.go b/callbacks/create.go index 5a3aaa24..028cdbc4 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -18,7 +18,7 @@ func SaveBeforeAssociations(db *gorm.DB) { func Create(db *gorm.DB) { db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") - + db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } @@ -29,17 +29,3 @@ func AfterCreate(db *gorm.DB) { // after save // after create } - -func objectToFieldsMap(stmt *gorm.Statement) { - if stmt.Schema != nil { - if s, ok := stmt.Clauses["SELECT"]; ok { - s.Attrs - } - - if s, ok := stmt.Clauses["OMIT"]; ok { - s.Attrs - } - - stmt.Schema.LookUpField(s.S) - } -} diff --git a/callbacks/delete.go b/callbacks/delete.go new file mode 100644 index 00000000..96c392f2 --- /dev/null +++ b/callbacks/delete.go @@ -0,0 +1,12 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeforeDelete(db *gorm.DB) { +} + +func Delete(db *gorm.DB) { +} + +func AfterDelete(db *gorm.DB) { +} diff --git a/callbacks/transaction.go b/callbacks/transaction.go new file mode 100644 index 00000000..253c4e82 --- /dev/null +++ b/callbacks/transaction.go @@ -0,0 +1,9 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeginTransaction(db *gorm.DB) { +} + +func CommitOrRollbackTransaction(db *gorm.DB) { +} diff --git a/callbacks/update.go b/callbacks/update.go new file mode 100644 index 00000000..8e504403 --- /dev/null +++ b/callbacks/update.go @@ -0,0 +1,12 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeforeUpdate(db *gorm.DB) { +} + +func Update(db *gorm.DB) { +} + +func AfterUpdate(db *gorm.DB) { +} diff --git a/dialects/sqlite/go.mod b/dialects/sqlite/go.mod deleted file mode 100644 index 79d48da8..00000000 --- a/dialects/sqlite/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/jinzhu/gorm/dialects/sqlite - -go 1.13 - -require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/dialects/sqlite/go.sum b/dialects/sqlite/go.sum deleted file mode 100644 index d6744290..00000000 --- a/dialects/sqlite/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= -github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/go.mod b/go.mod index 820046ba..516a9759 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module github.com/jinzhu/gorm go 1.13 -require ( - github.com/jinzhu/inflection v1.0.0 - gopkg.in/errgo.v2 v2.1.0 -) +require github.com/jinzhu/inflection v1.0.0 diff --git a/go.sum b/go.sum deleted file mode 100644 index a310b071..00000000 --- a/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= diff --git a/gorm.go b/gorm.go index 2264b9ae..8ac7e057 100644 --- a/gorm.go +++ b/gorm.go @@ -63,10 +63,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { Config: config, Dialector: dialector, clone: true, - callbacks: InitializeCallbacks(), cacheStore: &sync.Map{}, } + db.callbacks = initializeCallbacks(db) + if dialector != nil { err = dialector.Initialize(db) } diff --git a/statement.go b/statement.go index 86359177..4d959cbb 100644 --- a/statement.go +++ b/statement.go @@ -21,6 +21,13 @@ type Instance struct { Statement *Statement } +func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { + if len(clauses) > 0 { + instance.Statement.Build(clauses...) + } + return instance.Statement.SQL.String(), instance.Statement.Vars +} + // AddError add error to instance func (inst Instance) AddError(err error) { if inst.Error == nil { @@ -205,16 +212,17 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con // Build build sql with clauses names func (stmt Statement) Build(clauses ...string) { - var includeSpace bool + var firstClauseWritten bool for _, name := range clauses { if c, ok := stmt.Clauses[name]; ok { - if includeSpace { + if firstClauseWritten { stmt.WriteByte(' ') } - includeSpace = true + firstClauseWritten = true c.Build(stmt) } } + // TODO handle named vars } diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index af975a55..f8dc3e81 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -99,8 +99,8 @@ func TestCallbacks(t *testing.T) { } for idx, data := range datas { - var err error - callbacks := gorm.InitializeCallbacks() + db, err := gorm.Open(nil, nil) + callbacks := db.Callback() for _, c := range data.callbacks { var v interface{} = callbacks.Create()