diamond-orm/query.go

323 lines
8.6 KiB
Go

package orm
import (
"context"
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"reflect"
"strings"
)
// Query - contains the state and other details pertaining to the current
// operation: Find, Update/Create, or Kill (Delete).
// Builder methods such as Where, Join, Order and Limit mutate the Query in
// place and return it, so calls can be chained.
type Query struct {
	engine         *Engine          // the Engine instance that created this Query
	model          *Model           // the primary Model this Query pertains to
	tx             pgx.Tx           // the transaction for insert, update and delete operations
	ctx            context.Context  // context passed to the pgx calls this Query makes (e.g. conn.Query, tx.Rollback)
	populationTree map[string]any   // a tree-like map representing the dot-separated paths of fields to populate
	wheres         map[string][]any // WHERE clause text -> its placeholder args; must be non-nil before Where/WhereRaw/In write to it, and identical clause text replaces earlier args
	joins          []string         // JOIN clauses added by Join; needed when Where clauses reference fields/columns in other structs/tables
	orders         []string         // ORDER BY entries added by Order, in call order
	limit          int              // argument to a LIMIT clause; only emitted when positive
	offset         int              // argument to an OFFSET clause; only emitted when positive
}
// setModel - points the Query at the Model registered for val's type.
// All pointer levels are stripped first, so T, *T and **T resolve to the same
// entry, which is looked up by the bare type name. An unregistered type sets
// q.model to the map's zero value (nil). val must not be nil: reflect.TypeOf(nil)
// returns a nil Type and calling Kind on it panics.
func (q *Query) setModel(val any) *Query {
	typ := reflect.TypeOf(val)
	for typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}
	name := typ.Name()
	q.model = q.engine.modelMap.Map[name]
	return q
}
// cleanupTx - rolls back the Query's transaction, if any, and forgets it.
// The rollback is best-effort. pgx returns ErrTxClosed for a transaction that
// was already committed or rolled back, which is expected here, so the error is
// deliberately discarded. Calling cleanupTx with no open transaction (including
// a second call) is a no-op instead of a nil-interface panic.
func (q *Query) cleanupTx() {
	if q.tx == nil {
		return
	}
	_ = q.tx.Rollback(q.ctx)
	q.tx = nil
}
// Order - append an entry to this Query's `ORDER BY` list.
// Entries are kept in call order and are consumed by buildSQL when the SELECT
// for a Find is generated; normally each one is a struct field name of the
// primary model.
func (q *Query) Order(order string) *Query {
	orders := q.orders
	orders = append(orders, order)
	q.orders = orders
	return q
}
// Limit - cap the result set of a Find at `limit` rows.
// Negative values are ignored. 0 is accepted and removes a previously set
// limit, because buildSQL only emits LIMIT for positive values.
func (q *Query) Limit(limit int) *Query {
	if limit < 0 {
		return q
	}
	q.limit = limit
	return q
}
// Offset - skip the first `offset` rows of a Find's result set.
// Negative values are ignored. 0 is accepted and removes a previously set
// offset, because buildSQL only emits OFFSET for positive values.
func (q *Query) Offset(offset int) *Query {
	if offset < 0 {
		return q
	}
	q.offset = offset
	return q
}
// Where - add a `WHERE` clause to this query.
// The first word of cond may be a struct field name of the primary model, or a
// dot-separated relationship path ending in a field name (call Join for that
// path first). It is converted to a table-qualified column. The rest of cond
// is kept as written, and args bind its placeholders.
func (q *Query) Where(cond string, args ...any) *Query {
	const kind = "eq"
	q.processWheres(cond, kind, args...)
	return q
}
// WhereRaw - add a `WHERE` clause to this query with `cond` used verbatim
// (no field-name translation); args bind its placeholders.
// Clauses are stored keyed by their text, so calling WhereRaw again with an
// identical cond replaces the earlier args.
func (q *Query) WhereRaw(cond string, args ...any) *Query {
	clauses := q.wheres
	clauses[cond] = args
	return q
}
// In - add a `WHERE <column> IN (...)` clause to this query.
// cond names the column the same way Where does (a struct field name or a
// relationship path); only that first word is used. One placeholder is
// generated per element of args.
func (q *Query) In(cond string, args ...any) *Query {
	const kind = "in"
	q.processWheres(cond, kind, args...)
	return q
}
// Join - join the current model's table with the table(s) reached by following
// the dot-separated relationship path `field` (e.g. "Author" or
// "Author.Publisher"). Must be called before Where if the Where clauses
// reference fields of other structs/tables, to avoid errors.
//
// Every segment must name a relationship (by field name) on the model reached
// so far. If any segment does not, nothing is joined for the whole path.
// Table aliases come from getNestedAliases, so they match the column prefixes
// that Where generates for the same path. Per segment:
//   - many-to-many(-ish) relationships add two clauses, through the join table;
//   - HasMany/HasOne add one clause, but only if the related model has a
//     relationship keyed by the current model's Name;
//   - BelongsTo adds one clause.
//
// Relationship kinds that getNestedAliases assigns no aliases to abort the
// join for the whole path instead of panicking.
func (q *Query) Join(field string) *Query {
	var clauses []string
	cur := q.model
	aliasMap := q.getNestedAliases(field)
	for _, part := range strings.Split(field, ".") {
		rel, ok := cur.Relationships[part]
		if !ok || rel.FieldName != part {
			return q
		}
		aliases := aliasMap[rel]
		if len(aliases) < 2 {
			// Unhandled relationship kind: indexing below would panic.
			return q
		}
		curAlias, nalias := aliases[0], aliases[1]
		if rel.m2mIsh() || rel.Type == ManyToMany {
			if len(aliases) < 3 {
				return q
			}
			joinAlias := aliases[2]
			jc1 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s_id",
				rel.ComputeJoinTable(), joinAlias,
				curAlias, cur.idField().ColumnName,
				joinAlias, rel.Model.TableName,
			)
			jc2 := fmt.Sprintf("%s AS %s ON %s.%s_id = %s.%s",
				rel.RelatedModel.TableName, nalias,
				joinAlias, rel.RelatedModel.TableName,
				nalias, rel.relatedID().ColumnName,
			)
			clauses = append(clauses, jc1, jc2)
		}
		if rel.Type == HasMany || rel.Type == HasOne {
			// fkr.joinField() names the column on the related table that is
			// compared with cur's id.
			if fkr := rel.RelatedModel.Relationships[cur.Name]; fkr != nil {
				jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
					rel.RelatedModel.TableName, nalias,
					curAlias, cur.idField().ColumnName,
					nalias, pascalToSnakeCase(fkr.joinField()),
				)
				clauses = append(clauses, jc)
			}
		}
		if rel.Type == BelongsTo {
			jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
				rel.RelatedModel.TableName, nalias,
				curAlias, pascalToSnakeCase(rel.joinField()),
				nalias, rel.RelatedModel.idField().ColumnName,
			)
			clauses = append(clauses, jc)
		}
		cur = rel.RelatedModel
	}
	q.joins = append(q.joins, clauses...)
	return q
}
// getNestedAliases - walks the dot-separated relationship path `field` from the
// primary model and records, per *Relationship visited, the SQL aliases used
// for it:
//   - [0] alias of the table on the near side (the primary table name for the
//     first segment, otherwise the previous segment's alias);
//   - [1] alias of the related table (snake_case of the segment);
//   - [2] alias of the join table (many-to-many-like relationships only).
//
// Other relationship kinds get an empty slice. The nil key maps to
// []string{<primary table name>} and is present only when every segment
// resolved. On the first unresolvable segment the walk stops and the entries
// gathered so far are returned without the nil key.
func (q *Query) getNestedAliases(field string) (amap map[*Relationship][]string) {
	amap = make(map[*Relationship][]string)
	root := q.model.TableName
	model := q.model
	prevAlias := root
	for _, segment := range strings.Split(field, ".") {
		rel, ok := model.Relationships[segment]
		if !ok || rel.FieldName != segment {
			return amap
		}
		alias := pascalToSnakeCase(segment)
		switch {
		case rel.m2mIsh() || rel.Type == ManyToMany:
			amap[rel] = []string{prevAlias, alias, rel.ComputeJoinTable() + "_joined"}
		case rel.Type == HasMany, rel.Type == HasOne, rel.Type == BelongsTo:
			amap[rel] = []string{prevAlias, alias}
		default:
			amap[rel] = []string{}
		}
		prevAlias = alias
		model = rel.RelatedModel
	}
	amap[nil] = []string{root}
	return amap
}
// processWheres - translates the field reference at the start of cond into a
// table-qualified column and stores the resulting clause in q.wheres with args.
//
// cond has the form "<FieldPath> <rest>", e.g. "Name = ?" or
// "Author.Name ILIKE ?". FieldPath is one of two things:
//   - a struct field name of the primary model, qualified with the primary
//     table name;
//   - a dot-separated relationship path ending in a field of the last related
//     model, qualified with the alias getNestedAliases (and therefore Join)
//     assigns to that last relationship.
//
// A FieldPath that cannot be resolved is kept verbatim, so raw column
// references such as "users.id = ?" still produce valid SQL.
//
// exprKind "in" (case-insensitive) produces "<column> IN (<one placeholder per
// arg>)" and ignores <rest>. Any other kind produces "<column> <rest>".
func (q *Query) processWheres(cond string, exprKind string, args ...any) {
	fieldPath, rest, hasRest := strings.Cut(cond, " ")
	ncond := ""
	if hasRest {
		ncond = " " + rest
	}

	// Fall back to the caller's text rather than emitting a clause with an
	// empty column (" = ?"), which is never valid SQL.
	column := fieldPath
	if dot := strings.LastIndex(fieldPath, "."); dot >= 0 {
		relPath, fieldName := fieldPath[:dot], fieldPath[dot+1:]
		// Resolve against the LAST relationship of the path only. Checking every
		// relationship on the path (in random map order) could qualify the
		// column with an intermediate table that has a same-named field.
		cur := q.model
		var last *Relationship
		for _, part := range strings.Split(relPath, ".") {
			rel, ok := cur.Relationships[part]
			if !ok || rel.FieldName != part {
				last = nil
				break
			}
			last = rel
			cur = rel.RelatedModel
		}
		if last != nil {
			aliases := q.getNestedAliases(relPath)[last]
			if f, ok := last.RelatedModel.Fields[fieldName]; ok && f != nil && len(aliases) > 1 {
				column = fmt.Sprintf("%s.%s", aliases[1], f.ColumnName)
			}
		}
	} else if pf := q.model.Fields[fieldPath]; pf != nil {
		column = fmt.Sprintf("%s.%s", q.model.TableName, pf.ColumnName)
	}

	var clause string
	switch strings.ToLower(exprKind) {
	case "in":
		clause = fmt.Sprintf("%s IN (%s)", column, MakePlaceholders(len(args)))
	default:
		clause = column + ncond
	}
	q.wheres[clause] = args
}
// buildSQL - aggregates the information in this Query into a pgq.SelectBuilder.
// It also returns the selected column names to avoid issues with scanning:
// cols holds the primary model's plain (non-anonymous) field columns, and
// anonymousCols maps each anonymous field's column name to the columns of its
// embedded fields. Those are appended to the SELECT list after cols.
//
// When Join has been called, the matching primary-key values are first fetched
// with a DISTINCT query over the joins and wheres (fetchJoinedIDs). The returned
// builder is then restricted to those IDs; otherwise the wheres are applied
// directly. If the joined query matches nothing, the builder gets an
// always-false predicate, so executing it yields no rows rather than every row.
//
// Order entries are translated by resolveOrderTerm. LIMIT and OFFSET are only
// emitted when positive.
func (q *Query) buildSQL() (cols []string, anonymousCols map[string][]string, finalSb sb.SelectBuilder, err error) {
	anonymousCols = make(map[string][]string)
	for _, field := range q.model.Fields {
		if field.isAnonymous() {
			for _, ef := range field.embeddedFields {
				anonymousCols[field.ColumnName] = append(anonymousCols[field.ColumnName], ef.ColumnName)
			}
			continue
		}
		cols = append(cols, field.ColumnName)
	}
	finalSb = sb.Select(cols...)
	// NOTE(review): map iteration order is random, so the anonymous column
	// groups are selected in an unspecified order. Confirm the caller does not
	// rely on re-iterating anonymousCols in the same order.
	for _, cc := range anonymousCols {
		finalSb = finalSb.Columns(cc...)
	}
	finalSb = finalSb.From(q.model.TableName)

	if len(q.joins) > 0 {
		var ids []any
		if ids, err = q.fetchJoinedIDs(); err != nil {
			return
		}
		if len(ids) == 0 {
			// Without a predicate the builder would select the whole table.
			finalSb = finalSb.Where("FALSE")
		} else {
			finalSb = finalSb.Where(
				fmt.Sprintf("%s IN (%s)", q.model.idField().ColumnName, MakePlaceholders(len(ids))),
				ids...)
		}
	} else {
		for clause, clauseArgs := range q.wheres {
			finalSb = finalSb.Where(clause, clauseArgs...)
		}
	}

	for _, o := range q.orders {
		if term := q.resolveOrderTerm(o); term != "" {
			finalSb = finalSb.OrderBy(term)
		}
	}
	if q.limit > 0 {
		finalSb = finalSb.Limit(uint64(q.limit))
	}
	if q.offset > 0 {
		finalSb = finalSb.Offset(uint64(q.offset))
	}
	return
}

// fetchJoinedIDs - runs SELECT DISTINCT <table>.<id> over this Query's joins
// and wheres and returns the matching primary-key values.
func (q *Query) fetchJoinedIDs() ([]any, error) {
	idCol := fmt.Sprintf("%s.%s", q.model.TableName, q.model.idField().ColumnName)
	idq := sb.Select(idCol).Distinct().From(q.model.TableName)
	for clause, clauseArgs := range q.wheres {
		idq = idq.Where(clause, clauseArgs...)
	}
	for _, j := range q.joins {
		idq = idq.Join(j)
	}
	query, args, err := idq.SQL()
	if err != nil {
		return nil, fmt.Errorf("building joined id query: %w", err)
	}
	rows, err := q.engine.conn.Query(q.ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var ids []any
	for rows.Next() {
		var id any
		if err = rows.Scan(&id); err != nil {
			return nil, err
		}
		ids = append(ids, id)
	}
	// Next returning false can mean a read error, not only the end of the rows.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return ids, nil
}

// resolveOrderTerm - translates one Order() argument into an ORDER BY term.
// If its first word is a struct field name of the primary model, that word is
// replaced by the field's column name. Any remainder (e.g. "DESC",
// "NULLS LAST") is kept. Otherwise the trimmed text is used verbatim as raw SQL.
// Relationship paths such as "Author.Name" are not resolved, because the final
// query built by buildSQL has no joins to reference them through.
func (q *Query) resolveOrderTerm(o string) string {
	o = strings.TrimSpace(o)
	name, rest, hasRest := strings.Cut(o, " ")
	f, ok := q.model.Fields[name]
	if !ok || f == nil {
		return o
	}
	if !hasRest {
		return f.ColumnName
	}
	return f.ColumnName + " " + rest
}