From 4db3fde9c568f1486b066ada8d829e3b01a8a159 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 23 May 2025 18:06:34 +0800 Subject: [PATCH] Add default transaction timeout support --- finisher_api.go | 12 +++++-- generics.go | 6 +++- gorm.go | 6 ++-- tests/generics_test.go | 70 +++++++++++++++++++++++++++++++++++++-- tests/transaction_test.go | 16 +++++++-- 5 files changed, 101 insertions(+), 9 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 6802945c..57809d17 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } + ctx := tx.Statement.Context + if _, ok := ctx.Deadline(); !ok { + if db.Config.DefaultTransactionTimeout > 0 { + ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + } + } + switch beginner := tx.Statement.ConnPool.(type) { case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) default: err = ErrInvalidTransaction } diff --git a/generics.go b/generics.go index 52492c8f..ad2d063f 100644 --- a/generics.go +++ b/generics.go @@ -127,7 +127,11 @@ type g[T any] struct { } func (g *g[T]) apply(ctx context.Context) *DB { - db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + db := g.db + if !db.DryRun { + db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + } + for _, op := range g.ops { db = op(db) } diff --git a/gorm.go b/gorm.go index 63a28b37..27e4caa0 100644 --- a/gorm.go +++ b/gorm.go @@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt" type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // You can disable it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool + SkipDefaultTransaction bool + DefaultTransactionTimeout time.Duration + // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer // FullSaveAssociations full save associations @@ -519,7 +521,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance()) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) diff --git a/tests/generics_test.go b/tests/generics_test.go index f89678b9..39decb3f 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "sort" "strconv" "strings" @@ -593,15 +594,37 @@ func TestGenericsNestedPreloads(t *testing.T) { } user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { - db.LimitPerRecord(3) return nil }).Where(user.ID).Take(ctx) if err != nil { t.Errorf("failed to nested preload user") } CheckUser(t, user2, user) + if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 { + t.Fatalf("failed to nested preload") + } - if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { + if DB.Dialector.Name() == "mysql" { + // mysql 5.7 doesn't support row_number() + if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { + return + } + } + if DB.Dialector.Name() == "sqlserver" { + // sqlserver doesn't support order by in subquery + return + } + + user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { + db.LimitPerRecord(3) + return nil + }).Where(user.ID).Take(ctx) + if err != nil { + t.Errorf("failed to nested preload user") + } + CheckUser(t, user3, user) + + if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 { t.Errorf("failed to nested preload with limit per record") } } @@ -784,3 +807,46 @@ func TestGenericsReuse(t *testing.T) { } sg.Wait() } + +func TestGenericsWithTransaction(t *testing.T) { + ctx := context.Background() + tx := DB.Begin() + if tx.Error != nil { + t.Fatalf("failed to begin transaction: %v", tx.Error) + } + + users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}} + err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2) + + count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count != 2 { + t.Errorf("expected 2 records, got %d", count) + } + + if err := tx.Rollback().Error; err != nil { + t.Fatalf("failed to rollback transaction: %v", err) + } + + count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count2 != 0 { + t.Errorf("expected 0 records after rollback, got %d", count2) + } +} + +func TestGenericsToSQL(t *testing.T) { + ctx := context.Background() + sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + gorm.G[User](tx).Limit(10).Find(ctx) + return tx + }) + + if !regexp.MustCompile("SELECT \\* FROM `users`.* 10").MatchString(sql) { + t.Errorf("ToSQL: got wrong sql with Generics API %v", sql) + } +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 9f0f067c..80d3a7fc 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) { return tx2.Scan(&User{}).Error }) }) - if err != nil { t.Error(err) } @@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) { return tx3.Where("user_id", user.ID).Delete(&Account{}).Error }) }) - if err != nil { t.Error(err) } } + +func TestTransactionWithDefaultTimeout(t *testing.T) { + db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + tx := db.Begin() + time.Sleep(3 * time.Second) + if err = tx.Find(&User{}).Error; err == nil { + t.Errorf("should return error when transaction timeout, got error %v", err) + } +}