Fix config and add ut
This commit is contained in:
parent
40b4998893
commit
dac14f0b78
@ -1,7 +1,6 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
@ -114,7 +113,7 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 {
|
||||
if db.RowsAffected == 0 || db.DisableLastInsertID {
|
||||
return
|
||||
}
|
||||
|
||||
@ -127,22 +126,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
if db.IgnoreLastInsertIDWhenNotSupport {
|
||||
_, rowsAffectedErr := driver.RowsAffected(0).LastInsertId()
|
||||
if strings.Compare(err.Error(), rowsAffectedErr.Error()) == 0 {
|
||||
return
|
||||
}
|
||||
_, resultNoRowsErr := driver.ResultNoRows.LastInsertId()
|
||||
if strings.Compare(err.Error(), resultNoRowsErr.Error()) == 0 {
|
||||
return
|
||||
}
|
||||
if db.IsNotSupportLastInsertIDErr != nil && db.IsNotSupportLastInsertIDErr(err) {
|
||||
return
|
||||
}
|
||||
if db.Logger != nil {
|
||||
db.Logger.Warn(db.Statement.Context, "Failed to get last insert ID, err: %v", err)
|
||||
}
|
||||
}
|
||||
if !supportReturning {
|
||||
db.AddError(err)
|
||||
}
|
||||
|
7
gorm.go
7
gorm.go
@ -24,11 +24,8 @@ type Config struct {
|
||||
SkipDefaultTransaction bool
|
||||
DefaultTransactionTimeout time.Duration
|
||||
|
||||
// Not all database support LastInsertId, you can set `IgnoreLastInsertIDWhenNotSupport` to true in those cases
|
||||
IgnoreLastInsertIDWhenNotSupport bool
|
||||
// When 'IgnoreLastInsertIDWhenNotSupport' is true, you can set `IsNotSupportLastInsertIDErr` to check if the error is 'NotSupportLastInsertID'
|
||||
// Gorm only asserts the type of returned sql/driver.Result if 'IsNotSupportLastInsertIDErr' is not set.
|
||||
IsNotSupportLastInsertIDErr func(error) bool
|
||||
// Not all database support LastInsertId, you can set `DisableLastInsertID` to true in those cases
|
||||
DisableLastInsertID bool
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
|
@ -1,6 +1,9 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
@ -535,6 +538,40 @@ func TestCreateNilPointer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type ConnPoolLastInsertIDMock struct {
|
||||
gorm.ConnPool
|
||||
}
|
||||
|
||||
func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
rst, err := m.ConnPool.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
affected, _ := rst.RowsAffected()
|
||||
return driver.RowsAffected(affected), nil
|
||||
}
|
||||
|
||||
func TestCreateWithDisableLastInsertID(t *testing.T) {
|
||||
rawPool := DB.ConnPool
|
||||
DB.ConnPool = ConnPoolLastInsertIDMock{rawPool}
|
||||
defer func() {
|
||||
DB.ConnPool = rawPool
|
||||
}()
|
||||
|
||||
user := &User{Name: "TestCreateWithDisableLastInsertID"}
|
||||
err := DB.Create(&user).Error
|
||||
if err == nil {
|
||||
t.Fatalf("it should be error")
|
||||
}
|
||||
|
||||
DB.DisableLastInsertID = true
|
||||
err = DB.Create(&user).Error
|
||||
if err != nil {
|
||||
t.Fatalf("it should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstOrCreateRowsAffected(t *testing.T) {
|
||||
user := User{Name: "TestFirstOrCreateRowsAffected"}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user