Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names * Add a test for deeply nested wrapped transactions
This commit is contained in:
		
							parent
							
								
									0daaf1747c
								
							
						
					
					
						commit
						7f75b12bb2
					
				| @ -4,6 +4,7 @@ import ( | |||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"hash/maphash" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| @ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er | |||||||
| 	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { | 	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { | ||||||
| 		// nested transaction
 | 		// nested transaction
 | ||||||
| 		if !db.DisableNestedTransaction { | 		if !db.DisableNestedTransaction { | ||||||
| 			err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error | 			spID := new(maphash.Hash).Sum64() | ||||||
|  | 			err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			defer func() { | 			defer func() { | ||||||
| 				// Make sure to rollback when panic, Block error or Commit error
 | 				// Make sure to rollback when panic, Block error or Commit error
 | ||||||
| 				if panicked || err != nil { | 				if panicked || err != nil { | ||||||
| 					db.RollbackTo(fmt.Sprintf("sp%p", fc)) | 					db.RollbackTo(fmt.Sprintf("sp%d", spID)) | ||||||
| 				} | 				} | ||||||
| 			}() | 			}() | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -29,8 +29,8 @@ require ( | |||||||
| 	github.com/microsoft/go-mssqldb v1.7.2 // indirect | 	github.com/microsoft/go-mssqldb v1.7.2 // indirect | ||||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||||
| 	github.com/rogpeppe/go-internal v1.12.0 // indirect | 	github.com/rogpeppe/go-internal v1.12.0 // indirect | ||||||
| 	golang.org/x/crypto v0.24.0 // indirect | 	golang.org/x/crypto v0.26.0 // indirect | ||||||
| 	golang.org/x/text v0.16.0 // indirect | 	golang.org/x/text v0.17.0 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) { | ||||||
|  | 	transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error { | ||||||
|  | 		return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { | ||||||
|  | 			return callback(ctx, tx) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	var ( | ||||||
|  | 		user  = *GetUser("transaction-nested", Config{}) | ||||||
|  | 		user1 = *GetUser("transaction-nested-1", Config{}) | ||||||
|  | 		user2 = *GetUser("transaction-nested-2", Config{}) | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error { | ||||||
|  | 		tx.Create(&user) | ||||||
|  | 
 | ||||||
|  | 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||||
|  | 			t.Fatalf("Should find saved record") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error { | ||||||
|  | 			tx1.Create(&user1) | ||||||
|  | 
 | ||||||
|  | 			if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { | ||||||
|  | 				t.Fatalf("Should find saved record") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error { | ||||||
|  | 				tx2.Create(&user2) | ||||||
|  | 
 | ||||||
|  | 				if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||||
|  | 					t.Fatalf("Should find saved record") | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				return errors.New("inner rollback") | ||||||
|  | 			}); err == nil { | ||||||
|  | 				t.Fatalf("nested transaction has no error") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return errors.New("rollback") | ||||||
|  | 		}); err == nil { | ||||||
|  | 			t.Fatalf("nested transaction should returns error") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||||
|  | 			t.Fatalf("Should not find rollbacked record") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||||
|  | 			t.Fatalf("Should find saved record") | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}); err != nil { | ||||||
|  | 		t.Fatalf("no error should return, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||||
|  | 		t.Fatalf("Should find saved record") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||||
|  | 		t.Fatalf("Should not find rollbacked parent record") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||||
|  | 		t.Fatalf("Should not find rollbacked nested record") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestDisabledNestedTransaction(t *testing.T) { | func TestDisabledNestedTransaction(t *testing.T) { | ||||||
| 	var ( | 	var ( | ||||||
| 		user  = *GetUser("transaction-nested", Config{}) | 		user  = *GetUser("transaction-nested", Config{}) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Leo Sjöberg
						Leo Sjöberg