From 15cae34f69c437457e2eb4712d7a13781962e7a2 Mon Sep 17 00:00:00 2001 From: Saeid Saeidee Date: Tue, 21 Feb 2023 17:10:49 +0100 Subject: [PATCH] refactor: added ErrorTransltor interface --- error_translator.go | 27 ------- gorm.go | 10 +-- interfaces.go | 4 ++ tests/error_translator_test.go | 124 ++------------------------------- utils/tests/dummy_dialecter.go | 8 ++- 5 files changed, 24 insertions(+), 149 deletions(-) delete mode 100644 error_translator.go diff --git a/error_translator.go b/error_translator.go deleted file mode 100644 index c5197b81..00000000 --- a/error_translator.go +++ /dev/null @@ -1,27 +0,0 @@ -package gorm - -import "gorm.io/gorm/errtranslator" - -func TranslateErr(dialect string, err error) error { - var errTranslator errtranslator.ErrTranslator - - switch dialect { - case "sqlite": - errTranslator = &errtranslator.SqliteErrTranslator{} - case "postgres": - errTranslator = &errtranslator.PostgresErrTranslator{} - case "mysql": - errTranslator = &errtranslator.MysqlErrTranslator{} - case "mssql": - errTranslator = &errtranslator.MssqlErrTranslator{} - } - - if errTranslator != nil { - translatedErr := errTranslator.Translate(err) - if _, ok := translatedErr.(errtranslator.ErrDuplicatedKey); ok { - return ErrDuplicatedKey - } - } - - return err -} diff --git a/gorm.go b/gorm.go index 5f60ce77..b5d98196 100644 --- a/gorm.go +++ b/gorm.go @@ -347,12 +347,14 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { - translatedErr := TranslateErr(db.Dialector.Name(), err) + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } if db.Error == nil { - db.Error = translatedErr - } else if translatedErr != nil { - db.Error = fmt.Errorf("%v; %w", db.Error, translatedErr) + db.Error = err + } else if err != nil { + db.Error = fmt.Errorf("%v; %w", db.Error, err) } return db.Error } diff --git a/interfaces.go b/interfaces.go index cf9e07b9..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -86,3 +86,7 @@ type Rows interface { Err() error Close() error } + +type ErrorTranslator interface { + Translate(err error) error +} diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index 085f7be0..9aaab9d3 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -5,125 +5,15 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/utils/tests" ) -type City struct { - gorm.Model - Name string `gorm:"unique"` -} +func TestDialectorWithErrorTranslatorSupport(t *testing.T) { + translatedErr := errors.New("translated error") + var db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) -func TestPostgresErrorTranslator(t *testing.T) { - if DB.Dialector.Name() != "postgres" { - t.Skip() - } - - DB.Migrator().DropTable(&City{}) - - if err := DB.AutoMigrate(&City{}); err != nil { - t.Fatalf("Failed to migrate cities table, got error: %v", err) - } - - err := DB.Create(&City{Name: "Amsterdam"}).Error - if err != nil { - t.Fatalf("errors happened on create: %v", err) - } - - // test errors to be translated - err = DB.Create(&City{Name: "Amsterdam"}).Error - if !errors.Is(err, gorm.ErrDuplicatedKey) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) - } - - // test default errors to not be translated - err = DB.Where("name = ?", "Kabul").First(&City{}).Error - if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) - } -} - -func TestMysqlErrorTranslator(t *testing.T) { - if DB.Dialector.Name() != "mysql" { - t.Skip() - } - - DB.Migrator().DropTable(&City{}) - - if err := DB.AutoMigrate(&City{}); err != nil { - t.Fatalf("Failed to migrate cities table, got error: %v", err) - } - - err := DB.Create(&City{Name: "Berlin"}).Error - if err != nil { - t.Fatalf("errors happened on create: %v", err) - } - - // test errors to be translated - err = DB.Create(&City{Name: "Berlin"}).Error - if !errors.Is(err, gorm.ErrDuplicatedKey) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) - } - - // test default errors to not be translated - err = DB.Where("name = ?", "Istanbul").First(&City{}).Error - if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) - } -} - -func TestMssqlErrorTranslator(t *testing.T) { - if DB.Dialector.Name() != "mssql" { - t.Skip() - } - - DB.Migrator().DropTable(&City{}) - - if err := DB.AutoMigrate(&City{}); err != nil { - t.Fatalf("Failed to migrate cities table, got error: %v", err) - } - - err := DB.Create(&City{Name: "Paris"}).Error - if err != nil { - t.Fatalf("errors happened on create: %v", err) - } - - // test errors to be translated - err = DB.Create(&City{Name: "Paris"}).Error - if !errors.Is(err, gorm.ErrDuplicatedKey) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) - } - - // test default errors to not be translated - err = DB.Where("name = ?", "Prague").First(&City{}).Error - if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) - } -} - -func TestSqliteErrorTranslator(t *testing.T) { - if DB.Dialector.Name() != "sqlite" { - t.Skip() - } - - DB.Migrator().DropTable(&City{}) - - if err := DB.AutoMigrate(&City{}); err != nil { - t.Fatalf("Failed to migrate cities table, got error: %v", err) - } - - err := DB.Create(&City{Name: "Madrid"}).Error - if err != nil { - t.Fatalf("errors happened on create: %v", err) - } - - // test errors to be translated - err = DB.Create(&City{Name: "Madrid"}).Error - if !errors.Is(err, gorm.ErrDuplicatedKey) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) - } - - // test default errors to not be translated - err = DB.Where("name = ?", "Rome").First(&City{}).Error - if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) + err := db.AddError(errors.New("some random error")) + if !errors.Is(err, translatedErr) { + t.Fatalf("expected err: %v got err: %v", translatedErr, err) } } diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index c89b944a..a2d9c33d 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -8,7 +8,9 @@ import ( "gorm.io/gorm/schema" ) -type DummyDialector struct{} +type DummyDialector struct { + TranslatedErr error +} func (DummyDialector) Name() string { return "dummy" @@ -92,3 +94,7 @@ func (DummyDialector) Explain(sql string, vars ...interface{}) string { func (DummyDialector) DataTypeOf(*schema.Field) string { return "" } + +func (d DummyDialector) Translate(err error) error { + return d.TranslatedErr +}