diff --git a/callback_update.go b/callback_update.go index c52162c8..625a99e7 100644 --- a/callback_update.go +++ b/callback_update.go @@ -97,11 +97,31 @@ func updateCallback(scope *Scope) { } if len(sqls) > 0 { + joinSQLFirst := strings.TrimSpace(scope.joinsSQL()) + var joinSQLSecond string + quotedTableName := scope.QuotedTableName() + var postgres postgres + var mysql mysql + if joinSQLFirst != "" { + switch scope.Dialect().GetName() { + case mysql.GetName(): + case postgres.GetName(): + joinSQLSecond = "FROM " + joinSQLFirst + joinSQLFirst = "" + default: + joinSQLSecond = "FROM " + quotedTableName + addExtraSpaceIfExist(joinSQLFirst) + joinSQLFirst = "" + } + } + whereSQL := scope.whereSQL() + scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), + "UPDATE %v%v SET %v%v%v%v", + quotedTableName, + addExtraSpaceIfExist(joinSQLFirst), strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(joinSQLSecond), + addExtraSpaceIfExist(whereSQL+scope.orderSQL()+scope.limitAndOffsetSQL()), addExtraSpaceIfExist(extraOption), )).Exec() } diff --git a/main_test.go b/main_test.go index 14bf34ac..8692f95a 100644 --- a/main_test.go +++ b/main_test.go @@ -1203,6 +1203,45 @@ func TestWhereUpdates(t *testing.T) { db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) } +func TestJoinUpdateColumn(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect == "" { // sqlite doesn't support 'UPDATE ... JOIN ...' syntax + t.Skip() + } + + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(&User{Name: "user1"}) + DB.Create(&User{Name: "user3"}) + + var err error + if dialect == "mysql" || dialect == "mssql" { + err = db.Table("users").Joins("JOIN users AS users2 ON users2.id = users.id"). + Where("users.name='user1'").UpdateColumn("users.name", "user2").Error + } else if dialect == "postgres" { + err = db.Table("users").Joins("users AS users2"). + Where("users2.id = users.id"). + Where("users.name='user1'").UpdateColumn("name", "user2").Error + } + + if err != nil { + t.Error("Unexpected error on update with join") + } + + var names []string + err = db.Model(User{}).Order("name").Pluck("name", &names).Error + + if err != nil { + t.Error("Unexpected error on pluck") + } + + if len(names) != 2 || names[0] != "user2" || names[1] != "user3" { + t.Error("Unexpected result on pluck after updating") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ {