feat: duplicated key error translator for different drivers
This commit is contained in:
parent
d834dd60b7
commit
5563988ce8
27
error_translator.go
Normal file
27
error_translator.go
Normal file
@ -0,0 +1,27 @@
|
||||
package gorm
|
||||
|
||||
import "gorm.io/gorm/errtranslator"
|
||||
|
||||
func TranslateErr(dialect string, err error) error {
|
||||
var errTranslator errtranslator.ErrTranslator
|
||||
|
||||
switch dialect {
|
||||
case "sqlite":
|
||||
errTranslator = &errtranslator.SqliteErrTranslator{}
|
||||
case "postgres":
|
||||
errTranslator = &errtranslator.PostgresErrTranslator{}
|
||||
case "mysql":
|
||||
errTranslator = &errtranslator.MysqlErrTranslator{}
|
||||
case "mssql":
|
||||
errTranslator = &errtranslator.MssqlErrTranslator{}
|
||||
}
|
||||
|
||||
if errTranslator != nil {
|
||||
translatedErr := errTranslator.Translate(err)
|
||||
if _, ok := translatedErr.(errtranslator.ErrDuplicatedKey); ok {
|
||||
return ErrDuplicatedKey
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
@ -41,4 +41,6 @@ var (
|
||||
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
||||
// ErrPreloadNotAllowed preload is not allowed when count is used
|
||||
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
|
||||
// ErrDuplicatedKey occurs when there is a unique key constraint violation
|
||||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
)
|
||||
|
33
errtranslator/mssql.go
Normal file
33
errtranslator/mssql.go
Normal file
@ -0,0 +1,33 @@
|
||||
package errtranslator
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var mssqlErrCodes = map[string]int{
|
||||
"uniqueConstraint": 2627,
|
||||
}
|
||||
|
||||
type MssqlErrTranslator struct{}
|
||||
|
||||
type MssqlErr struct {
|
||||
Number int `json:"Number"`
|
||||
Message string `json:"Message"`
|
||||
}
|
||||
|
||||
func (m *MssqlErrTranslator) Translate(err error) error {
|
||||
parsedErr, marshalErr := json.Marshal(err)
|
||||
if marshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var mssqlErr MssqlErr
|
||||
unmarshalErr := json.Unmarshal(parsedErr, &mssqlErr)
|
||||
if unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mssqlErr.Number == mssqlErrCodes["uniqueConstraint"] {
|
||||
return ErrDuplicatedKey{Code: mssqlErr.Number, Message: mssqlErr.Message}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
33
errtranslator/mysql.go
Normal file
33
errtranslator/mysql.go
Normal file
@ -0,0 +1,33 @@
|
||||
package errtranslator
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var mysqlErrCodes = map[string]int{
|
||||
"uniqueConstraint": 1062,
|
||||
}
|
||||
|
||||
type MysqlErrTranslator struct{}
|
||||
|
||||
type MysqlErr struct {
|
||||
Number int `json:"Number"`
|
||||
Message string `json:"Message"`
|
||||
}
|
||||
|
||||
func (m *MysqlErrTranslator) Translate(err error) error {
|
||||
parsedErr, marshalErr := json.Marshal(err)
|
||||
if marshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var mysqlErr MysqlErr
|
||||
unmarshalErr := json.Unmarshal(parsedErr, &mysqlErr)
|
||||
if unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mysqlErr.Number == mysqlErrCodes["uniqueConstraint"] {
|
||||
return ErrDuplicatedKey{Code: mysqlErr.Number, Message: mysqlErr.Message}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
34
errtranslator/postgres.go
Normal file
34
errtranslator/postgres.go
Normal file
@ -0,0 +1,34 @@
|
||||
package errtranslator
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var postgresErrCodes = map[string]string{
|
||||
"uniqueConstraint": "23505",
|
||||
}
|
||||
|
||||
type PostgresErrTranslator struct{}
|
||||
|
||||
type PostgresErr struct {
|
||||
Code string `json:"Code"`
|
||||
Severity string `json:"Severity"`
|
||||
Message string `json:"Message"`
|
||||
}
|
||||
|
||||
func (p *PostgresErrTranslator) Translate(err error) error {
|
||||
parsedErr, marshalErr := json.Marshal(err)
|
||||
if marshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var postgresErr PostgresErr
|
||||
unmarshalErr := json.Unmarshal(parsedErr, &postgresErr)
|
||||
if unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if postgresErr.Code == postgresErrCodes["uniqueConstraint"] {
|
||||
return ErrDuplicatedKey{Code: postgresErr.Code, Message: postgresErr.Message}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
34
errtranslator/sqlite.go
Normal file
34
errtranslator/sqlite.go
Normal file
@ -0,0 +1,34 @@
|
||||
package errtranslator
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var sqliteErrCodes = map[string]int{
|
||||
"uniqueConstraint": 2067,
|
||||
}
|
||||
|
||||
type SqliteErrTranslator struct{}
|
||||
|
||||
type SqliteErr struct {
|
||||
Code int `json:"Code"`
|
||||
ExtendedCode int `json:"ExtendedCode"`
|
||||
SystemErrno int `json:"SystemErrno"`
|
||||
}
|
||||
|
||||
func (s *SqliteErrTranslator) Translate(err error) error {
|
||||
parsedErr, marshalErr := json.Marshal(err)
|
||||
if marshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sqliteErr SqliteErr
|
||||
unmarshalErr := json.Unmarshal(parsedErr, &sqliteErr)
|
||||
if unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqliteErr.ExtendedCode == sqliteErrCodes["uniqueConstraint"] {
|
||||
return ErrDuplicatedKey{Code: sqliteErr.ExtendedCode, Message: ""}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
16
errtranslator/types.go
Normal file
16
errtranslator/types.go
Normal file
@ -0,0 +1,16 @@
|
||||
package errtranslator
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ErrTranslator interface {
|
||||
Translate(err error) error
|
||||
}
|
||||
|
||||
type ErrDuplicatedKey struct {
|
||||
Code interface{}
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e ErrDuplicatedKey) Error() string {
|
||||
return fmt.Sprintf("duplicated key not allowed, code: %v, message: %s", e.Code, e.Message)
|
||||
}
|
8
gorm.go
8
gorm.go
@ -347,10 +347,12 @@ func (db *DB) Callback() *callbacks {
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
translatedErr := TranslateErr(db.Dialector.Name(), err)
|
||||
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else if err != nil {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
db.Error = translatedErr
|
||||
} else if translatedErr != nil {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, translatedErr)
|
||||
}
|
||||
return db.Error
|
||||
}
|
||||
|
160
tests/error_translator_test.go
Normal file
160
tests/error_translator_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestPostgresErrorTranslator(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
|
||||
if err := DB.AutoMigrate(&Product{}); err != nil {
|
||||
t.Fatalf("Failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
err := DB.Create(&Product{Name: "Milk"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("errors happened on create: %v", err)
|
||||
}
|
||||
|
||||
// test errors to be translated
|
||||
|
||||
err = DB.Create(&Product{Name: "Milk"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
|
||||
// test default errors to not be translated
|
||||
|
||||
var product Product
|
||||
|
||||
err = DB.Find(&product, "name = ?", "coffee").Error
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlErrorTranslator(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
|
||||
if err := DB.AutoMigrate(&Product{}); err != nil {
|
||||
t.Fatalf("Failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
err := DB.Create(&Product{Name: "Milk"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("errors happened on create: %v", err)
|
||||
}
|
||||
|
||||
// test errors to be translated
|
||||
|
||||
err = DB.Create(&Product{Name: "Milk"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
|
||||
// test default errors to not be translated
|
||||
|
||||
var product Product
|
||||
|
||||
err = DB.Find(&product, "name = ?", "coffee").Error
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMssqlErrorTranslator(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mssql" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
|
||||
if err := DB.AutoMigrate(&Product{}); err != nil {
|
||||
t.Fatalf("Failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
err := DB.Create(&Product{Name: "Milk"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("errors happened on create: %v", err)
|
||||
}
|
||||
|
||||
// test errors to be translated
|
||||
|
||||
err = DB.Create(&Product{Name: "Milk"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
|
||||
// test default errors to not be translated
|
||||
|
||||
var product Product
|
||||
|
||||
err = DB.Find(&product, "name = ?", "coffee").Error
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqliteErrorTranslator(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
|
||||
if err := DB.AutoMigrate(&Product{}); err != nil {
|
||||
t.Fatalf("Failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
err := DB.Create(&Product{Name: "Milk"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("errors happened on create: %v", err)
|
||||
}
|
||||
|
||||
// test errors to be translated
|
||||
|
||||
err = DB.Create(&Product{Name: "Milk"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
|
||||
// test default errors to not be translated
|
||||
|
||||
var product Product
|
||||
|
||||
err = DB.Find(&product, "name = ?", "coffee").Error
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
|
||||
}
|
||||
}
|
@ -7,6 +7,7 @@ require (
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.7
|
||||
github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
||||
github.com/microsoft/go-mssqldb v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.5.0 // indirect
|
||||
gorm.io/driver/mysql v1.4.5
|
||||
gorm.io/driver/postgres v1.4.6
|
||||
|
Loading…
x
Reference in New Issue
Block a user