add join to update clause
This commit is contained in:
parent
a2cac75218
commit
c00cf29ccc
@ -4,12 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
|
||||||
"gorm.io/gorm/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(db *gorm.DB) {
|
func Query(db *gorm.DB) {
|
||||||
@ -104,157 +101,8 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
||||||
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
fromClause.Joins = append(fromClause.Joins, gorm.GenJoinClauses(db, &clauseSelect)...)
|
||||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
|
||||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
|
||||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
specifiedRelationsName := make(map[string]interface{})
|
|
||||||
for _, join := range db.Statement.Joins {
|
|
||||||
if db.Statement.Schema != nil {
|
|
||||||
var isRelations bool // is relations or raw sql
|
|
||||||
var relations []*schema.Relationship
|
|
||||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
|
||||||
if ok {
|
|
||||||
isRelations = true
|
|
||||||
relations = append(relations, relation)
|
|
||||||
} else {
|
|
||||||
// handle nested join like "Manager.Company"
|
|
||||||
nestedJoinNames := strings.Split(join.Name, ".")
|
|
||||||
if len(nestedJoinNames) > 1 {
|
|
||||||
isNestedJoin := true
|
|
||||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
|
||||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
|
||||||
for _, relname := range nestedJoinNames {
|
|
||||||
// incomplete match, only treated as raw sql
|
|
||||||
if relation, ok = currentRelations[relname]; ok {
|
|
||||||
gussNestedRelations = append(gussNestedRelations, relation)
|
|
||||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
|
||||||
} else {
|
|
||||||
isNestedJoin = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isNestedJoin {
|
|
||||||
isRelations = true
|
|
||||||
relations = gussNestedRelations
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isRelations {
|
|
||||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
|
||||||
tableAliasName := relation.Name
|
|
||||||
if parentTableName != clause.CurrentTable {
|
|
||||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
|
||||||
}
|
|
||||||
|
|
||||||
columnStmt := gorm.Statement{
|
|
||||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
|
||||||
Selects: join.Selects, Omits: join.Omits,
|
|
||||||
}
|
|
||||||
|
|
||||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
|
||||||
for _, s := range relation.FieldSchema.DBNames {
|
|
||||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
|
||||||
Table: tableAliasName,
|
|
||||||
Name: s,
|
|
||||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := make([]clause.Expression, len(relation.References))
|
|
||||||
for idx, ref := range relation.References {
|
|
||||||
if ref.OwnPrimaryKey {
|
|
||||||
exprs[idx] = clause.Eq{
|
|
||||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ref.PrimaryValue == "" {
|
|
||||||
exprs[idx] = clause.Eq{
|
|
||||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
exprs[idx] = clause.Eq{
|
|
||||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
|
||||||
Value: ref.PrimaryValue,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
|
||||||
for _, c := range relation.FieldSchema.QueryClauses {
|
|
||||||
onStmt.AddClause(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
if join.On != nil {
|
|
||||||
onStmt.AddClause(join.On)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
|
||||||
if where, ok := cs.Expression.(clause.Where); ok {
|
|
||||||
where.Build(&onStmt)
|
|
||||||
|
|
||||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
|
||||||
vars := onStmt.Vars
|
|
||||||
for idx, v := range vars {
|
|
||||||
bindvar := strings.Builder{}
|
|
||||||
onStmt.Vars = vars[0 : idx+1]
|
|
||||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
|
||||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return clause.Join{
|
|
||||||
Type: joinType,
|
|
||||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
|
||||||
ON: clause.Where{Exprs: exprs},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
parentTableName := clause.CurrentTable
|
|
||||||
for _, rel := range relations {
|
|
||||||
// joins table alias like "Manager, Company, Manager__Company"
|
|
||||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
|
||||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
|
||||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
|
||||||
specifiedRelationsName[nestedAlias] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if parentTableName != clause.CurrentTable {
|
|
||||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
|
||||||
} else {
|
|
||||||
parentTableName = rel.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
|
||||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
|
||||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Statement.AddClause(fromClause)
|
db.Statement.AddClause(fromClause)
|
||||||
db.Statement.Joins = nil
|
|
||||||
} else {
|
} else {
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
}
|
}
|
||||||
|
@ -69,7 +69,9 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
|
|
||||||
if db.Statement.SQL.Len() == 0 {
|
if db.Statement.SQL.Len() == 0 {
|
||||||
db.Statement.SQL.Grow(180)
|
db.Statement.SQL.Grow(180)
|
||||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
|
||||||
|
gorm.CreateUpdateClause(db.Statement)
|
||||||
|
|
||||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||||
defer delete(db.Statement.Clauses, "SET")
|
defer delete(db.Statement.Clauses, "SET")
|
||||||
|
@ -3,6 +3,7 @@ package clause
|
|||||||
type Update struct {
|
type Update struct {
|
||||||
Modifier string
|
Modifier string
|
||||||
Table Table
|
Table Table
|
||||||
|
Joins []Join
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name update clause name
|
// Name update clause name
|
||||||
@ -22,6 +23,11 @@ func (update Update) Build(builder Builder) {
|
|||||||
} else {
|
} else {
|
||||||
builder.WriteQuoted(update.Table)
|
builder.WriteQuoted(update.Table)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, join := range update.Joins {
|
||||||
|
builder.WriteByte(' ')
|
||||||
|
join.Build(builder)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeClause merge update clause
|
// MergeClause merge update clause
|
||||||
|
180
clauses.go
Normal file
180
clauses.go
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateUpdateClause(stmt *Statement) {
|
||||||
|
updateClause := clause.Update{}
|
||||||
|
if v, ok := stmt.Clauses["UPDATE"].Expression.(clause.Update); ok {
|
||||||
|
updateClause = v
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(stmt.Joins) != 0 || len(updateClause.Joins) != 0 {
|
||||||
|
updateClause.Joins = append(updateClause.Joins, GenJoinClauses(stmt.DB, &clause.Select{})...)
|
||||||
|
stmt.AddClause(updateClause)
|
||||||
|
} else {
|
||||||
|
stmt.AddClauseIfNotExists(clause.Update{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenJoinClauses(db *DB, clauseSelect *clause.Select) []clause.Join {
|
||||||
|
joinClauses := []clause.Join{}
|
||||||
|
|
||||||
|
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||||
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||||
|
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||||
|
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
specifiedRelationsName := make(map[string]interface{})
|
||||||
|
for _, join := range db.Statement.Joins {
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
var isRelations bool // is relations or raw sql
|
||||||
|
var relations []*schema.Relationship
|
||||||
|
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||||
|
if ok {
|
||||||
|
isRelations = true
|
||||||
|
relations = append(relations, relation)
|
||||||
|
} else {
|
||||||
|
// handle nested join like "Manager.Company"
|
||||||
|
nestedJoinNames := strings.Split(join.Name, ".")
|
||||||
|
if len(nestedJoinNames) > 1 {
|
||||||
|
isNestedJoin := true
|
||||||
|
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||||
|
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||||
|
for _, relname := range nestedJoinNames {
|
||||||
|
// incomplete match, only treated as raw sql
|
||||||
|
if relation, ok = currentRelations[relname]; ok {
|
||||||
|
gussNestedRelations = append(gussNestedRelations, relation)
|
||||||
|
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||||
|
} else {
|
||||||
|
isNestedJoin = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNestedJoin {
|
||||||
|
isRelations = true
|
||||||
|
relations = gussNestedRelations
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isRelations {
|
||||||
|
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||||
|
tableAliasName := relation.Name
|
||||||
|
if parentTableName != clause.CurrentTable {
|
||||||
|
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnStmt := Statement{
|
||||||
|
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||||
|
Selects: join.Selects, Omits: join.Omits,
|
||||||
|
}
|
||||||
|
|
||||||
|
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||||
|
for _, s := range relation.FieldSchema.DBNames {
|
||||||
|
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||||
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
|
Table: tableAliasName,
|
||||||
|
Name: s,
|
||||||
|
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := make([]clause.Expression, len(relation.References))
|
||||||
|
for idx, ref := range relation.References {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
|
exprs[idx] = clause.Eq{
|
||||||
|
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||||
|
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ref.PrimaryValue == "" {
|
||||||
|
exprs[idx] = clause.Eq{
|
||||||
|
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
exprs[idx] = clause.Eq{
|
||||||
|
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: ref.PrimaryValue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
onStmt := Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||||
|
for _, c := range relation.FieldSchema.QueryClauses {
|
||||||
|
onStmt.AddClause(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if join.On != nil {
|
||||||
|
onStmt.AddClause(join.On)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||||
|
if where, ok := cs.Expression.(clause.Where); ok {
|
||||||
|
where.Build(&onStmt)
|
||||||
|
|
||||||
|
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||||
|
vars := onStmt.Vars
|
||||||
|
for idx, v := range vars {
|
||||||
|
bindvar := strings.Builder{}
|
||||||
|
onStmt.Vars = vars[0 : idx+1]
|
||||||
|
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||||
|
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return clause.Join{
|
||||||
|
Type: joinType,
|
||||||
|
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||||
|
ON: clause.Where{Exprs: exprs},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parentTableName := clause.CurrentTable
|
||||||
|
for _, rel := range relations {
|
||||||
|
// joins table alias like "Manager, Company, Manager__Company"
|
||||||
|
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||||
|
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||||
|
joinClauses = append(joinClauses, genJoinClause(join.JoinType, parentTableName, rel))
|
||||||
|
specifiedRelationsName[nestedAlias] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if parentTableName != clause.CurrentTable {
|
||||||
|
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||||
|
} else {
|
||||||
|
parentTableName = rel.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
joinClauses = append(joinClauses, clause.Join{
|
||||||
|
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
joinClauses = append(joinClauses, clause.Join{
|
||||||
|
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.Joins = nil
|
||||||
|
|
||||||
|
return joinClauses
|
||||||
|
}
|
@ -141,30 +141,34 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
|||||||
|
|
||||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||||
curTime := stmt.DB.NowFunc()
|
if _, ok := stmt.Clauses["SET"]; !ok {
|
||||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
|
curTime := stmt.DB.NowFunc()
|
||||||
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
|
||||||
|
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
||||||
|
|
||||||
if stmt.Schema != nil {
|
if stmt.Schema != nil {
|
||||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||||
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
|
||||||
}
|
|
||||||
|
|
||||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
|
||||||
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
|
||||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||||
|
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||||
|
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
|
if len(values) > 0 {
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||||
stmt.AddClauseIfNotExists(clause.Update{})
|
|
||||||
|
CreateUpdateClause(stmt)
|
||||||
|
|
||||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
75
tests/fork_update_test.go
Normal file
75
tests/fork_update_test.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// only mysql support update join
|
||||||
|
func TestReasonUpdateJoinUpdatedAtIsAmbiguous(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "mysql" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).InnerJoins("Account", DB.Where("number = ?", 1)).Update("name", "jinzhu").Error; !strings.Contains(err.Error(), "Column 'updated_at' in field list is ambiguous") {
|
||||||
|
t.Errorf(`Error should be column is ambiguous, but got: "%s"`, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// only mysql support update join
|
||||||
|
func TestUpdateJoinWorksManuallySettingSetClauses(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "mysql" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
users = []*User{
|
||||||
|
GetUser("update-1", Config{Account: true}),
|
||||||
|
GetUser("update-2", Config{Account: true}),
|
||||||
|
GetUser("update-3", Config{}),
|
||||||
|
}
|
||||||
|
user = users[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := DB.Create(&users).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
|
} else if user.ID == 0 {
|
||||||
|
t.Fatalf("user's primary value should not zero, %v", user.ID)
|
||||||
|
} else if user.UpdatedAt.IsZero() {
|
||||||
|
t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(user).InnerJoins("Account", DB.Where("number = ?", user.Account.Number))
|
||||||
|
tx.Statement.AddClause(clause.Set{
|
||||||
|
{
|
||||||
|
Column: clause.Column{
|
||||||
|
Name: "name",
|
||||||
|
Table: "users",
|
||||||
|
},
|
||||||
|
Value: "franco",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Column: clause.Column{
|
||||||
|
Name: "updated_at",
|
||||||
|
Table: "users",
|
||||||
|
},
|
||||||
|
Value: time.Now(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if rowsAffected := tx.Updates(nil).RowsAffected; rowsAffected != 1 {
|
||||||
|
t.Errorf("should only update one record, but got %v", rowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
if err := DB.First(&result, "name = ?", "franco").Error; err != nil {
|
||||||
|
t.Errorf("user's name should be updated")
|
||||||
|
} else if result.UpdatedAt.UnixNano() == user.UpdatedAt.UnixNano() {
|
||||||
|
t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, user.UpdatedAt)
|
||||||
|
}
|
||||||
|
}
|
@ -21,10 +21,10 @@ require (
|
|||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||||
github.com/jackc/pgx/v5 v5.5.0 // indirect
|
github.com/jackc/pgx/v5 v5.5.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||||
github.com/microsoft/go-mssqldb v1.6.0 // indirect
|
github.com/microsoft/go-mssqldb v1.6.0 // indirect
|
||||||
golang.org/x/crypto v0.15.0 // indirect
|
golang.org/x/crypto v0.13.0 // indirect
|
||||||
golang.org/x/text v0.14.0 // indirect
|
golang.org/x/text v0.13.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
Loading…
x
Reference in New Issue
Block a user