Fix config and add ut
This commit is contained in:
parent
40b4998893
commit
dac14f0b78
@ -1,7 +1,6 @@
|
|||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -114,7 +113,7 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.RowsAffected == 0 {
|
if db.RowsAffected == 0 || db.DisableLastInsertID {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,22 +126,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
insertOk := err == nil && insertID > 0
|
insertOk := err == nil && insertID > 0
|
||||||
|
|
||||||
if !insertOk {
|
if !insertOk {
|
||||||
if db.IgnoreLastInsertIDWhenNotSupport {
|
|
||||||
_, rowsAffectedErr := driver.RowsAffected(0).LastInsertId()
|
|
||||||
if strings.Compare(err.Error(), rowsAffectedErr.Error()) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, resultNoRowsErr := driver.ResultNoRows.LastInsertId()
|
|
||||||
if strings.Compare(err.Error(), resultNoRowsErr.Error()) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if db.IsNotSupportLastInsertIDErr != nil && db.IsNotSupportLastInsertIDErr(err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if db.Logger != nil {
|
|
||||||
db.Logger.Warn(db.Statement.Context, "Failed to get last insert ID, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !supportReturning {
|
if !supportReturning {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
|
7
gorm.go
7
gorm.go
@ -24,11 +24,8 @@ type Config struct {
|
|||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
DefaultTransactionTimeout time.Duration
|
DefaultTransactionTimeout time.Duration
|
||||||
|
|
||||||
// Not all database support LastInsertId, you can set `IgnoreLastInsertIDWhenNotSupport` to true in those cases
|
// Not all database support LastInsertId, you can set `DisableLastInsertID` to true in those cases
|
||||||
IgnoreLastInsertIDWhenNotSupport bool
|
DisableLastInsertID bool
|
||||||
// When 'IgnoreLastInsertIDWhenNotSupport' is true, you can set `IsNotSupportLastInsertIDErr` to check if the error is 'NotSupportLastInsertID'
|
|
||||||
// Gorm only asserts the type of returned sql/driver.Result if 'IsNotSupportLastInsertIDErr' is not set.
|
|
||||||
IsNotSupportLastInsertIDErr func(error) bool
|
|
||||||
|
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
NamingStrategy schema.Namer
|
NamingStrategy schema.Namer
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -535,6 +538,40 @@ func TestCreateNilPointer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ConnPoolLastInsertIDMock struct {
|
||||||
|
gorm.ConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
rst, err := m.ConnPool.ExecContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, _ := rst.RowsAffected()
|
||||||
|
return driver.RowsAffected(affected), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateWithDisableLastInsertID(t *testing.T) {
|
||||||
|
rawPool := DB.ConnPool
|
||||||
|
DB.ConnPool = ConnPoolLastInsertIDMock{rawPool}
|
||||||
|
defer func() {
|
||||||
|
DB.ConnPool = rawPool
|
||||||
|
}()
|
||||||
|
|
||||||
|
user := &User{Name: "TestCreateWithDisableLastInsertID"}
|
||||||
|
err := DB.Create(&user).Error
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("it should be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.DisableLastInsertID = true
|
||||||
|
err = DB.Create(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("it should be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFirstOrCreateRowsAffected(t *testing.T) {
|
func TestFirstOrCreateRowsAffected(t *testing.T) {
|
||||||
user := User{Name: "TestFirstOrCreateRowsAffected"}
|
user := User{Name: "TestFirstOrCreateRowsAffected"}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user