diff --git a/finisher_api.go b/finisher_api.go index 39d9fca3..c52efdb8 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -617,7 +617,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction - if !db.DisableNestedTransaction { + if !db.DisableNestedTransaction && (db.PrepareStmt && !db.DisablePrepareNestedTransaction) { err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error if err != nil { return diff --git a/gorm.go b/gorm.go index 37595ddd..3ef228ec 100644 --- a/gorm.go +++ b/gorm.go @@ -41,6 +41,8 @@ type Config struct { IgnoreRelationshipsWhenMigrating bool // DisableNestedTransaction disable nested transaction DisableNestedTransaction bool + // DisablePrepareNestedTransaction disable nested transaction in prepare statement + DisablePrepareNestedTransaction bool // AllowGlobalUpdate allow global update AllowGlobalUpdate bool // QueryFields executes the SQL query with all fields of the table diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 64baa01b..3ba3d335 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -51,6 +51,14 @@ func TestPreparedStmt(t *testing.T) { if err := tx.Find(&result3, user2.ID).Error; err != nil { t.Fatalf("no error should happen but got %v", err) } + + user3 := *GetUser("prepared_stmt_transaction", Config{}) + err := tx.Transaction(func(tx1 *gorm.DB) error { + return tx1.Transaction(func(tx2 *gorm.DB) error { + return tx2.Create(&user3).Error + }) + }) + AssertEqual(t, nil, err) } func TestPreparedStmtFromTransaction(t *testing.T) {