212 lines
4.4 KiB
Go
212 lines
4.4 KiB
Go
package orm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
sb "github.com/henvic/pgq"
|
|
"github.com/jackc/pgx/v5"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
type Query struct {
|
|
model *Model
|
|
relatedModels map[string]*Model
|
|
wheres map[string][]any
|
|
orders []string
|
|
limit int
|
|
offset int
|
|
joins map[*Relationship][3]string
|
|
engine *Engine
|
|
ctx context.Context
|
|
tx pgx.Tx
|
|
}
|
|
|
|
func (q *Query) totalWheres() int {
|
|
total := 0
|
|
for _, w := range q.wheres {
|
|
total += len(w)
|
|
}
|
|
return total
|
|
}
|
|
|
|
func (q *Query) Model(val any) *Query {
|
|
tt := reflect.TypeOf(val)
|
|
for tt.Kind() == reflect.Ptr {
|
|
tt = tt.Elem()
|
|
}
|
|
q.model = q.engine.modelMap.Map[tt.Name()]
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Where(cond any, args ...any) *Query {
|
|
switch v := cond.(type) {
|
|
case string:
|
|
q.wheres[strings.ReplaceAll(v, "$?", fmt.Sprintf("$%d", q.totalWheres()+1))] = args
|
|
default:
|
|
rv := reflect.ValueOf(cond)
|
|
for rv.Kind() == reflect.Ptr {
|
|
rv = rv.Elem()
|
|
}
|
|
rt := rv.Type()
|
|
for i := range rv.NumField() {
|
|
field := rt.Field(i)
|
|
fieldValue := rv.Field(i)
|
|
if isZero(fieldValue) {
|
|
continue
|
|
}
|
|
mm, ok := q.engine.modelMap.Map[rv.Type().Name()]
|
|
if !ok {
|
|
continue
|
|
}
|
|
ff, ok := mm.Fields[field.Name]
|
|
if !ok || ff.ColumnType == "" {
|
|
continue
|
|
}
|
|
whereClause := fmt.Sprintf("%s.%s = ?", mm.TableName, ff.ColumnName /*, q.totalWheres()+1*/)
|
|
args = append(args, fieldValue.Interface())
|
|
q.wheres[whereClause] = args
|
|
}
|
|
}
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Order(order string) *Query {
|
|
q.orders = append(q.orders, order)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Limit(limit int) *Query {
|
|
q.limit = limit
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Offset(offset int) *Query {
|
|
q.offset = offset
|
|
return q
|
|
}
|
|
|
|
// buildSelect assembles the SELECT column list for the query.
//
// It first emits "<expr> AS <alias>" for every column-backed field of q.model
// (fields with an empty ColumnType are skipped). Then, for each relationship in
// q.joins, it walks the fields of the relationship's Model and RelatedModel.
// seenModels makes sure each side, keyed by its alias, is walked at most once
// even when several joins touch it.
//
// NOTE(review): q.joins is a map, so the order of joined columns varies between
// calls. q.model's own columns are not recorded in seenModels; if a join's Model
// is q.model, its columns may be emitted a second time — confirm.
func (q *Query) buildSelect() sb.SelectBuilder {
	var fields []string

	for _, f := range q.model.Fields {
		if f.ColumnType == "" {
			continue
		}
		tn, a := f.alias()
		fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
	}
	// Keys are r.aliasThingy() / r.relatedAlias() values of already-expanded
	// join sides.
	seenModels := make(map[string]bool)
	// processField appends the select expression (if any) for field f of model
	// m, reached through the joined relationship pfk.
	processField := func(f *Field, m *Model, pfk *Relationship) {
		if f.ColumnType == "" {
			// Field without a column (presumably a relationship holder).
			if rel, ok := m.Relationships[f.Name]; ok {
				data, ok2 := q.joins[rel]
				// NOTE(review): f.ColumnType is always "" in this branch, so this
				// condition can never be true and nothing is appended here.
				if ok2 && f.ColumnType != "" {
					tn, a := f.aliasWith(data[0])
					fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
				}
			}
			return
		}
		{
			// Pick the relationship governing this column: the field's own fk,
			// else a relationship keyed by the field name on m, else the join
			// currently being processed.
			fk := f.fk
			if fk == nil {
				fk = m.Relationships[f.Name]
			}
			if fk == nil {
				fk = pfk
			}
			if fk != nil && fk.FieldName == f.Name {
				// data[0] is the join alias (see the LEFT JOIN rendering in buildSQL).
				data, ok2 := q.joins[fk]
				if ok2 {
					var (
						tn, a string
					)
					// For HasOne, select the related model's ID field under the
					// join alias; otherwise select f itself under the join alias.
					if fk.Type == HasOne {
						tn, a = fk.RelatedModel.Fields[fk.RelatedModel.IDField].aliasWith(data[0])
					} else {
						tn, a = f.aliasWith(data[0])
					}
					fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
				}
				// The governing field is never emitted with its plain alias, even
				// when its relationship is not joined.
				return
			}
		}
		// NOTE(review): the result of aliasWith is discarded, so a field whose
		// name matches the join's FieldName is not selected here — confirm
		// whether this is intended.
		if f.Name == pfk.FieldName {
			f.aliasWith(q.joins[pfk][0])
		} else {
			tn, a := f.alias()
			fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
		}
	}
	for r := range q.joins {
		if !seenModels[r.aliasThingy()] {
			seenModels[r.aliasThingy()] = true
			for _, f := range r.Model.Fields {
				processField(f, r.Model, r)
			}
		}
		if !seenModels[r.relatedAlias()] {
			seenModels[r.relatedAlias()] = true
			for _, f := range r.RelatedModel.Fields {
				processField(f, r.RelatedModel, r)
			}
		}
	}
	return sb.Select(fields...)
}
|
|
|
|
func (q *Query) buildSQL() (string, []any) {
|
|
sqlb := q.buildSelect().From(q.model.TableName)
|
|
whereargs := make([]any, 0)
|
|
if len(q.wheres) > 0 {
|
|
cnt := 0
|
|
for w, where := range q.wheres {
|
|
sqlb = sqlb.Where(w, where...)
|
|
|
|
whereargs = append(whereargs, where...)
|
|
cnt++
|
|
}
|
|
}
|
|
for _, j := range q.joins {
|
|
sqlb = sqlb.LeftJoin(fmt.Sprintf("%s as %s ON %s", j[1], j[0], j[2]))
|
|
}
|
|
ool:
|
|
for _, o := range q.orders {
|
|
ac, ok := q.model.Fields[o]
|
|
if !ok {
|
|
var rel = ac.fk
|
|
if ac.ColumnType == "" || rel == nil {
|
|
rel = q.model.Relationships[o]
|
|
}
|
|
if rel != nil {
|
|
if strings.Contains(o, ".") {
|
|
split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".")
|
|
cm := rel.Model
|
|
for i, s := range split {
|
|
if rel != nil {
|
|
cm = rel.RelatedModel
|
|
} else if i == len(split)-1 {
|
|
break
|
|
} else {
|
|
continue ool
|
|
}
|
|
rel = cm.Relationships[s]
|
|
}
|
|
lf := split[len(split)-1]
|
|
ac, ok = cm.Fields[lf]
|
|
if !ok {
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
sqlb = sqlb.OrderBy(ac.ColumnName)
|
|
}
|
|
if q.limit > 0 {
|
|
sqlb = sqlb.Limit(uint64(q.limit))
|
|
}
|
|
return sqlb.MustSQL()
|
|
}
|