diff --git a/gorm.go b/gorm.go index 6a6bb032..43c9c3ac 100644 --- a/gorm.go +++ b/gorm.go @@ -298,6 +298,11 @@ func (db *DB) WithContext(ctx context.Context) *DB { return db.Session(&Session{Context: ctx}) } +// Context returns the db.Statement.Context +func (db *DB) Context() context.Context { + return db.Statement.Context +} + // Debug start debug mode func (db *DB) Debug() (tx *DB) { return db.Session(&Session{ diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 4e4b6149..ed592e8f 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -78,6 +78,24 @@ func TestCancelTransaction(t *testing.T) { } } +func TestGetContextTransaction(t *testing.T) { + type ctxValue string + passThrough := ctxValue("passThrough") + ctx := context.Background() + ctx = context.WithValue(ctx, passThrough, "passed") + + user := *GetUser("get_context", Config{}) + DB.Create(&user) + + _ = DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + p := tx.Context().Value(passThrough).(string) + if p != "passed" { + t.Fatalf("Transaction did not contain the passThrough context value from context() function") + } + return nil + }) +} + func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() {