diff --git a/error_translator.go b/error_translator.go new file mode 100644 index 00000000..c5197b81 --- /dev/null +++ b/error_translator.go @@ -0,0 +1,27 @@ +package gorm + +import "gorm.io/gorm/errtranslator" + +func TranslateErr(dialect string, err error) error { + var errTranslator errtranslator.ErrTranslator + + switch dialect { + case "sqlite": + errTranslator = &errtranslator.SqliteErrTranslator{} + case "postgres": + errTranslator = &errtranslator.PostgresErrTranslator{} + case "mysql": + errTranslator = &errtranslator.MysqlErrTranslator{} + case "mssql": + errTranslator = &errtranslator.MssqlErrTranslator{} + } + + if errTranslator != nil { + translatedErr := errTranslator.Translate(err) + if _, ok := translatedErr.(errtranslator.ErrDuplicatedKey); ok { + return ErrDuplicatedKey + } + } + + return err +} diff --git a/errors.go b/errors.go index 49cbfe64..126c9f31 100644 --- a/errors.go +++ b/errors.go @@ -41,4 +41,6 @@ var ( ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") // ErrPreloadNotAllowed preload is not allowed when count is used ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") ) diff --git a/errtranslator/mssql.go b/errtranslator/mssql.go new file mode 100644 index 00000000..7bce4d0c --- /dev/null +++ b/errtranslator/mssql.go @@ -0,0 +1,33 @@ +package errtranslator + +import "encoding/json" + +var mssqlErrCodes = map[string]int{ + "uniqueConstraint": 2627, +} + +type MssqlErrTranslator struct{} + +type MssqlErr struct { + Number int `json:"Number"` + Message string `json:"Message"` +} + +func (m *MssqlErrTranslator) Translate(err error) error { + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var mssqlErr MssqlErr + unmarshalErr := json.Unmarshal(parsedErr, &mssqlErr) + if unmarshalErr != nil { + return err + } + + if mssqlErr.Number == mssqlErrCodes["uniqueConstraint"] { + return ErrDuplicatedKey{Code: mssqlErr.Number, Message: mssqlErr.Message} + } + + return err +} diff --git a/errtranslator/mysql.go b/errtranslator/mysql.go new file mode 100644 index 00000000..9f3800a9 --- /dev/null +++ b/errtranslator/mysql.go @@ -0,0 +1,33 @@ +package errtranslator + +import "encoding/json" + +var mysqlErrCodes = map[string]int{ + "uniqueConstraint": 1062, +} + +type MysqlErrTranslator struct{} + +type MysqlErr struct { + Number int `json:"Number"` + Message string `json:"Message"` +} + +func (m *MysqlErrTranslator) Translate(err error) error { + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var mysqlErr MysqlErr + unmarshalErr := json.Unmarshal(parsedErr, &mysqlErr) + if unmarshalErr != nil { + return err + } + + if mysqlErr.Number == mysqlErrCodes["uniqueConstraint"] { + return ErrDuplicatedKey{Code: mysqlErr.Number, Message: mysqlErr.Message} + } + + return err +} diff --git a/errtranslator/postgres.go b/errtranslator/postgres.go new file mode 100644 index 00000000..22b7d193 --- /dev/null +++ b/errtranslator/postgres.go @@ -0,0 +1,34 @@ +package errtranslator + +import "encoding/json" + +var postgresErrCodes = map[string]string{ + "uniqueConstraint": "23505", +} + +type PostgresErrTranslator struct{} + +type PostgresErr struct { + Code string `json:"Code"` + Severity string `json:"Severity"` + Message string `json:"Message"` +} + +func (p *PostgresErrTranslator) Translate(err error) error { + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var postgresErr PostgresErr + unmarshalErr := json.Unmarshal(parsedErr, &postgresErr) + if unmarshalErr != nil { + return err + } + + if postgresErr.Code == postgresErrCodes["uniqueConstraint"] { + return ErrDuplicatedKey{Code: postgresErr.Code, Message: postgresErr.Message} + } + + return err +} diff --git a/errtranslator/sqlite.go b/errtranslator/sqlite.go new file mode 100644 index 00000000..4bece997 --- /dev/null +++ b/errtranslator/sqlite.go @@ -0,0 +1,34 @@ +package errtranslator + +import "encoding/json" + +var sqliteErrCodes = map[string]int{ + "uniqueConstraint": 2067, +} + +type SqliteErrTranslator struct{} + +type SqliteErr struct { + Code int `json:"Code"` + ExtendedCode int `json:"ExtendedCode"` + SystemErrno int `json:"SystemErrno"` +} + +func (s *SqliteErrTranslator) Translate(err error) error { + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var sqliteErr SqliteErr + unmarshalErr := json.Unmarshal(parsedErr, &sqliteErr) + if unmarshalErr != nil { + return err + } + + if sqliteErr.ExtendedCode == sqliteErrCodes["uniqueConstraint"] { + return ErrDuplicatedKey{Code: sqliteErr.ExtendedCode, Message: ""} + } + + return err +} diff --git a/errtranslator/types.go b/errtranslator/types.go new file mode 100644 index 00000000..9a0eae57 --- /dev/null +++ b/errtranslator/types.go @@ -0,0 +1,16 @@ +package errtranslator + +import "fmt" + +type ErrTranslator interface { + Translate(err error) error +} + +type ErrDuplicatedKey struct { + Code interface{} + Message string +} + +func (e ErrDuplicatedKey) Error() string { + return fmt.Sprintf("duplicated key not allowed, code: %v, message: %s", e.Code, e.Message) +} diff --git a/gorm.go b/gorm.go index 37595ddd..5f60ce77 100644 --- a/gorm.go +++ b/gorm.go @@ -347,10 +347,12 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { + translatedErr := TranslateErr(db.Dialector.Name(), err) + if db.Error == nil { - db.Error = err - } else if err != nil { - db.Error = fmt.Errorf("%v; %w", db.Error, err) + db.Error = translatedErr + } else if translatedErr != nil { + db.Error = fmt.Errorf("%v; %w", db.Error, translatedErr) } return db.Error } diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go new file mode 100644 index 00000000..fcbb4f24 --- /dev/null +++ b/tests/error_translator_test.go @@ -0,0 +1,160 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" +) + +func TestPostgresErrorTranslator(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Product struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&Product{}) + + if err := DB.AutoMigrate(&Product{}); err != nil { + t.Fatalf("Failed to migrate: %v", err) + } + + err := DB.Create(&Product{Name: "Milk"}).Error + if err != nil { + t.Fatalf("errors happened on create: %v", err) + } + + // test errors to be translated + + err = DB.Create(&Product{Name: "Milk"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } + + // test default errors to not be translated + + var product Product + + err = DB.Find(&product, "name = ?", "coffee").Error + if !errors.Is(err, gorm.ErrRecordNotFound) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) + } +} + +func TestMysqlErrorTranslator(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type Product struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&Product{}) + + if err := DB.AutoMigrate(&Product{}); err != nil { + t.Fatalf("Failed to migrate: %v", err) + } + + err := DB.Create(&Product{Name: "Milk"}).Error + if err != nil { + t.Fatalf("errors happened on create: %v", err) + } + + // test errors to be translated + + err = DB.Create(&Product{Name: "Milk"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } + + // test default errors to not be translated + + var product Product + + err = DB.Find(&product, "name = ?", "coffee").Error + if !errors.Is(err, gorm.ErrRecordNotFound) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) + } +} + +func TestMssqlErrorTranslator(t *testing.T) { + if DB.Dialector.Name() != "mssql" { + t.Skip() + } + + type Product struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&Product{}) + + if err := DB.AutoMigrate(&Product{}); err != nil { + t.Fatalf("Failed to migrate: %v", err) + } + + err := DB.Create(&Product{Name: "Milk"}).Error + if err != nil { + t.Fatalf("errors happened on create: %v", err) + } + + // test errors to be translated + + err = DB.Create(&Product{Name: "Milk"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } + + // test default errors to not be translated + + var product Product + + err = DB.Find(&product, "name = ?", "coffee").Error + if !errors.Is(err, gorm.ErrRecordNotFound) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) + } +} + +func TestSqliteErrorTranslator(t *testing.T) { + if DB.Dialector.Name() != "sqlite" { + t.Skip() + } + + type Product struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&Product{}) + + if err := DB.AutoMigrate(&Product{}); err != nil { + t.Fatalf("Failed to migrate: %v", err) + } + + err := DB.Create(&Product{Name: "Milk"}).Error + if err != nil { + t.Fatalf("errors happened on create: %v", err) + } + + // test errors to be translated + + err = DB.Create(&Product{Name: "Milk"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } + + // test default errors to not be translated + + var product Product + + err = DB.Find(&product, "name = ?", "coffee").Error + if !errors.Is(err, gorm.ErrRecordNotFound) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrRecordNotFound, err) + } +} diff --git a/tests/go.mod b/tests/go.mod index 251aabb3..f4b018d5 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,6 +7,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect + github.com/microsoft/go-mssqldb v0.20.0 // indirect golang.org/x/crypto v0.5.0 // indirect gorm.io/driver/mysql v1.4.5 gorm.io/driver/postgres v1.4.6