Add default transaction timeout support
This commit is contained in:
parent
4ee59e1d87
commit
4db3fde9c5
@ -1,6 +1,7 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
ctx := tx.Statement.Context
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
if db.Config.DefaultTransactionTimeout > 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
switch beginner := tx.Statement.ConnPool.(type) {
|
||||
case TxBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
case ConnPoolBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
default:
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
@ -127,7 +127,11 @@ type g[T any] struct {
|
||||
}
|
||||
|
||||
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||
db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||
db := g.db
|
||||
if !db.DryRun {
|
||||
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||
}
|
||||
|
||||
for _, op := range g.ops {
|
||||
db = op(db)
|
||||
}
|
||||
|
6
gorm.go
6
gorm.go
@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
SkipDefaultTransaction bool
|
||||
DefaultTransactionTimeout time.Duration
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
@ -519,7 +521,7 @@ func (db *DB) Use(plugin Plugin) error {
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
|
||||
stmt := tx.Statement
|
||||
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -593,15 +594,37 @@ func TestGenericsNestedPreloads(t *testing.T) {
|
||||
}
|
||||
|
||||
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(3)
|
||||
return nil
|
||||
}).Where(user.ID).Take(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to nested preload user")
|
||||
}
|
||||
CheckUser(t, user2, user)
|
||||
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
|
||||
t.Fatalf("failed to nested preload")
|
||||
}
|
||||
|
||||
if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 {
|
||||
if DB.Dialector.Name() == "mysql" {
|
||||
// mysql 5.7 doesn't support row_number()
|
||||
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||
return
|
||||
}
|
||||
}
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver doesn't support order by in subquery
|
||||
return
|
||||
}
|
||||
|
||||
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(3)
|
||||
return nil
|
||||
}).Where(user.ID).Take(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to nested preload user")
|
||||
}
|
||||
CheckUser(t, user3, user)
|
||||
|
||||
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
|
||||
t.Errorf("failed to nested preload with limit per record")
|
||||
}
|
||||
}
|
||||
@ -784,3 +807,46 @@ func TestGenericsReuse(t *testing.T) {
|
||||
}
|
||||
sg.Wait()
|
||||
}
|
||||
|
||||
func TestGenericsWithTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatalf("failed to begin transaction: %v", tx.Error)
|
||||
}
|
||||
|
||||
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
|
||||
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
|
||||
|
||||
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 records, got %d", count)
|
||||
}
|
||||
|
||||
if err := tx.Rollback().Error; err != nil {
|
||||
t.Fatalf("failed to rollback transaction: %v", err)
|
||||
}
|
||||
|
||||
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count2 != 0 {
|
||||
t.Errorf("expected 0 records after rollback, got %d", count2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsToSQL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
gorm.G[User](tx).Limit(10).Find(ctx)
|
||||
return tx
|
||||
})
|
||||
|
||||
if !regexp.MustCompile("SELECT \\* FROM `users`.* 10").MatchString(sql) {
|
||||
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) {
|
||||
return tx2.Scan(&User{}).Error
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) {
|
||||
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionWithDefaultTimeout(t *testing.T) {
|
||||
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
time.Sleep(3 * time.Second)
|
||||
if err = tx.Find(&User{}).Error; err == nil {
|
||||
t.Errorf("should return error when transaction timeout, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user