diamond-orm/query_populate.go

401 lines
12 KiB
Go

package orm
import (
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"reflect"
"strings"
)
// PopulateAll is a sentinel accepted by Query.Populate that requests every
// relationship declared directly on the query's model (one level deep, not
// recursively).
const PopulateAll = "~~~ALL~~~"

// Populate requests that related structs/slices (relationships) of the
// query's model be pre-loaded.
//
// Each field is a relationship field name. Separate names with dots to
// address nested relationships (e.g. "Parent.Children"). Pass PopulateAll to
// request every relationship of the root model, non-recursively. Calls are
// additive: repeated or overlapping paths are merged into one population tree.
func (q *Query) Populate(fields ...string) *Query {
	if q.populationTree == nil {
		q.populationTree = make(map[string]any)
	}
	for _, field := range fields {
		if field == PopulateAll {
			for _, rel := range q.model.Relationships {
				// Tree keys are later matched against Relationship.FieldName,
				// so key by FieldName rather than by the Relationships map key,
				// which is not guaranteed to be the field name.
				if _, ok := q.populationTree[rel.FieldName]; !ok {
					q.populationTree[rel.FieldName] = make(map[string]any)
				}
			}
			continue
		}
		cur := q.populationTree
		for _, part := range strings.Split(field, ".") {
			next, ok := cur[part].(map[string]any)
			if !ok {
				next = make(map[string]any)
				cur[part] = next
			}
			cur = next
		}
	}
	return q
}
// processPopulate loads the relationships named in populationTree for every
// element of parent and assigns them onto the corresponding struct fields.
//
// parent is an indexable reflect.Value (e.g. a slice) whose elements are
// instances of model, either structs or pointers to structs; each element's ID
// is read from the field named model.IDField. Every top-level key of
// populationTree must equal the FieldName of one of model's relationships; its
// value is the nested tree, applied recursively to the records just loaded.
// Relationships whose type matches none of the handled cases are skipped.
//
// It returns an error if a key names no relationship on model or if loading
// any relationship (at any depth) fails.
func (q *Query) processPopulate(parent reflect.Value, model *Model, populationTree map[string]any) error {
	if parent.Len() == 0 {
		return nil
	}
	pids := make([]any, 0, parent.Len())
	for i := range parent.Len() {
		pval := parent.Index(i)
		if pval.Kind() == reflect.Pointer {
			pval = pval.Elem()
		}
		pids = append(pids, pval.FieldByName(model.IDField).Interface())
	}
	for p, nested := range populationTree {
		var rel *Relationship
		for _, r := range model.Relationships {
			if r.FieldName == p {
				rel = r
				break
			}
		}
		if rel == nil {
			return fmt.Errorf("field '%s' not found in model '%s'", p, model.Name)
		}
		var (
			childSlice reflect.Value
			err        error
		)
		// Case order matters: an m2m-ish HasMany/HasOne must fall through to
		// the many-to-many loader, while BelongsTo is checked before m2mIsh.
		switch {
		case (rel.Type == HasMany || rel.Type == HasOne) && !rel.m2mIsh():
			childSlice, err = q.populateHas(rel, parent, pids)
		case rel.Type == BelongsTo:
			childSlice, err = q.populateBelongsTo(rel, parent, pids)
		case rel.Type == ManyToMany || rel.m2mIsh():
			childSlice, err = q.populateManyToMany(rel, parent, pids)
		}
		if err != nil {
			return fmt.Errorf("failed to populate field at '%s': %w", p, err)
		}
		// Recurse only when a nested tree was requested and something was
		// actually loaded to apply it to.
		ntree, ok := nested.(map[string]any)
		if ok && len(ntree) > 0 && childSlice.IsValid() && childSlice.Len() > 0 {
			if err := q.processPopulate(childSlice, rel.RelatedModel, ntree); err != nil {
				return err
			}
		}
	}
	return nil
}
// populateHas loads HasMany / HasOne related records for every element of
// parent and assigns them onto each parent's rel.FieldName field.
//
// parentIds holds the parents' IDs. Related rows are selected from
// rel.RelatedModel.TableName where the foreign-key column is IN parentIds. They
// are grouped by the scanned foreign-key value and matched back to parents by
// the value of the parent's rel.Model.IDField field. For HasMany the field
// receives a slice of related structs. For HasOne it receives a single struct
// (if several rows match, the last one scanned wins).
//
// It returns a slice of pointers to the related records now stored in the
// parents' fields, for nested population. It returns an error if the
// foreign-key column cannot be determined, the query fails, or a row cannot
// be scanned.
func (q *Query) populateHas(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) {
	var fk string
	if fkf := rel.primaryID(); fkf != nil && fkf.ColumnType != "" {
		fk = fkf.ColumnName
	} else if rid := rel.relatedID(); rid != nil {
		fk = pascalToSnakeCase(rel.RelatedModel.Name + rid.Name)
	}
	if rel.RelatedModel.embeddedIsh && !rel.Model.embeddedIsh && rel.Type == HasMany {
		// The column comes from the inverse relationship's join field.
		arel := rel.RelatedModel.Relationships[rel.Model.Name]
		if arel == nil {
			return reflect.Value{}, fmt.Errorf("inverse relationship '%s' not found in model '%s'", rel.Model.Name, rel.RelatedModel.Name)
		}
		fk = pascalToSnakeCase(arel.joinField())
	}
	if fk == "" {
		// Without a column the WHERE clause would be invalid SQL.
		return reflect.Value{}, fmt.Errorf("cannot determine foreign key column for field '%s'", rel.FieldName)
	}

	// Plain columns first, then columns of anonymous (embedded) fields, then
	// join columns of ManyToOne relationships.
	ccols := make([]string, 0, len(rel.RelatedModel.Fields))
	anonymousCols := make(map[string]map[string]*Field)
	for _, f := range rel.RelatedModel.Fields {
		if !f.isAnonymous() {
			ccols = append(ccols, f.ColumnName)
		}
	}
	for _, f := range rel.RelatedModel.Fields {
		if f.isAnonymous() {
			ccols = append(ccols, f.anonymousColumnNames()...)
			anonymousCols[f.Name] = f.embeddedFields
		}
	}
	for _, r := range rel.RelatedModel.Relationships {
		if r.Type == ManyToOne {
			ccols = append(ccols, pascalToSnakeCase(r.joinField()))
		}
	}

	aq, aa := sb.Select(ccols...).
		From(rel.RelatedModel.TableName).
		Where(fmt.Sprintf("%s IN (%s)", fk, MakePlaceholders(len(parentIds))), parentIds...).
		MustSQL()
	q.engine.logQuery("populate", aq, aa)
	rows, err := q.engine.conn.Query(q.ctx, aq, aa...)
	if err != nil {
		return reflect.Value{}, err
	}
	defer rows.Close()

	idFieldName := rel.Model.IDField
	idField := rel.Model.Fields[idFieldName]
	if rel.Type == HasMany {
		sliceType := reflect.SliceOf(rel.RelatedModel.Type)
		childMap := reflect.MakeMap(reflect.MapOf(idField.Type, sliceType))
		for rows.Next() {
			child := reflect.New(rel.RelatedModel.Type).Elem()
			var fkValue any
			scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &fkValue)
			if err = rows.Scan(scanDest...); err != nil {
				return reflect.Value{}, err
			}
			fkVal := reflect.ValueOf(fkValue)
			children := childMap.MapIndex(fkVal)
			if !children.IsValid() {
				children = reflect.MakeSlice(sliceType, 0, 0)
			}
			childMap.SetMapIndex(fkVal, reflect.Append(children, child))
		}
		// pgx reports query and mid-iteration failures through Err once Next
		// returns false; without this they would be silently ignored.
		if err = rows.Err(); err != nil {
			return reflect.Value{}, err
		}
		for i := range parent.Len() {
			ps := parent.Index(i)
			if ps.Kind() == reflect.Pointer {
				ps = ps.Elem()
			}
			if c := childMap.MapIndex(ps.FieldByName(idFieldName)); c.IsValid() {
				ps.FieldByName(rel.FieldName).Set(c)
			}
		}
	} else {
		childMap := reflect.MakeMap(reflect.MapOf(idField.Type, rel.RelatedModel.Type))
		for rows.Next() {
			child := reflect.New(rel.RelatedModel.Type).Elem()
			var fkValue any
			// NOTE(review): every other buildScanDest call passes the model
			// describing the scanned value (rel.RelatedModel here); confirm
			// that rel.Model is intended for HasOne.
			scanDest, _ := buildScanDest(child, rel.Model, rel, ccols, anonymousCols, &fkValue)
			if err = rows.Scan(scanDest...); err != nil {
				return reflect.Value{}, err
			}
			childMap.SetMapIndex(reflect.ValueOf(fkValue), child)
		}
		if err = rows.Err(); err != nil {
			return reflect.Value{}, err
		}
		for i := range parent.Len() {
			ps := parent.Index(i)
			if ps.Kind() == reflect.Pointer {
				ps = ps.Elem()
			}
			if child := childMap.MapIndex(ps.FieldByName(idFieldName)); child.IsValid() {
				ps.FieldByName(rel.FieldName).Set(child)
			}
		}
	}

	// Collect pointers into the parents' fields so nested population mutates
	// the records already assigned above.
	childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0)
	for i := range parent.Len() {
		ps := parent.Index(i)
		if ps.Kind() == reflect.Pointer {
			ps = ps.Elem()
		}
		childField := ps.FieldByName(rel.FieldName)
		if !childField.IsValid() {
			continue
		}
		if rel.Type == HasMany {
			for j := range childField.Len() {
				childSlice = reflect.Append(childSlice, childField.Index(j).Addr())
			}
		} else if !childField.IsZero() {
			childSlice = reflect.Append(childSlice, childField.Addr())
		}
	}
	return childSlice, nil
}
// populateManyToMany loads ManyToMany (and m2m-like) related records for every
// element of parent through the relationship's join table. It assigns them, as
// a slice, onto each parent's rel.FieldName field.
//
// Related rows (alias "m") are joined to the join table returned by
// rel.ComputeJoinTable() (alias "jt") on m.<related id column> =
// jt.<related table>_id. They are filtered by jt.<model table>_id IN
// parentIds. That parent-side column is selected last and used to group rows
// by parent. Parents are matched by the field named rel.primaryID().Name.
//
// It returns a slice of pointers to the related records now stored in the
// parents' fields, for nested population. It returns an error if the query
// fails or a row cannot be scanned.
func (q *Query) populateManyToMany(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) {
	ccols := make([]string, 0, len(rel.RelatedModel.Fields)+1)
	anonymousCols := make(map[string]map[string]*Field)
	for _, f := range rel.RelatedModel.Fields {
		if !f.isAnonymous() {
			ccols = append(ccols, "m."+f.ColumnName)
		}
	}
	for _, f := range rel.RelatedModel.Fields {
		if f.isAnonymous() {
			// Select the embedded field's column names, as the has/belongs-to
			// loaders do, rather than the raw embeddedFields map keys.
			for _, col := range f.anonymousColumnNames() {
				ccols = append(ccols, "m."+col)
			}
			anonymousCols[f.Name] = f.embeddedFields
		}
	}
	ccols = append(ccols, fmt.Sprintf("jt.%s_id", rel.Model.TableName))

	mq, ma := sb.Select(ccols...).
		From(fmt.Sprintf("%s AS m", rel.RelatedModel.TableName)).
		Join(
			fmt.Sprintf("%s AS jt ON m.%s = jt.%s_id",
				rel.ComputeJoinTable(),
				rel.relatedID().ColumnName, rel.RelatedModel.TableName)).
		Where(fmt.Sprintf("jt.%s_id IN (%s)",
			rel.Model.TableName, MakePlaceholders(len(parentIds))), parentIds...).
		MustSQL()
	q.engine.logQuery("populate/join", mq, ma)
	rows, err := q.engine.conn.Query(q.ctx, mq, ma...)
	if err != nil {
		return reflect.Value{}, err
	}
	defer rows.Close()

	idField := rel.Model.Fields[rel.Model.IDField]
	sliceType := reflect.SliceOf(rel.RelatedModel.Type)
	childMap := reflect.MakeMap(reflect.MapOf(idField.Type, sliceType))
	for rows.Next() {
		child := reflect.New(rel.RelatedModel.Type).Elem()
		var foreignKeyValue any
		scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &foreignKeyValue)
		if err = rows.Scan(scanDest...); err != nil {
			return reflect.Value{}, err
		}
		fkVal := reflect.ValueOf(foreignKeyValue)
		children := childMap.MapIndex(fkVal)
		if !children.IsValid() {
			children = reflect.MakeSlice(sliceType, 0, 0)
		}
		childMap.SetMapIndex(fkVal, reflect.Append(children, child))
	}
	// pgx reports query and mid-iteration failures through Err once Next
	// returns false; without this they would be silently ignored.
	if err = rows.Err(); err != nil {
		return reflect.Value{}, err
	}

	// Assign each parent's children, then collect pointers into the parent's
	// field so nested population mutates the stored records.
	childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0)
	for i := range parent.Len() {
		p := parent.Index(i)
		if p.Kind() == reflect.Pointer {
			p = p.Elem()
		}
		parentID := p.FieldByName(rel.primaryID().Name)
		if children := childMap.MapIndex(parentID); children.IsValid() {
			p.FieldByName(rel.FieldName).Set(children)
		}
		if childField := p.FieldByName(rel.FieldName); childField.IsValid() {
			for j := range childField.Len() {
				childSlice = reflect.Append(childSlice, childField.Index(j).Addr())
			}
		}
	}
	return childSlice, nil
}
// populateBelongsTo loads the BelongsTo parent record for every element of
// childrenSlice and assigns it onto each child's rel.FieldName field.
//
// It runs two queries:
//  1. Read (child id, foreign key) pairs from rel.Model.TableName for
//     childIDs, using the column derived from rel.joinField().
//  2. Load the distinct, non-NULL referenced rows from
//     rel.RelatedModel.TableName by their ID column.
//
// Children whose foreign key is NULL, or whose parent row is not found, are
// left untouched. childrenSlice elements may be structs or pointers to
// structs.
//
// It returns a slice of pointers to the populated (non-zero) parent fields,
// for nested population. It returns an invalid reflect.Value when no child
// references a parent. It returns an error if either query fails or a row
// cannot be scanned.
func (q *Query) populateBelongsTo(rel *Relationship, childrenSlice reflect.Value, childIDs []any) (reflect.Value, error) {
	childIDField := rel.Model.Fields[rel.Model.IDField]
	parentIDField := rel.RelatedModel.Fields[rel.RelatedModel.IDField]
	fk := pascalToSnakeCase(rel.joinField())

	// Step 1: map each child ID to the parent key it references.
	qs, qa := sb.Select(childIDField.ColumnName, fk).
		From(rel.Model.TableName).
		Where(fmt.Sprintf("%s IN (%s)",
			childIDField.ColumnName, MakePlaceholders(len(childIDs)),
		), childIDs...).MustSQL()
	q.engine.logQuery("populate/belongs-to", qs, qa)
	rows, err := q.engine.conn.Query(q.ctx, qs, qa...)
	if err != nil {
		return reflect.Value{}, err
	}
	childParentKeyMap := make(map[any]any)
	var parentKeyValues []any
	parentKeySet := make(map[any]struct{})
	for rows.Next() {
		var cid, pfk any
		if err = rows.Scan(&cid, &pfk); err != nil {
			rows.Close()
			return reflect.Value{}, err
		}
		childParentKeyMap[cid] = pfk
		if pfk == nil {
			// NULL foreign key: nothing to load for this child.
			continue
		}
		if _, seen := parentKeySet[pfk]; !seen {
			parentKeySet[pfk] = struct{}{}
			parentKeyValues = append(parentKeyValues, pfk)
		}
	}
	rows.Close()
	// pgx reports query and mid-iteration failures through Err; without this
	// they would be silently ignored.
	if err = rows.Err(); err != nil {
		return reflect.Value{}, err
	}
	if len(parentKeyValues) == 0 {
		return reflect.Value{}, nil
	}

	// Step 2: load the referenced parent rows.
	pcols := make([]string, 0, len(rel.RelatedModel.Fields))
	anonymousCols := make(map[string]map[string]*Field)
	for _, f := range rel.RelatedModel.Fields {
		if !f.isAnonymous() {
			pcols = append(pcols, f.ColumnName)
		}
	}
	for _, f := range rel.RelatedModel.Fields {
		if f.isAnonymous() {
			pcols = append(pcols, f.anonymousColumnNames()...)
			anonymousCols[f.Name] = f.embeddedFields
		}
	}
	pquery, pqargs := sb.Select(pcols...).
		From(rel.RelatedModel.TableName).
		Where(fmt.Sprintf("%s IN (%s)",
			parentIDField.ColumnName,
			MakePlaceholders(len(parentKeyValues))), parentKeyValues...).
		MustSQL()
	q.engine.logQuery("populate/belongs-to->parent", pquery, pqargs)
	parentRows, err := q.engine.conn.Query(q.ctx, pquery, pqargs...)
	if err != nil {
		return reflect.Value{}, err
	}
	defer parentRows.Close()
	parentMap := reflect.MakeMap(reflect.MapOf(parentIDField.Type, rel.RelatedModel.Type))
	for parentRows.Next() {
		parent := reflect.New(rel.RelatedModel.Type).Elem()
		scanDst, _ := buildScanDest(parent, rel.RelatedModel, rel, pcols, anonymousCols, nil)
		if err = parentRows.Scan(scanDst...); err != nil {
			return reflect.Value{}, err
		}
		parentMap.SetMapIndex(parent.FieldByName(rel.RelatedModel.IDField), parent)
	}
	if err = parentRows.Err(); err != nil {
		return reflect.Value{}, err
	}

	// Step 3: assign parents to children and collect pointers to the
	// populated fields for nested population.
	ntype := rel.RelatedModel.Type
	if rel.Kind == reflect.Pointer {
		ntype = reflect.PointerTo(rel.RelatedModel.Type)
	}
	parentSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(ntype)), 0, childrenSlice.Len())
	for i := range childrenSlice.Len() {
		child := childrenSlice.Index(i)
		// Nested population passes slices of pointers; FieldByName on a
		// pointer Value panics, so dereference first.
		if child.Kind() == reflect.Pointer {
			child = child.Elem()
		}
		field := child.FieldByName(rel.FieldName)
		childID := child.FieldByName(rel.Model.IDField)
		// NOTE(review): lookups rely on the scanned key types matching the Go
		// ID field types (and parentIDField.Type for MapIndex); confirm for
		// non-int64/string IDs.
		if parentKey, ok := childParentKeyMap[childID.Interface()]; ok && parentKey != nil {
			if parent := parentMap.MapIndex(reflect.ValueOf(parentKey)); parent.IsValid() {
				field.Set(parent)
			}
		}
		// A zero field means no parent was loaded. Passing it on would make
		// nested population query by a zero ID and attach unrelated rows.
		if field.IsValid() && !field.IsZero() {
			parentSlice = reflect.Append(parentSlice, field.Addr())
		}
	}
	return parentSlice, nil
}