diamond-orm/query_tail.go

370 lines
11 KiB
Go

package orm
import (
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"reflect"
"time"
)
// Find - transpiles this query into SQL, runs it and places the result in
// `dest`, which must be a non-nil pointer to one of:
//
//   - a struct: a single row is scanned into it;
//   - a slice: each row is scanned into a new element which is appended
//     (elements already present in the slice are kept);
//   - an array: rows are written from index 0 onwards; receiving more rows
//     than the array can hold is an error.
//
// When relationship population was requested (q.populationTree is not
// empty), related records are loaded after the rows have been scanned.
func (q *Query) Find(dest any) error {
	dstVal := reflect.ValueOf(dest)
	if dstVal.Kind() != reflect.Ptr {
		return fmt.Errorf("destination must be a pointer, got: %v", dstVal.Kind())
	}
	if dstVal.IsNil() {
		return fmt.Errorf("destination must be a non-nil pointer")
	}
	maybeSlice := dstVal.Elem()
	cols, acols, sqlb, err := q.buildSQL()
	if err != nil {
		return err
	}
	qq, qa := sqlb.MustSQL()
	q.engine.logQuery("find", qq, qa)
	switch kind := maybeSlice.Kind(); kind {
	case reflect.Struct:
		row := q.engine.conn.QueryRow(q.ctx, qq, qa...)
		if err = scanRow(row, cols, acols, maybeSlice, q.model); err != nil {
			return err
		}
	case reflect.Slice, reflect.Array:
		var rows pgx.Rows
		rows, err = q.engine.conn.Query(q.ctx, qq, qa...)
		if err != nil {
			return err
		}
		defer rows.Close()
		etype := maybeSlice.Type().Elem()
		for i := 0; rows.Next(); i++ {
			// reflect.Append panics on arrays, so fixed-size destinations
			// are filled by index and bounded by their length.
			if kind == reflect.Array && i >= maybeSlice.Len() {
				return fmt.Errorf("query returned more rows than the destination array can hold (%d)", maybeSlice.Len())
			}
			nelem := reflect.New(etype).Elem()
			if err = scanRow(rows, cols, acols, nelem, q.model); err != nil {
				return err
			}
			if kind == reflect.Array {
				maybeSlice.Index(i).Set(nelem)
			} else {
				maybeSlice.Set(reflect.Append(maybeSlice, nelem))
			}
		}
		// Next simply returns false when reading fails; the failure itself
		// is only reported through Err, so it must be checked here.
		if err = rows.Err(); err != nil {
			return err
		}
	default:
		return fmt.Errorf("unsupported destination type: %s", kind)
	}
	if len(q.populationTree) > 0 {
		nslice := maybeSlice
		var wasPassedStruct bool
		if nslice.Kind() == reflect.Struct {
			// processPopulate works on collections, so wrap the single
			// struct in a one-element slice and copy it back afterwards.
			nslice = reflect.MakeSlice(reflect.SliceOf(maybeSlice.Type()), 0, 0)
			wasPassedStruct = true
			nslice = reflect.Append(nslice, maybeSlice)
		}
		err = q.processPopulate(nslice, q.model, q.populationTree)
		if err == nil && wasPassedStruct {
			maybeSlice.Set(nslice.Index(0))
		}
		return err
	}
	return nil
}
// Save - create or update `val` in the database. `val` must be a pointer to
// a struct whose model is registered with the engine; the value and its
// relationships are written in a single transaction.
func (q *Query) Save(val any) error {
	const createHint = false
	return q.saveOrCreate(val, createHint)
}
// Create - like Save, but hints to the query processor that you want to
// insert, not update. Useful when importing data whose IDs must be kept
// intact.
func (q *Query) Create(val any) error {
	const createHint = true
	return q.saveOrCreate(val, createHint)
}
// UpdateRaw - takes a mapping of struct field names (or column names) to
// values and updates the associated column of every row selected by this
// query, returning the number of rows affected.
//
// A string value is inserted verbatim into the statement as a SQL expression
// (e.g. "view_count + 1"); any other value is bound as a query parameter.
// Keys naming neither a field nor a column are ignored; if no key matches,
// an error is returned. When a key is both a field name and a column name,
// the field's column is used so the column is assigned only once.
//
// SECURITY: string values are NOT escaped. Never pass untrusted input as a
// string value to this method.
func (q *Query) UpdateRaw(values map[string]any) (int64, error) {
	var err error
	var subQuery sb.SelectBuilder
	_, _, subQuery, err = q.buildSQL()
	if err != nil {
		return 0, err
	}
	idCol := q.model.idField().ColumnName
	// Restrict the update to rows whose primary key the query selects.
	subQuery = sb.Select(idCol).FromSelect(subQuery, "subQuery")
	stmt := sb.Update(q.model.TableName).Where(wrapQueryIn(subQuery, idCol))

	assigned := 0
	for k, v := range values {
		var col string
		if f, ok := q.model.Fields[k]; ok {
			col = f.ColumnName
		} else if _, ok := q.model.FieldsByColumnName[k]; ok {
			col = k
		} else {
			continue
		}
		if expr, isString := v.(string); isString {
			stmt = stmt.Set(col, sb.Expr(expr))
		} else {
			stmt = stmt.Set(col, v)
		}
		assigned++
	}
	// An UPDATE without SET clauses cannot be rendered (MustSQL would panic).
	if assigned == 0 {
		return 0, fmt.Errorf("update/raw: no key in values matches a field or column of model %s", q.model.Name)
	}

	sql, args := stmt.MustSQL()
	q.engine.logQuery("update/raw", sql, args)
	q.tx, err = q.engine.conn.Begin(q.ctx)
	if err != nil {
		return 0, err
	}
	defer q.cleanupTx()
	ctag, err := q.tx.Exec(q.ctx, sql, args...)
	if err != nil {
		return 0, err
	}
	return ctag.RowsAffected(), q.tx.Commit(q.ctx)
}
// Delete - delete one or more entities matching previous conditions specified
// by methods like Where, WhereRaw, or In, returning the number of rows
// removed. It refuses to execute (returning ErrNoConditionOnDeleteOrUpdate)
// if no conditions were specified, for safety reasons; to override this, call
// WhereRaw("true") or WhereRaw("1 = 1") before this method.
//
// Rows are matched by primary key. The query runs as a subquery that selects
// only the ID column, filtered with wrapQueryIn in the same way UpdateRaw does.
// The deletion runs in its own transaction, which is committed on success.
func (q *Query) Delete() (int64, error) {
	if len(q.wheres) < 1 {
		return 0, ErrNoConditionOnDeleteOrUpdate
	}
	var err error
	var subQuery sb.SelectBuilder
	_, _, subQuery, err = q.buildSQL()
	if err != nil {
		return 0, err
	}
	idCol := q.model.idField().ColumnName
	// The built query selects many columns. It cannot be used directly as a
	// WHERE predicate, so reduce it to its primary keys.
	subQuery = sb.Select(idCol).FromSelect(subQuery, "subQuery")
	sqlb := sb.Delete(q.model.TableName).Where(wrapQueryIn(subQuery, idCol))
	sql, sqla := sqlb.MustSQL()
	q.engine.logQuery("delete", sql, sqla)

	q.tx, err = q.engine.conn.Begin(q.ctx)
	if err != nil {
		return 0, err
	}
	defer q.cleanupTx()
	cmdTag, err := q.tx.Exec(q.ctx, sql, sqla...)
	if err != nil {
		return 0, fmt.Errorf("failed to delete: %w", err)
	}
	// Without an explicit commit the deletion would be discarded when the
	// transaction is cleaned up.
	return cmdTag.RowsAffected(), q.tx.Commit(q.ctx)
}
// saveOrCreate persists the struct pointed to by val, together with its
// relationships, inside one transaction that is committed on success.
// shouldCreate is forwarded to doSave as the "prefer insert" hint from Create.
// It returns an error if val is not a pointer to a struct, or if the
// struct's type has no registered model.
func (q *Query) saveOrCreate(val any, shouldCreate bool) error {
	v := reflect.ValueOf(val)
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("Save() must be called with a pointer to a struct")
	}
	typeName := v.Elem().Type().Name()
	model := q.engine.modelMap.Map[typeName]
	if model == nil {
		// doSave dereferences the model immediately, so fail cleanly here
		// instead of panicking mid-transaction.
		return fmt.Errorf("no model registered for type %q", typeName)
	}
	var err error
	// NOTE: PostgreSQL runs READ UNCOMMITTED as READ COMMITTED.
	q.tx, err = q.engine.conn.BeginTx(q.ctx, pgx.TxOptions{
		AccessMode: pgx.ReadWrite,
		IsoLevel:   pgx.ReadUncommitted,
	})
	if err != nil {
		return err
	}
	defer q.cleanupTx()
	if _, err = q.doSave(v.Elem(), model, nil, shouldCreate); err != nil {
		return err
	}
	return q.tx.Commit(q.ctx)
}
// doSave writes one model value, and recursively its relationships, using
// the transaction in q.tx. It returns the value's primary key.
//
// val may be a struct or a pointer to a struct. A nil pointer, or a pointer
// to a zero struct, is skipped and yields (nil, nil). parentFks holds extra
// column/value pairs that are written alongside val's own columns; callers use
// it for foreign keys that point at the parent being saved.
//
// The row is inserted when its primary key is zero, or when no row with that
// key exists yet. Otherwise it is updated. BelongsTo parents are saved first,
// so their keys can go into this row. HasMany, embedded HasOne and
// many-to-many children are saved afterwards.
//
// NOTE(review): shouldInsert is currently not consulted. The insert/update
// decision comes only from the primary key and the existence check.
func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any, shouldInsert bool) (any, error) {
	idField := model.Fields[model.IDField]
	if val.Kind() == reflect.Pointer {
		if !val.Elem().IsValid() || val.Elem().IsZero() {
			return nil, nil
		}
		// Work on the pointed-to struct from here on: FieldByName, Field and
		// NumField below all panic when called on a pointer Value.
		val = val.Elem()
	}
	pkField := val.FieldByName(model.IDField)
	isNew := pkField.IsZero()
	var exists bool
	if !isNew {
		eb := sb.Select("1").
			Prefix("SELECT EXISTS (").
			From(model.TableName).
			Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()).
			Suffix(")")
		ebs, eba := eb.MustSQL()
		if err := q.tx.QueryRow(q.ctx, ebs, eba...).Scan(&exists); err != nil {
			// A failed statement aborts the surrounding PostgreSQL
			// transaction, so carrying on would only fail later with a less
			// useful error.
			return nil, fmt.Errorf("existence check failed for model %s: %w", model.Name, err)
		}
	}
	/*{
		el, ok := q.seenIds[model]
		if !ok {
			q.seenIds[model] = make(map[any]bool)
		}
		if ok && el[pkField.Interface()] {
			return pkField.Interface(), nil
		}
		if !isNew {
			q.seenIds[model][pkField.Interface()] = true
		}
	}*/
	doInsert := isNew || !exists
	var cols []string
	args := make([]any, 0)
	seenJoinTables := make(map[string]map[any]bool)

	// Save BelongsTo parents first so their keys can be stored in this row's
	// foreign-key columns.
	for _, rel := range model.Relationships {
		if rel.Type != BelongsTo {
			continue
		}
		parentVal := val.FieldByName(rel.FieldName)
		if !parentVal.IsValid() {
			continue
		}
		nid, err := q.doSave(parentVal, rel.RelatedModel, nil, rel.RelatedModel.needsPrimaryKey(parentVal) && isNew)
		if err != nil {
			return nil, err
		}
		cols = append(cols, pascalToSnakeCase(rel.joinField()))
		args = append(args, nid)
	}

	for _, ff := range model.Fields {
		var fv reflect.Value
		if ff.Index > -1 && !ff.isAnonymous() {
			fv = val.Field(ff.Index)
		} else if ff.Index > -1 {
			// Anonymous (embedded) struct: each promoted field is its own column.
			fv = val.Field(ff.Index)
			for col, ef := range ff.embeddedFields {
				cols = append(cols, col)
				eif := fv.FieldByName(ef.Name)
				if ff.Name == documentField && canConvertTo[Document](ff.Type) {
					asTime, ok := eif.Interface().(time.Time)
					shouldCreate := ok && (asTime.IsZero() || eif.IsZero())
					// NOTE(review): the second condition already covers the
					// first, so every zero time.Time field of the document is
					// stamped, and modifiedField is always stamped.
					if doInsert && ef.Name == createdField && shouldCreate {
						eif.Set(reflect.ValueOf(time.Now()))
					} else if ef.Name == modifiedField || shouldCreate {
						eif.Set(reflect.ValueOf(time.Now()))
					}
				}
				args = append(args, eif.Interface())
			}
			continue
		}
		if ff.Name == model.IDField {
			// A zero key is left out so the database can generate one.
			if !isNew && fv.IsValid() {
				cols = append(cols, ff.ColumnName)
				args = append(args, fv.Interface())
			}
			continue
		}
		if fv.IsValid() {
			cols = append(cols, ff.ColumnName)
			args = append(args, fv.Interface())
		}
	}
	for k, fk := range parentFks {
		cols = append(cols, k)
		args = append(args, fk)
	}

	var qq string
	var qa []any
	if doInsert {
		osb := sb.Insert(model.TableName)
		if len(cols) == 0 {
			qq = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", model.TableName, idField.ColumnName)
		} else {
			osb = osb.Columns(cols...).Values(args...)
			qq, qa = osb.Returning(idField.ColumnName).MustSQL()
		}
	} else {
		osb := sb.Update(model.TableName)
		for i := range cols {
			osb = osb.Set(cols[i], args[i])
		}
		osb = osb.Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface())
		qq, qa = osb.MustSQL()
	}
	if doInsert {
		var nid any
		q.engine.logQuery("insert", qq, qa)
		if err := q.tx.QueryRow(q.ctx, qq, qa...).Scan(&nid); err != nil {
			return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err)
		}
		if err := assignGeneratedID(pkField, nid); err != nil {
			return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err)
		}
	} else {
		q.engine.logQuery("update", qq, qa)
		if _, err := q.tx.Exec(q.ctx, qq, qa...); err != nil {
			return nil, fmt.Errorf("update failed for model %s: %w", model.Name, err)
		}
	}
	/*if _, ok := q.seenIds[model]; !ok {
		q.seenIds[model] = make(map[any]bool)
	}
	q.seenIds[model][pkField.Interface()] = true*/

	// Save children now that this row's primary key is known.
	for _, rel := range model.Relationships {
		if rel.Idx > -1 && rel.Idx < val.NumField() {
			fv := val.FieldByName(rel.FieldName)
			cm := rel.RelatedModel
			pfks := map[string]any{}
			if !model.embeddedIsh && rel.Type == HasMany {
				{
					rm := cm.Relationships[model.Name]
					if rm != nil && rm.Type == ManyToOne {
						pfks[pascalToSnakeCase(rm.joinField())] = pkField.Interface()
					}
				}
				for j := range fv.Len() {
					child := fv.Index(j).Addr().Elem()
					if _, err := q.doSave(child, cm, pfks, cm.needsPrimaryKey(child)); err != nil {
						return nil, err
					}
				}
			} else if rel.Type == HasOne && cm.embeddedIsh {
				if _, err := q.doSave(fv, cm, pfks, cm.needsPrimaryKey(fv)); err != nil {
					return nil, err
				}
			} else if rel.m2mIsh() || rel.Type == ManyToMany || (model.embeddedIsh && cm.embeddedIsh && rel.Type == HasMany) {
				// Clear existing join rows once per (join table, owner) before
				// re-inserting the current associations.
				if seenJoinTables[rel.ComputeJoinTable()] == nil {
					seenJoinTables[rel.ComputeJoinTable()] = make(map[any]bool)
				}
				if !seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] {
					seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] = true
					if err := rel.joinDelete(pkField.Interface(), nil, q); err != nil {
						return nil, fmt.Errorf("error deleting existing association: %w", err)
					}
				}
				if fv.Kind() == reflect.Slice || fv.Kind() == reflect.Array {
					mField := model.Fields[model.IDField]
					mpks := map[string]any{}
					if !model.embeddedIsh {
						mpks[model.TableName+"_"+mField.ColumnName] = pkField.Interface()
					}
					for i := range fv.Len() {
						cur := fv.Index(i)
						if _, err := q.doSave(cur, cm, mpks, cm.needsPrimaryKey(cur) && pkField.IsZero()); err != nil {
							return nil, err
						}
						if rel.m2mIsh() || rel.Type == ManyToMany {
							if err := rel.joinInsert(cur, q, pkField.Interface()); err != nil {
								return nil, fmt.Errorf("failed to insert association for model %s: %w", model.Name, err)
							}
						}
					}
				}
			}
		}
	}
	return pkField.Interface(), nil
}

// assignGeneratedID stores a primary key read back from the database into dst.
// The driver may report the key with a Go type that differs from the field's
// (for example int64 where the field is an int). Numeric keys are therefore
// converted to the field's type when they are not directly assignable. A NULL
// key, or any other mismatch, is reported as an error instead of panicking.
func assignGeneratedID(dst reflect.Value, id any) error {
	src := reflect.ValueOf(id)
	if !src.IsValid() {
		return fmt.Errorf("database returned a NULL primary key")
	}
	if src.Type().AssignableTo(dst.Type()) {
		dst.Set(src)
		return nil
	}
	isNumeric := func(k reflect.Kind) bool { return k >= reflect.Int && k <= reflect.Float64 }
	if isNumeric(src.Kind()) && isNumeric(dst.Kind()) {
		dst.Set(src.Convert(dst.Type()))
		return nil
	}
	return fmt.Errorf("cannot store primary key of type %s in field of type %s", src.Type(), dst.Type())
}