diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..a2e53e27 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -42,6 +42,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register("gorm:create", Create(config)) + createCallback.Register("gorm:after_error", AfterError) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) @@ -49,6 +50,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) + queryCallback.Register("gorm:after_error", AfterError) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Clauses = config.QueryClauses @@ -58,6 +60,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete", Delete(config)) + deleteCallback.Register("gorm:after_error", AfterError) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Clauses = config.DeleteClauses @@ -68,6 +71,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:update", Update(config)) + updateCallback.Register("gorm:after_error", AfterError) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) @@ -75,9 +79,11 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { rowCallback := db.Callback().Row() rowCallback.Register("gorm:row", RowQuery) + rowCallback.Register("gorm:after_error", AfterError) rowCallback.Clauses = config.QueryClauses rawCallback := db.Callback().Raw() rawCallback.Register("gorm:raw", RawExec) + rawCallback.Register("gorm:after_error", AfterError) rawCallback.Clauses = config.QueryClauses } diff --git a/callbacks/error.go b/callbacks/error.go new file mode 100644 index 00000000..27734aad --- /dev/null +++ b/callbacks/error.go @@ -0,0 +1,25 @@ +package callbacks + +import ( + "gorm.io/gorm" + "reflect" +) + +// AfterError after error callback executes if any error happens during main callbacks +func AfterError(db *gorm.DB) { + if db.Statement.ReflectValue.Kind() == reflect.Ptr && db.Statement.ReflectValue.IsNil() { + return + } + if db.Error != nil && db.Statement.Schema != nil && !db.Statement.SkipHooks { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if db.Statement.Schema.AfterError { + if i, ok := value.(AfterErrorInterface); ok { + db.AddError(i.AfterError(tx)) + return true + } + } + return false + }) + } + return +} diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go index 2302470f..baa23b06 100644 --- a/callbacks/interfaces.go +++ b/callbacks/interfaces.go @@ -37,3 +37,7 @@ type AfterDeleteInterface interface { type AfterFindInterface interface { AfterFind(*gorm.DB) error } + +type AfterErrorInterface interface { + AfterError(*gorm.DB) error +} diff --git a/schema/schema.go b/schema/schema.go index 3e7459ce..3ca1dda2 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -25,6 +25,7 @@ const ( callbackTypeBeforeDelete callbackType = "BeforeDelete" callbackTypeAfterDelete callbackType = "AfterDelete" callbackTypeAfterFind callbackType = "AfterFind" + callbackTypeAfterError callbackType = "AfterError" ) // ErrUnsupportedDataType unsupported data type @@ -53,6 +54,7 @@ type Schema struct { BeforeDelete, AfterDelete bool BeforeSave, AfterSave bool AfterFind bool + AfterError bool err error initialized chan struct{} namer Namer @@ -308,6 +310,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam callbackTypeBeforeSave, callbackTypeAfterSave, callbackTypeBeforeDelete, callbackTypeAfterDelete, callbackTypeAfterFind, + callbackTypeAfterError, } for _, cbName := range callbackTypes { if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { @@ -397,6 +400,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect return modelType.MethodByName(string(callbackTypeAfterDelete)) case callbackTypeAfterFind: return modelType.MethodByName(string(callbackTypeAfterFind)) + case callbackTypeAfterError: + return modelType.MethodByName(string(callbackTypeAfterError)) default: return reflect.ValueOf(nil) }