This commit is contained in:
Krisdiano 2025-05-30 15:54:46 +08:00
parent dac14f0b78
commit 636f90fbcd
2 changed files with 38 additions and 6 deletions

View File

@ -2,7 +2,9 @@ package callbacks
import (
"fmt"
"os"
"reflect"
"strconv"
"strings"
"gorm.io/gorm"
@ -36,12 +38,21 @@ func BeforeCreate(db *gorm.DB) {
// Create create hook
func Create(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
rawSupportReturning := supportReturning
return func(db *gorm.DB) {
if db.Error != nil {
return
}
mock := os.Getenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING")
mockSupportReturning, err := strconv.ParseBool(mock)
if err == nil {
supportReturning = mockSupportReturning
} else {
supportReturning = rawSupportReturning
}
if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {

View File

@ -6,6 +6,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"os"
"regexp"
"testing"
"time"
@ -553,19 +554,39 @@ func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string,
}
func TestCreateWithDisableLastInsertID(t *testing.T) {
rawPool := DB.ConnPool
DB.ConnPool = ConnPoolLastInsertIDMock{rawPool}
defer func() {
DB.ConnPool = rawPool
}()
mockCreateSupportReturning := func() func() {
revertCreateSupportReturning := func() {
os.Setenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING", "")
}
os.Setenv("GORM_E2E_TEST_MOCK_CREATE_RETURNING", "false")
return revertCreateSupportReturning
}
mockConnPoolExec := func() func() {
rawPool := DB.ConnPool
DB.ConnPool = ConnPoolLastInsertIDMock{rawPool}
rawStatementPool := DB.Statement.ConnPool
DB.Statement.ConnPool = ConnPoolLastInsertIDMock{rawStatementPool}
return func() {
DB.ConnPool = rawPool
DB.Statement.ConnPool = rawStatementPool
}
}
defer mockCreateSupportReturning()()
defer mockConnPoolExec()()
user := &User{Name: "TestCreateWithDisableLastInsertID"}
err := DB.Create(&user).Error
if err == nil {
if DB.RowsAffected > 0 && err == nil {
t.Fatalf("it should be error")
}
DB.DisableLastInsertID = true
defer func() {
DB.DisableLastInsertID = false
}()
err = DB.Create(&user).Error
if err != nil {
t.Fatalf("it should be nil")