Merge branch 'master' into master
This commit is contained in:
commit
53c0fcfe7d
25
callback.go
25
callback.go
@ -1,6 +1,6 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "log"
|
import "fmt"
|
||||||
|
|
||||||
// DefaultCallback default callbacks defined by gorm
|
// DefaultCallback default callbacks defined by gorm
|
||||||
var DefaultCallback = &Callback{}
|
var DefaultCallback = &Callback{}
|
||||||
@ -13,6 +13,7 @@ var DefaultCallback = &Callback{}
|
|||||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
type Callback struct {
|
type Callback struct {
|
||||||
|
logger logger
|
||||||
creates []*func(scope *Scope)
|
creates []*func(scope *Scope)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(scope *Scope)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(scope *Scope)
|
||||||
@ -23,6 +24,7 @@ type Callback struct {
|
|||||||
|
|
||||||
// CallbackProcessor contains callback informations
|
// CallbackProcessor contains callback informations
|
||||||
type CallbackProcessor struct {
|
type CallbackProcessor struct {
|
||||||
|
logger logger
|
||||||
name string // current callback's name
|
name string // current callback's name
|
||||||
before string // register current callback before a callback
|
before string // register current callback before a callback
|
||||||
after string // register current callback after a callback
|
after string // register current callback after a callback
|
||||||
@ -33,8 +35,9 @@ type CallbackProcessor struct {
|
|||||||
parent *Callback
|
parent *Callback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Callback) clone() *Callback {
|
func (c *Callback) clone(logger logger) *Callback {
|
||||||
return &Callback{
|
return &Callback{
|
||||||
|
logger: logger,
|
||||||
creates: c.creates,
|
creates: c.creates,
|
||||||
updates: c.updates,
|
updates: c.updates,
|
||||||
deletes: c.deletes,
|
deletes: c.deletes,
|
||||||
@ -53,28 +56,28 @@ func (c *Callback) clone() *Callback {
|
|||||||
// scope.Err(errors.New("error"))
|
// scope.Err(errors.New("error"))
|
||||||
// })
|
// })
|
||||||
func (c *Callback) Create() *CallbackProcessor {
|
func (c *Callback) Create() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "create", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||||
func (c *Callback) Update() *CallbackProcessor {
|
func (c *Callback) Update() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "update", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||||
func (c *Callback) Delete() *CallbackProcessor {
|
func (c *Callback) Delete() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "delete", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||||
// Refer `Create` for usage
|
// Refer `Create` for usage
|
||||||
func (c *Callback) Query() *CallbackProcessor {
|
func (c *Callback) Query() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "query", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||||
func (c *Callback) RowQuery() *CallbackProcessor {
|
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "row_query", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||||
@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
|||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
|
||||||
cp.before = "gorm:row_query"
|
cp.before = "gorm:row_query"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,7 +110,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
|||||||
// Remove a registered callback
|
// Remove a registered callback
|
||||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.remove = true
|
cp.remove = true
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
@ -120,7 +123,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
// scope.SetColumn("Updated", now)
|
// scope.SetColumn("Updated", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
@ -159,7 +162,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
|||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||||
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
|
||||||
}
|
}
|
||||||
allNames = append(allNames, cp.name)
|
allNames = append(allNames, cp.name)
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {}
|
|||||||
func afterCreate2(s *Scope) {}
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &Callback{}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("before_create2", beforeCreate2)
|
callback.Create().Register("before_create2", beforeCreate2)
|
||||||
@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
var callback1 = &Callback{}
|
var callback1 = &Callback{logger: defaultLogger}
|
||||||
callback1.Create().Register("before_create1", beforeCreate1)
|
callback1.Create().Register("before_create1", beforeCreate1)
|
||||||
callback1.Create().Register("create", create)
|
callback1.Create().Register("create", create)
|
||||||
callback1.Create().Register("after_create1", afterCreate1)
|
callback1.Create().Register("after_create1", afterCreate1)
|
||||||
@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
|||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &Callback{}
|
var callback2 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback2.Update().Register("create", create)
|
callback2.Update().Register("create", create)
|
||||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||||
@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
var callback1 = &Callback{}
|
var callback1 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback1.Query().Register("before_create1", beforeCreate1)
|
callback1.Query().Register("before_create1", beforeCreate1)
|
||||||
@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &Callback{}
|
var callback2 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||||
@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||||||
func replaceCreate(s *Scope) {}
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &Callback{}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveCallback(t *testing.T) {
|
func TestRemoveCallback(t *testing.T) {
|
||||||
var callback = &Callback{}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
6
main.go
6
main.go
@ -138,7 +138,7 @@ func (s *DB) Dialect() Dialect {
|
|||||||
// db.Callback().Create().Register("update_created_at", updateCreated)
|
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||||
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
||||||
func (s *DB) Callback() *Callback {
|
func (s *DB) Callback() *Callback {
|
||||||
s.parent.callbacks = s.parent.callbacks.clone()
|
s.parent.callbacks = s.parent.callbacks.clone(s.logger)
|
||||||
return s.parent.callbacks
|
return s.parent.callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -533,7 +533,9 @@ func (s *DB) Commit() *DB {
|
|||||||
func (s *DB) Rollback() *DB {
|
func (s *DB) Rollback() *DB {
|
||||||
var emptySQLTx *sql.Tx
|
var emptySQLTx *sql.Tx
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
s.AddError(db.Rollback())
|
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||||
|
s.AddError(err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
|
16
main_test.go
16
main_test.go
@ -421,6 +421,22 @@ func TestTransaction(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
u := User{Name: "transcation"}
|
||||||
|
if err := tx.Save(&u).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
t.Errorf("Commit should not raise error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Rollback().Error; err != nil {
|
||||||
|
t.Errorf("Rollback should not raise error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRow(t *testing.T) {
|
func TestRow(t *testing.T) {
|
||||||
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
4
scope.go
4
scope.go
@ -402,7 +402,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
|||||||
// Begin start a transaction
|
// Begin start a transaction
|
||||||
func (scope *Scope) Begin() *Scope {
|
func (scope *Scope) Begin() *Scope {
|
||||||
if db, ok := scope.SQLDB().(sqlDb); ok {
|
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); scope.Err(err) == nil {
|
||||||
scope.db.db = interface{}(tx).(SQLCommon)
|
scope.db.db = interface{}(tx).(SQLCommon)
|
||||||
scope.InstanceSet("gorm:started_transaction", true)
|
scope.InstanceSet("gorm:started_transaction", true)
|
||||||
}
|
}
|
||||||
@ -1194,7 +1194,7 @@ func (scope *Scope) createTable() *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropTable() *Scope {
|
func (scope *Scope) dropTable() *Scope {
|
||||||
scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec()
|
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,3 +78,16 @@ func TestFailedValuer(t *testing.T) {
|
|||||||
t.Errorf("The error should be returned from Valuer, but get %v", err)
|
t.Errorf("The error should be returned from Valuer, but get %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDropTableWithTableOptions(t *testing.T) {
|
||||||
|
type UserWithOptions struct {
|
||||||
|
gorm.Model
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&UserWithOptions{})
|
||||||
|
|
||||||
|
DB = DB.Set("gorm:table_options", "CHARSET=utf8")
|
||||||
|
err := DB.DropTable(&UserWithOptions{}).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Table must be dropped, got error %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user