Add SavePoint/RollbackTo/NestedTransaction
This commit is contained in:
		
							parent
							
								
									2c1b04a2cf
								
							
						
					
					
						commit
						7dc255acfe
					
				| @ -25,4 +25,6 @@ var ( | ||||
| 	ErrorPrimaryKeyRequired = errors.New("primary key required") | ||||
| 	// ErrorModelValueRequired model value required
 | ||||
| 	ErrorModelValueRequired = errors.New("model value required") | ||||
| 	// ErrUnsupportedDriver unsupported driver
 | ||||
| 	ErrUnsupportedDriver = errors.New("unsupported driver") | ||||
| ) | ||||
|  | ||||
| @ -3,6 +3,7 @@ package gorm | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 
 | ||||
| @ -343,18 +344,33 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { | ||||
| // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | ||||
| func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { | ||||
| 	panicked := true | ||||
| 	tx := db.Begin(opts...) | ||||
| 	defer func() { | ||||
| 		// Make sure to rollback when panic, Block error or Commit error
 | ||||
| 		if panicked || err != nil { | ||||
| 			tx.Rollback() | ||||
| 
 | ||||
| 	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { | ||||
| 		// nested transaction
 | ||||
| 		db.SavePoint(fmt.Sprintf("sp%p", fc)) | ||||
| 		defer func() { | ||||
| 			// Make sure to rollback when panic, Block error or Commit error
 | ||||
| 			if panicked || err != nil { | ||||
| 				db.RollbackTo(fmt.Sprintf("sp%p", fc)) | ||||
| 			} | ||||
| 		}() | ||||
| 
 | ||||
| 		err = fc(db.Session(&Session{WithConditions: true})) | ||||
| 	} else { | ||||
| 		tx := db.Begin(opts...) | ||||
| 
 | ||||
| 		defer func() { | ||||
| 			// Make sure to rollback when panic, Block error or Commit error
 | ||||
| 			if panicked || err != nil { | ||||
| 				tx.Rollback() | ||||
| 			} | ||||
| 		}() | ||||
| 
 | ||||
| 		err = fc(tx) | ||||
| 
 | ||||
| 		if err == nil { | ||||
| 			err = tx.Commit().Error | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	err = fc(tx) | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		err = tx.Commit().Error | ||||
| 	} | ||||
| 
 | ||||
| 	panicked = false | ||||
| @ -409,6 +425,24 @@ func (db *DB) Rollback() *DB { | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| func (db *DB) SavePoint(name string) *DB { | ||||
| 	if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { | ||||
| 		savePointer.SavePoint(db, name) | ||||
| 	} else { | ||||
| 		db.AddError(ErrUnsupportedDriver) | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| func (db *DB) RollbackTo(name string) *DB { | ||||
| 	if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { | ||||
| 		savePointer.RollbackTo(db, name) | ||||
| 	} else { | ||||
| 		db.AddError(ErrUnsupportedDriver) | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| // Exec execute raw sql
 | ||||
| func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
|  | ||||
| @ -27,6 +27,11 @@ type ConnPool interface { | ||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||
| } | ||||
| 
 | ||||
| type SavePointerDialectorInterface interface { | ||||
| 	SavePoint(tx *DB, name string) error | ||||
| 	RollbackTo(tx *DB, name string) error | ||||
| } | ||||
| 
 | ||||
| type TxBeginner interface { | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||
| } | ||||
|  | ||||
							
								
								
									
										10
									
								
								tests/go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								tests/go.mod
									
									
									
									
									
								
							| @ -6,11 +6,11 @@ require ( | ||||
| 	github.com/google/uuid v1.1.1 | ||||
| 	github.com/jinzhu/now v1.1.1 | ||||
| 	github.com/lib/pq v1.6.0 | ||||
| 	gorm.io/driver/mysql v0.2.0 | ||||
| 	gorm.io/driver/postgres v0.2.0 | ||||
| 	gorm.io/driver/sqlite v1.0.2 | ||||
| 	gorm.io/driver/sqlserver v0.2.0 | ||||
| 	gorm.io/gorm v0.0.0-00010101000000-000000000000 | ||||
| 	gorm.io/driver/mysql v0.2.1 | ||||
| 	gorm.io/driver/postgres v0.2.1 | ||||
| 	gorm.io/driver/sqlite v1.0.4 | ||||
| 	gorm.io/driver/sqlserver v0.2.1 | ||||
| 	gorm.io/gorm v0.2.7 | ||||
| ) | ||||
| 
 | ||||
| replace gorm.io/gorm => ../ | ||||
|  | ||||
| @ -142,3 +142,123 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { | ||||
| 		t.Fatalf("Rollback after commit should raise error") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestTransactionWithSavePoint(t *testing.T) { | ||||
| 	tx := DB.Begin() | ||||
| 
 | ||||
| 	user := *GetUser("transaction-save-point", Config{}) | ||||
| 	tx.Create(&user) | ||||
| 
 | ||||
| 	if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.SavePoint("save_point1").Error; err != nil { | ||||
| 		t.Fatalf("Failed to save point, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user1 := *GetUser("transaction-save-point-1", Config{}) | ||||
| 	tx.Create(&user1) | ||||
| 
 | ||||
| 	if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.RollbackTo("save_point1").Error; err != nil { | ||||
| 		t.Fatalf("Failed to save point, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||
| 		t.Fatalf("Should not find rollbacked record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.SavePoint("save_point2").Error; err != nil { | ||||
| 		t.Fatalf("Failed to save point, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user2 := *GetUser("transaction-save-point-2", Config{}) | ||||
| 	tx.Create(&user2) | ||||
| 
 | ||||
| 	if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Commit().Error; err != nil { | ||||
| 		t.Fatalf("Failed to commit, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||
| 		t.Fatalf("Should not find rollbacked record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNestedTransactionWithBlock(t *testing.T) { | ||||
| 	var ( | ||||
| 		user  = *GetUser("transaction-nested", Config{}) | ||||
| 		user1 = *GetUser("transaction-nested-1", Config{}) | ||||
| 		user2 = *GetUser("transaction-nested-2", Config{}) | ||||
| 	) | ||||
| 
 | ||||
| 	if err := DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		tx.Create(&user) | ||||
| 
 | ||||
| 		if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 			t.Fatalf("Should find saved record") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Transaction(func(tx1 *gorm.DB) error { | ||||
| 			tx1.Create(&user1) | ||||
| 
 | ||||
| 			if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { | ||||
| 				t.Fatalf("Should find saved record") | ||||
| 			} | ||||
| 
 | ||||
| 			return errors.New("rollback") | ||||
| 		}); err == nil { | ||||
| 			t.Fatalf("nested transaction should returns error") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||
| 			t.Fatalf("Should not find rollbacked record") | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Transaction(func(tx2 *gorm.DB) error { | ||||
| 			tx2.Create(&user2) | ||||
| 
 | ||||
| 			if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 				t.Fatalf("Should find saved record") | ||||
| 			} | ||||
| 
 | ||||
| 			return nil | ||||
| 		}); err != nil { | ||||
| 			t.Fatalf("nested transaction returns error: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 			t.Fatalf("Should find saved record") | ||||
| 		} | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		t.Fatalf("no error should return, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { | ||||
| 		t.Fatalf("Should not find rollbacked record") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { | ||||
| 		t.Fatalf("Should find saved record") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -124,9 +124,3 @@ build: | ||||
|                 name: test mssql | ||||
|                 code: | | ||||
|                     GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh | ||||
| 
 | ||||
|         - script: | ||||
|                 name: codecov | ||||
|                 code: | | ||||
|                     go test -race -coverprofile=coverage.txt -covermode=atomic ./... | ||||
|                     bash <(curl -s https://codecov.io/bash) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu