From c00cf29ccc7c6269141873e8bfc1ad5f8c09108f Mon Sep 17 00:00:00 2001 From: Franco Liberali Date: Wed, 6 Sep 2023 15:04:01 +0200 Subject: [PATCH] add join to update clause --- callbacks/query.go | 154 +------------------------------- callbacks/update.go | 4 +- clause/update.go | 6 ++ clauses.go | 180 ++++++++++++++++++++++++++++++++++++++ soft_delete.go | 34 +++---- tests/fork_update_test.go | 75 ++++++++++++++++ tests/go.mod | 6 +- 7 files changed, 287 insertions(+), 172 deletions(-) create mode 100644 clauses.go create mode 100644 tests/fork_update_test.go diff --git a/callbacks/query.go b/callbacks/query.go index e89dd199..b71ff5f5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -4,12 +4,9 @@ import ( "fmt" "reflect" "sort" - "strings" "gorm.io/gorm" "gorm.io/gorm/clause" - "gorm.io/gorm/schema" - "gorm.io/gorm/utils" ) func Query(db *gorm.DB) { @@ -104,157 +101,8 @@ func BuildQuerySQL(db *gorm.DB) { } if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { - if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { - clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) - for idx, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} - } - } - - specifiedRelationsName := make(map[string]interface{}) - for _, join := range db.Statement.Joins { - if db.Statement.Schema != nil { - var isRelations bool // is relations or raw sql - var relations []*schema.Relationship - relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] - if ok { - isRelations = true - relations = append(relations, relation) - } else { - // handle nested join like "Manager.Company" - nestedJoinNames := strings.Split(join.Name, ".") - if len(nestedJoinNames) > 1 { - isNestedJoin := true - gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) - currentRelations := db.Statement.Schema.Relationships.Relations - for _, relname := range nestedJoinNames { - // incomplete match, only treated as raw sql - if relation, ok = currentRelations[relname]; ok { - gussNestedRelations = append(gussNestedRelations, relation) - currentRelations = relation.FieldSchema.Relationships.Relations - } else { - isNestedJoin = false - break - } - } - - if isNestedJoin { - isRelations = true - relations = gussNestedRelations - } - } - } - - if isRelations { - genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) - } - - columnStmt := gorm.Statement{ - Table: tableAliasName, DB: db, Schema: relation.FieldSchema, - Selects: join.Selects, Omits: join.Omits, - } - - selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) - for _, s := range relation.FieldSchema.DBNames { - if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: utils.NestedRelationName(tableAliasName, s), - }) - } - } - - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } - } - } - } - - { - onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} - for _, c := range relation.FieldSchema.QueryClauses { - onStmt.AddClause(c) - } - - if join.On != nil { - onStmt.AddClause(join.On) - } - - if cs, ok := onStmt.Clauses["WHERE"]; ok { - if where, ok := cs.Expression.(clause.Where); ok { - where.Build(&onStmt) - - if onSQL := onStmt.SQL.String(); onSQL != "" { - vars := onStmt.Vars - for idx, v := range vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) - } - - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) - } - } - } - } - - return clause.Join{ - Type: joinType, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - } - } - - parentTableName := clause.CurrentTable - for _, rel := range relations { - // joins table alias like "Manager, Company, Manager__Company" - nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) - if _, ok := specifiedRelationsName[nestedAlias]; !ok { - fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) - specifiedRelationsName[nestedAlias] = nil - } - - if parentTableName != clause.CurrentTable { - parentTableName = utils.NestedRelationName(parentTableName, rel.Name) - } else { - parentTableName = rel.Name - } - } - } else { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) - } - } else { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) - } - } - + fromClause.Joins = append(fromClause.Joins, gorm.GenJoinClauses(db, &clauseSelect)...) db.Statement.AddClause(fromClause) - db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } diff --git a/callbacks/update.go b/callbacks/update.go index ff075dcf..b9ac9dca 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -69,7 +69,9 @@ func Update(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Update{}) + + gorm.CreateUpdateClause(db.Statement) + if _, ok := db.Statement.Clauses["SET"]; !ok { if set := ConvertToAssignments(db.Statement); len(set) != 0 { defer delete(db.Statement.Clauses, "SET") diff --git a/clause/update.go b/clause/update.go index f9d68ac6..ab6917b9 100644 --- a/clause/update.go +++ b/clause/update.go @@ -3,6 +3,7 @@ package clause type Update struct { Modifier string Table Table + Joins []Join } // Name update clause name @@ -22,6 +23,11 @@ func (update Update) Build(builder Builder) { } else { builder.WriteQuoted(update.Table) } + + for _, join := range update.Joins { + builder.WriteByte(' ') + join.Build(builder) + } } // MergeClause merge update clause diff --git a/clauses.go b/clauses.go new file mode 100644 index 00000000..6429eede --- /dev/null +++ b/clauses.go @@ -0,0 +1,180 @@ +package gorm + +import ( + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func CreateUpdateClause(stmt *Statement) { + updateClause := clause.Update{} + if v, ok := stmt.Clauses["UPDATE"].Expression.(clause.Update); ok { + updateClause = v + } + + if len(stmt.Joins) != 0 || len(updateClause.Joins) != 0 { + updateClause.Joins = append(updateClause.Joins, GenJoinClauses(stmt.DB, &clause.Select{})...) + stmt.AddClause(updateClause) + } else { + stmt.AddClauseIfNotExists(clause.Update{}) + } +} + +func GenJoinClauses(db *DB, clauseSelect *clause.Select) []clause.Join { + joinClauses := []clause.Join{} + + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + + specifiedRelationsName := make(map[string]interface{}) + for _, join := range db.Statement.Joins { + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } + } + + if isNestedJoin { + isRelations = true + relations = gussNestedRelations + } + } + } + + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } + + columnStmt := Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) + } + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + { + onStmt := Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + joinClauses = append(joinClauses, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } + } + } else { + joinClauses = append(joinClauses, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } else { + joinClauses = append(joinClauses, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } + + db.Statement.Joins = nil + + return joinClauses +} diff --git a/soft_delete.go b/soft_delete.go index 5673d3b8..c038eb07 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -141,30 +141,34 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - curTime := stmt.DB.NowFunc() - stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) - stmt.SetColumn(sd.Field.DBName, curTime, true) + if _, ok := stmt.Clauses["SET"]; !ok { + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) - if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) - column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) - - if len(values) > 0 { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) - } - - if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) - column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } + + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } } } SoftDeleteQueryClause(sd).ModifyStatement(stmt) - stmt.AddClauseIfNotExists(clause.Update{}) + + CreateUpdateClause(stmt) + stmt.Build(stmt.DB.Callback().Update().Clauses...) } } diff --git a/tests/fork_update_test.go b/tests/fork_update_test.go new file mode 100644 index 00000000..2b1d8ab1 --- /dev/null +++ b/tests/fork_update_test.go @@ -0,0 +1,75 @@ +package tests_test + +import ( + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" +) + +// only mysql support update join +func TestReasonUpdateJoinUpdatedAtIsAmbiguous(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).InnerJoins("Account", DB.Where("number = ?", 1)).Update("name", "jinzhu").Error; !strings.Contains(err.Error(), "Column 'updated_at' in field list is ambiguous") { + t.Errorf(`Error should be column is ambiguous, but got: "%s"`, err) + } +} + +// only mysql support update join +func TestUpdateJoinWorksManuallySettingSetClauses(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + var ( + users = []*User{ + GetUser("update-1", Config{Account: true}), + GetUser("update-2", Config{Account: true}), + GetUser("update-3", Config{}), + } + user = users[1] + ) + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Fatalf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) + } + + tx := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(user).InnerJoins("Account", DB.Where("number = ?", user.Account.Number)) + tx.Statement.AddClause(clause.Set{ + { + Column: clause.Column{ + Name: "name", + Table: "users", + }, + Value: "franco", + }, + { + Column: clause.Column{ + Name: "updated_at", + Table: "users", + }, + Value: time.Now(), + }, + }) + + if rowsAffected := tx.Updates(nil).RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + var result User + if err := DB.First(&result, "name = ?", "franco").Error; err != nil { + t.Errorf("user's name should be updated") + } else if result.UpdatedAt.UnixNano() == user.UpdatedAt.UnixNano() { + t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, user.UpdatedAt) + } +} diff --git a/tests/go.mod b/tests/go.mod index 71079050..7a453f87 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -21,10 +21,10 @@ require ( github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.5.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.18 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect - golang.org/x/crypto v0.15.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/crypto v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect ) replace gorm.io/gorm => ../