Add default transaction timeout support

This commit is contained in:
Jinzhu 2025-05-23 18:06:34 +08:00
parent 4ee59e1d87
commit 4db3fde9c5
5 changed files with 101 additions and 9 deletions

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0] opt = opts[0]
} }
ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
}
}
switch beginner := tx.Statement.ConnPool.(type) { switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner: case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner: case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default: default:
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }

View File

@ -127,7 +127,11 @@ type g[T any] struct {
} }
func (g *g[T]) apply(ctx context.Context) *DB { func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance() db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops { for _, op := range g.ops {
db = op(db) db = op(db)
} }

View File

@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
type Config struct { type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true // You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer NamingStrategy schema.Namer
// FullSaveAssociations full save associations // FullSaveAssociations full save associations
@ -519,7 +521,7 @@ func (db *DB) Use(plugin Plugin) error {
// .First(&User{}) // .First(&User{})
// }) // })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -593,15 +594,37 @@ func TestGenericsNestedPreloads(t *testing.T) {
} }
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(3)
return nil return nil
}).Where(user.ID).Take(ctx) }).Where(user.ID).Take(ctx)
if err != nil { if err != nil {
t.Errorf("failed to nested preload user") t.Errorf("failed to nested preload user")
} }
CheckUser(t, user2, user) CheckUser(t, user2, user)
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
t.Fatalf("failed to nested preload")
}
if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { if DB.Dialector.Name() == "mysql" {
// mysql 5.7 doesn't support row_number()
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
return
}
}
if DB.Dialector.Name() == "sqlserver" {
// sqlserver doesn't support order by in subquery
return
}
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(3)
return nil
}).Where(user.ID).Take(ctx)
if err != nil {
t.Errorf("failed to nested preload user")
}
CheckUser(t, user3, user)
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
t.Errorf("failed to nested preload with limit per record") t.Errorf("failed to nested preload with limit per record")
} }
} }
@ -784,3 +807,46 @@ func TestGenericsReuse(t *testing.T) {
} }
sg.Wait() sg.Wait()
} }
func TestGenericsWithTransaction(t *testing.T) {
ctx := context.Background()
tx := DB.Begin()
if tx.Error != nil {
t.Fatalf("failed to begin transaction: %v", tx.Error)
}
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count != 2 {
t.Errorf("expected 2 records, got %d", count)
}
if err := tx.Rollback().Error; err != nil {
t.Fatalf("failed to rollback transaction: %v", err)
}
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count2 != 0 {
t.Errorf("expected 0 records after rollback, got %d", count2)
}
}
func TestGenericsToSQL(t *testing.T) {
ctx := context.Background()
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
gorm.G[User](tx).Limit(10).Find(ctx)
return tx
})
if !regexp.MustCompile("SELECT \\* FROM `users`.* 10").MatchString(sql) {
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
}
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"testing" "testing"
"time"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) {
return tx2.Scan(&User{}).Error return tx2.Scan(&User{}).Error
}) })
}) })
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) {
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
}) })
}) })
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestTransactionWithDefaultTimeout(t *testing.T) {
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
tx := db.Begin()
time.Sleep(3 * time.Second)
if err = tx.Find(&User{}).Error; err == nil {
t.Errorf("should return error when transaction timeout, got error %v", err)
}
}