370 lines
11 KiB
Go
370 lines
11 KiB
Go
package orm
|
|
|
|
import (
|
|
"fmt"
|
|
sb "github.com/henvic/pgq"
|
|
"github.com/jackc/pgx/v5"
|
|
"reflect"
|
|
"time"
|
|
)
|
|
|
|
// Find - transpiles this query into SQL and places the result in `dest`
|
|
func (q *Query) Find(dest any) error {
|
|
dstVal := reflect.ValueOf(dest)
|
|
if dstVal.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("destination must be a pointer, got: %v", dstVal.Kind())
|
|
}
|
|
maybeSlice := dstVal.Elem()
|
|
cols, acols, sqlb, err := q.buildSQL()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
qq, qa := sqlb.MustSQL()
|
|
q.engine.logQuery("find", qq, qa)
|
|
|
|
if maybeSlice.Kind() == reflect.Struct {
|
|
row := q.engine.conn.QueryRow(q.ctx, qq, qa...)
|
|
if err = scanRow(row, cols, acols, maybeSlice, q.model); err != nil {
|
|
return err
|
|
}
|
|
} else if maybeSlice.Kind() == reflect.Slice ||
|
|
maybeSlice.Kind() == reflect.Array {
|
|
var rows pgx.Rows
|
|
rows, err = q.engine.conn.Query(q.ctx, qq, qa...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
etype := maybeSlice.Type().Elem()
|
|
for rows.Next() {
|
|
nelem := reflect.New(etype).Elem()
|
|
if err = scanRow(rows, cols, acols, nelem, q.model); err != nil {
|
|
return err
|
|
}
|
|
maybeSlice.Set(reflect.Append(maybeSlice, nelem))
|
|
}
|
|
} else {
|
|
return fmt.Errorf("unsupported destination type: %s", maybeSlice.Kind())
|
|
}
|
|
if len(q.populationTree) > 0 {
|
|
nslice := maybeSlice
|
|
var wasPassedStruct bool
|
|
if nslice.Kind() == reflect.Struct {
|
|
nslice = reflect.MakeSlice(reflect.SliceOf(maybeSlice.Type()), 0, 0)
|
|
wasPassedStruct = true
|
|
nslice = reflect.Append(nslice, maybeSlice)
|
|
}
|
|
err = q.processPopulate(nslice, q.model, q.populationTree)
|
|
if err == nil && wasPassedStruct {
|
|
maybeSlice.Set(nslice.Index(0))
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Save - create or update `val` in the database
|
|
func (q *Query) Save(val any) error {
|
|
return q.saveOrCreate(val, false)
|
|
}
|
|
|
|
// Create - like Save, but hints to the query processor that you want to insert, not update.
|
|
// useful if you're importing data and want to keep the IDs intact.
|
|
func (q *Query) Create(val any) error {
|
|
return q.saveOrCreate(val, true)
|
|
}
|
|
|
|
// UpdateRaw - takes a mapping of struct field names to
|
|
// SQL expressions, updating each field's associated column accordingly
|
|
func (q *Query) UpdateRaw(values map[string]any) (int64, error) {
|
|
var err error
|
|
var subQuery sb.SelectBuilder
|
|
stmt := sb.Update(q.model.TableName)
|
|
_, _, subQuery, err = q.buildSQL()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
subQuery = sb.Select(q.model.idField().ColumnName).FromSelect(subQuery, "subQuery")
|
|
stmt = stmt.Where(wrapQueryIn(subQuery,
|
|
q.model.idField().ColumnName))
|
|
for k, v := range values {
|
|
asString, isString := v.(string)
|
|
if f, ok := q.model.Fields[k]; ok {
|
|
if isString {
|
|
stmt = stmt.Set(f.ColumnName, sb.Expr(asString))
|
|
} else {
|
|
stmt = stmt.Set(f.ColumnName, v)
|
|
}
|
|
}
|
|
if _, ok := q.model.FieldsByColumnName[k]; ok {
|
|
if isString {
|
|
stmt = stmt.Set(k, sb.Expr(asString))
|
|
} else {
|
|
stmt = stmt.Set(k, v)
|
|
}
|
|
}
|
|
}
|
|
sql, args := stmt.MustSQL()
|
|
q.engine.logQuery("update/raw", sql, args)
|
|
q.tx, err = q.engine.conn.Begin(q.ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer q.cleanupTx()
|
|
|
|
ctag, err := q.tx.Exec(q.ctx, sql, args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return ctag.RowsAffected(), q.tx.Commit(q.ctx)
|
|
}
|
|
|
|
// Delete - delete one or more entities matching previous conditions specified
|
|
// by methods like Where, WhereRaw, or In. will refuse to execute if no
|
|
// conditions were specified for safety reasons. to override this, call
|
|
// WhereRaw("true") or WhereRaw("1 = 1") before this method.
|
|
func (q *Query) Delete() (int64, error) {
|
|
var err error
|
|
var subQuery sb.SelectBuilder
|
|
if len(q.wheres) < 1 {
|
|
return 0, ErrNoConditionOnDeleteOrUpdate
|
|
}
|
|
q.tx, err = q.engine.conn.Begin(q.ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer q.cleanupTx()
|
|
_, _, subQuery, err = q.buildSQL()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
sqlb := sb.Delete(q.model.TableName).Where(subQuery)
|
|
sql, sqla := sqlb.MustSQL()
|
|
q.engine.logQuery("delete", sql, sqla)
|
|
cmdTag, err := q.tx.Exec(q.ctx, sql, sqla...)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to delete: %w", err)
|
|
}
|
|
return cmdTag.RowsAffected(), nil
|
|
}
|
|
|
|
func (q *Query) saveOrCreate(val any, shouldCreate bool) error {
|
|
v := reflect.ValueOf(val)
|
|
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("Save() must be called with a pointer to a struct")
|
|
}
|
|
var err error
|
|
q.tx, err = q.engine.conn.BeginTx(q.ctx, pgx.TxOptions{
|
|
AccessMode: pgx.ReadWrite,
|
|
IsoLevel: pgx.ReadUncommitted,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer q.cleanupTx()
|
|
if _, err = q.doSave(v.Elem(), q.engine.modelMap.Map[v.Elem().Type().Name()], nil, shouldCreate); err != nil {
|
|
return err
|
|
}
|
|
return q.tx.Commit(q.ctx)
|
|
}
|
|
|
|
func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any, shouldInsert bool) (any, error) {
|
|
idField := model.Fields[model.IDField]
|
|
var pkField reflect.Value
|
|
if val.Kind() == reflect.Pointer {
|
|
if !val.Elem().IsValid() || val.Elem().IsZero() {
|
|
return nil, nil
|
|
}
|
|
pkField = val.Elem().FieldByName(model.IDField)
|
|
} else {
|
|
pkField = val.FieldByName(model.IDField)
|
|
}
|
|
isNew := pkField.IsZero()
|
|
var exists bool
|
|
if !pkField.IsZero() {
|
|
eb := sb.Select("1").
|
|
Prefix("SELECT EXISTS (").
|
|
From(model.TableName).
|
|
Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()).
|
|
Suffix(")")
|
|
ebs, eba := eb.MustSQL()
|
|
var ex bool
|
|
err := q.tx.QueryRow(q.ctx, ebs, eba...).Scan(&ex)
|
|
if err != nil {
|
|
q.engine.logger.Warn("error while checking existence", "err", err.Error())
|
|
}
|
|
exists = ex
|
|
}
|
|
/*{
|
|
el, ok := q.seenIds[model]
|
|
if !ok {
|
|
q.seenIds[model] = make(map[any]bool)
|
|
}
|
|
if ok && el[pkField.Interface()] {
|
|
return pkField.Interface(), nil
|
|
}
|
|
if !isNew {
|
|
q.seenIds[model][pkField.Interface()] = true
|
|
}
|
|
}*/
|
|
doInsert := isNew || !exists
|
|
var cols []string
|
|
args := make([]any, 0)
|
|
seenJoinTables := make(map[string]map[any]bool)
|
|
for _, rel := range model.Relationships {
|
|
if rel.Type != BelongsTo {
|
|
continue
|
|
}
|
|
parentVal := val.FieldByName(rel.FieldName)
|
|
if parentVal.IsValid() {
|
|
nid, err := q.doSave(parentVal, rel.RelatedModel, nil, rel.RelatedModel.needsPrimaryKey(parentVal) && isNew)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cols = append(cols, pascalToSnakeCase(rel.joinField()))
|
|
args = append(args, nid)
|
|
} else if parentVal.IsValid() {
|
|
_, nid := rel.RelatedModel.getPrimaryKey(parentVal)
|
|
cols = append(cols, pascalToSnakeCase(rel.joinField()))
|
|
args = append(args, nid)
|
|
}
|
|
}
|
|
for _, ff := range model.Fields {
|
|
var fv reflect.Value
|
|
if ff.Index > -1 && !ff.isAnonymous() {
|
|
fv = val.Field(ff.Index)
|
|
} else if ff.Index > -1 {
|
|
for col, ef := range ff.embeddedFields {
|
|
fv = val.Field(ff.Index)
|
|
cols = append(cols, col)
|
|
eif := fv.FieldByName(ef.Name)
|
|
if ff.Name == documentField && canConvertTo[Document](ff.Type) {
|
|
asTime, ok := eif.Interface().(time.Time)
|
|
shouldCreate := ok && (asTime.IsZero() || eif.IsZero())
|
|
if doInsert && ef.Name == createdField && shouldCreate {
|
|
eif.Set(reflect.ValueOf(time.Now()))
|
|
} else if ef.Name == modifiedField || shouldCreate {
|
|
eif.Set(reflect.ValueOf(time.Now()))
|
|
}
|
|
args = append(args, eif.Interface())
|
|
continue
|
|
}
|
|
args = append(args, fv.FieldByName(ef.Name).Interface())
|
|
}
|
|
continue
|
|
}
|
|
if ff.Name == model.IDField {
|
|
if !isNew && fv.IsValid() {
|
|
cols = append(cols, ff.ColumnName)
|
|
args = append(args, fv.Interface())
|
|
}
|
|
continue
|
|
}
|
|
|
|
if fv.IsValid() {
|
|
cols = append(cols, ff.ColumnName)
|
|
args = append(args, fv.Interface())
|
|
}
|
|
}
|
|
for k, fk := range parentFks {
|
|
cols = append(cols, k)
|
|
args = append(args, fk)
|
|
}
|
|
var qq string
|
|
var qa []any
|
|
if doInsert {
|
|
osb := sb.Insert(model.TableName)
|
|
if len(cols) == 0 {
|
|
qq = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", model.TableName, idField.ColumnName)
|
|
} else {
|
|
osb = osb.Columns(cols...).Values(args...)
|
|
qq, qa = osb.Returning(idField.ColumnName).MustSQL()
|
|
}
|
|
} else {
|
|
osb := sb.Update(model.TableName)
|
|
for i := range cols {
|
|
osb = osb.Set(cols[i], args[i])
|
|
}
|
|
osb = osb.Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface())
|
|
qq, qa = osb.MustSQL()
|
|
}
|
|
if doInsert {
|
|
var nid any
|
|
q.engine.logQuery("insert", qq, qa)
|
|
row := q.tx.QueryRow(q.ctx, qq, qa...)
|
|
err := row.Scan(&nid)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err)
|
|
}
|
|
pkField.Set(reflect.ValueOf(nid))
|
|
} else {
|
|
q.engine.logQuery("update", qq, qa)
|
|
_, err := q.tx.Exec(q.ctx, qq, qa...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("update failed for model %s: %w", model.Name, err)
|
|
}
|
|
}
|
|
/*if _, ok := q.seenIds[model]; !ok {
|
|
q.seenIds[model] = make(map[any]bool)
|
|
}
|
|
q.seenIds[model][pkField.Interface()] = true*/
|
|
for _, rel := range model.Relationships {
|
|
|
|
if rel.Idx > -1 && rel.Idx < val.NumField() {
|
|
fv := val.FieldByName(rel.FieldName)
|
|
cm := rel.RelatedModel
|
|
pfks := map[string]any{}
|
|
if !model.embeddedIsh && rel.Type == HasMany {
|
|
{
|
|
rm := cm.Relationships[model.Name]
|
|
if rm != nil && rm.Type == ManyToOne {
|
|
pfks[pascalToSnakeCase(rm.joinField())] = pkField.Interface()
|
|
}
|
|
}
|
|
for j := range fv.Len() {
|
|
child := fv.Index(j).Addr().Elem()
|
|
if _, err := q.doSave(child, cm, pfks, cm.needsPrimaryKey(child)); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
} else if rel.Type == HasOne && cm.embeddedIsh {
|
|
if _, err := q.doSave(fv, cm, pfks, cm.needsPrimaryKey(fv)); err != nil {
|
|
return nil, err
|
|
}
|
|
} else if rel.m2mIsh() || rel.Type == ManyToMany || (model.embeddedIsh && cm.embeddedIsh && rel.Type == HasMany) {
|
|
if seenJoinTables[rel.ComputeJoinTable()] == nil {
|
|
seenJoinTables[rel.ComputeJoinTable()] = make(map[any]bool)
|
|
}
|
|
if !seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] {
|
|
seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] = true
|
|
if err := rel.joinDelete(pkField.Interface(), nil, q); err != nil {
|
|
return nil, fmt.Errorf("error deleting existing association: %w", err)
|
|
}
|
|
}
|
|
if fv.Kind() == reflect.Slice || fv.Kind() == reflect.Array {
|
|
mField := model.Fields[model.IDField]
|
|
mpks := map[string]any{}
|
|
if !model.embeddedIsh {
|
|
mpks[model.TableName+"_"+mField.ColumnName] = pkField.Interface()
|
|
}
|
|
for i := range fv.Len() {
|
|
cur := fv.Index(i)
|
|
if _, err := q.doSave(cur, cm, mpks, cm.needsPrimaryKey(cur) && pkField.IsZero()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if rel.m2mIsh() || rel.Type == ManyToMany {
|
|
if err := rel.joinInsert(cur, q, pkField.Interface()); err != nil {
|
|
return nil, fmt.Errorf("failed to insert association for model %s: %w", model.Name, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return pkField.Interface(), nil
|
|
}
|