add join to update clause

This commit is contained in:
Franco Liberali 2023-09-06 15:04:01 +02:00 committed by FrancoLiberali
parent a2cac75218
commit c00cf29ccc
7 changed files with 287 additions and 172 deletions

View File

@ -4,12 +4,9 @@ import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func Query(db *gorm.DB) {
@ -104,157 +101,8 @@ func BuildQuerySQL(db *gorm.DB) {
}
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
fromClause.Joins = append(fromClause.Joins, gorm.GenJoinClauses(db, &clauseSelect)...)
db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}

View File

@ -69,7 +69,9 @@ func Update(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
gorm.CreateUpdateClause(db.Statement)
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")

View File

@ -3,6 +3,7 @@ package clause
type Update struct {
Modifier string
Table Table
Joins []Join
}
// Name update clause name
@ -22,6 +23,11 @@ func (update Update) Build(builder Builder) {
} else {
builder.WriteQuoted(update.Table)
}
for _, join := range update.Joins {
builder.WriteByte(' ')
join.Build(builder)
}
}
// MergeClause merge update clause

180
clauses.go Normal file
View File

@ -0,0 +1,180 @@
package gorm
import (
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func CreateUpdateClause(stmt *Statement) {
updateClause := clause.Update{}
if v, ok := stmt.Clauses["UPDATE"].Expression.(clause.Update); ok {
updateClause = v
}
if len(stmt.Joins) != 0 || len(updateClause.Joins) != 0 {
updateClause.Joins = append(updateClause.Joins, GenJoinClauses(stmt.DB, &clause.Select{})...)
stmt.AddClause(updateClause)
} else {
stmt.AddClauseIfNotExists(clause.Update{})
}
}
func GenJoinClauses(db *DB, clauseSelect *clause.Select) []clause.Join {
joinClauses := []clause.Join{}
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
columnStmt := Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
joinClauses = append(joinClauses, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
joinClauses = append(joinClauses, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
joinClauses = append(joinClauses, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.Joins = nil
return joinClauses
}

View File

@ -141,6 +141,7 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
if _, ok := stmt.Clauses["SET"]; !ok {
curTime := stmt.DB.NowFunc()
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
stmt.SetColumn(sd.Field.DBName, curTime, true)
@ -162,9 +163,12 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
}
}
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
stmt.AddClauseIfNotExists(clause.Update{})
CreateUpdateClause(stmt)
stmt.Build(stmt.DB.Callback().Update().Clauses...)
}
}

75
tests/fork_update_test.go Normal file
View File

@ -0,0 +1,75 @@
package tests_test
import (
"strings"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
// only mysql support update join
func TestReasonUpdateJoinUpdatedAtIsAmbiguous(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
return
}
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).InnerJoins("Account", DB.Where("number = ?", 1)).Update("name", "jinzhu").Error; !strings.Contains(err.Error(), "Column 'updated_at' in field list is ambiguous") {
t.Errorf(`Error should be column is ambiguous, but got: "%s"`, err)
}
}
// only mysql support update join
func TestUpdateJoinWorksManuallySettingSetClauses(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
return
}
var (
users = []*User{
GetUser("update-1", Config{Account: true}),
GetUser("update-2", Config{Account: true}),
GetUser("update-3", Config{}),
}
user = users[1]
)
if err := DB.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
} else if user.ID == 0 {
t.Fatalf("user's primary value should not zero, %v", user.ID)
} else if user.UpdatedAt.IsZero() {
t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
}
tx := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(user).InnerJoins("Account", DB.Where("number = ?", user.Account.Number))
tx.Statement.AddClause(clause.Set{
{
Column: clause.Column{
Name: "name",
Table: "users",
},
Value: "franco",
},
{
Column: clause.Column{
Name: "updated_at",
Table: "users",
},
Value: time.Now(),
},
})
if rowsAffected := tx.Updates(nil).RowsAffected; rowsAffected != 1 {
t.Errorf("should only update one record, but got %v", rowsAffected)
}
var result User
if err := DB.First(&result, "name = ?", "franco").Error; err != nil {
t.Errorf("user's name should be updated")
} else if result.UpdatedAt.UnixNano() == user.UpdatedAt.UnixNano() {
t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, user.UpdatedAt)
}
}

View File

@ -21,10 +21,10 @@ require (
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/microsoft/go-mssqldb v1.6.0 // indirect
golang.org/x/crypto v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
)
replace gorm.io/gorm => ../