Add callbacks
This commit is contained in:
parent
d833efe8b9
commit
728c0d4470
29
callbacks.go
29
callbacks.go
@ -9,15 +9,15 @@ import (
|
|||||||
"github.com/jinzhu/gorm/utils"
|
"github.com/jinzhu/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitializeCallbacks() *callbacks {
|
func initializeCallbacks(db *DB) *callbacks {
|
||||||
return &callbacks{
|
return &callbacks{
|
||||||
processors: map[string]*processor{
|
processors: map[string]*processor{
|
||||||
"create": &processor{},
|
"create": &processor{db: db},
|
||||||
"query": &processor{},
|
"query": &processor{db: db},
|
||||||
"update": &processor{},
|
"update": &processor{db: db},
|
||||||
"delete": &processor{},
|
"delete": &processor{db: db},
|
||||||
"row": &processor{},
|
"row": &processor{db: db},
|
||||||
"raw": &processor{},
|
"raw": &processor{db: db},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -118,7 +118,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
|
|||||||
return (&callback{processor: p}).Replace(name, fn)
|
return (&callback{processor: p}).Replace(name, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) compile(db *DB) (err error) {
|
func (p *processor) compile() (err error) {
|
||||||
|
var callbacks []*callback
|
||||||
|
for _, callback := range p.callbacks {
|
||||||
|
if callback.match == nil || callback.match(p.db) {
|
||||||
|
callbacks = append(callbacks, callback)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||||
logger.Default.Error("Got error when compile callbacks, got %v", err)
|
logger.Default.Error("Got error when compile callbacks, got %v", err)
|
||||||
}
|
}
|
||||||
@ -139,7 +146,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
|
|||||||
c.name = name
|
c.name = name
|
||||||
c.handler = fn
|
c.handler = fn
|
||||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
return c.processor.compile(c.processor.db)
|
return c.processor.compile()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Remove(name string) error {
|
func (c *callback) Remove(name string) error {
|
||||||
@ -147,7 +154,7 @@ func (c *callback) Remove(name string) error {
|
|||||||
c.name = name
|
c.name = name
|
||||||
c.remove = true
|
c.remove = true
|
||||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
return c.processor.compile(c.processor.db)
|
return c.processor.compile()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||||
@ -156,7 +163,7 @@ func (c *callback) Replace(name string, fn func(*DB)) error {
|
|||||||
c.handler = fn
|
c.handler = fn
|
||||||
c.replace = true
|
c.replace = true
|
||||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
return c.processor.compile(c.processor.db)
|
return c.processor.compile()
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRIndex get right index from string slice
|
// getRIndex get right index from string slice
|
||||||
|
@ -3,10 +3,37 @@ package callbacks
|
|||||||
import "github.com/jinzhu/gorm"
|
import "github.com/jinzhu/gorm"
|
||||||
|
|
||||||
func RegisterDefaultCallbacks(db *gorm.DB) {
|
func RegisterDefaultCallbacks(db *gorm.DB) {
|
||||||
callback := db.Callback()
|
enableTransaction := func(db *gorm.DB) bool {
|
||||||
callback.Create().Register("gorm:before_create", BeforeCreate)
|
return !db.SkipDefaultTransaction
|
||||||
callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
}
|
||||||
callback.Create().Register("gorm:create", Create)
|
|
||||||
callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
|
createCallback := db.Callback().Create()
|
||||||
callback.Create().Register("gorm:after_create", AfterCreate)
|
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
|
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||||
|
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||||
|
createCallback.Register("gorm:create", Create)
|
||||||
|
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||||
|
createCallback.Register("gorm:after_create", AfterCreate)
|
||||||
|
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
|
|
||||||
|
queryCallback := db.Callback().Query()
|
||||||
|
queryCallback.Register("gorm:query", BeforeCreate)
|
||||||
|
queryCallback.Register("gorm:preload", Preload)
|
||||||
|
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||||
|
|
||||||
|
deleteCallback := db.Callback().Delete()
|
||||||
|
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
|
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||||
|
deleteCallback.Register("gorm:delete", Delete)
|
||||||
|
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||||
|
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
|
|
||||||
|
updateCallback := db.Callback().Update()
|
||||||
|
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
|
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||||
|
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||||
|
updateCallback.Register("gorm:update", Update)
|
||||||
|
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||||
|
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||||
|
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||||||
|
|
||||||
func Create(db *gorm.DB) {
|
func Create(db *gorm.DB) {
|
||||||
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
|
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
|
||||||
|
db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,17 +29,3 @@ func AfterCreate(db *gorm.DB) {
|
|||||||
// after save
|
// after save
|
||||||
// after create
|
// after create
|
||||||
}
|
}
|
||||||
|
|
||||||
func objectToFieldsMap(stmt *gorm.Statement) {
|
|
||||||
if stmt.Schema != nil {
|
|
||||||
if s, ok := stmt.Clauses["SELECT"]; ok {
|
|
||||||
s.Attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
if s, ok := stmt.Clauses["OMIT"]; ok {
|
|
||||||
s.Attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt.Schema.LookUpField(s.S)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
12
callbacks/delete.go
Normal file
12
callbacks/delete.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package callbacks
|
||||||
|
|
||||||
|
import "github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
func BeforeDelete(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func Delete(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func AfterDelete(db *gorm.DB) {
|
||||||
|
}
|
9
callbacks/transaction.go
Normal file
9
callbacks/transaction.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package callbacks
|
||||||
|
|
||||||
|
import "github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
func BeginTransaction(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||||
|
}
|
12
callbacks/update.go
Normal file
12
callbacks/update.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package callbacks
|
||||||
|
|
||||||
|
import "github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func Update(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func AfterUpdate(db *gorm.DB) {
|
||||||
|
}
|
@ -1,5 +0,0 @@
|
|||||||
module github.com/jinzhu/gorm/dialects/sqlite
|
|
||||||
|
|
||||||
go 1.13
|
|
||||||
|
|
||||||
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
|
@ -1,2 +0,0 @@
|
|||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
|
5
go.mod
5
go.mod
@ -2,7 +2,4 @@ module github.com/jinzhu/gorm
|
|||||||
|
|
||||||
go 1.13
|
go 1.13
|
||||||
|
|
||||||
require (
|
require github.com/jinzhu/inflection v1.0.0
|
||||||
github.com/jinzhu/inflection v1.0.0
|
|
||||||
gopkg.in/errgo.v2 v2.1.0
|
|
||||||
)
|
|
||||||
|
2
go.sum
2
go.sum
@ -1,2 +0,0 @@
|
|||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
|
3
gorm.go
3
gorm.go
@ -63,10 +63,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||||||
Config: config,
|
Config: config,
|
||||||
Dialector: dialector,
|
Dialector: dialector,
|
||||||
clone: true,
|
clone: true,
|
||||||
callbacks: InitializeCallbacks(),
|
|
||||||
cacheStore: &sync.Map{},
|
cacheStore: &sync.Map{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.callbacks = initializeCallbacks(db)
|
||||||
|
|
||||||
if dialector != nil {
|
if dialector != nil {
|
||||||
err = dialector.Initialize(db)
|
err = dialector.Initialize(db)
|
||||||
}
|
}
|
||||||
|
14
statement.go
14
statement.go
@ -21,6 +21,13 @@ type Instance struct {
|
|||||||
Statement *Statement
|
Statement *Statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
||||||
|
if len(clauses) > 0 {
|
||||||
|
instance.Statement.Build(clauses...)
|
||||||
|
}
|
||||||
|
return instance.Statement.SQL.String(), instance.Statement.Vars
|
||||||
|
}
|
||||||
|
|
||||||
// AddError add error to instance
|
// AddError add error to instance
|
||||||
func (inst Instance) AddError(err error) {
|
func (inst Instance) AddError(err error) {
|
||||||
if inst.Error == nil {
|
if inst.Error == nil {
|
||||||
@ -205,16 +212,17 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
|||||||
|
|
||||||
// Build build sql with clauses names
|
// Build build sql with clauses names
|
||||||
func (stmt Statement) Build(clauses ...string) {
|
func (stmt Statement) Build(clauses ...string) {
|
||||||
var includeSpace bool
|
var firstClauseWritten bool
|
||||||
|
|
||||||
for _, name := range clauses {
|
for _, name := range clauses {
|
||||||
if c, ok := stmt.Clauses[name]; ok {
|
if c, ok := stmt.Clauses[name]; ok {
|
||||||
if includeSpace {
|
if firstClauseWritten {
|
||||||
stmt.WriteByte(' ')
|
stmt.WriteByte(' ')
|
||||||
}
|
}
|
||||||
|
|
||||||
includeSpace = true
|
firstClauseWritten = true
|
||||||
c.Build(stmt)
|
c.Build(stmt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// TODO handle named vars
|
||||||
}
|
}
|
||||||
|
@ -99,8 +99,8 @@ func TestCallbacks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for idx, data := range datas {
|
for idx, data := range datas {
|
||||||
var err error
|
db, err := gorm.Open(nil, nil)
|
||||||
callbacks := gorm.InitializeCallbacks()
|
callbacks := db.Callback()
|
||||||
|
|
||||||
for _, c := range data.callbacks {
|
for _, c := range data.callbacks {
|
||||||
var v interface{} = callbacks.Create()
|
var v interface{} = callbacks.Create()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user