package orm import ( "context" "fmt" sb "github.com/henvic/pgq" "github.com/jackc/pgx/v5" "reflect" "strings" ) type Query struct { engine *Engine model *Model tx pgx.Tx ctx context.Context populationTree map[string]any wheres map[string][]any joins []string orders []string limit int offset int } func (q *Query) setModel(val any) *Query { tt := reflect.TypeOf(val) for tt.Kind() == reflect.Ptr { tt = tt.Elem() } q.model = q.engine.modelMap.Map[tt.Name()] return q } func (q *Query) cleanupTx() { q.tx.Rollback(q.ctx) q.tx = nil } func (q *Query) Order(order string) *Query { q.orders = append(q.orders, order) return q } func (q *Query) Limit(limit int) *Query { if limit > -1 { q.limit = limit } return q } func (q *Query) Offset(offset int) *Query { if offset > -1 { q.offset = offset } return q } func (q *Query) Where(cond string, args ...any) *Query { q.processWheres(cond, "eq", args...) return q } func (q *Query) WhereRaw(cond string, args ...any) *Query { q.wheres[cond] = args return q } func (q *Query) In(cond string, args ...any) *Query { q.processWheres(cond, "in", args...) return q } func (q *Query) Join(field string) *Query { var clauses []string parts := strings.Split(field, ".") cur := q.model found := false aliasMap := q.getNestedAliases(field) for _, part := range parts { rel, ok := cur.Relationships[part] if !ok { found = false break } if rel.FieldName != part { found = false break } found = true aliases := aliasMap[rel] curAlias := aliases[0] nalias := aliases[1] if rel.m2mIsh() || rel.Type == ManyToMany { joinAlias := aliases[2] jc1 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s_id", rel.ComputeJoinTable(), joinAlias, curAlias, cur.idField().ColumnName, joinAlias, rel.Model.TableName, ) jc2 := fmt.Sprintf("%s AS %s ON %s.%s_id = %s.%s", rel.RelatedModel.TableName, nalias, joinAlias, rel.RelatedModel.TableName, nalias, rel.relatedID().ColumnName, ) clauses = append(clauses, jc1, jc2) } if rel.Type == HasMany || rel.Type == HasOne { fkr := rel.RelatedModel.Relationships[cur.Name] if fkr != nil { jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", rel.RelatedModel.TableName, nalias, curAlias, cur.idField().ColumnName, nalias, pascalToSnakeCase(fkr.joinField()), ) clauses = append(clauses, jc) } } if rel.Type == BelongsTo { jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", rel.RelatedModel.TableName, nalias, curAlias, pascalToSnakeCase(rel.joinField()), nalias, rel.RelatedModel.idField().ColumnName, ) clauses = append(clauses, jc) } curAlias = nalias cur = rel.RelatedModel } if found { q.joins = append(q.joins, clauses...) } return q } func (q *Query) getNestedAliases(field string) (amap map[*Relationship][]string) { amap = make(map[*Relationship][]string) parts := strings.Split(field, ".") cur := q.model curAlias := q.model.TableName first := curAlias found := false for _, part := range parts { rel, ok := cur.Relationships[part] if !ok { found = false break } if rel.FieldName != part { found = false break } found = true amap[rel] = make([]string, 0) nalias := pascalToSnakeCase(part) if rel.m2mIsh() || rel.Type == ManyToMany { joinAlias := rel.ComputeJoinTable() + "_joined" amap[rel] = append(amap[rel], curAlias, nalias, joinAlias) } else if rel.Type == HasMany || rel.Type == HasOne || rel.Type == BelongsTo { amap[rel] = append(amap[rel], curAlias, nalias) } curAlias = nalias cur = rel.RelatedModel } if !found { return } amap[nil] = []string{first} return } func (q *Query) processWheres(cond string, exprKind string, args ...any) { parts := strings.SplitN(cond, " ", 2) var translatedColumn string fieldPath := parts[0] ncond := "" if len(parts) > 1 { ncond = " " + parts[1] } pathParts := strings.Split(fieldPath, ".") if len(pathParts) > 1 { relPath := pathParts[:len(pathParts)-1] fieldName := pathParts[len(pathParts)-1] relPathStr := strings.Join(relPath, ".") aliasMap := q.getNestedAliases(relPathStr) for r, a := range aliasMap { if r == nil { continue } f, ok := r.RelatedModel.Fields[fieldName] if ok { translatedColumn = fmt.Sprintf("%s.%s", a[1], f.ColumnName) } } } else if pf := q.model.Fields[pathParts[0]]; pf != nil { translatedColumn = fmt.Sprintf("%s.%s", q.model.TableName, pf.ColumnName) } var tq string switch strings.ToLower(exprKind) { case "in": tq = fmt.Sprintf("%s IN (%s)", translatedColumn, MakePlaceholders(len(args))) default: tq = fmt.Sprintf("%s%s", translatedColumn, ncond) } q.wheres[tq] = args } func (q *Query) buildSQL() (cols []string, anonymousCols map[string][]string, finalSb sb.SelectBuilder, err error) { var inParents []any anonymousCols = make(map[string][]string) for _, field := range q.model.Fields { if field.isAnonymous() { for _, ef := range field.embeddedFields { anonymousCols[field.ColumnName] = append(anonymousCols[field.ColumnName], ef.ColumnName) } continue } cols = append(cols, field.ColumnName) } finalSb = sb.Select(cols...) for _, cc := range anonymousCols { finalSb = finalSb.Columns(cc...) } finalSb = finalSb.From(q.model.TableName) if len(q.joins) > 0 { idq := sb.Select(fmt.Sprintf("%s.%s", q.model.TableName, q.model.idField().ColumnName)). Distinct(). From(q.model.TableName) for w, arg := range q.wheres { idq = idq.Where(w, arg...) } for _, j := range q.joins { idq = idq.Join(j) } qq, qa := idq.MustSQL() var rows pgx.Rows rows, err = q.engine.conn.Query(q.ctx, qq, qa...) if err != nil { return } defer rows.Close() for rows.Next() { var id any if err = rows.Scan(&id); err != nil { return } inParents = append(inParents, id) } if len(inParents) == 0 { return } } if len(inParents) > 0 { finalSb = finalSb.Where( fmt.Sprintf("%s IN (%s)", q.model.idField().ColumnName, MakePlaceholders(len(inParents))), inParents...) } else if len(q.wheres) > 0 { for k, vv := range q.wheres { finalSb = finalSb.Where(k, vv...) } } ool: for _, o := range q.orders { ac, ok := q.model.Fields[o] if !ok { var rel = q.model.Relationships[o] if rel != nil { if strings.Contains(o, ".") { split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".") cm := rel.Model for i, s := range split { if rel != nil { cm = rel.RelatedModel } else if i == len(split)-1 { break } else { continue ool } rel = cm.Relationships[s] } lf := split[len(split)-1] ac, ok = cm.Fields[lf] if !ok { continue } } } } finalSb = finalSb.OrderBy(ac.ColumnName) } if q.limit > 0 { finalSb = finalSb.Limit(uint64(q.limit)) } if q.offset > 0 { finalSb = finalSb.Offset(uint64(q.offset)) } return }