diff --git a/callbacks/create.go b/callbacks/create.go index f6bf2242..5cf48ec9 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,7 +1,6 @@ package callbacks import ( - "database/sql/driver" "fmt" "reflect" "strings" @@ -114,7 +113,7 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Result.RowsAffected = db.RowsAffected } - if db.RowsAffected == 0 { + if db.RowsAffected == 0 || db.DisableLastInsertID { return } @@ -127,22 +126,6 @@ func Create(config *Config) func(db *gorm.DB) { insertOk := err == nil && insertID > 0 if !insertOk { - if db.IgnoreLastInsertIDWhenNotSupport { - _, rowsAffectedErr := driver.RowsAffected(0).LastInsertId() - if strings.Compare(err.Error(), rowsAffectedErr.Error()) == 0 { - return - } - _, resultNoRowsErr := driver.ResultNoRows.LastInsertId() - if strings.Compare(err.Error(), resultNoRowsErr.Error()) == 0 { - return - } - if db.IsNotSupportLastInsertIDErr != nil && db.IsNotSupportLastInsertIDErr(err) { - return - } - if db.Logger != nil { - db.Logger.Warn(db.Statement.Context, "Failed to get last insert ID, err: %v", err) - } - } if !supportReturning { db.AddError(err) } diff --git a/gorm.go b/gorm.go index d1e10fbb..eb9d6267 100644 --- a/gorm.go +++ b/gorm.go @@ -24,11 +24,8 @@ type Config struct { SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration - // Not all database support LastInsertId, you can set `IgnoreLastInsertIDWhenNotSupport` to true in those cases - IgnoreLastInsertIDWhenNotSupport bool - // When 'IgnoreLastInsertIDWhenNotSupport' is true, you can set `IsNotSupportLastInsertIDErr` to check if the error is 'NotSupportLastInsertID' - // Gorm only asserts the type of returned sql/driver.Result if 'IsNotSupportLastInsertIDErr' is not set. - IsNotSupportLastInsertIDErr func(error) bool + // Not all database support LastInsertId, you can set `DisableLastInsertID` to true in those cases + DisableLastInsertID bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer diff --git a/tests/create_test.go b/tests/create_test.go index abb82472..5fd086bf 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,6 +1,9 @@ package tests_test import ( + "context" + "database/sql" + "database/sql/driver" "errors" "fmt" "regexp" @@ -535,6 +538,40 @@ func TestCreateNilPointer(t *testing.T) { } } +type ConnPoolLastInsertIDMock struct { + gorm.ConnPool +} + +func (m ConnPoolLastInsertIDMock) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + rst, err := m.ConnPool.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + affected, _ := rst.RowsAffected() + return driver.RowsAffected(affected), nil +} + +func TestCreateWithDisableLastInsertID(t *testing.T) { + rawPool := DB.ConnPool + DB.ConnPool = ConnPoolLastInsertIDMock{rawPool} + defer func() { + DB.ConnPool = rawPool + }() + + user := &User{Name: "TestCreateWithDisableLastInsertID"} + err := DB.Create(&user).Error + if err == nil { + t.Fatalf("it should be error") + } + + DB.DisableLastInsertID = true + err = DB.Create(&user).Error + if err != nil { + t.Fatalf("it should be nil") + } +} + func TestFirstOrCreateRowsAffected(t *testing.T) { user := User{Name: "TestFirstOrCreateRowsAffected"}