diamond-orm/query.go

212 lines
4.4 KiB
Go

package orm
import (
"context"
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"reflect"
"strings"
)
// Query collects the parts of a SELECT statement (target model, predicates,
// ordering, paging and joins) before they are rendered to SQL by buildSQL.
type Query struct {
	// model is the primary model: its table is the FROM target and its
	// fields form the base column list.
	model *Model

	// relatedModels is not referenced by the methods in this file.
	// NOTE(review): presumably models pulled in through joins, keyed by
	// name — confirm against the code that populates it.
	relatedModels map[string]*Model

	// wheres maps a predicate string to the arguments bound to it.
	wheres map[string][]any

	// orders holds the ordering keys in the order they were added.
	orders []string

	// limit is applied to the statement only when it is greater than zero.
	limit int

	// offset is the value recorded by Offset.
	offset int

	// joins maps a relationship to three strings that buildSQL renders as
	// "<[1]> as <[0]> ON <[2]>": index 0 is the alias, index 1 the joined
	// table expression and index 2 the join condition.
	joins map[*Relationship][3]string

	// engine gives access to the registered models (engine.modelMap.Map).
	engine *Engine

	// ctx and tx are not used by the methods in this file.
	// NOTE(review): presumably the context and optional transaction used
	// when the query is executed — confirm against the executing code.
	ctx context.Context
	tx  pgx.Tx
}
// totalWheres reports how many bound arguments are stored across all
// predicates in q.wheres (the sum of the lengths of the argument slices,
// not the number of predicates).
func (q *Query) totalWheres() int {
	var count int
	for clause := range q.wheres {
		count += len(q.wheres[clause])
	}
	return count
}
// Model sets the query's primary model from the dynamic type of val.
//
// Pointer indirections are stripped from val's type, and the resulting
// unqualified type name is looked up in q.engine.modelMap.Map. The result of
// that lookup is stored as-is, so a type name that is not present in the map
// leaves the query with the map's zero value as its model.
//
// An untyped nil val carries no type information; in that case the query is
// returned unchanged instead of panicking.
//
// It returns q to allow call chaining.
func (q *Query) Model(val any) *Query {
	tt := reflect.TypeOf(val)
	if tt == nil {
		// reflect.TypeOf yields a nil Type for a nil interface value, and
		// calling Kind on it would panic.
		return q
	}
	for tt.Kind() == reflect.Ptr {
		tt = tt.Elem()
	}
	q.model = q.engine.modelMap.Map[tt.Name()]
	return q
}
// Where adds one or more predicates to the query and returns q for chaining.
//
// When cond is a string it is used as the predicate text. Every occurrence of
// "$?" in it is replaced by "$N", where N is one more than the number of
// arguments already stored on the query, and args are stored as that
// predicate's arguments. Predicates are keyed by their text, so adding the
// same text twice replaces the earlier arguments.
//
// When cond is a struct (or a pointer to one) whose type name is registered in
// q.engine.modelMap.Map, one "<table>.<column> = ?" predicate is added for
// every non-zero field that the model maps to a column. Each such predicate is
// bound to exactly that field's value; args are not used in this mode because
// every generated predicate has a single placeholder. Conditions that are nil,
// nil pointers, non-structs or unregistered types add nothing.
func (q *Query) Where(cond any, args ...any) *Query {
	if q.wheres == nil {
		// Writing to a nil map panics; allocate lazily so a Query that was
		// built without a wheres map still works.
		q.wheres = make(map[string][]any)
	}

	switch v := cond.(type) {
	case string:
		q.wheres[strings.ReplaceAll(v, "$?", fmt.Sprintf("$%d", q.totalWheres()+1))] = args
	default:
		rv := reflect.ValueOf(cond)
		for rv.Kind() == reflect.Ptr {
			rv = rv.Elem()
		}
		// A nil condition or nil pointer ends up as an invalid Value, and
		// NumField panics on anything that is not a struct.
		if rv.Kind() != reflect.Struct {
			return q
		}
		rt := rv.Type()

		// The model depends only on the struct type, so resolve it once.
		mm, ok := q.engine.modelMap.Map[rt.Name()]
		if !ok {
			return q
		}

		for i := range rv.NumField() {
			field := rt.Field(i)
			if !field.IsExported() {
				// Value.Interface panics for values obtained from
				// unexported fields.
				continue
			}
			fieldValue := rv.Field(i)
			if isZero(fieldValue) {
				continue
			}
			ff, ok := mm.Fields[field.Name]
			if !ok || ff.ColumnType == "" {
				continue
			}
			whereClause := fmt.Sprintf("%s.%s = ?", mm.TableName, ff.ColumnName)
			// One placeholder, one argument: use a fresh slice per predicate
			// so values from other fields never leak into this one.
			q.wheres[whereClause] = []any{fieldValue.Interface()}
		}
	}
	return q
}
// Order appends order to the query's list of ordering keys, keeping the order
// in which keys were added, and returns q for chaining. The key is stored
// verbatim; it is interpreted only when the SQL is built.
func (q *Query) Order(order string) *Query {
	keys := append(q.orders, order)
	q.orders = keys
	return q
}
// Limit records the maximum number of rows for the query, replacing any
// earlier value, and returns q for chaining. buildSQL emits a LIMIT only when
// the recorded value is greater than zero.
func (q *Query) Limit(limit int) *Query {
	q.limit = limit

	return q
}
// Offset records the row offset for the query, replacing any earlier value,
// and returns q for chaining.
func (q *Query) Offset(offset int) *Query {
	q.offset = offset

	return q
}
// buildSelect assembles the "<expr> AS <alias>" column list for the query and
// returns a select builder over it.
//
// The list starts with every field of the primary model that has a column
// type. For each joined relationship it then adds the columns of the
// relationship's model and of its related model, visiting each alias (as
// reported by aliasThingy / relatedAlias) at most once.
//
// Both q.model.Fields and q.joins are iterated with range, so the order of the
// resulting columns follows whatever order those ranges yield.
func (q *Query) buildSelect() sb.SelectBuilder {
	var columns []string

	// Base columns: primary-model fields that are backed by a column.
	for _, field := range q.model.Fields {
		if field.ColumnType != "" {
			expr, as := field.alias()
			columns = append(columns, fmt.Sprintf("%s AS %s", expr, as))
		}
	}

	// addJoined appends the column for field f of model m, where m was
	// reached through the joined relationship via.
	addJoined := func(f *Field, m *Model, via *Relationship) {
		// Fields without a column type contribute no column.
		if f.ColumnType == "" {
			return
		}

		// Pick the relationship that governs this field: the field's own
		// foreign key, else the model's relationship of the same name, else
		// the relationship we arrived through.
		rel := f.fk
		if rel == nil {
			rel = m.Relationships[f.Name]
		}
		if rel == nil {
			rel = via
		}

		// The field is the relationship's own field: select it through the
		// join alias, but only if that relationship is actually joined.
		if rel != nil && rel.FieldName == f.Name {
			if join, joined := q.joins[rel]; joined {
				var expr, as string
				if rel.Type == HasOne {
					// For HasOne the related model's ID field is selected
					// instead of the field itself.
					expr, as = rel.RelatedModel.Fields[rel.RelatedModel.IDField].aliasWith(join[0])
				} else {
					expr, as = f.aliasWith(join[0])
				}
				columns = append(columns, fmt.Sprintf("%s AS %s", expr, as))
			}
			return
		}

		if f.Name != via.FieldName {
			expr, as := f.alias()
			columns = append(columns, fmt.Sprintf("%s AS %s", expr, as))
			return
		}

		// NOTE(review): the result of this call is discarded, so no column
		// is added for this field — confirm whether that is intentional.
		f.aliasWith(q.joins[via][0])
	}

	visited := make(map[string]bool)
	for rel := range q.joins {
		if key := rel.aliasThingy(); !visited[key] {
			visited[key] = true
			for _, f := range rel.Model.Fields {
				addJoined(f, rel.Model, rel)
			}
		}
		if key := rel.relatedAlias(); !visited[key] {
			visited[key] = true
			for _, f := range rel.RelatedModel.Fields {
				addJoined(f, rel.RelatedModel, rel)
			}
		}
	}

	return sb.Select(columns...)
}
// buildSQL renders the query to a SQL string and its bound arguments.
//
// The statement selects the columns produced by buildSelect from the primary
// model's table, then adds every stored predicate, a LEFT JOIN per entry in
// q.joins, the resolvable ordering keys, and LIMIT / OFFSET when they are
// greater than zero.
//
// Ordering keys are resolved by name: a key that matches a field of the
// primary model orders by that field's column. A dotted key ("Rel.Field",
// "Rel.Other.Field", ...) is followed through the models' relationships
// starting at the primary model, and orders by the column of the final field.
// Keys that cannot be resolved are skipped.
//
// NOTE(review): q.wheres and q.joins are maps, so predicates and joins are
// emitted in whatever order range yields them. Predicates that carry explicit
// "$N" placeholders rely on their arguments being appended in a matching
// order — confirm this is acceptable or switch to an ordered container.
//
// NOTE(review): columns are passed to OrderBy without a table or alias
// qualifier, including columns of related models — confirm this cannot be
// ambiguous for the joins in use.
//
// The builder's MustSQL is used, so an error while rendering panics.
func (q *Query) buildSQL() (string, []any) {
	sqlb := q.buildSelect().From(q.model.TableName)

	for clause, args := range q.wheres {
		sqlb = sqlb.Where(clause, args...)
	}

	for _, j := range q.joins {
		sqlb = sqlb.LeftJoin(fmt.Sprintf("%s as %s ON %s", j[1], j[0], j[2]))
	}

orders:
	for _, o := range q.orders {
		field, ok := q.model.Fields[o]
		if !ok {
			// Not a direct field: only a dotted relationship path can
			// still resolve.
			if !strings.Contains(o, ".") {
				continue
			}
			path := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".")

			// Walk every segment but the last through the relationships;
			// the last segment names the field on the model we end up at.
			cm := q.model
			for _, name := range path[:len(path)-1] {
				rel := cm.Relationships[name]
				if rel == nil || rel.RelatedModel == nil {
					continue orders
				}
				cm = rel.RelatedModel
			}
			field = cm.Fields[path[len(path)-1]]
		}
		if field == nil {
			// Unknown key: skip it rather than dereference a nil field.
			continue
		}
		sqlb = sqlb.OrderBy(field.ColumnName)
	}

	if q.limit > 0 {
		sqlb = sqlb.Limit(uint64(q.limit))
	}
	if q.offset > 0 {
		// Without this the value recorded by Offset never reaches the SQL.
		sqlb = sqlb.Offset(uint64(q.offset))
	}

	return sqlb.MustSQL()
}