Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names * Add a test for deeply nested wrapped transactions
This commit is contained in:
		
							parent
							
								
									0daaf1747c
								
							
						
					
					
						commit
						7f75b12bb2
					
				| @ -4,6 +4,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"hash/maphash" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 
 | ||||
| @ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er | ||||
| 	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { | ||||
| 		// nested transaction
 | ||||
| 		if !db.DisableNestedTransaction { | ||||
| 			err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error | ||||
| 			spID := new(maphash.Hash).Sum64() | ||||
| 			err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			defer func() { | ||||
| 				// Make sure to rollback when panic, Block error or Commit error
 | ||||
| 				if panicked || err != nil { | ||||
| 					db.RollbackTo(fmt.Sprintf("sp%p", fc)) | ||||
| 					db.RollbackTo(fmt.Sprintf("sp%d", spID)) | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
|  | ||||
| @ -29,8 +29,8 @@ require ( | ||||
| 	github.com/microsoft/go-mssqldb v1.7.2 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/rogpeppe/go-internal v1.12.0 // indirect | ||||
| 	golang.org/x/crypto v0.24.0 // indirect | ||||
| 	golang.org/x/text v0.16.0 // indirect | ||||
| 	golang.org/x/crypto v0.26.0 // indirect | ||||
| 	golang.org/x/text v0.17.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
| 
 | ||||
|  | ||||
| @ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) { | ||||
| 	transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error { | ||||
| 		return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { | ||||
| 			return callback(ctx, tx) | ||||
| 		}) | ||||
| 	} | ||||
| 	var ( | ||||
| 		user  = *GetUser("transaction-nested", Config{}) | ||||
| 		user1 = *GetUser("transaction-nested-1", Config{}) | ||||
| 		user2 = *GetUser("transaction-nested-2", Config{}) | ||||
| 	) | ||||
| 
 | ||||
| 	if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error { | ||||
| 		tx.Create(&user) | ||||
| 
 | ||||
| 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 			t.Fatalf("Should find saved record") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error { | ||||
| 			tx1.Create(&user1) | ||||
| 
 | ||||
| 			if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { | ||||
| 				t.Fatalf("Should find saved record") | ||||
| 			} | ||||
| 
 | ||||
| 			if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error { | ||||
| 				tx2.Create(&user2) | ||||
| 
 | ||||
| 				if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 					t.Fatalf("Should find saved record") | ||||
| 				} | ||||
| 
 | ||||
| 				return errors.New("inner rollback") | ||||
| 			}); err == nil { | ||||
| 				t.Fatalf("nested transaction has no error") | ||||
| 			} | ||||
| 
 | ||||
| 			return errors.New("rollback") | ||||
| 		}); err == nil { | ||||
| 			t.Fatalf("nested transaction should returns error") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||
| 			t.Fatalf("Should not find rollbacked record") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 			t.Fatalf("Should find saved record") | ||||
| 		} | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		t.Fatalf("no error should return, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||
| 		t.Fatalf("Should not find rollbacked parent record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should not find rollbacked nested record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDisabledNestedTransaction(t *testing.T) { | ||||
| 	var ( | ||||
| 		user  = *GetUser("transaction-nested", Config{}) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Leo Sjöberg
						Leo Sjöberg