refactor: added ErrorTransltor interface

This commit is contained in:
Saeid Saeidee 2023-02-21 17:10:49 +01:00
parent 5bd20f3121
commit 15cae34f69
5 changed files with 24 additions and 149 deletions

View File

@ -1,27 +0,0 @@
package gorm
import "gorm.io/gorm/errtranslator"
func TranslateErr(dialect string, err error) error {
var errTranslator errtranslator.ErrTranslator
switch dialect {
case "sqlite":
errTranslator = &errtranslator.SqliteErrTranslator{}
case "postgres":
errTranslator = &errtranslator.PostgresErrTranslator{}
case "mysql":
errTranslator = &errtranslator.MysqlErrTranslator{}
case "mssql":
errTranslator = &errtranslator.MssqlErrTranslator{}
}
if errTranslator != nil {
translatedErr := errTranslator.Translate(err)
if _, ok := translatedErr.(errtranslator.ErrDuplicatedKey); ok {
return ErrDuplicatedKey
}
}
return err
}

10
gorm.go
View File

@ -347,12 +347,14 @@ func (db *DB) Callback() *callbacks {
// AddError add error to db
func (db *DB) AddError(err error) error {
translatedErr := TranslateErr(db.Dialector.Name(), err)
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
err = errTranslator.Translate(err)
}
if db.Error == nil {
db.Error = translatedErr
} else if translatedErr != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, translatedErr)
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
return db.Error
}

View File

@ -86,3 +86,7 @@ type Rows interface {
Err() error
Close() error
}
type ErrorTranslator interface {
Translate(err error) error
}

View File

@ -5,125 +5,15 @@ import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/utils/tests"
)
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
translatedErr := errors.New("translated error")
var db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
func TestPostgresErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
DB.Migrator().DropTable(&City{})
if err := DB.AutoMigrate(&City{}); err != nil {
t.Fatalf("Failed to migrate cities table, got error: %v", err)
}
err := DB.Create(&City{Name: "Amsterdam"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&City{Name: "Amsterdam"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
err = DB.Where("name = ?", "Kabul").First(&City{}).Error
if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}
func TestMysqlErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
DB.Migrator().DropTable(&City{})
if err := DB.AutoMigrate(&City{}); err != nil {
t.Fatalf("Failed to migrate cities table, got error: %v", err)
}
err := DB.Create(&City{Name: "Berlin"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&City{Name: "Berlin"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
err = DB.Where("name = ?", "Istanbul").First(&City{}).Error
if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}
func TestMssqlErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "mssql" {
t.Skip()
}
DB.Migrator().DropTable(&City{})
if err := DB.AutoMigrate(&City{}); err != nil {
t.Fatalf("Failed to migrate cities table, got error: %v", err)
}
err := DB.Create(&City{Name: "Paris"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&City{Name: "Paris"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
err = DB.Where("name = ?", "Prague").First(&City{}).Error
if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
}
}
func TestSqliteErrorTranslator(t *testing.T) {
if DB.Dialector.Name() != "sqlite" {
t.Skip()
}
DB.Migrator().DropTable(&City{})
if err := DB.AutoMigrate(&City{}); err != nil {
t.Fatalf("Failed to migrate cities table, got error: %v", err)
}
err := DB.Create(&City{Name: "Madrid"}).Error
if err != nil {
t.Fatalf("errors happened on create: %v", err)
}
// test errors to be translated
err = DB.Create(&City{Name: "Madrid"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
// test default errors to not be translated
err = DB.Where("name = ?", "Rome").First(&City{}).Error
if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err)
err := db.AddError(errors.New("some random error"))
if !errors.Is(err, translatedErr) {
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
}
}

View File

@ -8,7 +8,9 @@ import (
"gorm.io/gorm/schema"
)
type DummyDialector struct{}
type DummyDialector struct {
TranslatedErr error
}
func (DummyDialector) Name() string {
return "dummy"
@ -92,3 +94,7 @@ func (DummyDialector) Explain(sql string, vars ...interface{}) string {
func (DummyDialector) DataTypeOf(*schema.Field) string {
return ""
}
func (d DummyDialector) Translate(err error) error {
return d.TranslatedErr
}