Add default transaction timeout support
This commit is contained in:
parent
4ee59e1d87
commit
4db3fde9c5
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|||||||
opt = opts[0]
|
opt = opts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := tx.Statement.Context
|
||||||
|
if _, ok := ctx.Deadline(); !ok {
|
||||||
|
if db.Config.DefaultTransactionTimeout > 0 {
|
||||||
|
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch beginner := tx.Statement.ConnPool.(type) {
|
switch beginner := tx.Statement.ConnPool.(type) {
|
||||||
case TxBeginner:
|
case TxBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||||
case ConnPoolBeginner:
|
case ConnPoolBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||||
default:
|
default:
|
||||||
err = ErrInvalidTransaction
|
err = ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,11 @@ type g[T any] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *g[T]) apply(ctx context.Context) *DB {
|
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||||
db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
db := g.db
|
||||||
|
if !db.DryRun {
|
||||||
|
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||||
|
}
|
||||||
|
|
||||||
for _, op := range g.ops {
|
for _, op := range g.ops {
|
||||||
db = op(db)
|
db = op(db)
|
||||||
}
|
}
|
||||||
|
6
gorm.go
6
gorm.go
@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
|
DefaultTransactionTimeout time.Duration
|
||||||
|
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
NamingStrategy schema.Namer
|
NamingStrategy schema.Namer
|
||||||
// FullSaveAssociations full save associations
|
// FullSaveAssociations full save associations
|
||||||
@ -519,7 +521,7 @@ func (db *DB) Use(plugin Plugin) error {
|
|||||||
// .First(&User{})
|
// .First(&User{})
|
||||||
// })
|
// })
|
||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
|
||||||
stmt := tx.Statement
|
stmt := tx.Statement
|
||||||
|
|
||||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -593,15 +594,37 @@ func TestGenericsNestedPreloads(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||||
db.LimitPerRecord(3)
|
|
||||||
return nil
|
return nil
|
||||||
}).Where(user.ID).Take(ctx)
|
}).Where(user.ID).Take(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to nested preload user")
|
t.Errorf("failed to nested preload user")
|
||||||
}
|
}
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
|
||||||
|
t.Fatalf("failed to nested preload")
|
||||||
|
}
|
||||||
|
|
||||||
if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 {
|
if DB.Dialector.Name() == "mysql" {
|
||||||
|
// mysql 5.7 doesn't support row_number()
|
||||||
|
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if DB.Dialector.Name() == "sqlserver" {
|
||||||
|
// sqlserver doesn't support order by in subquery
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.LimitPerRecord(3)
|
||||||
|
return nil
|
||||||
|
}).Where(user.ID).Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to nested preload user")
|
||||||
|
}
|
||||||
|
CheckUser(t, user3, user)
|
||||||
|
|
||||||
|
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
|
||||||
t.Errorf("failed to nested preload with limit per record")
|
t.Errorf("failed to nested preload with limit per record")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -784,3 +807,46 @@ func TestGenericsReuse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
sg.Wait()
|
sg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenericsWithTransaction(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
tx := DB.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatalf("failed to begin transaction: %v", tx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
|
||||||
|
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
|
||||||
|
|
||||||
|
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Count failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("expected 2 records, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Rollback().Error; err != nil {
|
||||||
|
t.Fatalf("failed to rollback transaction: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Count failed: %v", err)
|
||||||
|
}
|
||||||
|
if count2 != 0 {
|
||||||
|
t.Errorf("expected 0 records after rollback, got %d", count2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsToSQL(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
gorm.G[User](tx).Limit(10).Find(ctx)
|
||||||
|
return tx
|
||||||
|
})
|
||||||
|
|
||||||
|
if !regexp.MustCompile("SELECT \\* FROM `users`.* 10").MatchString(sql) {
|
||||||
|
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) {
|
|||||||
return tx2.Scan(&User{}).Error
|
return tx2.Scan(&User{}).Error
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) {
|
|||||||
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
|
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransactionWithDefaultTimeout(t *testing.T) {
|
||||||
|
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect database, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
if err = tx.Find(&User{}).Error; err == nil {
|
||||||
|
t.Errorf("should return error when transaction timeout, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user