Make join logic sharable (so it can be used by delete or update) and use it in delete

This commit is contained in:
mtsoltan 2024-08-12 03:15:34 +03:00
parent 4a50b36f63
commit f9315d3d01
No known key found for this signature in database
GPG Key ID: 2F56E59CE4D1D296
6 changed files with 201 additions and 168 deletions

View File

@ -126,7 +126,19 @@ func Delete(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100) db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{})
deleteClause := clause.Delete{}
HandleJoins(
db,
func(db *gorm.DB) {
deleteClause.Modifier = db.Statement.Table
},
func(db *gorm.DB, tableAliasName string, join gorm.Join, relation *schema.Relationship) {
},
)
db.Statement.AddClauseIfNotExists(deleteClause)
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)

154
callbacks/join.go Normal file
View File

@ -0,0 +1,154 @@
package callbacks
import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
"strings"
)
func HandleJoins(db *gorm.DB, prejoinCallback func(db *gorm.DB), perFieldNameCallback func(db *gorm.DB, tableAliasName string, join gorm.Join, relation *schema.Relationship)) {
// inline joins
fromClause := clause.From{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
prejoinCallback(db)
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
perFieldNameCallback(db, tableAliasName, join, relation)
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.AddClause(fromClause)
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
}

View File

@ -2,13 +2,11 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
"reflect"
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
@ -96,166 +94,34 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
// inline joins HandleJoins(
fromClause := clause.From{} db,
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { func(db *gorm.DB) {
fromClause = v if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
} clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
} }
}
},
func(db *gorm.DB, tableAliasName string, join gorm.Join, relation *schema.Relationship) {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
if isRelations { selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { for _, s := range relation.FieldSchema.DBNames {
tableAliasName := relation.Name if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
if parentTableName != clause.CurrentTable { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) Table: tableAliasName,
} Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} }
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
} }
} },
)
db.Statement.AddClause(fromClause)
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.AddClauseIfNotExists(clauseSelect)

View File

@ -260,7 +260,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{})
if len(args) == 1 { if len(args) == 1 {
if db, ok := args[0].(*DB); ok { if db, ok := args[0].(*DB); ok {
j := join{ j := Join{
Name: query, Conds: args, Selects: db.Statement.Selects, Name: query, Conds: args, Selects: db.Statement.Selects,
Omits: db.Statement.Omits, JoinType: joinType, Omits: db.Statement.Omits, JoinType: joinType,
} }
@ -272,7 +272,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{})
} }
} }
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) tx.Statement.Joins = append(tx.Statement.Joins, Join{Name: query, Conds: args, JoinType: joinType})
return return
} }
@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
// Unscoped allows queries to include records marked as deleted, // Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior. // overriding the soft deletion behavior.
// Example: // Example:
// var users []User //
// db.Unscoped().Find(&users) // var users []User
// // Retrieves all users, including deleted ones. // db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Unscoped = true tx.Statement.Unscoped = true

View File

@ -13,7 +13,7 @@ func (d Delete) Build(builder Builder) {
if d.Modifier != "" { if d.Modifier != "" {
builder.WriteByte(' ') builder.WriteByte(' ')
builder.WriteString(d.Modifier) builder.WriteQuoted(d.Modifier)
} }
} }

View File

@ -33,7 +33,7 @@ type Statement struct {
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
ColumnMapping map[string]string // map columns ColumnMapping map[string]string // map columns
Joins []join Joins []Join
Preloads map[string][]interface{} Preloads map[string][]interface{}
Settings sync.Map Settings sync.Map
ConnPool ConnPool ConnPool ConnPool
@ -49,7 +49,7 @@ type Statement struct {
scopes []func(*DB) *DB scopes []func(*DB) *DB
} }
type join struct { type Join struct {
Name string Name string
Conds []interface{} Conds []interface{}
On *clause.Where On *clause.Where
@ -538,7 +538,7 @@ func (stmt *Statement) clone() *Statement {
} }
if len(stmt.Joins) > 0 { if len(stmt.Joins) > 0 {
newStmt.Joins = make([]join, len(stmt.Joins)) newStmt.Joins = make([]Join, len(stmt.Joins))
copy(newStmt.Joins, stmt.Joins) copy(newStmt.Joins, stmt.Joins)
} }