diff --git a/scope.go b/scope.go index eb7525b8..8e33820b 100644 --- a/scope.go +++ b/scope.go @@ -401,10 +401,12 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { // Begin start a transaction func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); scope.Err(err) == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) + if _, ok := scope.Get("xa"); !ok { + if db, ok := scope.SQLDB().(sqlDb); ok { + if tx, err := db.Begin(); scope.Err(err) == nil { + scope.db.db = interface{}(tx).(SQLCommon) + scope.InstanceSet("gorm:started_transaction", true) + } } } return scope @@ -412,14 +414,16 @@ func (scope *Scope) Begin() *Scope { // CommitOrRollback commit current transaction if no error happened, otherwise will rollback it func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) + if _, ok := scope.Get("xa"); !ok { + if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { + if db, ok := scope.db.db.(sqlTx); ok { + if scope.HasError() { + db.Rollback() + } else { + scope.Err(db.Commit()) + } + scope.db.db = scope.db.parent.db } - scope.db.db = scope.db.parent.db } } return scope @@ -857,8 +861,10 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { defer func() { if err := recover(); err != nil { - if db, ok := scope.db.db.(sqlTx); ok { - db.Rollback() + if _, ok := scope.Get("xa"); !ok { + if db, ok := scope.db.db.(sqlTx); ok { + db.Rollback() + } } panic(err) }