Add default transaction timeout support
This commit is contained in:
		
							parent
							
								
									4ee59e1d87
								
							
						
					
					
						commit
						4db3fde9c5
					
				| @ -1,6 +1,7 @@ | |||||||
| package gorm | package gorm | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | |||||||
| 		opt = opts[0] | 		opt = opts[0] | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	ctx := tx.Statement.Context | ||||||
|  | 	if _, ok := ctx.Deadline(); !ok { | ||||||
|  | 		if db.Config.DefaultTransactionTimeout > 0 { | ||||||
|  | 			ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	switch beginner := tx.Statement.ConnPool.(type) { | 	switch beginner := tx.Statement.ConnPool.(type) { | ||||||
| 	case TxBeginner: | 	case TxBeginner: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) | ||||||
| 	case ConnPoolBeginner: | 	case ConnPoolBeginner: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) | ||||||
| 	default: | 	default: | ||||||
| 		err = ErrInvalidTransaction | 		err = ErrInvalidTransaction | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -127,7 +127,11 @@ type g[T any] struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (g *g[T]) apply(ctx context.Context) *DB { | func (g *g[T]) apply(ctx context.Context) *DB { | ||||||
| 	db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance() | 	db := g.db | ||||||
|  | 	if !db.DryRun { | ||||||
|  | 		db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for _, op := range g.ops { | 	for _, op := range g.ops { | ||||||
| 		db = op(db) | 		db = op(db) | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								gorm.go
									
									
									
									
									
								
							| @ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt" | |||||||
| type Config struct { | type Config struct { | ||||||
| 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | ||||||
| 	// You can disable it by setting `SkipDefaultTransaction` to true
 | 	// You can disable it by setting `SkipDefaultTransaction` to true
 | ||||||
| 	SkipDefaultTransaction bool | 	SkipDefaultTransaction    bool | ||||||
|  | 	DefaultTransactionTimeout time.Duration | ||||||
|  | 
 | ||||||
| 	// NamingStrategy tables, columns naming strategy
 | 	// NamingStrategy tables, columns naming strategy
 | ||||||
| 	NamingStrategy schema.Namer | 	NamingStrategy schema.Namer | ||||||
| 	// FullSaveAssociations full save associations
 | 	// FullSaveAssociations full save associations
 | ||||||
| @ -519,7 +521,7 @@ func (db *DB) Use(plugin Plugin) error { | |||||||
| //				.First(&User{})
 | //				.First(&User{})
 | ||||||
| //	})
 | //	})
 | ||||||
| func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { | func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { | ||||||
| 	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) | 	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance()) | ||||||
| 	stmt := tx.Statement | 	stmt := tx.Statement | ||||||
| 
 | 
 | ||||||
| 	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | 	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"regexp" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| @ -593,15 +594,37 @@ func TestGenericsNestedPreloads(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | 	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||||
| 		db.LimitPerRecord(3) |  | ||||||
| 		return nil | 		return nil | ||||||
| 	}).Where(user.ID).Take(ctx) | 	}).Where(user.ID).Take(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("failed to nested preload user") | 		t.Errorf("failed to nested preload user") | ||||||
| 	} | 	} | ||||||
| 	CheckUser(t, user2, user) | 	CheckUser(t, user2, user) | ||||||
|  | 	if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 { | ||||||
|  | 		t.Fatalf("failed to nested preload") | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { | 	if DB.Dialector.Name() == "mysql" { | ||||||
|  | 		// mysql 5.7 doesn't support row_number()
 | ||||||
|  | 		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if DB.Dialector.Name() == "sqlserver" { | ||||||
|  | 		// sqlserver doesn't support order by in subquery
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.LimitPerRecord(3) | ||||||
|  | 		return nil | ||||||
|  | 	}).Where(user.ID).Take(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to nested preload user") | ||||||
|  | 	} | ||||||
|  | 	CheckUser(t, user3, user) | ||||||
|  | 
 | ||||||
|  | 	if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 { | ||||||
| 		t.Errorf("failed to nested preload with limit per record") | 		t.Errorf("failed to nested preload with limit per record") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -784,3 +807,46 @@ func TestGenericsReuse(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	sg.Wait() | 	sg.Wait() | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsWithTransaction(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	tx := DB.Begin() | ||||||
|  | 	if tx.Error != nil { | ||||||
|  | 		t.Fatalf("failed to begin transaction: %v", tx.Error) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}} | ||||||
|  | 	err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2) | ||||||
|  | 
 | ||||||
|  | 	count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Count failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if count != 2 { | ||||||
|  | 		t.Errorf("expected 2 records, got %d", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := tx.Rollback().Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to rollback transaction: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Count failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if count2 != 0 { | ||||||
|  | 		t.Errorf("expected 0 records after rollback, got %d", count2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsToSQL(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { | ||||||
|  | 		gorm.G[User](tx).Limit(10).Find(ctx) | ||||||
|  | 		return tx | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if !regexp.MustCompile("SELECT \\* FROM `users`.* 10").MatchString(sql) { | ||||||
|  | 		t.Errorf("ToSQL: got wrong sql with Generics API %v", sql) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -4,6 +4,7 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| @ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) { | |||||||
| 			return tx2.Scan(&User{}).Error | 			return tx2.Scan(&User{}).Error | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Error(err) | 		t.Error(err) | ||||||
| 	} | 	} | ||||||
| @ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) { | |||||||
| 			return tx3.Where("user_id", user.ID).Delete(&Account{}).Error | 			return tx3.Where("user_id", user.ID).Delete(&Account{}).Error | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Error(err) | 		t.Error(err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestTransactionWithDefaultTimeout(t *testing.T) { | ||||||
|  | 	db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to connect database, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tx := db.Begin() | ||||||
|  | 	time.Sleep(3 * time.Second) | ||||||
|  | 	if err = tx.Find(&User{}).Error; err == nil { | ||||||
|  | 		t.Errorf("should return error when transaction timeout, got error %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu