fix UPDATE with joins

This commit is contained in:
Dmitry Zenovich 2019-05-06 17:14:52 +03:00
parent 9954086a91
commit 311be47205
2 changed files with 62 additions and 3 deletions

View File

@ -97,11 +97,31 @@ func updateCallback(scope *Scope) {
}
if len(sqls) > 0 {
joinSQLFirst := strings.TrimSpace(scope.joinsSQL())
var joinSQLSecond string
quotedTableName := scope.QuotedTableName()
var postgres postgres
var mysql mysql
if joinSQLFirst != "" {
switch scope.Dialect().GetName() {
case mysql.GetName():
case postgres.GetName():
joinSQLSecond = "FROM " + joinSQLFirst
joinSQLFirst = ""
default:
joinSQLSecond = "FROM " + quotedTableName + addExtraSpaceIfExist(joinSQLFirst)
joinSQLFirst = ""
}
}
whereSQL := scope.whereSQL()
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v%v%v",
scope.QuotedTableName(),
"UPDATE %v%v SET %v%v%v%v",
quotedTableName,
addExtraSpaceIfExist(joinSQLFirst),
strings.Join(sqls, ", "),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(joinSQLSecond),
addExtraSpaceIfExist(whereSQL+scope.orderSQL()+scope.limitAndOffsetSQL()),
addExtraSpaceIfExist(extraOption),
)).Exec()
}

View File

@ -1203,6 +1203,45 @@ func TestWhereUpdates(t *testing.T) {
db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"})
}
func TestJoinUpdateColumn(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect == "" { // sqlite doesn't support 'UPDATE ... JOIN ...' syntax
t.Skip()
}
db := DB.New()
db.Delete(User{})
defer db.Delete(User{})
DB.Create(&User{Name: "user1"})
DB.Create(&User{Name: "user3"})
var err error
if dialect == "mysql" || dialect == "mssql" {
err = db.Table("users").Joins("JOIN users AS users2 ON users2.id = users.id").
Where("users.name='user1'").UpdateColumn("users.name", "user2").Error
} else if dialect == "postgres" {
err = db.Table("users").Joins("users AS users2").
Where("users2.id = users.id").
Where("users.name='user1'").UpdateColumn("name", "user2").Error
}
if err != nil {
t.Error("Unexpected error on update with join")
}
var names []string
err = db.Model(User{}).Order("name").Pluck("name", &names).Error
if err != nil {
t.Error("Unexpected error on pluck")
}
if len(names) != 2 || names[0] != "user2" || names[1] != "user3" {
t.Error("Unexpected result on pluck after updating")
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {