refactor: provides error for finisher api updateable vaild
This commit is contained in:
		
							parent
							
								
									3d7019a7c2
								
							
						
					
					
						commit
						9eb6d44bcf
					
				@ -20,6 +20,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(value)
 | 
				
			||||||
	tx.Statement.Dest = value
 | 
						tx.Statement.Dest = value
 | 
				
			||||||
	return tx.callbacks.Create().Execute(tx)
 | 
						return tx.callbacks.Create().Execute(tx)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -32,6 +33,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
 | 
				
			|||||||
	case reflect.Slice, reflect.Array:
 | 
						case reflect.Slice, reflect.Array:
 | 
				
			||||||
		var rowsAffected int64
 | 
							var rowsAffected int64
 | 
				
			||||||
		tx = db.getInstance()
 | 
							tx = db.getInstance()
 | 
				
			||||||
 | 
							tx.vaildUpdateableParam(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		callFc := func(tx *DB) error {
 | 
							callFc := func(tx *DB) error {
 | 
				
			||||||
			// the reflection length judgment of the optimized value
 | 
								// the reflection length judgment of the optimized value
 | 
				
			||||||
@ -62,6 +64,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
 | 
				
			|||||||
		tx.RowsAffected = rowsAffected
 | 
							tx.RowsAffected = rowsAffected
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		tx = db.getInstance()
 | 
							tx = db.getInstance()
 | 
				
			||||||
 | 
							tx.vaildUpdateableParam(value)
 | 
				
			||||||
		tx.Statement.Dest = value
 | 
							tx.Statement.Dest = value
 | 
				
			||||||
		tx = tx.callbacks.Create().Execute(tx)
 | 
							tx = tx.callbacks.Create().Execute(tx)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -71,6 +74,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
 | 
				
			|||||||
// Save update value in database, if the value doesn't have primary key, will insert it
 | 
					// Save update value in database, if the value doesn't have primary key, will insert it
 | 
				
			||||||
func (db *DB) Save(value interface{}) (tx *DB) {
 | 
					func (db *DB) Save(value interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(value)
 | 
				
			||||||
	tx.Statement.Dest = value
 | 
						tx.Statement.Dest = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	reflectValue := reflect.Indirect(reflect.ValueOf(value))
 | 
						reflectValue := reflect.Indirect(reflect.ValueOf(value))
 | 
				
			||||||
@ -115,6 +119,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			|||||||
	tx = db.Limit(1).Order(clause.OrderByColumn{
 | 
						tx = db.Limit(1).Order(clause.OrderByColumn{
 | 
				
			||||||
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
 | 
							Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(dest)
 | 
				
			||||||
	if len(conds) > 0 {
 | 
						if len(conds) > 0 {
 | 
				
			||||||
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
							if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
				
			||||||
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
								tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
				
			||||||
@ -128,6 +133,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			|||||||
// Take return a record that match given conditions, the order will depend on the database implementation
 | 
					// Take return a record that match given conditions, the order will depend on the database implementation
 | 
				
			||||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
					func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.Limit(1)
 | 
						tx = db.Limit(1)
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(dest)
 | 
				
			||||||
	if len(conds) > 0 {
 | 
						if len(conds) > 0 {
 | 
				
			||||||
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
							if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
				
			||||||
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
								tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
				
			||||||
@ -144,6 +150,8 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			|||||||
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
 | 
							Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
 | 
				
			||||||
		Desc:   true,
 | 
							Desc:   true,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(conds) > 0 {
 | 
						if len(conds) > 0 {
 | 
				
			||||||
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
							if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
				
			||||||
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
								tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
				
			||||||
@ -157,6 +165,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			|||||||
// Find find records that match given conditions
 | 
					// Find find records that match given conditions
 | 
				
			||||||
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
					func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(dest)
 | 
				
			||||||
	if len(conds) > 0 {
 | 
						if len(conds) > 0 {
 | 
				
			||||||
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
							if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
				
			||||||
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
								tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
				
			||||||
@ -457,6 +466,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
	tx.Config = &config
 | 
						tx.Config = &config
 | 
				
			||||||
 | 
						tx.vaildUpdateableParam(dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if rows, err := tx.Rows(); err == nil {
 | 
						if rows, err := tx.Rows(); err == nil {
 | 
				
			||||||
		if rows.Next() {
 | 
							if rows.Next() {
 | 
				
			||||||
@ -665,3 +675,24 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return tx.callbacks.Raw().Execute(tx)
 | 
						return tx.callbacks.Raw().Execute(tx)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// vaild updateable params
 | 
				
			||||||
 | 
					// just check Struct/Slice/Array
 | 
				
			||||||
 | 
					// because other params will not be updated
 | 
				
			||||||
 | 
					// like map[string]interface{}
 | 
				
			||||||
 | 
					func (db *DB) vaildUpdateableParam(value interface{}) {
 | 
				
			||||||
 | 
						vi := reflect.Indirect(reflect.ValueOf(value))
 | 
				
			||||||
 | 
						switch vi.Kind() {
 | 
				
			||||||
 | 
						case reflect.Slice, reflect.Array:
 | 
				
			||||||
 | 
							for i := 0; i < vi.Len(); i++ {
 | 
				
			||||||
 | 
								if !reflect.Indirect(vi.Index(i)).CanAddr() {
 | 
				
			||||||
 | 
									db.AddError(ErrInvalidValue)
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case reflect.Struct:
 | 
				
			||||||
 | 
							if !vi.CanAddr() {
 | 
				
			||||||
 | 
								db.AddError(ErrInvalidValue)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -366,7 +366,7 @@ func TestCreateInvalidSlice(t *testing.T) {
 | 
				
			|||||||
		nil,
 | 
							nil,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) {
 | 
						if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidValue) {
 | 
				
			||||||
		t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
 | 
							t.Errorf("should returns error invalid data when creating from slice that contains invalid data")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										42
									
								
								tests/invalid_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/invalid_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestInvalidParamTypeStruct(t *testing.T) {
 | 
				
			||||||
 | 
						user := User{Name: "TestInvalidParam"}
 | 
				
			||||||
 | 
						DB.Create(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// panic when update values api
 | 
				
			||||||
 | 
						invalidUser := User{Name: "TestInvalidParam_invalid"}
 | 
				
			||||||
 | 
						invalidUsers := [1]User{invalidUser}
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.Create(invalidUser))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.CreateInBatches(invalidUser, 1))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.CreateInBatches(invalidUsers, 1))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.Save(invalidUser))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// panic when found and update values api
 | 
				
			||||||
 | 
						var invalidQueryUser User
 | 
				
			||||||
 | 
						invalidQueryUser.ID = user.ID
 | 
				
			||||||
 | 
						invalidQueryUsers := [1]User{invalidQueryUser}
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.First(invalidQueryUser))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.Take(invalidQueryUser))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.Last(invalidQueryUser))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.Find(invalidQueryUsers))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.FindInBatches(invalidQueryUsers, 1, func(tx *gorm.DB, batch int) error {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.FirstOrInit(invalidQueryUser))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.FirstOrCreate(invalidQueryUser))
 | 
				
			||||||
 | 
						assertInvalidValueError(t, DB.Model(User{}).Scan(invalidQueryUser))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func assertInvalidValueError(t *testing.T, tx *gorm.DB) {
 | 
				
			||||||
 | 
						if !errors.Is(tx.Error, gorm.ErrInvalidValue) {
 | 
				
			||||||
 | 
							t.Errorf("should returns error invalid")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user