Make join logic sharable (so it can be used by delete or update) and use it in delete
This commit is contained in:
		
							parent
							
								
									4a50b36f63
								
							
						
					
					
						commit
						f9315d3d01
					
				| @ -126,7 +126,19 @@ func Delete(config *Config) func(db *gorm.DB) { | ||||
| 
 | ||||
| 		if db.Statement.SQL.Len() == 0 { | ||||
| 			db.Statement.SQL.Grow(100) | ||||
| 			db.Statement.AddClauseIfNotExists(clause.Delete{}) | ||||
| 
 | ||||
| 			deleteClause := clause.Delete{} | ||||
| 
 | ||||
| 			HandleJoins( | ||||
| 				db, | ||||
| 				func(db *gorm.DB) { | ||||
| 					deleteClause.Modifier = db.Statement.Table | ||||
| 				}, | ||||
| 				func(db *gorm.DB, tableAliasName string, join gorm.Join, relation *schema.Relationship) { | ||||
| 				}, | ||||
| 			) | ||||
| 
 | ||||
| 			db.Statement.AddClauseIfNotExists(deleteClause) | ||||
| 
 | ||||
| 			if db.Statement.Schema != nil { | ||||
| 				_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) | ||||
|  | ||||
							
								
								
									
										154
									
								
								callbacks/join.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								callbacks/join.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,154 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| func HandleJoins(db *gorm.DB, prejoinCallback func(db *gorm.DB), perFieldNameCallback func(db *gorm.DB, tableAliasName string, join gorm.Join, relation *schema.Relationship)) { | ||||
| 	// inline joins
 | ||||
| 	fromClause := clause.From{} | ||||
| 	if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { | ||||
| 		fromClause = v | ||||
| 	} | ||||
| 
 | ||||
| 	if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { | ||||
| 		prejoinCallback(db) | ||||
| 
 | ||||
| 		specifiedRelationsName := make(map[string]interface{}) | ||||
| 		for _, join := range db.Statement.Joins { | ||||
| 			if db.Statement.Schema != nil { | ||||
| 				var isRelations bool // is relations or raw sql
 | ||||
| 				var relations []*schema.Relationship | ||||
| 				relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] | ||||
| 				if ok { | ||||
| 					isRelations = true | ||||
| 					relations = append(relations, relation) | ||||
| 				} else { | ||||
| 					// handle nested join like "Manager.Company"
 | ||||
| 					nestedJoinNames := strings.Split(join.Name, ".") | ||||
| 					if len(nestedJoinNames) > 1 { | ||||
| 						isNestedJoin := true | ||||
| 						gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 						currentRelations := db.Statement.Schema.Relationships.Relations | ||||
| 						for _, relname := range nestedJoinNames { | ||||
| 							// incomplete match, only treated as raw sql
 | ||||
| 							if relation, ok = currentRelations[relname]; ok { | ||||
| 								gussNestedRelations = append(gussNestedRelations, relation) | ||||
| 								currentRelations = relation.FieldSchema.Relationships.Relations | ||||
| 							} else { | ||||
| 								isNestedJoin = false | ||||
| 								break | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						if isNestedJoin { | ||||
| 							isRelations = true | ||||
| 							relations = gussNestedRelations | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if isRelations { | ||||
| 					genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 						tableAliasName := relation.Name | ||||
| 						if parentTableName != clause.CurrentTable { | ||||
| 							tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 						} | ||||
| 
 | ||||
| 						perFieldNameCallback(db, tableAliasName, join, relation) | ||||
| 
 | ||||
| 						exprs := make([]clause.Expression, len(relation.References)) | ||||
| 						for idx, ref := range relation.References { | ||||
| 							if ref.OwnPrimaryKey { | ||||
| 								exprs[idx] = clause.Eq{ | ||||
| 									Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, | ||||
| 									Value:  clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, | ||||
| 								} | ||||
| 							} else { | ||||
| 								if ref.PrimaryValue == "" { | ||||
| 									exprs[idx] = clause.Eq{ | ||||
| 										Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, | ||||
| 										Value:  clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, | ||||
| 									} | ||||
| 								} else { | ||||
| 									exprs[idx] = clause.Eq{ | ||||
| 										Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, | ||||
| 										Value:  ref.PrimaryValue, | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						{ | ||||
| 							onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} | ||||
| 							for _, c := range relation.FieldSchema.QueryClauses { | ||||
| 								onStmt.AddClause(c) | ||||
| 							} | ||||
| 
 | ||||
| 							if join.On != nil { | ||||
| 								onStmt.AddClause(join.On) | ||||
| 							} | ||||
| 
 | ||||
| 							if cs, ok := onStmt.Clauses["WHERE"]; ok { | ||||
| 								if where, ok := cs.Expression.(clause.Where); ok { | ||||
| 									where.Build(&onStmt) | ||||
| 
 | ||||
| 									if onSQL := onStmt.SQL.String(); onSQL != "" { | ||||
| 										vars := onStmt.Vars | ||||
| 										for idx, v := range vars { | ||||
| 											bindvar := strings.Builder{} | ||||
| 											onStmt.Vars = vars[0 : idx+1] | ||||
| 											db.Dialector.BindVarTo(&bindvar, &onStmt, v) | ||||
| 											onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) | ||||
| 										} | ||||
| 
 | ||||
| 										exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						return clause.Join{ | ||||
| 							Type:  joinType, | ||||
| 							Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, | ||||
| 							ON:    clause.Where{Exprs: exprs}, | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					parentTableName := clause.CurrentTable | ||||
| 					for _, rel := range relations { | ||||
| 						// joins table alias like "Manager, Company, Manager__Company"
 | ||||
| 						nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 						if _, ok := specifiedRelationsName[nestedAlias]; !ok { | ||||
| 							fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) | ||||
| 							specifiedRelationsName[nestedAlias] = nil | ||||
| 						} | ||||
| 
 | ||||
| 						if parentTableName != clause.CurrentTable { | ||||
| 							parentTableName = utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 						} else { | ||||
| 							parentTableName = rel.Name | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
| 					}) | ||||
| 				} | ||||
| 			} else { | ||||
| 				fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 					Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.AddClause(fromClause) | ||||
| 	} else { | ||||
| 		db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| @ -2,13 +2,11 @@ package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| func Query(db *gorm.DB) { | ||||
| @ -96,166 +94,34 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// inline joins
 | ||||
| 		fromClause := clause.From{} | ||||
| 		if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { | ||||
| 			fromClause = v | ||||
| 		} | ||||
| 
 | ||||
| 		if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { | ||||
| 			if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { | ||||
| 				clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) | ||||
| 				for idx, dbName := range db.Statement.Schema.DBNames { | ||||
| 					clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			specifiedRelationsName := make(map[string]interface{}) | ||||
| 			for _, join := range db.Statement.Joins { | ||||
| 				if db.Statement.Schema != nil { | ||||
| 					var isRelations bool // is relations or raw sql
 | ||||
| 					var relations []*schema.Relationship | ||||
| 					relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] | ||||
| 					if ok { | ||||
| 						isRelations = true | ||||
| 						relations = append(relations, relation) | ||||
| 					} else { | ||||
| 						// handle nested join like "Manager.Company"
 | ||||
| 						nestedJoinNames := strings.Split(join.Name, ".") | ||||
| 						if len(nestedJoinNames) > 1 { | ||||
| 							isNestedJoin := true | ||||
| 							gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 							currentRelations := db.Statement.Schema.Relationships.Relations | ||||
| 							for _, relname := range nestedJoinNames { | ||||
| 								// incomplete match, only treated as raw sql
 | ||||
| 								if relation, ok = currentRelations[relname]; ok { | ||||
| 									gussNestedRelations = append(gussNestedRelations, relation) | ||||
| 									currentRelations = relation.FieldSchema.Relationships.Relations | ||||
| 								} else { | ||||
| 									isNestedJoin = false | ||||
| 									break | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							if isNestedJoin { | ||||
| 								isRelations = true | ||||
| 								relations = gussNestedRelations | ||||
| 							} | ||||
| 						} | ||||
| 		HandleJoins( | ||||
| 			db, | ||||
| 			func(db *gorm.DB) { | ||||
| 				if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { | ||||
| 					clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) | ||||
| 					for idx, dbName := range db.Statement.Schema.DBNames { | ||||
| 						clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} | ||||
| 					} | ||||
| 				} | ||||
| 			}, | ||||
| 			func(db *gorm.DB, tableAliasName string, join gorm.Join, relation *schema.Relationship) { | ||||
| 				columnStmt := gorm.Statement{ | ||||
| 					Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | ||||
| 					Selects: join.Selects, Omits: join.Omits, | ||||
| 				} | ||||
| 
 | ||||
| 					if isRelations { | ||||
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							tableAliasName := relation.Name | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 							} | ||||
| 
 | ||||
| 							columnStmt := gorm.Statement{ | ||||
| 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | ||||
| 								Selects: join.Selects, Omits: join.Omits, | ||||
| 							} | ||||
| 
 | ||||
| 							selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) | ||||
| 							for _, s := range relation.FieldSchema.DBNames { | ||||
| 								if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { | ||||
| 									clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 										Table: tableAliasName, | ||||
| 										Name:  s, | ||||
| 										Alias: utils.NestedRelationName(tableAliasName, s), | ||||
| 									}) | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							exprs := make([]clause.Expression, len(relation.References)) | ||||
| 							for idx, ref := range relation.References { | ||||
| 								if ref.OwnPrimaryKey { | ||||
| 									exprs[idx] = clause.Eq{ | ||||
| 										Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, | ||||
| 										Value:  clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, | ||||
| 									} | ||||
| 								} else { | ||||
| 									if ref.PrimaryValue == "" { | ||||
| 										exprs[idx] = clause.Eq{ | ||||
| 											Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, | ||||
| 											Value:  clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, | ||||
| 										} | ||||
| 									} else { | ||||
| 										exprs[idx] = clause.Eq{ | ||||
| 											Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, | ||||
| 											Value:  ref.PrimaryValue, | ||||
| 										} | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							{ | ||||
| 								onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} | ||||
| 								for _, c := range relation.FieldSchema.QueryClauses { | ||||
| 									onStmt.AddClause(c) | ||||
| 								} | ||||
| 
 | ||||
| 								if join.On != nil { | ||||
| 									onStmt.AddClause(join.On) | ||||
| 								} | ||||
| 
 | ||||
| 								if cs, ok := onStmt.Clauses["WHERE"]; ok { | ||||
| 									if where, ok := cs.Expression.(clause.Where); ok { | ||||
| 										where.Build(&onStmt) | ||||
| 
 | ||||
| 										if onSQL := onStmt.SQL.String(); onSQL != "" { | ||||
| 											vars := onStmt.Vars | ||||
| 											for idx, v := range vars { | ||||
| 												bindvar := strings.Builder{} | ||||
| 												onStmt.Vars = vars[0 : idx+1] | ||||
| 												db.Dialector.BindVarTo(&bindvar, &onStmt, v) | ||||
| 												onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) | ||||
| 											} | ||||
| 
 | ||||
| 											exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) | ||||
| 										} | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							return clause.Join{ | ||||
| 								Type:  joinType, | ||||
| 								Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, | ||||
| 								ON:    clause.Where{Exprs: exprs}, | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						parentTableName := clause.CurrentTable | ||||
| 						for _, rel := range relations { | ||||
| 							// joins table alias like "Manager, Company, Manager__Company"
 | ||||
| 							nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							if _, ok := specifiedRelationsName[nestedAlias]; !ok { | ||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) | ||||
| 								specifiedRelationsName[nestedAlias] = nil | ||||
| 							} | ||||
| 
 | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								parentTableName = utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							} else { | ||||
| 								parentTableName = rel.Name | ||||
| 							} | ||||
| 						} | ||||
| 					} else { | ||||
| 						fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 							Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
| 				selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) | ||||
| 				for _, s := range relation.FieldSchema.DBNames { | ||||
| 					if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { | ||||
| 						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 							Table: tableAliasName, | ||||
| 							Name:  s, | ||||
| 							Alias: utils.NestedRelationName(tableAliasName, s), | ||||
| 						}) | ||||
| 					} | ||||
| 				} else { | ||||
| 					fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			db.Statement.AddClause(fromClause) | ||||
| 		} else { | ||||
| 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 		} | ||||
| 			}, | ||||
| 		) | ||||
| 
 | ||||
| 		db.Statement.AddClauseIfNotExists(clauseSelect) | ||||
| 
 | ||||
|  | ||||
| @ -260,7 +260,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) | ||||
| 
 | ||||
| 	if len(args) == 1 { | ||||
| 		if db, ok := args[0].(*DB); ok { | ||||
| 			j := join{ | ||||
| 			j := Join{ | ||||
| 				Name: query, Conds: args, Selects: db.Statement.Selects, | ||||
| 				Omits: db.Statement.Omits, JoinType: joinType, | ||||
| 			} | ||||
| @ -272,7 +272,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) | ||||
| 	tx.Statement.Joins = append(tx.Statement.Joins, Join{Name: query, Conds: args, JoinType: joinType}) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { | ||||
| // Unscoped allows queries to include records marked as deleted,
 | ||||
| // overriding the soft deletion behavior.
 | ||||
| // Example:
 | ||||
| //    var users []User
 | ||||
| //    db.Unscoped().Find(&users)
 | ||||
| //    // Retrieves all users, including deleted ones.
 | ||||
| //
 | ||||
| //	var users []User
 | ||||
| //	db.Unscoped().Find(&users)
 | ||||
| //	// Retrieves all users, including deleted ones.
 | ||||
| func (db *DB) Unscoped() (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.Unscoped = true | ||||
|  | ||||
| @ -13,7 +13,7 @@ func (d Delete) Build(builder Builder) { | ||||
| 
 | ||||
| 	if d.Modifier != "" { | ||||
| 		builder.WriteByte(' ') | ||||
| 		builder.WriteString(d.Modifier) | ||||
| 		builder.WriteQuoted(d.Modifier) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -33,7 +33,7 @@ type Statement struct { | ||||
| 	Selects              []string          // selected columns
 | ||||
| 	Omits                []string          // omit columns
 | ||||
| 	ColumnMapping        map[string]string // map columns
 | ||||
| 	Joins                []join | ||||
| 	Joins                []Join | ||||
| 	Preloads             map[string][]interface{} | ||||
| 	Settings             sync.Map | ||||
| 	ConnPool             ConnPool | ||||
| @ -49,7 +49,7 @@ type Statement struct { | ||||
| 	scopes               []func(*DB) *DB | ||||
| } | ||||
| 
 | ||||
| type join struct { | ||||
| type Join struct { | ||||
| 	Name     string | ||||
| 	Conds    []interface{} | ||||
| 	On       *clause.Where | ||||
| @ -538,7 +538,7 @@ func (stmt *Statement) clone() *Statement { | ||||
| 	} | ||||
| 
 | ||||
| 	if len(stmt.Joins) > 0 { | ||||
| 		newStmt.Joins = make([]join, len(stmt.Joins)) | ||||
| 		newStmt.Joins = make([]Join, len(stmt.Joins)) | ||||
| 		copy(newStmt.Joins, stmt.Joins) | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 mtsoltan
						mtsoltan