323 lines
8.6 KiB
Go
323 lines
8.6 KiB
Go
package orm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
sb "github.com/henvic/pgq"
|
|
"github.com/jackc/pgx/v5"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
// Query - contains the state and other details
|
|
// pertaining to the current FUCK operation (Find, Update/Create, Kill [Delete])
|
|
type Query struct {
|
|
engine *Engine // the Engine instance that created this Query
|
|
model *Model // the primary Model this Query pertains to
|
|
tx pgx.Tx // the transaction for insert, update and delete operations
|
|
ctx context.Context // does nothing, but is needed by some pgx functions
|
|
populationTree map[string]any // a tree-like map representing the dot-separated paths of fields to populate
|
|
wheres map[string][]any // a mapping of where clauses to a list of their arguments
|
|
joins []string // slice of tables to join on before executing Find. useful for hwne you have Where clauses referencing fields/columns in other structs/tables
|
|
orders []string // slice of `ORDER BY` clauses
|
|
limit int // argument to a LIMIT clause, if non-zero
|
|
offset int // unused (for now)
|
|
}
|
|
|
|
func (q *Query) setModel(val any) *Query {
|
|
tt := reflect.TypeOf(val)
|
|
for tt.Kind() == reflect.Ptr {
|
|
tt = tt.Elem()
|
|
}
|
|
q.model = q.engine.modelMap.Map[tt.Name()]
|
|
return q
|
|
}
|
|
|
|
func (q *Query) cleanupTx() {
|
|
q.tx.Rollback(q.ctx)
|
|
q.tx = nil
|
|
}
|
|
|
|
// Order - add an `ORDER BY` clause to the current Query.
|
|
// Only applicable for Find queries
|
|
func (q *Query) Order(order string) *Query {
|
|
q.orders = append(q.orders, order)
|
|
return q
|
|
}
|
|
|
|
// Limit - limit resultset to at most `limit` results.
|
|
// does nothing if `limit` <= 0 or the final operation isn't Find
|
|
func (q *Query) Limit(limit int) *Query {
|
|
if limit > -1 {
|
|
q.limit = limit
|
|
}
|
|
return q
|
|
}
|
|
|
|
// Offset - skip to the nth result, where n = `offset`.
|
|
// does nothing if `offset` <= 0 or the final operation isn't Find
|
|
func (q *Query) Offset(offset int) *Query {
|
|
if offset > -1 {
|
|
q.offset = offset
|
|
}
|
|
return q
|
|
}
|
|
|
|
// Where - add a `WHERE` clause to this query.
|
|
// struct field names can be passed to this method,
|
|
// and they will be automatically converted
|
|
func (q *Query) Where(cond string, args ...any) *Query {
|
|
q.processWheres(cond, "eq", args...)
|
|
return q
|
|
}
|
|
|
|
// WhereRaw - add a `WHERE` clause to this query, except `cond` is passed as-is.
|
|
func (q *Query) WhereRaw(cond string, args ...any) *Query {
|
|
q.wheres[cond] = args
|
|
return q
|
|
}
|
|
|
|
// In - add a `WHERE ... IN(...)` clause to this query
|
|
func (q *Query) In(cond string, args ...any) *Query {
|
|
q.processWheres(cond, "in", args...)
|
|
return q
|
|
}
|
|
|
|
// Join - join the current model's table with the table
|
|
// representing the type of struct field named `field`.
|
|
// Must be called before Where if referencing other
|
|
// structs/types to avoid errors
|
|
func (q *Query) Join(field string) *Query {
|
|
var clauses []string
|
|
parts := strings.Split(field, ".")
|
|
cur := q.model
|
|
found := false
|
|
aliasMap := q.getNestedAliases(field)
|
|
|
|
for _, part := range parts {
|
|
rel, ok := cur.Relationships[part]
|
|
if !ok {
|
|
found = false
|
|
break
|
|
}
|
|
if rel.FieldName != part {
|
|
found = false
|
|
break
|
|
}
|
|
found = true
|
|
aliases := aliasMap[rel]
|
|
curAlias := aliases[0]
|
|
nalias := aliases[1]
|
|
if rel.m2mIsh() || rel.Type == ManyToMany {
|
|
joinAlias := aliases[2]
|
|
jc1 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s_id",
|
|
rel.ComputeJoinTable(), joinAlias,
|
|
curAlias, cur.idField().ColumnName,
|
|
joinAlias, rel.Model.TableName,
|
|
)
|
|
jc2 := fmt.Sprintf("%s AS %s ON %s.%s_id = %s.%s",
|
|
rel.RelatedModel.TableName, nalias,
|
|
joinAlias, rel.RelatedModel.TableName,
|
|
nalias, rel.relatedID().ColumnName,
|
|
)
|
|
clauses = append(clauses, jc1, jc2)
|
|
}
|
|
if rel.Type == HasMany || rel.Type == HasOne {
|
|
fkr := rel.RelatedModel.Relationships[cur.Name]
|
|
if fkr != nil {
|
|
jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
|
|
rel.RelatedModel.TableName, nalias,
|
|
curAlias, cur.idField().ColumnName,
|
|
nalias, pascalToSnakeCase(fkr.joinField()),
|
|
)
|
|
clauses = append(clauses, jc)
|
|
}
|
|
}
|
|
if rel.Type == BelongsTo {
|
|
jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
|
|
rel.RelatedModel.TableName, nalias,
|
|
curAlias, pascalToSnakeCase(rel.joinField()),
|
|
nalias, rel.RelatedModel.idField().ColumnName,
|
|
)
|
|
clauses = append(clauses, jc)
|
|
}
|
|
curAlias = nalias
|
|
cur = rel.RelatedModel
|
|
}
|
|
if found {
|
|
q.joins = append(q.joins, clauses...)
|
|
}
|
|
return q
|
|
}
|
|
|
|
func (q *Query) getNestedAliases(field string) (amap map[*Relationship][]string) {
|
|
amap = make(map[*Relationship][]string)
|
|
parts := strings.Split(field, ".")
|
|
cur := q.model
|
|
curAlias := q.model.TableName
|
|
first := curAlias
|
|
found := false
|
|
|
|
for _, part := range parts {
|
|
rel, ok := cur.Relationships[part]
|
|
if !ok {
|
|
found = false
|
|
break
|
|
}
|
|
if rel.FieldName != part {
|
|
found = false
|
|
break
|
|
}
|
|
found = true
|
|
amap[rel] = make([]string, 0)
|
|
|
|
nalias := pascalToSnakeCase(part)
|
|
if rel.m2mIsh() || rel.Type == ManyToMany {
|
|
joinAlias := rel.ComputeJoinTable() + "_joined"
|
|
amap[rel] = append(amap[rel], curAlias, nalias, joinAlias)
|
|
} else if rel.Type == HasMany || rel.Type == HasOne || rel.Type == BelongsTo {
|
|
amap[rel] = append(amap[rel], curAlias, nalias)
|
|
}
|
|
|
|
curAlias = nalias
|
|
cur = rel.RelatedModel
|
|
}
|
|
if !found {
|
|
return
|
|
}
|
|
amap[nil] = []string{first}
|
|
return
|
|
}
|
|
|
|
func (q *Query) processWheres(cond string, exprKind string, args ...any) {
|
|
parts := strings.SplitN(cond, " ", 2)
|
|
var translatedColumn string
|
|
fieldPath := parts[0]
|
|
ncond := ""
|
|
if len(parts) > 1 {
|
|
ncond = " " + parts[1]
|
|
}
|
|
pathParts := strings.Split(fieldPath, ".")
|
|
if len(pathParts) > 1 {
|
|
relPath := pathParts[:len(pathParts)-1]
|
|
fieldName := pathParts[len(pathParts)-1]
|
|
relPathStr := strings.Join(relPath, ".")
|
|
aliasMap := q.getNestedAliases(relPathStr)
|
|
for r, a := range aliasMap {
|
|
if r == nil {
|
|
continue
|
|
}
|
|
f, ok := r.RelatedModel.Fields[fieldName]
|
|
if ok {
|
|
translatedColumn = fmt.Sprintf("%s.%s", a[1], f.ColumnName)
|
|
}
|
|
}
|
|
} else if pf := q.model.Fields[pathParts[0]]; pf != nil {
|
|
translatedColumn = fmt.Sprintf("%s.%s", q.model.TableName, pf.ColumnName)
|
|
}
|
|
var tq string
|
|
switch strings.ToLower(exprKind) {
|
|
case "in":
|
|
tq = fmt.Sprintf("%s IN (%s)", translatedColumn, MakePlaceholders(len(args)))
|
|
default:
|
|
tq = fmt.Sprintf("%s%s", translatedColumn, ncond)
|
|
}
|
|
q.wheres[tq] = args
|
|
}
|
|
|
|
// buildSQL - aggregates the information in this Query into a pgq.SelectBuilder.
|
|
// it returns a slice of column names as well to avoid issues with scanning
|
|
func (q *Query) buildSQL() (cols []string, anonymousCols map[string][]string, finalSb sb.SelectBuilder, err error) {
|
|
var inParents []any
|
|
anonymousCols = make(map[string][]string)
|
|
for _, field := range q.model.Fields {
|
|
if field.isAnonymous() {
|
|
for _, ef := range field.embeddedFields {
|
|
anonymousCols[field.ColumnName] = append(anonymousCols[field.ColumnName], ef.ColumnName)
|
|
}
|
|
continue
|
|
}
|
|
cols = append(cols, field.ColumnName)
|
|
}
|
|
finalSb = sb.Select(cols...)
|
|
for _, cc := range anonymousCols {
|
|
finalSb = finalSb.Columns(cc...)
|
|
}
|
|
finalSb = finalSb.From(q.model.TableName)
|
|
if len(q.joins) > 0 {
|
|
idq := sb.Select(fmt.Sprintf("%s.%s", q.model.TableName, q.model.idField().ColumnName)).
|
|
Distinct().
|
|
From(q.model.TableName)
|
|
for w, arg := range q.wheres {
|
|
idq = idq.Where(w, arg...)
|
|
}
|
|
for _, j := range q.joins {
|
|
idq = idq.Join(j)
|
|
}
|
|
qq, qa := idq.MustSQL()
|
|
var rows pgx.Rows
|
|
rows, err = q.engine.conn.Query(q.ctx, qq, qa...)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var id any
|
|
if err = rows.Scan(&id); err != nil {
|
|
return
|
|
}
|
|
inParents = append(inParents, id)
|
|
}
|
|
if len(inParents) == 0 {
|
|
return
|
|
}
|
|
}
|
|
if len(inParents) > 0 {
|
|
finalSb = finalSb.Where(
|
|
fmt.Sprintf("%s IN (%s)",
|
|
q.model.idField().ColumnName,
|
|
MakePlaceholders(len(inParents))), inParents...)
|
|
} else if len(q.wheres) > 0 {
|
|
for k, vv := range q.wheres {
|
|
finalSb = finalSb.Where(k, vv...)
|
|
}
|
|
}
|
|
ool:
|
|
for _, o := range q.orders {
|
|
ac, ok := q.model.Fields[o]
|
|
if !ok {
|
|
var rel = q.model.Relationships[o]
|
|
|
|
if rel != nil {
|
|
if strings.Contains(o, ".") {
|
|
split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".")
|
|
cm := rel.Model
|
|
for i, s := range split {
|
|
if rel != nil {
|
|
cm = rel.RelatedModel
|
|
} else if i == len(split)-1 {
|
|
break
|
|
} else {
|
|
continue ool
|
|
}
|
|
rel = cm.Relationships[s]
|
|
}
|
|
lf := split[len(split)-1]
|
|
ac, ok = cm.Fields[lf]
|
|
if !ok {
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
finalSb = finalSb.OrderBy(ac.ColumnName)
|
|
}
|
|
if q.limit > 0 {
|
|
finalSb = finalSb.Limit(uint64(q.limit))
|
|
}
|
|
if q.offset > 0 {
|
|
finalSb = finalSb.Offset(uint64(q.offset))
|
|
}
|
|
return
|
|
}
|