refactor: provides error for finisher api updateable vaild
This commit is contained in:
parent
3d7019a7c2
commit
9eb6d44bcf
@ -20,6 +20,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.vaildUpdateableParam(value)
|
||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
return tx.callbacks.Create().Execute(tx)
|
return tx.callbacks.Create().Execute(tx)
|
||||||
}
|
}
|
||||||
@ -32,6 +33,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var rowsAffected int64
|
var rowsAffected int64
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.vaildUpdateableParam(value)
|
||||||
|
|
||||||
callFc := func(tx *DB) error {
|
callFc := func(tx *DB) error {
|
||||||
// the reflection length judgment of the optimized value
|
// the reflection length judgment of the optimized value
|
||||||
@ -62,6 +64,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||||||
tx.RowsAffected = rowsAffected
|
tx.RowsAffected = rowsAffected
|
||||||
default:
|
default:
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.vaildUpdateableParam(value)
|
||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
tx = tx.callbacks.Create().Execute(tx)
|
tx = tx.callbacks.Create().Execute(tx)
|
||||||
}
|
}
|
||||||
@ -71,6 +74,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.vaildUpdateableParam(value)
|
||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
@ -115,6 +119,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
})
|
})
|
||||||
|
tx.vaildUpdateableParam(dest)
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||||
@ -128,6 +133,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.Limit(1)
|
tx = db.Limit(1)
|
||||||
|
tx.vaildUpdateableParam(dest)
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||||
@ -144,6 +150,8 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
Desc: true,
|
Desc: true,
|
||||||
})
|
})
|
||||||
|
tx.vaildUpdateableParam(dest)
|
||||||
|
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||||
@ -157,6 +165,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
// Find find records that match given conditions
|
// Find find records that match given conditions
|
||||||
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.vaildUpdateableParam(dest)
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||||
@ -457,6 +466,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|||||||
|
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Config = &config
|
tx.Config = &config
|
||||||
|
tx.vaildUpdateableParam(dest)
|
||||||
|
|
||||||
if rows, err := tx.Rows(); err == nil {
|
if rows, err := tx.Rows(); err == nil {
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
@ -665,3 +675,24 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
|||||||
|
|
||||||
return tx.callbacks.Raw().Execute(tx)
|
return tx.callbacks.Raw().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// vaild updateable params
|
||||||
|
// just check Struct/Slice/Array
|
||||||
|
// because other params will not be updated
|
||||||
|
// like map[string]interface{}
|
||||||
|
func (db *DB) vaildUpdateableParam(value interface{}) {
|
||||||
|
vi := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
switch vi.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
for i := 0; i < vi.Len(); i++ {
|
||||||
|
if !reflect.Indirect(vi.Index(i)).CanAddr() {
|
||||||
|
db.AddError(ErrInvalidValue)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if !vi.CanAddr() {
|
||||||
|
db.AddError(ErrInvalidValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -366,7 +366,7 @@ func TestCreateInvalidSlice(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) {
|
if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidValue) {
|
||||||
t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
|
t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
42
tests/invalid_test.go
Normal file
42
tests/invalid_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInvalidParamTypeStruct(t *testing.T) {
|
||||||
|
user := User{Name: "TestInvalidParam"}
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
// panic when update values api
|
||||||
|
invalidUser := User{Name: "TestInvalidParam_invalid"}
|
||||||
|
invalidUsers := [1]User{invalidUser}
|
||||||
|
assertInvalidValueError(t, DB.Create(invalidUser))
|
||||||
|
assertInvalidValueError(t, DB.CreateInBatches(invalidUser, 1))
|
||||||
|
assertInvalidValueError(t, DB.CreateInBatches(invalidUsers, 1))
|
||||||
|
assertInvalidValueError(t, DB.Save(invalidUser))
|
||||||
|
|
||||||
|
// panic when found and update values api
|
||||||
|
var invalidQueryUser User
|
||||||
|
invalidQueryUser.ID = user.ID
|
||||||
|
invalidQueryUsers := [1]User{invalidQueryUser}
|
||||||
|
assertInvalidValueError(t, DB.First(invalidQueryUser))
|
||||||
|
assertInvalidValueError(t, DB.Take(invalidQueryUser))
|
||||||
|
assertInvalidValueError(t, DB.Last(invalidQueryUser))
|
||||||
|
assertInvalidValueError(t, DB.Find(invalidQueryUsers))
|
||||||
|
assertInvalidValueError(t, DB.FindInBatches(invalidQueryUsers, 1, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
assertInvalidValueError(t, DB.FirstOrInit(invalidQueryUser))
|
||||||
|
assertInvalidValueError(t, DB.FirstOrCreate(invalidQueryUser))
|
||||||
|
assertInvalidValueError(t, DB.Model(User{}).Scan(invalidQueryUser))
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertInvalidValueError(t *testing.T, tx *gorm.DB) {
|
||||||
|
if !errors.Is(tx.Error, gorm.ErrInvalidValue) {
|
||||||
|
t.Errorf("should returns error invalid")
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user