Fix config and add ut

This commit is contained in:
Krisdiano 2025-05-30 12:04:47 +08:00
parent 40b4998893
commit dac14f0b78
3 changed files with 40 additions and 23 deletions

View File

@ -1,7 +1,6 @@
package callbacks
import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
@ -114,7 +113,7 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
if db.RowsAffected == 0 || db.DisableLastInsertID {
return
}
@ -127,22 +126,6 @@ func Create(config *Config) func(db *gorm.DB) {
insertOk := err == nil && insertID > 0
if !insertOk {
if db.IgnoreLastInsertIDWhenNotSupport {
_, rowsAffectedErr := driver.RowsAffected(0).LastInsertId()
if strings.Compare(err.Error(), rowsAffectedErr.Error()) == 0 {
return
}
_, resultNoRowsErr := driver.ResultNoRows.LastInsertId()
if strings.Compare(err.Error(), resultNoRowsErr.Error()) == 0 {
return
}
if db.IsNotSupportLastInsertIDErr != nil && db.IsNotSupportLastInsertIDErr(err) {
return
}
if db.Logger != nil {
db.Logger.Warn(db.Statement.Context, "Failed to get last insert ID, err: %v", err)
}
}
if !supportReturning {
db.AddError(err)
}

View File

@ -24,11 +24,8 @@ type Config struct {
SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// Not all database support LastInsertId, you can set `IgnoreLastInsertIDWhenNotSupport` to true in those cases
IgnoreLastInsertIDWhenNotSupport bool
// When 'IgnoreLastInsertIDWhenNotSupport' is true, you can set `IsNotSupportLastInsertIDErr` to check if the error is 'NotSupportLastInsertID'
// Gorm only asserts the type of returned sql/driver.Result if 'IsNotSupportLastInsertIDErr' is not set.
IsNotSupportLastInsertIDErr func(error) bool
// Not all database support LastInsertId, you can set `DisableLastInsertID` to true in those cases
DisableLastInsertID bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer

View File

@ -1,6 +1,9 @@
package tests_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
@ -535,6 +538,40 @@ func TestCreateNilPointer(t *testing.T) {
}
}
type ConnPoolLastInsertIDMock struct {
gorm.ConnPool
}
func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
rst, err := m.ConnPool.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
affected, _ := rst.RowsAffected()
return driver.RowsAffected(affected), nil
}
func TestCreateWithDisableLastInsertID(t *testing.T) {
rawPool := DB.ConnPool
DB.ConnPool = ConnPoolLastInsertIDMock{rawPool}
defer func() {
DB.ConnPool = rawPool
}()
user := &User{Name: "TestCreateWithDisableLastInsertID"}
err := DB.Create(&user).Error
if err == nil {
t.Fatalf("it should be error")
}
DB.DisableLastInsertID = true
err = DB.Create(&user).Error
if err != nil {
t.Fatalf("it should be nil")
}
}
func TestFirstOrCreateRowsAffected(t *testing.T) {
user := User{Name: "TestFirstOrCreateRowsAffected"}