package orm import ( "context" "fmt" sb "github.com/henvic/pgq" "github.com/jackc/pgx/v5" "reflect" "strings" ) type Query struct { model *Model relatedModels map[string]*Model wheres map[string][]any orders []string limit int offset int joins map[*Relationship][3]string engine *Engine ctx context.Context tx pgx.Tx } func (q *Query) totalWheres() int { total := 0 for _, w := range q.wheres { total += len(w) } return total } func (q *Query) Model(val any) *Query { tt := reflect.TypeOf(val) for tt.Kind() == reflect.Ptr { tt = tt.Elem() } q.model = q.engine.modelMap.Map[tt.Name()] return q } func (q *Query) Where(cond any, args ...any) *Query { switch v := cond.(type) { case string: q.wheres[strings.ReplaceAll(v, "$?", fmt.Sprintf("$%d", q.totalWheres()+1))] = args default: rv := reflect.ValueOf(cond) for rv.Kind() == reflect.Ptr { rv = rv.Elem() } rt := rv.Type() for i := range rv.NumField() { field := rt.Field(i) fieldValue := rv.Field(i) if isZero(fieldValue) { continue } mm, ok := q.engine.modelMap.Map[rv.Type().Name()] if !ok { continue } ff, ok := mm.Fields[field.Name] if !ok || ff.ColumnType == "" { continue } whereClause := fmt.Sprintf("%s.%s = ?", mm.TableName, ff.ColumnName /*, q.totalWheres()+1*/) args = append(args, fieldValue.Interface()) q.wheres[whereClause] = args } } return q } func (q *Query) Order(order string) *Query { q.orders = append(q.orders, order) return q } func (q *Query) Limit(limit int) *Query { q.limit = limit return q } func (q *Query) Offset(offset int) *Query { q.offset = offset return q } func (q *Query) buildSelect() sb.SelectBuilder { var fields []string for _, f := range q.model.Fields { if f.ColumnType == "" { continue } tn, a := f.alias() fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) } seenModels := make(map[string]bool) processField := func(f *Field, m *Model, pfk *Relationship) { if f.ColumnType == "" { if rel, ok := m.Relationships[f.Name]; ok { data, ok2 := q.joins[rel] if ok2 && f.ColumnType != "" { tn, a := f.aliasWith(data[0]) fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) } } return } { fk := f.fk if fk == nil { fk = m.Relationships[f.Name] } if fk == nil { fk = pfk } if fk != nil && fk.FieldName == f.Name { data, ok2 := q.joins[fk] if ok2 { var ( tn, a string ) if fk.Type == HasOne { tn, a = fk.RelatedModel.Fields[fk.RelatedModel.IDField].aliasWith(data[0]) } else { tn, a = f.aliasWith(data[0]) } fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) } return } } if f.Name == pfk.FieldName { f.aliasWith(q.joins[pfk][0]) } else { tn, a := f.alias() fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) } } for r := range q.joins { if !seenModels[r.aliasThingy()] { seenModels[r.aliasThingy()] = true for _, f := range r.Model.Fields { processField(f, r.Model, r) } } if !seenModels[r.relatedAlias()] { seenModels[r.relatedAlias()] = true for _, f := range r.RelatedModel.Fields { processField(f, r.RelatedModel, r) } } } return sb.Select(fields...) } func (q *Query) buildSQL() (string, []any) { sqlb := q.buildSelect().From(q.model.TableName) whereargs := make([]any, 0) if len(q.wheres) > 0 { cnt := 0 for w, where := range q.wheres { sqlb = sqlb.Where(w, where...) whereargs = append(whereargs, where...) cnt++ } } for _, j := range q.joins { sqlb = sqlb.LeftJoin(fmt.Sprintf("%s as %s ON %s", j[1], j[0], j[2])) } ool: for _, o := range q.orders { ac, ok := q.model.Fields[o] if !ok { var rel = ac.fk if ac.ColumnType == "" || rel == nil { rel = q.model.Relationships[o] } if rel != nil { if strings.Contains(o, ".") { split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".") cm := rel.Model for i, s := range split { if rel != nil { cm = rel.RelatedModel } else if i == len(split)-1 { break } else { continue ool } rel = cm.Relationships[s] } lf := split[len(split)-1] ac, ok = cm.Fields[lf] if !ok { continue } } } } sqlb = sqlb.OrderBy(ac.ColumnName) } if q.limit > 0 { sqlb = sqlb.Limit(uint64(q.limit)) } return sqlb.MustSQL() }