diff --git a/finisher_api.go b/finisher_api.go index b4d29b71..66943773 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -20,6 +20,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { } tx = db.getInstance() + tx.vaildUpdateableParam(value) tx.Statement.Dest = value return tx.callbacks.Create().Execute(tx) } @@ -32,6 +33,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { case reflect.Slice, reflect.Array: var rowsAffected int64 tx = db.getInstance() + tx.vaildUpdateableParam(value) callFc := func(tx *DB) error { // the reflection length judgment of the optimized value @@ -62,6 +64,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { tx.RowsAffected = rowsAffected default: tx = db.getInstance() + tx.vaildUpdateableParam(value) tx.Statement.Dest = value tx = tx.callbacks.Create().Execute(tx) } @@ -71,6 +74,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { // Save update value in database, if the value doesn't have primary key, will insert it func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() + tx.vaildUpdateableParam(value) tx.Statement.Dest = value reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -115,6 +119,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) + tx.vaildUpdateableParam(dest) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) @@ -128,6 +133,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) + tx.vaildUpdateableParam(dest) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) @@ -144,6 +150,8 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) + tx.vaildUpdateableParam(dest) + if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) @@ -157,6 +165,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { // Find find records that match given conditions func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + tx.vaildUpdateableParam(dest) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) @@ -457,6 +466,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() tx.Config = &config + tx.vaildUpdateableParam(dest) if rows, err := tx.Rows(); err == nil { if rows.Next() { @@ -665,3 +675,24 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { return tx.callbacks.Raw().Execute(tx) } + +// vaild updateable params +// just check Struct/Slice/Array +// because other params will not be updated +// like map[string]interface{} +func (db *DB) vaildUpdateableParam(value interface{}) { + vi := reflect.Indirect(reflect.ValueOf(value)) + switch vi.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < vi.Len(); i++ { + if !reflect.Indirect(vi.Index(i)).CanAddr() { + db.AddError(ErrInvalidValue) + break + } + } + case reflect.Struct: + if !vi.CanAddr() { + db.AddError(ErrInvalidValue) + } + } +} diff --git a/tests/create_test.go b/tests/create_test.go index 2b23d440..b60f507d 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -366,7 +366,7 @@ func TestCreateInvalidSlice(t *testing.T) { nil, } - if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { + if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidValue) { t.Errorf("should returns error invalid data when creating from slice that contains invalid data") } } diff --git a/tests/invalid_test.go b/tests/invalid_test.go new file mode 100644 index 00000000..429cf3f6 --- /dev/null +++ b/tests/invalid_test.go @@ -0,0 +1,42 @@ +package tests_test + +import ( + "errors" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" + "testing" +) + +func TestInvalidParamTypeStruct(t *testing.T) { + user := User{Name: "TestInvalidParam"} + DB.Create(&user) + + // panic when update values api + invalidUser := User{Name: "TestInvalidParam_invalid"} + invalidUsers := [1]User{invalidUser} + assertInvalidValueError(t, DB.Create(invalidUser)) + assertInvalidValueError(t, DB.CreateInBatches(invalidUser, 1)) + assertInvalidValueError(t, DB.CreateInBatches(invalidUsers, 1)) + assertInvalidValueError(t, DB.Save(invalidUser)) + + // panic when found and update values api + var invalidQueryUser User + invalidQueryUser.ID = user.ID + invalidQueryUsers := [1]User{invalidQueryUser} + assertInvalidValueError(t, DB.First(invalidQueryUser)) + assertInvalidValueError(t, DB.Take(invalidQueryUser)) + assertInvalidValueError(t, DB.Last(invalidQueryUser)) + assertInvalidValueError(t, DB.Find(invalidQueryUsers)) + assertInvalidValueError(t, DB.FindInBatches(invalidQueryUsers, 1, func(tx *gorm.DB, batch int) error { + return nil + })) + assertInvalidValueError(t, DB.FirstOrInit(invalidQueryUser)) + assertInvalidValueError(t, DB.FirstOrCreate(invalidQueryUser)) + assertInvalidValueError(t, DB.Model(User{}).Scan(invalidQueryUser)) +} + +func assertInvalidValueError(t *testing.T, tx *gorm.DB) { + if !errors.Is(tx.Error, gorm.ErrInvalidValue) { + t.Errorf("should returns error invalid") + } +}