feat: add Connection to execute multiple commands in a single connection;
				
					
				
			This commit is contained in:
		
							parent
							
								
									4dd2647967
								
							
						
					
					
						commit
						bf1508f3f0
					
				@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
 | 
				
			|||||||
	return tx.Error
 | 
						return tx.Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Connection  use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
 | 
				
			||||||
 | 
					func (db *DB) Connection(fc func(tx *DB) error) (err error) {
 | 
				
			||||||
 | 
						if db.Error != nil {
 | 
				
			||||||
 | 
							return db.Error
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx := db.getInstance()
 | 
				
			||||||
 | 
						sqlDB, err := tx.DB()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := sqlDB.Conn(tx.Statement.Context)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
						tx.Statement.ConnPool = conn
 | 
				
			||||||
 | 
						err = fc(tx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | 
					// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | 
				
			||||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
 | 
					func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
 | 
				
			||||||
	panicked := true
 | 
						panicked := true
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										48
									
								
								tests/connection_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								tests/connection_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,48 @@
 | 
				
			|||||||
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"gorm.io/driver/mysql"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWithSingleConnection(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var expectedName = "test"
 | 
				
			||||||
 | 
						var actualName string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						setSQL, getSQL := getSetSQL(DB.Dialector.Name())
 | 
				
			||||||
 | 
						if len(setSQL) == 0 || len(getSQL) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := DB.Connection(func(tx *gorm.DB) error {
 | 
				
			||||||
 | 
							if err := tx.Exec(setSQL, expectedName).Error; err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if actualName != expectedName {
 | 
				
			||||||
 | 
							t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getSetSQL(driverName string) (string, string) {
 | 
				
			||||||
 | 
						switch driverName {
 | 
				
			||||||
 | 
						case mysql.Dialector{}.Name():
 | 
				
			||||||
 | 
							return "SET @testName := ?", "SELECT @testName"
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return "", ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user