added error types for easier error handling with cases
This commit is contained in:
parent
465f8ea05b
commit
aa4916d4a3
@ -79,7 +79,7 @@ func queryCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if scope.db.RowsAffected == 0 && !isSlice {
|
if scope.db.RowsAffected == 0 && !isSlice {
|
||||||
scope.Err(ErrRecordNotFound)
|
scope.Err(NewErrRecordNotFound())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
43
errors.go
43
errors.go
@ -5,19 +5,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
|
|
||||||
ErrRecordNotFound = errors.New("record not found")
|
|
||||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
|
||||||
ErrInvalidSQL = errors.New("invalid SQL")
|
|
||||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
|
||||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
|
||||||
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
|
||||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
|
||||||
// ErrUnaddressable unaddressable value
|
|
||||||
ErrUnaddressable = errors.New("using unaddressable value")
|
|
||||||
)
|
|
||||||
|
|
||||||
type errorsInterface interface {
|
type errorsInterface interface {
|
||||||
GetErrors() []error
|
GetErrors() []error
|
||||||
}
|
}
|
||||||
@ -32,6 +19,36 @@ func (errs Errors) GetErrors() []error {
|
|||||||
return errs.errors
|
return errs.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ErrRecordNotFound struct{ error }
|
||||||
|
|
||||||
|
type ErrInvalidSQL struct{ error }
|
||||||
|
|
||||||
|
type ErrInvalidTransaction struct{ error }
|
||||||
|
|
||||||
|
type ErrCantStartTransaction struct{ error }
|
||||||
|
|
||||||
|
type ErrUnaddressable struct{ error }
|
||||||
|
|
||||||
|
func NewErrRecordNotFound() error {
|
||||||
|
return ErrRecordNotFound{errors.New("record not found")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrInvalidSQL() error {
|
||||||
|
return ErrInvalidSQL{errors.New("invalid SQL")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrInvalidTransaction() error {
|
||||||
|
return ErrCantStartTransaction{errors.New("no valid transaction")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrCantStartTransaction() error {
|
||||||
|
return ErrCantStartTransaction{errors.New("can't start transaction")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrUnaddressable() error {
|
||||||
|
return ErrUnaddressable{errors.New("using unaddressable value")}
|
||||||
|
}
|
||||||
|
|
||||||
// Add add an error
|
// Add add an error
|
||||||
func (errs *Errors) Add(err error) {
|
func (errs *Errors) Add(err error) {
|
||||||
if errors, ok := err.(errorsInterface); ok {
|
if errors, ok := err.(errorsInterface); ok {
|
||||||
|
2
field.go
2
field.go
@ -21,7 +21,7 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !field.Field.CanAddr() {
|
if !field.Field.CanAddr() {
|
||||||
return ErrUnaddressable
|
return NewErrUnaddressable()
|
||||||
}
|
}
|
||||||
|
|
||||||
reflectValue, ok := value.(reflect.Value)
|
reflectValue, ok := value.(reflect.Value)
|
||||||
|
10
main.go
10
main.go
@ -424,7 +424,7 @@ func (s *DB) Begin() *DB {
|
|||||||
c.db = interface{}(tx).(sqlCommon)
|
c.db = interface{}(tx).(sqlCommon)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(ErrCantStartTransaction)
|
c.AddError(NewErrCantStartTransaction())
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
@ -434,7 +434,7 @@ func (s *DB) Commit() *DB {
|
|||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
s.AddError(db.Commit())
|
s.AddError(db.Commit())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(NewErrInvalidTransaction())
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@ -444,7 +444,7 @@ func (s *DB) Rollback() *DB {
|
|||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
s.AddError(db.Rollback())
|
s.AddError(db.Rollback())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(NewErrInvalidTransaction())
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@ -457,7 +457,7 @@ func (s *DB) NewRecord(value interface{}) bool {
|
|||||||
// RecordNotFound check if returning ErrRecordNotFound error
|
// RecordNotFound check if returning ErrRecordNotFound error
|
||||||
func (s *DB) RecordNotFound() bool {
|
func (s *DB) RecordNotFound() bool {
|
||||||
for _, err := range s.GetErrors() {
|
for _, err := range s.GetErrors() {
|
||||||
if err == ErrRecordNotFound {
|
if _, ok := err.(ErrRecordNotFound); ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -633,7 +633,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||||||
// AddError add error to the db
|
// AddError add error to the db
|
||||||
func (s *DB) AddError(err error) error {
|
func (s *DB) AddError(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != ErrRecordNotFound {
|
if _, ok := err.(ErrRecordNotFound); !ok {
|
||||||
if s.logMode == 0 {
|
if s.logMode == 0 {
|
||||||
go s.print(fileWithLineNum(), err)
|
go s.print(fileWithLineNum(), err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -479,7 +479,7 @@ func TestRaw(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
||||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
|
if _, ok := DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error.(gorm.ErrRecordNotFound); !ok {
|
||||||
t.Error("Raw sql to update records")
|
t.Error("Raw sql to update records")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -709,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
|
if _, ok := db.Where("name = ?", "jnfeinstein").First(&user).Error.(gorm.ErrRecordNotFound); ok {
|
||||||
t.Errorf("Should have found existing record")
|
t.Errorf("Should have found existing record")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
|
|||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
if err, ok := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error.(gorm.ErrRecordNotFound); !ok {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -981,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
|
|||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
if err, ok := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error.(gorm.ErrRecordNotFound); !ok {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1038,7 +1038,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
|
|||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
if err, ok := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error.(gorm.ErrRecordNotFound); !ok {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
2
scope.go
2
scope.go
@ -850,7 +850,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
|||||||
err := field.Set(value)
|
err := field.Set(value)
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
if err == ErrUnaddressable {
|
if _, ok := err.(ErrUnaddressable); ok {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
results[field.DBName] = value
|
results[field.DBName] = value
|
||||||
} else {
|
} else {
|
||||||
|
@ -139,7 +139,7 @@ func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
|||||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
str = fmt.Sprintf("%v", value)
|
str = fmt.Sprintf("%v", value)
|
||||||
default:
|
default:
|
||||||
s.db.AddError(ErrInvalidSQL)
|
s.db.AddError(NewErrInvalidSQL())
|
||||||
}
|
}
|
||||||
|
|
||||||
if str == "-1" {
|
if str == "-1" {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user