package orm import ( "fmt" sb "github.com/henvic/pgq" "github.com/jackc/pgx/v5" "reflect" "strings" ) const PopulateAll = "~~~ALL~~~" // Populate - allows you to pre-load embedded structs/slices within the current model. // use dots between field names to specify nested paths. use the PopulateAll constant to populate all // relationships non-recursively func (q *Query) Populate(fields ...string) *Query { if q.populationTree == nil { q.populationTree = make(map[string]any) } for _, field := range fields { if field == PopulateAll { for k := range q.model.Relationships { if _, ok := q.populationTree[k]; !ok { q.populationTree[k] = make(map[string]any) } } continue } cur := q.populationTree parts := strings.Split(field, ".") for _, part := range parts { if _, ok := cur[part]; !ok { cur[part] = make(map[string]any) } cur = cur[part].(map[string]any) } } return q } func (q *Query) processPopulate(parent reflect.Value, model *Model, populationTree map[string]any) error { if parent.Len() == 0 { return nil } pids := make([]any, 0) var err error idField := model.IDField for i := range parent.Len() { pval := parent.Index(i) if pval.Kind() == reflect.Pointer { pval = pval.Elem() } pids = append(pids, pval.FieldByName(idField).Interface()) } toClose := make([]pgx.Rows, 0) defer func() { for _, c := range toClose { c.Close() } }() for p, nested := range populationTree { var rel *Relationship for _, r := range model.Relationships { if r.FieldName == p { rel = r break } } if rel == nil { return fmt.Errorf("field '%s' not found in model '%s'", p, model.Name) } childSlice := reflect.Value{} if (rel.Type == HasMany || rel.Type == HasOne) && !rel.m2mIsh() { childSlice, err = q.populateHas(rel, parent, pids) } else if rel.Type == BelongsTo { childSlice, err = q.populateBelongsTo(rel, parent, pids) } else if rel.Type == ManyToMany || rel.m2mIsh() { childSlice, err = q.populateManyToMany(rel, parent, pids) } if err != nil { return fmt.Errorf("failed to populate field at '%s': %w", p, err) } ntree, ok := nested.(map[string]any) if ok && len(ntree) > 0 && childSlice.IsValid() && childSlice.Len() > 0 { if err = q.processPopulate(childSlice, rel.RelatedModel, ntree); err != nil { return err } } } return nil } func (q *Query) populateHas(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) { fkf := rel.primaryID() var fk string if fkf != nil && fkf.ColumnType != "" { fk = fkf.ColumnName } else if rel.relatedID() != nil { fk = pascalToSnakeCase(rel.RelatedModel.Name + rel.relatedID().Name) } if rel.RelatedModel.embeddedIsh && !rel.Model.embeddedIsh && rel.Type == HasMany { arel := rel.RelatedModel.Relationships[rel.Model.Name] fk = pascalToSnakeCase(arel.joinField()) } ccols := make([]string, 0) anonymousCols := make(map[string]map[string]*Field) for _, f := range rel.RelatedModel.Fields { if !f.isAnonymous() { ccols = append(ccols, f.ColumnName) } } for _, f := range rel.RelatedModel.Fields { if f.isAnonymous() { ccols = append(ccols, f.anonymousColumnNames()...) anonymousCols[f.Name] = f.embeddedFields } } for _, r := range rel.RelatedModel.Relationships { if r.Type != ManyToOne { continue } ccols = append(ccols, pascalToSnakeCase(r.joinField())) } /*var tableName string if rel.Type == HasOne { tableName = rel.Model.TableName } if rel.Type == HasMany { tableName = rel.RelatedModel.TableName }*/ aq, aa := sb.Select(ccols...). From(rel.RelatedModel.TableName). Where(fmt.Sprintf("%s IN (%s)", fk, MakePlaceholders(len(parentIds))), parentIds...).MustSQL() q.engine.logQuery("populate", aq, aa) rows, err := q.engine.conn.Query(q.ctx, aq, aa...) if err != nil { return reflect.Value{}, err } defer rows.Close() idFieldName := rel.Model.IDField idField := rel.Model.Fields[idFieldName] if rel.Type == HasMany { childMap := reflect.MakeMap(reflect.MapOf( idField.Type, reflect.SliceOf(rel.RelatedModel.Type), )) for rows.Next() { child := reflect.New(rel.RelatedModel.Type).Elem() var fkValue any scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &fkValue) if err = rows.Scan(scanDest...); err != nil { return reflect.Value{}, err } fkVal := reflect.ValueOf(fkValue) childrenOfParent := childMap.MapIndex(fkVal) if !childrenOfParent.IsValid() { childrenOfParent = reflect.MakeSlice(reflect.SliceOf(rel.RelatedModel.Type), 0, 0) } childrenOfParent = reflect.Append(childrenOfParent, child) childMap.SetMapIndex(fkVal, childrenOfParent) } for i := range parent.Len() { ps := parent.Index(i) if ps.Kind() == reflect.Pointer { ps = ps.Elem() } pid := ps.FieldByName(idFieldName) c := childMap.MapIndex(pid) if c.IsValid() { ps.FieldByName(rel.FieldName).Set(c) } } } else { childMap := reflect.MakeMap(reflect.MapOf(idField.Type, rel.RelatedModel.Type)) for rows.Next() { child := reflect.New(rel.RelatedModel.Type).Elem() var fkValue any scanDest, _ := buildScanDest(child, rel.Model, rel, ccols, anonymousCols, &fkValue) if err = rows.Scan(scanDest...); err != nil { return reflect.Value{}, err } fkVal := reflect.ValueOf(fkValue) childMap.SetMapIndex(fkVal, child) } for i := range parent.Len() { ps := parent.Index(i) if ps.Kind() == reflect.Pointer { ps = ps.Elem() } parentID := ps.FieldByName(idFieldName) if child := childMap.MapIndex(parentID); child.IsValid() { ps.FieldByName(rel.FieldName).Set(child) } } } childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0) for i := range parent.Len() { ps := parent.Index(i) if ps.Kind() == reflect.Ptr { ps = ps.Elem() } childField := ps.FieldByName(rel.FieldName) if !childField.IsValid() { continue } if rel.Type == HasMany { for j := range childField.Len() { childSlice = reflect.Append(childSlice, childField.Index(j).Addr()) } } else { if !childField.IsZero() { childSlice = reflect.Append(childSlice, childField.Addr()) } } } return childSlice, nil } func (q *Query) populateManyToMany(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) { inPlaceholders := MakePlaceholders(len(parentIds)) ccols := make([]string, 0) anonymousCols := make(map[string]map[string]*Field) for _, f := range rel.RelatedModel.Fields { if !f.isAnonymous() { ccols = append(ccols, "m."+f.ColumnName) } } for _, f := range rel.RelatedModel.Fields { if f.isAnonymous() { for ecol := range f.embeddedFields { ccols = append(ccols, "m."+ecol) } anonymousCols[f.Name] = f.embeddedFields } } ccols = append(ccols, fmt.Sprintf("jt.%s_id", rel.Model.TableName)) mq, ma := sb.Select(ccols...). From(fmt.Sprintf("%s AS m", rel.RelatedModel.TableName)). Join( fmt.Sprintf("%s AS jt ON m.%s = jt.%s_id", rel.ComputeJoinTable(), rel.relatedID().ColumnName, rel.RelatedModel.TableName)). Where(fmt.Sprintf("jt.%s_id IN (%s)", rel.Model.TableName, inPlaceholders), parentIds...).MustSQL() q.engine.logQuery("populate/join", mq, ma) rows, err := q.engine.conn.Query(q.ctx, mq, ma...) if err != nil { return reflect.Value{}, err } defer rows.Close() idFieldName := rel.Model.IDField idField := rel.Model.Fields[idFieldName] childMap := reflect.MakeMap(reflect.MapOf( idField.Type, reflect.SliceOf(rel.RelatedModel.Type))) for rows.Next() { child := reflect.New(rel.RelatedModel.Type).Elem() var foreignKeyValue any scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &foreignKeyValue) if err = rows.Scan(scanDest...); err != nil { return reflect.Value{}, err } fkVal := reflect.ValueOf(foreignKeyValue) childrenOfParent := childMap.MapIndex(fkVal) if !childrenOfParent.IsValid() { childrenOfParent = reflect.MakeSlice(reflect.SliceOf(rel.RelatedModel.Type), 0, 0) } childrenOfParent = reflect.Append(childrenOfParent, child) childMap.SetMapIndex(fkVal, childrenOfParent) } for i := range parent.Len() { p := parent.Index(i) if p.Kind() == reflect.Ptr { p = p.Elem() } parentID := p.FieldByName(rel.primaryID().Name) if children := childMap.MapIndex(parentID); children.IsValid() { p.FieldByName(rel.FieldName).Set(children) } } childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0) for i := range parent.Len() { ps := parent.Index(i) if ps.Kind() == reflect.Ptr { ps = ps.Elem() } childField := ps.FieldByName(rel.FieldName) if childField.IsValid() { for j := range childField.Len() { childSlice = reflect.Append(childSlice, childField.Index(j).Addr()) } } } return childSlice, nil } func (q *Query) populateBelongsTo(rel *Relationship, childrenSlice reflect.Value, childIDs []any) (reflect.Value, error) { childIdField := rel.Model.Fields[rel.Model.IDField] parentIdField := rel.RelatedModel.Fields[rel.RelatedModel.IDField] fk := pascalToSnakeCase(rel.joinField()) qs, qa := sb.Select(childIdField.ColumnName, fk). From(rel.Model.TableName). Where(fmt.Sprintf("%s IN (%s)", childIdField.ColumnName, MakePlaceholders(len(childIDs)), ), childIDs...).MustSQL() q.engine.logQuery("populate/belongs-to", qs, qa) rows, err := q.engine.conn.Query(q.ctx, qs, qa...) if err != nil { return reflect.Value{}, err } childParentKeyMap := make(map[any]any) parentKeyValues := make([]any, 0) parentKeySet := make(map[any]bool) for rows.Next() { var cid, pfk any err = rows.Scan(&cid, &pfk) if err != nil { rows.Close() return reflect.Value{}, err } childParentKeyMap[cid] = pfk if !parentKeySet[pfk] { parentKeySet[pfk] = true parentKeyValues = append(parentKeyValues, pfk) } } rows.Close() if len(parentKeyValues) == 0 { return reflect.Value{}, nil } pcols := make([]string, 0) anonymousCols := make(map[string]map[string]*Field) for _, f := range rel.RelatedModel.Fields { if !f.isAnonymous() { pcols = append(pcols, f.ColumnName) } } for _, f := range rel.RelatedModel.Fields { if f.isAnonymous() { pcols = append(pcols, f.anonymousColumnNames()...) anonymousCols[f.Name] = f.embeddedFields } } pquery, pqargs := sb.Select(pcols...). From(rel.RelatedModel.TableName). Where(fmt.Sprintf("%s IN (%s)", parentIdField.ColumnName, MakePlaceholders(len(parentKeyValues))), parentKeyValues...). MustSQL() q.engine.logQuery("populate/belongs-to->parent", pquery, pqargs) parentRows, err := q.engine.conn.Query(q.ctx, pquery, pqargs...) if err != nil { return reflect.Value{}, err } defer parentRows.Close() parentMap := reflect.MakeMap(reflect.MapOf( parentIdField.Type, rel.RelatedModel.Type, )) for parentRows.Next() { parent := reflect.New(rel.RelatedModel.Type).Elem() scanDst, _ := buildScanDest(parent, rel.RelatedModel, rel, pcols, anonymousCols, nil) if err = parentRows.Scan(scanDst...); err != nil { return reflect.Value{}, err } parentId := parent.FieldByName(rel.RelatedModel.IDField) parentMap.SetMapIndex(parentId, parent) } for i := range childrenSlice.Len() { child := childrenSlice.Index(i) childID := child.FieldByName(rel.Model.IDField) if parentKey, ok := childParentKeyMap[childID.Interface()]; ok && parentKey != nil { if parent := parentMap.MapIndex(reflect.ValueOf(parentKey)); parent.IsValid() { child.FieldByName(rel.FieldName).Set(parent) } } } ntype := rel.RelatedModel.Type if rel.Kind == reflect.Pointer { ntype = reflect.PointerTo(rel.RelatedModel.Type) } parentSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(ntype)), 0, 0) for i := range childrenSlice.Len() { ps := childrenSlice.Index(i) if ps.Kind() == reflect.Ptr { ps = ps.Elem() } childField := ps.FieldByName(rel.FieldName) if childField.IsValid() { parentSlice = reflect.Append(parentSlice, childField.Addr()) } } return parentSlice, nil }