add join to update clause
This commit is contained in:
parent
a2cac75218
commit
c00cf29ccc
@ -4,12 +4,9 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
@ -104,157 +101,8 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
||||
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
}
|
||||
|
||||
if parentTableName != clause.CurrentTable {
|
||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||
} else {
|
||||
parentTableName = rel.Name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, gorm.GenJoinClauses(db, &clauseSelect)...)
|
||||
db.Statement.AddClause(fromClause)
|
||||
db.Statement.Joins = nil
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
@ -69,7 +69,9 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
|
||||
gorm.CreateUpdateClause(db.Statement)
|
||||
|
||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
defer delete(db.Statement.Clauses, "SET")
|
||||
|
@ -3,6 +3,7 @@ package clause
|
||||
type Update struct {
|
||||
Modifier string
|
||||
Table Table
|
||||
Joins []Join
|
||||
}
|
||||
|
||||
// Name update clause name
|
||||
@ -22,6 +23,11 @@ func (update Update) Build(builder Builder) {
|
||||
} else {
|
||||
builder.WriteQuoted(update.Table)
|
||||
}
|
||||
|
||||
for _, join := range update.Joins {
|
||||
builder.WriteByte(' ')
|
||||
join.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge update clause
|
||||
|
180
clauses.go
Normal file
180
clauses.go
Normal file
@ -0,0 +1,180 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func CreateUpdateClause(stmt *Statement) {
|
||||
updateClause := clause.Update{}
|
||||
if v, ok := stmt.Clauses["UPDATE"].Expression.(clause.Update); ok {
|
||||
updateClause = v
|
||||
}
|
||||
|
||||
if len(stmt.Joins) != 0 || len(updateClause.Joins) != 0 {
|
||||
updateClause.Joins = append(updateClause.Joins, GenJoinClauses(stmt.DB, &clause.Select{})...)
|
||||
stmt.AddClause(updateClause)
|
||||
} else {
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
}
|
||||
}
|
||||
|
||||
func GenJoinClauses(db *DB, clauseSelect *clause.Select) []clause.Join {
|
||||
joinClauses := []clause.Join{}
|
||||
|
||||
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
|
||||
columnStmt := Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
joinClauses = append(joinClauses, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
}
|
||||
|
||||
if parentTableName != clause.CurrentTable {
|
||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||
} else {
|
||||
parentTableName = rel.Name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
joinClauses = append(joinClauses, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
joinClauses = append(joinClauses, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.Joins = nil
|
||||
|
||||
return joinClauses
|
||||
}
|
@ -141,6 +141,7 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
||||
|
||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
if _, ok := stmt.Clauses["SET"]; !ok {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
|
||||
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
||||
@ -162,9 +163,12 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
|
||||
CreateUpdateClause(stmt)
|
||||
|
||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||
}
|
||||
}
|
||||
|
75
tests/fork_update_test.go
Normal file
75
tests/fork_update_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
// only mysql support update join
|
||||
func TestReasonUpdateJoinUpdatedAtIsAmbiguous(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).InnerJoins("Account", DB.Where("number = ?", 1)).Update("name", "jinzhu").Error; !strings.Contains(err.Error(), "Column 'updated_at' in field list is ambiguous") {
|
||||
t.Errorf(`Error should be column is ambiguous, but got: "%s"`, err)
|
||||
}
|
||||
}
|
||||
|
||||
// only mysql support update join
|
||||
func TestUpdateJoinWorksManuallySettingSetClauses(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
users = []*User{
|
||||
GetUser("update-1", Config{Account: true}),
|
||||
GetUser("update-2", Config{Account: true}),
|
||||
GetUser("update-3", Config{}),
|
||||
}
|
||||
user = users[1]
|
||||
)
|
||||
|
||||
if err := DB.Create(&users).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
} else if user.ID == 0 {
|
||||
t.Fatalf("user's primary value should not zero, %v", user.ID)
|
||||
} else if user.UpdatedAt.IsZero() {
|
||||
t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(user).InnerJoins("Account", DB.Where("number = ?", user.Account.Number))
|
||||
tx.Statement.AddClause(clause.Set{
|
||||
{
|
||||
Column: clause.Column{
|
||||
Name: "name",
|
||||
Table: "users",
|
||||
},
|
||||
Value: "franco",
|
||||
},
|
||||
{
|
||||
Column: clause.Column{
|
||||
Name: "updated_at",
|
||||
Table: "users",
|
||||
},
|
||||
Value: time.Now(),
|
||||
},
|
||||
})
|
||||
|
||||
if rowsAffected := tx.Updates(nil).RowsAffected; rowsAffected != 1 {
|
||||
t.Errorf("should only update one record, but got %v", rowsAffected)
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.First(&result, "name = ?", "franco").Error; err != nil {
|
||||
t.Errorf("user's name should be updated")
|
||||
} else if result.UpdatedAt.UnixNano() == user.UpdatedAt.UnixNano() {
|
||||
t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, user.UpdatedAt)
|
||||
}
|
||||
}
|
@ -21,10 +21,10 @@ require (
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.6.0 // indirect
|
||||
golang.org/x/crypto v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/crypto v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
Loading…
x
Reference in New Issue
Block a user