feat: adds AfterError callback

This commit is contained in:
guilhermefbarbosa 2023-10-19 09:45:04 -03:00
parent 6bef318891
commit 2e00b2bd7d
4 changed files with 40 additions and 0 deletions

View File

@ -42,6 +42,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:after_error", AfterError)
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
@ -49,6 +50,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:after_error", AfterError)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
@ -58,6 +60,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_error", AfterError)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
@ -68,6 +71,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:after_error", AfterError)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
@ -75,9 +79,11 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Register("gorm:after_error", AfterError)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Register("gorm:after_error", AfterError)
rawCallback.Clauses = config.QueryClauses
}

25
callbacks/error.go Normal file
View File

@ -0,0 +1,25 @@
package callbacks
import (
"gorm.io/gorm"
"reflect"
)
// AfterError after error callback executes if any error happens during main callbacks
func AfterError(db *gorm.DB) {
if db.Statement.ReflectValue.Kind() == reflect.Ptr && db.Statement.ReflectValue.IsNil() {
return
}
if db.Error != nil && db.Statement.Schema != nil && !db.Statement.SkipHooks {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if db.Statement.Schema.AfterError {
if i, ok := value.(AfterErrorInterface); ok {
db.AddError(i.AfterError(tx))
return true
}
}
return false
})
}
return
}

View File

@ -37,3 +37,7 @@ type AfterDeleteInterface interface {
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}
type AfterErrorInterface interface {
AfterError(*gorm.DB) error
}

View File

@ -25,6 +25,7 @@ const (
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
callbackTypeAfterError callbackType = "AfterError"
)
// ErrUnsupportedDataType unsupported data type
@ -53,6 +54,7 @@ type Schema struct {
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
AfterError bool
err error
initialized chan struct{}
namer Namer
@ -308,6 +310,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
callbackTypeAfterError,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
@ -397,6 +400,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
case callbackTypeAfterError:
return modelType.MethodByName(string(callbackTypeAfterError))
default:
return reflect.ValueOf(nil)
}