Add returning tests
This commit is contained in:
		
							parent
							
								
									835d7bde59
								
							
						
					
					
						commit
						e953880d19
					
				| @ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) { | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					dest := db.Statement.Dest | ||||
| 					db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					db.Statement.Dest = dest | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
| @ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | ||||
| 		switch stmt.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var primaryKeyExprs []clause.Expression | ||||
| 			for i := 0; i < stmt.ReflectValue.Len(); i++ { | ||||
| 				var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||
| 				var notZero bool | ||||
| 				for idx, field := range stmt.Schema.PrimaryFields { | ||||
| 					value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) | ||||
| 					exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | ||||
| 					notZero = notZero || !isZero | ||||
| 				} | ||||
| 				if notZero { | ||||
| 					primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) | ||||
| 			if size := stmt.ReflectValue.Len(); size > 0 { | ||||
| 				var primaryKeyExprs []clause.Expression | ||||
| 				for i := 0; i < stmt.ReflectValue.Len(); i++ { | ||||
| 					var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||
| 					var notZero bool | ||||
| 					for idx, field := range stmt.Schema.PrimaryFields { | ||||
| 						value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) | ||||
| 						exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | ||||
| 						notZero = notZero || !isZero | ||||
| 					} | ||||
| 					if notZero { | ||||
| 						primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) | ||||
| 			} | ||||
| 			stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) | ||||
| 		case reflect.Struct: | ||||
| 			for _, field := range stmt.Schema.PrimaryFields { | ||||
| 				if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | ||||
|  | ||||
							
								
								
									
										16
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								scan.go
									
									
									
									
									
								
							| @ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | ||||
| 
 | ||||
| 	switch dest := db.Statement.Dest.(type) { | ||||
| 	case map[string]interface{}, *map[string]interface{}: | ||||
| 		if update && db.Statement.Schema != nil { | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Struct: | ||||
| 				fields := make([]*schema.Field, len(columns)) | ||||
| 				for idx, column := range columns { | ||||
| 					if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 						fields[idx] = field | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if initialized || rows.Next() { | ||||
| 					db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if initialized || rows.Next() { | ||||
| 			columnTypes, _ := rows.ColumnTypes() | ||||
| 			prepareValues(values, db, columnTypes, columns) | ||||
|  | ||||
| @ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { | ||||
| 		} | ||||
| 
 | ||||
| 		stmt.AddClauseIfNotExists(clause.Update{}) | ||||
| 		stmt.Build("UPDATE", "SET", "WHERE") | ||||
| 		stmt.Build(stmt.DB.Callback().Update().Clauses...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // only sqlite, postgres support returning
 | ||||
| func TestSoftDeleteReturning(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	users := []*User{ | ||||
| 		GetUser("delete-returning-1", Config{}), | ||||
| 		GetUser("delete-returning-2", Config{}), | ||||
| 		GetUser("delete-returning-3", Config{}), | ||||
| 	} | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	var results []User | ||||
| 	DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) | ||||
| 	if len(results) != 2 { | ||||
| 		t.Errorf("failed to return delete data, got %v", results) | ||||
| 	} | ||||
| 
 | ||||
| 	var count int64 | ||||
| 	DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) | ||||
| 	if count != 1 { | ||||
| 		t.Errorf("failed to delete data, current count %v", count) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDeleteReturning(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	companies := []Company{ | ||||
| 		{Name: "delete-returning-1"}, | ||||
| 		{Name: "delete-returning-2"}, | ||||
| 		{Name: "delete-returning-3"}, | ||||
| 	} | ||||
| 	DB.Create(&companies) | ||||
| 
 | ||||
| 	var results []Company | ||||
| 	DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) | ||||
| 	if len(results) != 2 { | ||||
| 		t.Errorf("failed to return delete data, got %v", results) | ||||
| 	} | ||||
| 
 | ||||
| 	var count int64 | ||||
| 	DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) | ||||
| 	if count != 1 { | ||||
| 		t.Errorf("failed to delete data, current count %v", count) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -7,8 +7,8 @@ require ( | ||||
| 	github.com/jinzhu/now v1.1.2 | ||||
| 	github.com/lib/pq v1.10.3 | ||||
| 	gorm.io/driver/mysql v1.1.2 | ||||
| 	gorm.io/driver/postgres v1.2.0 | ||||
| 	gorm.io/driver/sqlite v1.2.0 | ||||
| 	gorm.io/driver/postgres v1.2.1 | ||||
| 	gorm.io/driver/sqlite v1.2.2 | ||||
| 	gorm.io/driver/sqlserver v1.1.2 | ||||
| 	gorm.io/gorm v1.22.0 | ||||
| ) | ||||
|  | ||||
| @ -167,16 +167,13 @@ func TestUpdates(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	// update with gorm exprs
 | ||||
| 	if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 	if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||
| 	} | ||||
| 	var user4 User | ||||
| 	DB.First(&user4, user3.ID) | ||||
| 
 | ||||
| 	// sqlite, postgres support returning
 | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 		user3.Age += 100 | ||||
| 	} | ||||
| 	user3.Age += 100 | ||||
| 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | ||||
| } | ||||
| 
 | ||||
| @ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) { | ||||
| 		t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // only sqlite, postgres support returning
 | ||||
| func TestUpdateReturning(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	users := []*User{ | ||||
| 		GetUser("update-returning-1", Config{}), | ||||
| 		GetUser("update-returning-2", Config{}), | ||||
| 		GetUser("update-returning-3", Config{}), | ||||
| 	} | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	var results []User | ||||
| 	DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88) | ||||
| 	if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 { | ||||
| 		t.Errorf("failed to return updated data, got %v", results) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if results[1].Age-results[0].Age != 100 { | ||||
| 		t.Errorf("failed to return updated age column") | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu