fix UPDATE with joins
This commit is contained in:
parent
9954086a91
commit
311be47205
@ -97,11 +97,31 @@ func updateCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(sqls) > 0 {
|
if len(sqls) > 0 {
|
||||||
|
joinSQLFirst := strings.TrimSpace(scope.joinsSQL())
|
||||||
|
var joinSQLSecond string
|
||||||
|
quotedTableName := scope.QuotedTableName()
|
||||||
|
var postgres postgres
|
||||||
|
var mysql mysql
|
||||||
|
if joinSQLFirst != "" {
|
||||||
|
switch scope.Dialect().GetName() {
|
||||||
|
case mysql.GetName():
|
||||||
|
case postgres.GetName():
|
||||||
|
joinSQLSecond = "FROM " + joinSQLFirst
|
||||||
|
joinSQLFirst = ""
|
||||||
|
default:
|
||||||
|
joinSQLSecond = "FROM " + quotedTableName + addExtraSpaceIfExist(joinSQLFirst)
|
||||||
|
joinSQLFirst = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
whereSQL := scope.whereSQL()
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"UPDATE %v SET %v%v%v",
|
"UPDATE %v%v SET %v%v%v%v",
|
||||||
scope.QuotedTableName(),
|
quotedTableName,
|
||||||
|
addExtraSpaceIfExist(joinSQLFirst),
|
||||||
strings.Join(sqls, ", "),
|
strings.Join(sqls, ", "),
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
addExtraSpaceIfExist(joinSQLSecond),
|
||||||
|
addExtraSpaceIfExist(whereSQL+scope.orderSQL()+scope.limitAndOffsetSQL()),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
)).Exec()
|
)).Exec()
|
||||||
}
|
}
|
||||||
|
39
main_test.go
39
main_test.go
@ -1203,6 +1203,45 @@ func TestWhereUpdates(t *testing.T) {
|
|||||||
db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"})
|
db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJoinUpdateColumn(t *testing.T) {
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
if dialect == "" { // sqlite doesn't support 'UPDATE ... JOIN ...' syntax
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
db := DB.New()
|
||||||
|
db.Delete(User{})
|
||||||
|
defer db.Delete(User{})
|
||||||
|
|
||||||
|
DB.Create(&User{Name: "user1"})
|
||||||
|
DB.Create(&User{Name: "user3"})
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if dialect == "mysql" || dialect == "mssql" {
|
||||||
|
err = db.Table("users").Joins("JOIN users AS users2 ON users2.id = users.id").
|
||||||
|
Where("users.name='user1'").UpdateColumn("users.name", "user2").Error
|
||||||
|
} else if dialect == "postgres" {
|
||||||
|
err = db.Table("users").Joins("users AS users2").
|
||||||
|
Where("users2.id = users.id").
|
||||||
|
Where("users.name='user1'").UpdateColumn("name", "user2").Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on update with join")
|
||||||
|
}
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
err = db.Model(User{}).Order("name").Pluck("name", &names).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on pluck")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(names) != 2 || names[0] != "user2" || names[1] != "user3" {
|
||||||
|
t.Error("Unexpected result on pluck after updating")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user