diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 3ba3d335..34ab7cb0 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -52,13 +52,31 @@ func TestPreparedStmt(t *testing.T) { t.Fatalf("no error should happen but got %v", err) } - user3 := *GetUser("prepared_stmt_transaction", Config{}) + users := []User{ + *GetUser("prepared_stmt_transaction_1", Config{}), + *GetUser("prepared_stmt_transaction_2", Config{}), + *GetUser("prepared_stmt_transaction_3", Config{}), + } err := tx.Transaction(func(tx1 *gorm.DB) error { - return tx1.Transaction(func(tx2 *gorm.DB) error { - return tx2.Create(&user3).Error + tx1.Create(&users[0]) + + tx1.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&users[1]) + return errors.New("rollback user2") // Rollback user3 }) + + tx1.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&users[2]) + return nil + }) + return nil }) AssertEqual(t, nil, err) + + var psUsers []User + err = tx.Where("name like ?", "prepared_stmt_transaction%").Find(&psUsers).Error + AssertEqual(t, nil, err) + AssertEqual(t, 2, len(psUsers)) } func TestPreparedStmtFromTransaction(t *testing.T) {