feat: duplicated key error translator for different drivers

This commit is contained in:
Saeid Saeidee 2023-01-29 17:34:58 +01:00
parent d834dd60b7
commit 5563988ce8
10 changed files with 345 additions and 3 deletions

27
error_translator.go Normal file
View File

@ -0,0 +1,27 @@
package gorm
import "gorm.io/gorm/errtranslator"
func TranslateErr(dialect string, err error) error {
var errTranslator errtranslator.ErrTranslator
switch dialect {
case "sqlite":
errTranslator = &errtranslator.SqliteErrTranslator{}
case "postgres":
errTranslator = &errtranslator.PostgresErrTranslator{}
case "mysql":
errTranslator = &errtranslator.MysqlErrTranslator{}
case "mssql":
errTranslator = &errtranslator.MssqlErrTranslator{}
}
if errTranslator != nil {
translatedErr := errTranslator.Translate(err)
if _, ok := translatedErr.(errtranslator.ErrDuplicatedKey); ok {
return ErrDuplicatedKey
}
}
return err
}

View File

@ -41,4 +41,6 @@ var (
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
)

33
errtranslator/mssql.go Normal file
View File

@ -0,0 +1,33 @@
package errtranslator
import "encoding/json"
var mssqlErrCodes = map[string]int{
"uniqueConstraint": 2627,
}
type MssqlErrTranslator struct{}
type MssqlErr struct {
Number int `json:"Number"`
Message string `json:"Message"`
}
func (m *MssqlErrTranslator) Translate(err error) error {
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}
var mssqlErr MssqlErr
unmarshalErr := json.Unmarshal(parsedErr, &mssqlErr)
if unmarshalErr != nil {
return err
}
if mssqlErr.Number == mssqlErrCodes["uniqueConstraint"] {
return ErrDuplicatedKey{Code: mssqlErr.Number, Message: mssqlErr.Message}
}
return err
}

33
errtranslator/mysql.go Normal file
View File

@ -0,0 +1,33 @@
package errtranslator
import "encoding/json"
var mysqlErrCodes = map[string]int{
"uniqueConstraint": 1062,
}
type MysqlErrTranslator struct{}
type MysqlErr struct {
Number int `json:"Number"`
Message string `json:"Message"`
}
func (m *MysqlErrTranslator) Translate(err error) error {
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}
var mysqlErr MysqlErr
unmarshalErr := json.Unmarshal(parsedErr, &mysqlErr)
if unmarshalErr != nil {
return err
}
if mysqlErr.Number == mysqlErrCodes["uniqueConstraint"] {
return ErrDuplicatedKey{Code: mysqlErr.Number, Message: mysqlErr.Message}
}
return err
}

34
errtranslator/postgres.go Normal file
View File

@ -0,0 +1,34 @@
package errtranslator
import "encoding/json"
var postgresErrCodes = map[string]string{
"uniqueConstraint": "23505",
}
type PostgresErrTranslator struct{}
type PostgresErr struct {
Code string `json:"Code"`
Severity string `json:"Severity"`
Message string `json:"Message"`
}
func (p *PostgresErrTranslator) Translate(err error) error {
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}
var postgresErr PostgresErr
unmarshalErr := json.Unmarshal(parsedErr, &postgresErr)
if unmarshalErr != nil {
return err
}
if postgresErr.Code == postgresErrCodes["uniqueConstraint"] {
return ErrDuplicatedKey{Code: postgresErr.Code, Message: postgresErr.Message}
}
return err
}

34
errtranslator/sqlite.go Normal file
View File

@ -0,0 +1,34 @@
package errtranslator
import "encoding/json"
var sqliteErrCodes = map[string]int{
"uniqueConstraint": 2067,
}
type SqliteErrTranslator struct{}
type SqliteErr struct {
Code int `json:"Code"`
ExtendedCode int `json:"ExtendedCode"`
SystemErrno int `json:"SystemErrno"`
}
func (s *SqliteErrTranslator) Translate(err error) error {
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}
var sqliteErr SqliteErr
unmarshalErr := json.Unmarshal(parsedErr, &sqliteErr)
if unmarshalErr != nil {
return err
}
if sqliteErr.ExtendedCode == sqliteErrCodes["uniqueConstraint"] {
return ErrDuplicatedKey{Code: sqliteErr.ExtendedCode, Message: ""}
}
return err
}

16
errtranslator/types.go Normal file
View File

@ -0,0 +1,16 @@
package errtranslator
import "fmt"
type ErrTranslator interface {
Translate(err error) error
}
type ErrDuplicatedKey struct {
Code interface{}
Message string
}
func (e ErrDuplicatedKey) Error() string {
return fmt.Sprintf("duplicated key not allowed, code: %v, message: %s", e.Code, e.Message)
}

View File

@ -347,10 +347,12 @@ func (db *DB) Callback() *callbacks {
// AddError add error to db
func (db *DB) AddError(err error) error {
translatedErr := TranslateErr(db.Dialector.Name(), err)
if db.Error == nil {
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
db.Error = translatedErr
} else if translatedErr != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, translatedErr)
}
return db.Error
}

View File

@ -0,0 +1,160 @@
package tests_test
import (
"errors"
"testing"
"gorm.io/gorm"
)
func TestPostgresErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
type Product struct {
gorm.Model
Name string `gorm:"unique"`
}
DB.Migrator().DropTable(&Product{})
if err := DB.AutoMigrate(&Product{}); err != nil {
t.Fatalf("Failed to migrate: %v", err)
}
err := DB.Create(&Product{Name: "Milk"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&Product{Name: "Milk"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
var product Product
err = DB.Find(&product, "name = ?", "coffee").Error
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}
func TestMysqlErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
type Product struct {
gorm.Model
Name string `gorm:"unique"`
}
DB.Migrator().DropTable(&Product{})
if err := DB.AutoMigrate(&Product{}); err != nil {
t.Fatalf("Failed to migrate: %v", err)
}
err := DB.Create(&Product{Name: "Milk"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&Product{Name: "Milk"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
var product Product
err = DB.Find(&product, "name = ?", "coffee").Error
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}
func TestMssqlErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "mssql" {
t.Skip()
}
type Product struct {
gorm.Model
Name string `gorm:"unique"`
}
DB.Migrator().DropTable(&Product{})
if err := DB.AutoMigrate(&Product{}); err != nil {
t.Fatalf("Failed to migrate: %v", err)
}
err := DB.Create(&Product{Name: "Milk"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&Product{Name: "Milk"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
var product Product
err = DB.Find(&product, "name = ?", "coffee").Error
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}
func TestSqliteErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "sqlite" {
t.Skip()
}
type Product struct {
gorm.Model
Name string `gorm:"unique"`
}
DB.Migrator().DropTable(&Product{})
if err := DB.AutoMigrate(&Product{}); err != nil {
t.Fatalf("Failed to migrate: %v", err)
}
err := DB.Create(&Product{Name: "Milk"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&Product{Name: "Milk"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
var product Product
err = DB.Find(&product, "name = ?", "coffee").Error
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}

View File

@ -7,6 +7,7 @@ require (
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.16 // indirect
github.com/microsoft/go-mssqldb v0.20.0 // indirect
golang.org/x/crypto v0.5.0 // indirect
gorm.io/driver/mysql v1.4.5
gorm.io/driver/postgres v1.4.6