Add callbacks
This commit is contained in:
		
							parent
							
								
									d833efe8b9
								
							
						
					
					
						commit
						728c0d4470
					
				
							
								
								
									
										29
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -9,15 +9,15 @@ import ( | |||||||
| 	"github.com/jinzhu/gorm/utils" | 	"github.com/jinzhu/gorm/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func InitializeCallbacks() *callbacks { | func initializeCallbacks(db *DB) *callbacks { | ||||||
| 	return &callbacks{ | 	return &callbacks{ | ||||||
| 		processors: map[string]*processor{ | 		processors: map[string]*processor{ | ||||||
| 			"create": &processor{}, | 			"create": &processor{db: db}, | ||||||
| 			"query":  &processor{}, | 			"query":  &processor{db: db}, | ||||||
| 			"update": &processor{}, | 			"update": &processor{db: db}, | ||||||
| 			"delete": &processor{}, | 			"delete": &processor{db: db}, | ||||||
| 			"row":    &processor{}, | 			"row":    &processor{db: db}, | ||||||
| 			"raw":    &processor{}, | 			"raw":    &processor{db: db}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -118,7 +118,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error { | |||||||
| 	return (&callback{processor: p}).Replace(name, fn) | 	return (&callback{processor: p}).Replace(name, fn) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *processor) compile(db *DB) (err error) { | func (p *processor) compile() (err error) { | ||||||
|  | 	var callbacks []*callback | ||||||
|  | 	for _, callback := range p.callbacks { | ||||||
|  | 		if callback.match == nil || callback.match(p.db) { | ||||||
|  | 			callbacks = append(callbacks, callback) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if p.fns, err = sortCallbacks(p.callbacks); err != nil { | 	if p.fns, err = sortCallbacks(p.callbacks); err != nil { | ||||||
| 		logger.Default.Error("Got error when compile callbacks, got %v", err) | 		logger.Default.Error("Got error when compile callbacks, got %v", err) | ||||||
| 	} | 	} | ||||||
| @ -139,7 +146,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { | |||||||
| 	c.name = name | 	c.name = name | ||||||
| 	c.handler = fn | 	c.handler = fn | ||||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||||
| 	return c.processor.compile(c.processor.db) | 	return c.processor.compile() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *callback) Remove(name string) error { | func (c *callback) Remove(name string) error { | ||||||
| @ -147,7 +154,7 @@ func (c *callback) Remove(name string) error { | |||||||
| 	c.name = name | 	c.name = name | ||||||
| 	c.remove = true | 	c.remove = true | ||||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||||
| 	return c.processor.compile(c.processor.db) | 	return c.processor.compile() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *callback) Replace(name string, fn func(*DB)) error { | func (c *callback) Replace(name string, fn func(*DB)) error { | ||||||
| @ -156,7 +163,7 @@ func (c *callback) Replace(name string, fn func(*DB)) error { | |||||||
| 	c.handler = fn | 	c.handler = fn | ||||||
| 	c.replace = true | 	c.replace = true | ||||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||||
| 	return c.processor.compile(c.processor.db) | 	return c.processor.compile() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getRIndex get right index from string slice
 | // getRIndex get right index from string slice
 | ||||||
|  | |||||||
| @ -3,10 +3,37 @@ package callbacks | |||||||
| import "github.com/jinzhu/gorm" | import "github.com/jinzhu/gorm" | ||||||
| 
 | 
 | ||||||
| func RegisterDefaultCallbacks(db *gorm.DB) { | func RegisterDefaultCallbacks(db *gorm.DB) { | ||||||
| 	callback := db.Callback() | 	enableTransaction := func(db *gorm.DB) bool { | ||||||
| 	callback.Create().Register("gorm:before_create", BeforeCreate) | 		return !db.SkipDefaultTransaction | ||||||
| 	callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) | 	} | ||||||
| 	callback.Create().Register("gorm:create", Create) | 
 | ||||||
| 	callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) | 	createCallback := db.Callback().Create() | ||||||
| 	callback.Create().Register("gorm:after_create", AfterCreate) | 	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||||
|  | 	createCallback.Register("gorm:before_create", BeforeCreate) | ||||||
|  | 	createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) | ||||||
|  | 	createCallback.Register("gorm:create", Create) | ||||||
|  | 	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) | ||||||
|  | 	createCallback.Register("gorm:after_create", AfterCreate) | ||||||
|  | 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
|  | 
 | ||||||
|  | 	queryCallback := db.Callback().Query() | ||||||
|  | 	queryCallback.Register("gorm:query", BeforeCreate) | ||||||
|  | 	queryCallback.Register("gorm:preload", Preload) | ||||||
|  | 	queryCallback.Register("gorm:after_query", AfterQuery) | ||||||
|  | 
 | ||||||
|  | 	deleteCallback := db.Callback().Delete() | ||||||
|  | 	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||||
|  | 	deleteCallback.Register("gorm:before_delete", BeforeDelete) | ||||||
|  | 	deleteCallback.Register("gorm:delete", Delete) | ||||||
|  | 	deleteCallback.Register("gorm:after_delete", AfterDelete) | ||||||
|  | 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
|  | 
 | ||||||
|  | 	updateCallback := db.Callback().Update() | ||||||
|  | 	updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||||
|  | 	updateCallback.Register("gorm:before_update", BeforeUpdate) | ||||||
|  | 	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) | ||||||
|  | 	updateCallback.Register("gorm:update", Update) | ||||||
|  | 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) | ||||||
|  | 	updateCallback.Register("gorm:after_update", AfterUpdate) | ||||||
|  | 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
| } | } | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| func Create(db *gorm.DB) { | func Create(db *gorm.DB) { | ||||||
| 	db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") | 	db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") | ||||||
| 
 | 	db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -29,17 +29,3 @@ func AfterCreate(db *gorm.DB) { | |||||||
| 	// after save
 | 	// after save
 | ||||||
| 	// after create
 | 	// after create
 | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func objectToFieldsMap(stmt *gorm.Statement) { |  | ||||||
| 	if stmt.Schema != nil { |  | ||||||
| 		if s, ok := stmt.Clauses["SELECT"]; ok { |  | ||||||
| 			s.Attrs |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if s, ok := stmt.Clauses["OMIT"]; ok { |  | ||||||
| 			s.Attrs |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		stmt.Schema.LookUpField(s.S) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								callbacks/delete.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								callbacks/delete.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import "github.com/jinzhu/gorm" | ||||||
|  | 
 | ||||||
|  | func BeforeDelete(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Delete(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func AfterDelete(db *gorm.DB) { | ||||||
|  | } | ||||||
							
								
								
									
										9
									
								
								callbacks/transaction.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								callbacks/transaction.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import "github.com/jinzhu/gorm" | ||||||
|  | 
 | ||||||
|  | func BeginTransaction(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func CommitOrRollbackTransaction(db *gorm.DB) { | ||||||
|  | } | ||||||
							
								
								
									
										12
									
								
								callbacks/update.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								callbacks/update.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | |||||||
|  | package callbacks | ||||||
|  | 
 | ||||||
|  | import "github.com/jinzhu/gorm" | ||||||
|  | 
 | ||||||
|  | func BeforeUpdate(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Update(db *gorm.DB) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func AfterUpdate(db *gorm.DB) { | ||||||
|  | } | ||||||
| @ -1,5 +0,0 @@ | |||||||
| module github.com/jinzhu/gorm/dialects/sqlite |  | ||||||
| 
 |  | ||||||
| go 1.13 |  | ||||||
| 
 |  | ||||||
| require github.com/mattn/go-sqlite3 v2.0.3+incompatible |  | ||||||
| @ -1,2 +0,0 @@ | |||||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= |  | ||||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= |  | ||||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @ -2,7 +2,4 @@ module github.com/jinzhu/gorm | |||||||
| 
 | 
 | ||||||
| go 1.13 | go 1.13 | ||||||
| 
 | 
 | ||||||
| require ( | require github.com/jinzhu/inflection v1.0.0 | ||||||
| 	github.com/jinzhu/inflection v1.0.0 |  | ||||||
| 	gopkg.in/errgo.v2 v2.1.0 |  | ||||||
| ) |  | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @ -1,2 +0,0 @@ | |||||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= |  | ||||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= |  | ||||||
							
								
								
									
										3
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								gorm.go
									
									
									
									
									
								
							| @ -63,10 +63,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | |||||||
| 		Config:     config, | 		Config:     config, | ||||||
| 		Dialector:  dialector, | 		Dialector:  dialector, | ||||||
| 		clone:      true, | 		clone:      true, | ||||||
| 		callbacks:  InitializeCallbacks(), |  | ||||||
| 		cacheStore: &sync.Map{}, | 		cacheStore: &sync.Map{}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	db.callbacks = initializeCallbacks(db) | ||||||
|  | 
 | ||||||
| 	if dialector != nil { | 	if dialector != nil { | ||||||
| 		err = dialector.Initialize(db) | 		err = dialector.Initialize(db) | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								statement.go
									
									
									
									
									
								
							| @ -21,6 +21,13 @@ type Instance struct { | |||||||
| 	Statement    *Statement | 	Statement    *Statement | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { | ||||||
|  | 	if len(clauses) > 0 { | ||||||
|  | 		instance.Statement.Build(clauses...) | ||||||
|  | 	} | ||||||
|  | 	return instance.Statement.SQL.String(), instance.Statement.Vars | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // AddError add error to instance
 | // AddError add error to instance
 | ||||||
| func (inst Instance) AddError(err error) { | func (inst Instance) AddError(err error) { | ||||||
| 	if inst.Error == nil { | 	if inst.Error == nil { | ||||||
| @ -205,16 +212,17 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con | |||||||
| 
 | 
 | ||||||
| // Build build sql with clauses names
 | // Build build sql with clauses names
 | ||||||
| func (stmt Statement) Build(clauses ...string) { | func (stmt Statement) Build(clauses ...string) { | ||||||
| 	var includeSpace bool | 	var firstClauseWritten bool | ||||||
| 
 | 
 | ||||||
| 	for _, name := range clauses { | 	for _, name := range clauses { | ||||||
| 		if c, ok := stmt.Clauses[name]; ok { | 		if c, ok := stmt.Clauses[name]; ok { | ||||||
| 			if includeSpace { | 			if firstClauseWritten { | ||||||
| 				stmt.WriteByte(' ') | 				stmt.WriteByte(' ') | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			includeSpace = true | 			firstClauseWritten = true | ||||||
| 			c.Build(stmt) | 			c.Build(stmt) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	// TODO handle named vars
 | ||||||
| } | } | ||||||
|  | |||||||
| @ -99,8 +99,8 @@ func TestCallbacks(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for idx, data := range datas { | 	for idx, data := range datas { | ||||||
| 		var err error | 		db, err := gorm.Open(nil, nil) | ||||||
| 		callbacks := gorm.InitializeCallbacks() | 		callbacks := db.Callback() | ||||||
| 
 | 
 | ||||||
| 		for _, c := range data.callbacks { | 		for _, c := range data.callbacks { | ||||||
| 			var v interface{} = callbacks.Create() | 			var v interface{} = callbacks.Create() | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu