feat: support nested join (#6067)
* feat: support nested join * fix: empty rel value
This commit is contained in:
		
							parent
							
								
									654b5f2006
								
							
						
					
					
						commit
						8bf1f269cf
					
				| @ -8,6 +8,8 @@ import ( | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| func Query(db *gorm.DB) { | ||||
| @ -109,13 +111,46 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			specifiedRelationsName := make(map[string]interface{}) | ||||
| 			for _, join := range db.Statement.Joins { | ||||
| 				if db.Statement.Schema == nil { | ||||
| 					fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
| 					}) | ||||
| 				} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { | ||||
| 				if db.Statement.Schema != nil { | ||||
| 					var isRelations bool // is relations or raw sql
 | ||||
| 					var relations []*schema.Relationship | ||||
| 					relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] | ||||
| 					if ok { | ||||
| 						isRelations = true | ||||
| 						relations = append(relations, relation) | ||||
| 					} else { | ||||
| 						// handle nested join like "Manager.Company"
 | ||||
| 						nestedJoinNames := strings.Split(join.Name, ".") | ||||
| 						if len(nestedJoinNames) > 1 { | ||||
| 							isNestedJoin := true | ||||
| 							gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 							currentRelations := db.Statement.Schema.Relationships.Relations | ||||
| 							for _, relname := range nestedJoinNames { | ||||
| 								// incomplete match, only treated as raw sql
 | ||||
| 								if relation, ok = currentRelations[relname]; ok { | ||||
| 									gussNestedRelations = append(gussNestedRelations, relation) | ||||
| 									currentRelations = relation.FieldSchema.Relationships.Relations | ||||
| 								} else { | ||||
| 									isNestedJoin = false | ||||
| 									break | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							if isNestedJoin { | ||||
| 								isRelations = true | ||||
| 								relations = gussNestedRelations | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					if isRelations { | ||||
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							tableAliasName := relation.Name | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 							} | ||||
| 
 | ||||
| 							columnStmt := gorm.Statement{ | ||||
| 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | ||||
| @ -128,7 +163,7 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 									clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 										Table: tableAliasName, | ||||
| 										Name:  s, | ||||
| 								Alias: tableAliasName + "__" + s, | ||||
| 										Alias: utils.NestedRelationName(tableAliasName, s), | ||||
| 									}) | ||||
| 								} | ||||
| 							} | ||||
| @ -137,13 +172,13 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 							for idx, ref := range relation.References { | ||||
| 								if ref.OwnPrimaryKey { | ||||
| 									exprs[idx] = clause.Eq{ | ||||
| 								Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, | ||||
| 										Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, | ||||
| 										Value:  clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, | ||||
| 									} | ||||
| 								} else { | ||||
| 									if ref.PrimaryValue == "" { | ||||
| 										exprs[idx] = clause.Eq{ | ||||
| 									Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, | ||||
| 											Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, | ||||
| 											Value:  clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, | ||||
| 										} | ||||
| 									} else { | ||||
| @ -184,11 +219,28 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 					fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 						Type:  join.JoinType, | ||||
| 							return clause.Join{ | ||||
| 								Type:  joinType, | ||||
| 								Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, | ||||
| 								ON:    clause.Where{Exprs: exprs}, | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						parentTableName := clause.CurrentTable | ||||
| 						for _, rel := range relations { | ||||
| 							// joins table alias like "Manager, Company, Manager__Company"
 | ||||
| 							nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							if _, ok := specifiedRelationsName[nestedAlias]; !ok { | ||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) | ||||
| 								specifiedRelationsName[nestedAlias] = nil | ||||
| 							} | ||||
| 							parentTableName = rel.Name | ||||
| 						} | ||||
| 					} else { | ||||
| 						fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 							Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
| 						}) | ||||
| 					} | ||||
| 				} else { | ||||
| 					fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
| 						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, | ||||
|  | ||||
							
								
								
									
										60
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								scan.go
									
									
									
									
									
								
							| @ -4,10 +4,10 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| // prepareValues prepare values slice
 | ||||
| @ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { | ||||
| func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { | ||||
| 	for idx, field := range fields { | ||||
| 		if field != nil { | ||||
| 			values[idx] = field.NewValuePool.Get() | ||||
| @ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int | ||||
| 
 | ||||
| 	db.RowsAffected++ | ||||
| 	db.AddError(rows.Scan(values...)) | ||||
| 	joinedSchemaMap := make(map[*schema.Field]interface{}) | ||||
| 	joinedNestedSchemaMap := make(map[string]interface{}) | ||||
| 	for idx, field := range fields { | ||||
| 		if field == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if len(joinFields) == 0 || joinFields[idx][0] == nil { | ||||
| 		if len(joinFields) == 0 || len(joinFields[idx]) == 0 { | ||||
| 			db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) | ||||
| 		} else { | ||||
| 			joinSchema := joinFields[idx][0] | ||||
| 			relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) | ||||
| 		} else { // joinFields count is larger than 2 when using join
 | ||||
| 			var isNilPtrValue bool | ||||
| 			var relValue reflect.Value | ||||
| 			// does not contain raw dbname
 | ||||
| 			nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] | ||||
| 			// current reflect value
 | ||||
| 			currentReflectValue := reflectValue | ||||
| 			fullRels := make([]string, 0, len(nestedJoinSchemas)) | ||||
| 			for _, joinSchema := range nestedJoinSchemas { | ||||
| 				fullRels = append(fullRels, joinSchema.Name) | ||||
| 				relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) | ||||
| 				if relValue.Kind() == reflect.Ptr { | ||||
| 				if _, ok := joinedSchemaMap[joinSchema]; !ok { | ||||
| 					fullRelsName := utils.JoinNestedRelationNames(fullRels) | ||||
| 					// same nested structure
 | ||||
| 					if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { | ||||
| 						if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { | ||||
| 						continue | ||||
| 							isNilPtrValue = true | ||||
| 							break | ||||
| 						} | ||||
| 
 | ||||
| 						relValue.Set(reflect.New(relValue.Type().Elem())) | ||||
| 					joinedSchemaMap[joinSchema] = nil | ||||
| 						joinedNestedSchemaMap[fullRelsName] = nil | ||||
| 					} | ||||
| 				} | ||||
| 			db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) | ||||
| 				currentReflectValue = relValue | ||||
| 			} | ||||
| 
 | ||||
| 			if !isNilPtrValue { // ignore if value is nil
 | ||||
| 				f := joinFields[idx][len(joinFields[idx])-1] | ||||
| 				db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// release data to pool
 | ||||
| @ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 	default: | ||||
| 		var ( | ||||
| 			fields       = make([]*schema.Field, len(columns)) | ||||
| 			joinFields   [][2]*schema.Field | ||||
| 			joinFields   [][]*schema.Field | ||||
| 			sch          = db.Statement.Schema | ||||
| 			reflectValue = db.Statement.ReflectValue | ||||
| 		) | ||||
| @ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 						} else { | ||||
| 							matchedFieldCount[column] = 1 | ||||
| 						} | ||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | ||||
| 						if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 							subNameCount := len(names) | ||||
| 							// nested relation fields
 | ||||
| 							relFields := make([]*schema.Field, 0, subNameCount-1) | ||||
| 							relFields = append(relFields, rel.Field) | ||||
| 							for _, name := range names[1 : subNameCount-1] { | ||||
| 								rel = rel.FieldSchema.Relationships.Relations[name] | ||||
| 								relFields = append(relFields, rel.Field) | ||||
| 							} | ||||
| 							// lastest name is raw dbname
 | ||||
| 							dbName := names[subNameCount-1] | ||||
| 							if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { | ||||
| 								fields[idx] = field | ||||
| 
 | ||||
| 								if len(joinFields) == 0 { | ||||
| 									joinFields = make([][2]*schema.Field, len(columns)) | ||||
| 									joinFields = make([][]*schema.Field, len(columns)) | ||||
| 								} | ||||
| 								joinFields[idx] = [2]*schema.Field{rel.Field, field} | ||||
| 								relFields = append(relFields, field) | ||||
| 								joinFields[idx] = relFields | ||||
| 								continue | ||||
| 							} | ||||
| 						} | ||||
|  | ||||
| @ -325,3 +325,66 @@ func TestJoinArgsWithDB(t *testing.T) { | ||||
| 	} | ||||
| 	AssertEqual(t, user4.NamedPet.Name, "") | ||||
| } | ||||
| 
 | ||||
| func TestNestedJoins(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{ | ||||
| 			Name:     "nested-joins-1", | ||||
| 			Manager:  GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}), | ||||
| 			NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name:     "nested-joins-2", | ||||
| 			Manager:  GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}), | ||||
| 			NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	var userIDs []uint | ||||
| 	for _, user := range users { | ||||
| 		userIDs = append(userIDs, user.ID) | ||||
| 	} | ||||
| 
 | ||||
| 	var users2 []User | ||||
| 	if err := DB. | ||||
| 		Joins("Manager"). | ||||
| 		Joins("Manager.Company"). | ||||
| 		Joins("Manager.NamedPet"). | ||||
| 		Joins("NamedPet"). | ||||
| 		Joins("NamedPet.Toy"). | ||||
| 		Find(&users2, "users.id IN ?", userIDs).Error; err != nil { | ||||
| 		t.Fatalf("Failed to load with joins, got error: %v", err) | ||||
| 	} else if len(users2) != len(users) { | ||||
| 		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Slice(users2, func(i, j int) bool { | ||||
| 		return users2[i].ID > users2[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(users, func(i, j int) bool { | ||||
| 		return users[i].ID > users[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	for idx, user := range users { | ||||
| 		// user
 | ||||
| 		CheckUser(t, user, users2[idx]) | ||||
| 		if users2[idx].Manager == nil { | ||||
| 			t.Fatalf("Failed to load Manager") | ||||
| 		} | ||||
| 		// manager
 | ||||
| 		CheckUser(t, *user.Manager, *users2[idx].Manager) | ||||
| 		// user pet
 | ||||
| 		if users2[idx].NamedPet == nil { | ||||
| 			t.Fatalf("Failed to load NamedPet") | ||||
| 		} | ||||
| 		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) | ||||
| 		// manager pet
 | ||||
| 		if users2[idx].Manager.NamedPet == nil { | ||||
| 			t.Fatalf("Failed to load NamedPet") | ||||
| 		} | ||||
| 		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -13,8 +13,14 @@ import ( | ||||
| 
 | ||||
| func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { | ||||
| 	for _, name := range names { | ||||
| 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||
| 		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||
| 		rv := reflect.Indirect(reflect.ValueOf(r)) | ||||
| 		ev := reflect.Indirect(reflect.ValueOf(e)) | ||||
| 		if rv.IsValid() != ev.IsValid() { | ||||
| 			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e) | ||||
| 			return | ||||
| 		} | ||||
| 		got := rv.FieldByName(name).Interface() | ||||
| 		expect := ev.FieldByName(name).Interface() | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			AssertEqual(t, got, expect) | ||||
| 		}) | ||||
|  | ||||
| @ -131,3 +131,20 @@ func ToString(value interface{}) string { | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| const nestedRelationSplit = "__" | ||||
| 
 | ||||
| // NestedRelationName nested relationships like `Manager__Company`
 | ||||
| func NestedRelationName(prefix, name string) string { | ||||
| 	return prefix + nestedRelationSplit + name | ||||
| } | ||||
| 
 | ||||
| // SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
 | ||||
| func SplitNestedRelationName(name string) []string { | ||||
| 	return strings.Split(name, nestedRelationSplit) | ||||
| } | ||||
| 
 | ||||
| // JoinNestedRelationNames nested relationships like `Manager__Company`
 | ||||
| func JoinNestedRelationNames(relationNames []string) string { | ||||
| 	return strings.Join(relationNames, nestedRelationSplit) | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Cr
						Cr