From aa4916d4a33a716af333350dc8df3e13767e20eb Mon Sep 17 00:00:00 2001 From: Chris Belsole Date: Tue, 3 May 2016 11:49:41 -0400 Subject: [PATCH] added error types for easier error handling with cases --- callback_query.go | 2 +- errors.go | 43 ++++++++++++++++++++++++++++++------------- field.go | 2 +- main.go | 10 +++++----- main_test.go | 4 ++-- preload_test.go | 6 +++--- scope.go | 2 +- search.go | 2 +- 8 files changed, 44 insertions(+), 27 deletions(-) diff --git a/callback_query.go b/callback_query.go index 93782b1d..11a31180 100644 --- a/callback_query.go +++ b/callback_query.go @@ -79,7 +79,7 @@ func queryCallback(scope *Scope) { } if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) + scope.Err(NewErrRecordNotFound()) } } } diff --git a/errors.go b/errors.go index ce3a25c0..e4f52f42 100644 --- a/errors.go +++ b/errors.go @@ -5,19 +5,6 @@ import ( "strings" ) -var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - type errorsInterface interface { GetErrors() []error } @@ -32,6 +19,36 @@ func (errs Errors) GetErrors() []error { return errs.errors } +type ErrRecordNotFound struct{ error } + +type ErrInvalidSQL struct{ error } + +type ErrInvalidTransaction struct{ error } + +type ErrCantStartTransaction struct{ error } + +type ErrUnaddressable struct{ error } + +func NewErrRecordNotFound() error { + return ErrRecordNotFound{errors.New("record not found")} +} + +func NewErrInvalidSQL() error { + return ErrInvalidSQL{errors.New("invalid SQL")} +} + +func NewErrInvalidTransaction() error { + return ErrCantStartTransaction{errors.New("no valid transaction")} +} + +func NewErrCantStartTransaction() error { + return ErrCantStartTransaction{errors.New("can't start transaction")} +} + +func NewErrUnaddressable() error { + return ErrUnaddressable{errors.New("using unaddressable value")} +} + // Add add an error func (errs *Errors) Add(err error) { if errors, ok := err.(errorsInterface); ok { diff --git a/field.go b/field.go index 11c410b0..e8c34766 100644 --- a/field.go +++ b/field.go @@ -21,7 +21,7 @@ func (field *Field) Set(value interface{}) (err error) { } if !field.Field.CanAddr() { - return ErrUnaddressable + return NewErrUnaddressable() } reflectValue, ok := value.(reflect.Value) diff --git a/main.go b/main.go index cd445555..ea73a25c 100644 --- a/main.go +++ b/main.go @@ -424,7 +424,7 @@ func (s *DB) Begin() *DB { c.db = interface{}(tx).(sqlCommon) c.AddError(err) } else { - c.AddError(ErrCantStartTransaction) + c.AddError(NewErrCantStartTransaction()) } return c } @@ -434,7 +434,7 @@ func (s *DB) Commit() *DB { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Commit()) } else { - s.AddError(ErrInvalidTransaction) + s.AddError(NewErrInvalidTransaction()) } return s } @@ -444,7 +444,7 @@ func (s *DB) Rollback() *DB { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Rollback()) } else { - s.AddError(ErrInvalidTransaction) + s.AddError(NewErrInvalidTransaction()) } return s } @@ -457,7 +457,7 @@ func (s *DB) NewRecord(value interface{}) bool { // RecordNotFound check if returning ErrRecordNotFound error func (s *DB) RecordNotFound() bool { for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { + if _, ok := err.(ErrRecordNotFound); ok { return true } } @@ -633,7 +633,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join // AddError add error to the db func (s *DB) AddError(err error) error { if err != nil { - if err != ErrRecordNotFound { + if _, ok := err.(ErrRecordNotFound); !ok { if s.logMode == 0 { go s.print(fileWithLineNum(), err) } else { diff --git a/main_test.go b/main_test.go index 8ac015c8..c1f598c5 100644 --- a/main_test.go +++ b/main_test.go @@ -479,7 +479,7 @@ func TestRaw(t *testing.T) { } DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { + if _, ok := DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error.(gorm.ErrRecordNotFound); !ok { t.Error("Raw sql to update records") } } @@ -709,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) { } var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { + if _, ok := db.Where("name = ?", "jnfeinstein").First(&user).Error.(gorm.ErrRecordNotFound); ok { t.Errorf("Should have found existing record") } } diff --git a/preload_test.go b/preload_test.go index 5c49ecc2..733aa777 100644 --- a/preload_test.go +++ b/preload_test.go @@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err, ok := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error.(gorm.ErrRecordNotFound); !ok { t.Error(err) } } @@ -981,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err, ok := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error.(gorm.ErrRecordNotFound); !ok { t.Error(err) } } @@ -1038,7 +1038,7 @@ func TestNestedManyToManyPreload2(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err, ok := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error.(gorm.ErrRecordNotFound); !ok { t.Error(err) } } diff --git a/scope.go b/scope.go index 844df85c..1a614196 100644 --- a/scope.go +++ b/scope.go @@ -850,7 +850,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin err := field.Set(value) if field.IsNormal { hasUpdate = true - if err == ErrUnaddressable { + if _, ok := err.(ErrUnaddressable); ok { fmt.Println(err) results[field.DBName] = value } else { diff --git a/search.go b/search.go index 078bd429..5c7173dd 100644 --- a/search.go +++ b/search.go @@ -139,7 +139,7 @@ func (s *search) getInterfaceAsSQL(value interface{}) (str string) { case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: str = fmt.Sprintf("%v", value) default: - s.db.AddError(ErrInvalidSQL) + s.db.AddError(NewErrInvalidSQL()) } if str == "-1" {