diff --git a/finisher_api.go b/finisher_api.go index 8a3d4199..19534460 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -445,7 +445,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // Commit commit a transaction func (db *DB) Commit() *DB { - if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -456,7 +456,9 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { - db.AddError(committer.Rollback()) + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } } else { db.AddError(ErrInvalidTransaction) } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c101388a..aea151d9 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "errors" "testing" @@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) { } } +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} + func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() {