From 636f90fbcd534a5bc62a43522b569d9eeeba9029 Mon Sep 17 00:00:00 2001 From: Krisdiano Date: Fri, 30 May 2025 15:54:46 +0800 Subject: [PATCH] Fix ut --- callbacks/create.go | 11 +++++++++++ tests/create_test.go | 33 +++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 5cf48ec9..1e5447ee 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -2,7 +2,9 @@ package callbacks import ( "fmt" + "os" "reflect" + "strconv" "strings" "gorm.io/gorm" @@ -36,12 +38,21 @@ func BeforeCreate(db *gorm.DB) { // Create create hook func Create(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.CreateClauses, "RETURNING") + rawSupportReturning := supportReturning return func(db *gorm.DB) { if db.Error != nil { return } + mock := os.Getenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING") + mockSupportReturning, err := strconv.ParseBool(mock) + if err == nil { + supportReturning = mockSupportReturning + } else { + supportReturning = rawSupportReturning + } + if db.Statement.Schema != nil { if !db.Statement.Unscoped { for _, c := range db.Statement.Schema.CreateClauses { diff --git a/tests/create_test.go b/tests/create_test.go index 5fd086bf..2a636843 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "errors" "fmt" + "os" "regexp" "testing" "time" @@ -553,19 +554,39 @@ func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string, } func TestCreateWithDisableLastInsertID(t *testing.T) { - rawPool := DB.ConnPool - DB.ConnPool = ConnPoolLastInsertIDMock{rawPool} - defer func() { - DB.ConnPool = rawPool - }() + mockCreateSupportReturning := func() func() { + revertCreateSupportReturning := func() { + os.Setenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING", "") + } + + os.Setenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING", "false") + return revertCreateSupportReturning + } + + mockConnPoolExec := func() func() { + rawPool := DB.ConnPool + DB.ConnPool = ConnPoolLastInsertIDMock{rawPool} + rawStatementPool := DB.Statement.ConnPool + DB.Statement.ConnPool = ConnPoolLastInsertIDMock{rawStatementPool} + return func() { + DB.ConnPool = rawPool + DB.Statement.ConnPool = rawStatementPool + } + } + + defer mockCreateSupportReturning()() + defer mockConnPoolExec()() user := &User{Name: "TestCreateWithDisableLastInsertID"} err := DB.Create(&user).Error - if err == nil { + if DB.RowsAffected > 0 && err == nil { t.Fatalf("it should be error") } DB.DisableLastInsertID = true + defer func() { + DB.DisableLastInsertID = false + }() err = DB.Create(&user).Error if err != nil { t.Fatalf("it should be nil")