diff --git a/callbacks/create.go b/callbacks/create.go index 1e5447ee..5cf48ec9 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -2,9 +2,7 @@ package callbacks import ( "fmt" - "os" "reflect" - "strconv" "strings" "gorm.io/gorm" @@ -38,21 +36,12 @@ func BeforeCreate(db *gorm.DB) { // Create create hook func Create(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.CreateClauses, "RETURNING") - rawSupportReturning := supportReturning return func(db *gorm.DB) { if db.Error != nil { return } - mock := os.Getenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING") - mockSupportReturning, err := strconv.ParseBool(mock) - if err == nil { - supportReturning = mockSupportReturning - } else { - supportReturning = rawSupportReturning - } - if db.Statement.Schema != nil { if !db.Statement.Unscoped { for _, c := range db.Statement.Schema.CreateClauses { diff --git a/tests/create_test.go b/tests/create_test.go index 393721a9..a9071fb8 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -6,7 +6,6 @@ import ( "database/sql/driver" "errors" "fmt" - "os" "regexp" "testing" "time" @@ -557,28 +556,6 @@ func TestCreateWithDisableLastInsertID(t *testing.T) { if isSQLServer() { t.Skip("SQLServer driver doesn't use default create hook in gorm") } - mockCreateSupportReturning := func() func() { - revertCreateSupportReturning := func() { - os.Setenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING", "") - } - - os.Setenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING", "false") - return revertCreateSupportReturning - } - - mockConnPoolExec := func() func() { - rawPool := DB.ConnPool - DB.ConnPool = ConnPoolLastInsertIDMock{rawPool} - rawStatementPool := DB.Statement.ConnPool - DB.Statement.ConnPool = ConnPoolLastInsertIDMock{rawStatementPool} - return func() { - DB.ConnPool = rawPool - DB.Statement.ConnPool = rawStatementPool - } - } - - defer mockCreateSupportReturning()() - defer mockConnPoolExec()() user := GetUser("TestCreateWithDisableLastInsertID0", Config{}) err := DB.Create(user).Error @@ -586,12 +563,27 @@ func TestCreateWithDisableLastInsertID(t *testing.T) { t.Fatalf("it should be error") } - DB.DisableLastInsertID = true - defer func() { - DB.DisableLastInsertID = false - }() + // create a new connection with new config + db, err := OpenTestConnection(&gorm.Config{DisableLastInsertID: true}) + if err != nil { + t.Fatal("failed to connect database") + } - err = DB.Create(user).Error + // mock driver result + mockConnPoolExec := func() func() { + rawPool := db.ConnPool + db.ConnPool = ConnPoolLastInsertIDMock{rawPool} + rawStatementPool := db.Statement.ConnPool + db.Statement.ConnPool = ConnPoolLastInsertIDMock{rawStatementPool} + return func() { + db.ConnPool = rawPool + db.Statement.ConnPool = rawStatementPool + } + } + defer mockConnPoolExec()() + + user = GetUser("TestCreateWithDisableLastInsertID1", Config{}) + err = db.Create(user).Error if err != nil { t.Fatalf("it should be nil, got %v", err) }