package orm

import (
	"fmt"
	"reflect"
	"sort"
	"time"

	sb "github.com/henvic/pgq"
	"github.com/jackc/pgx/v5"
)

// Find executes the query and scans the result into dest.
//
// dest must be a pointer to either a struct (a single row is scanned into it
// via QueryRow) or a slice (every returned row is appended to it). When the
// query has a population tree, related records are then loaded into the
// scanned value(s) via processPopulate.
func (q *Query) Find(dest any) error {
	dstVal := reflect.ValueOf(dest)
	if dstVal.Kind() != reflect.Ptr {
		return fmt.Errorf("destination must be a pointer, got: %v", dstVal.Kind())
	}
	target := dstVal.Elem()

	cols, acols, sqlb, err := q.buildSQL()
	if err != nil {
		return err
	}
	qq, qa := sqlb.MustSQL()
	q.engine.logQuery("find", qq, qa)

	switch target.Kind() {
	case reflect.Struct:
		row := q.engine.conn.QueryRow(q.ctx, qq, qa...)
		if err := scanRow(row, cols, acols, target, q.model); err != nil {
			return err
		}
	case reflect.Slice:
		// Arrays are intentionally not accepted: reflect.Append only works on
		// slices and panics on an array destination.
		rows, err := q.engine.conn.Query(q.ctx, qq, qa...)
		if err != nil {
			return err
		}
		defer rows.Close()
		elemType := target.Type().Elem()
		for rows.Next() {
			elem := reflect.New(elemType).Elem()
			if err := scanRow(rows, cols, acols, elem, q.model); err != nil {
				return err
			}
			target.Set(reflect.Append(target, elem))
		}
		// Next returning false can mean a failure mid-stream, not just the end.
		if err := rows.Err(); err != nil {
			return err
		}
	default:
		return fmt.Errorf("unsupported destination type: %s", target.Kind())
	}

	if len(q.populationTree) == 0 {
		return nil
	}
	if target.Kind() == reflect.Struct {
		// processPopulate works on slices, so wrap the single struct in a
		// one-element slice and copy the populated element back afterwards.
		wrapped := reflect.Append(reflect.MakeSlice(reflect.SliceOf(target.Type()), 0, 1), target)
		if err := q.processPopulate(wrapped, q.model, q.populationTree); err != nil {
			return err
		}
		target.Set(wrapped.Index(0))
		return nil
	}
	return q.processPopulate(target, q.model, q.populationTree)
}

// Save persists val (a pointer to a struct with a registered model) and its
// relationships inside a single transaction.
func (q *Query) Save(val any) error {
	return q.saveOrCreate(val, false)
}

// Create persists val like Save, but passes shouldCreate=true down to doSave.
//
// NOTE(review): doSave does not currently consult that flag, so Create and
// Save behave the same; confirm whether Create should force an INSERT.
func (q *Query) Create(val any) error {
	return q.saveOrCreate(val, true)
}

// UpdateRaw updates every row matched by the query and returns the number of
// affected rows.
//
// values maps struct field names (or column names) to new values. String
// values are used verbatim as SQL expressions; any other value is bound as a
// query parameter. Keys that match neither a field nor a column are ignored,
// but at least one key must match, otherwise an error is returned.
//
// NOTE(review): string values are raw SQL — never pass untrusted input here.
// NOTE(review): unlike Delete, a query without conditions updates every row.
func (q *Query) UpdateRaw(values map[string]any) (int64, error) {
	_, _, subQuery, err := q.buildSQL()
	if err != nil {
		return 0, err
	}

	// Restrict the UPDATE to the primary keys selected by the query, so joins,
	// ordering or limits in the SELECT cannot leak into the UPDATE itself.
	idCol := q.model.idField().ColumnName
	idQuery := sb.Select(idCol).FromSelect(subQuery, "subQuery")
	stmt := sb.Update(q.model.TableName).Where(wrapQueryIn(idQuery, idCol))

	// Sorted keys give deterministic SQL text (stable logs and statement cache).
	keys := make([]string, 0, len(values))
	for k := range values {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	setCount := 0
	for _, k := range keys {
		var col string
		if f, ok := q.model.Fields[k]; ok {
			col = f.ColumnName
		} else if _, ok := q.model.FieldsByColumnName[k]; ok {
			col = k
		} else {
			continue
		}
		v := values[k]
		if expr, isString := v.(string); isString {
			stmt = stmt.Set(col, sb.Expr(expr))
		} else {
			stmt = stmt.Set(col, v)
		}
		setCount++
	}
	if setCount == 0 {
		// Without a SET clause MustSQL would panic.
		return 0, fmt.Errorf("update/raw: no field or column of model %s matches the given keys", q.model.Name)
	}

	sql, args := stmt.MustSQL()
	q.engine.logQuery("update/raw", sql, args)

	if q.tx, err = q.engine.conn.Begin(q.ctx); err != nil {
		return 0, err
	}
	defer q.cleanupTx()

	ctag, err := q.tx.Exec(q.ctx, sql, args...)
	if err != nil {
		return 0, err
	}
	if err := q.tx.Commit(q.ctx); err != nil {
		return 0, fmt.Errorf("committing update/raw: %w", err)
	}
	return ctag.RowsAffected(), nil
}

// Delete removes every row matched by the query inside its own transaction
// and returns the number of deleted rows. A query without conditions is
// refused with ErrNoConditionOnDeleteOrUpdate.
func (q *Query) Delete() (int64, error) {
	if len(q.wheres) == 0 {
		return 0, ErrNoConditionOnDeleteOrUpdate
	}

	_, _, subQuery, err := q.buildSQL()
	if err != nil {
		return 0, err
	}

	// The query's SELECT cannot be used as a WHERE predicate directly; match
	// on the primary keys it selects, exactly as UpdateRaw does.
	idCol := q.model.idField().ColumnName
	idQuery := sb.Select(idCol).FromSelect(subQuery, "subQuery")
	sql, args := sb.Delete(q.model.TableName).Where(wrapQueryIn(idQuery, idCol)).MustSQL()
	q.engine.logQuery("delete", sql, args)

	if q.tx, err = q.engine.conn.Begin(q.ctx); err != nil {
		return 0, err
	}
	defer q.cleanupTx()

	cmdTag, err := q.tx.Exec(q.ctx, sql, args...)
	if err != nil {
		return 0, fmt.Errorf("failed to delete: %w", err)
	}
	// Without an explicit commit the deletion would be discarded when the
	// transaction is cleaned up.
	if err := q.tx.Commit(q.ctx); err != nil {
		return 0, fmt.Errorf("committing delete: %w", err)
	}
	return cmdTag.RowsAffected(), nil
}

// saveOrCreate checks that val is a pointer to a struct whose type has a
// registered model, then saves it through doSave in a read-write transaction
// that is committed on success. q.cleanupTx runs on every exit path once the
// transaction has been opened.
func (q *Query) saveOrCreate(val any, shouldCreate bool) error {
	v := reflect.ValueOf(val)
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("Save() must be called with a pointer to a struct")
	}
	target := v.Elem()
	model := q.engine.modelMap.Map[target.Type().Name()]
	if model == nil {
		return fmt.Errorf("no model registered for type %s", target.Type())
	}

	var err error
	q.tx, err = q.engine.conn.BeginTx(q.ctx, pgx.TxOptions{
		AccessMode: pgx.ReadWrite,
		// PostgreSQL executes READ UNCOMMITTED as READ COMMITTED.
		IsoLevel: pgx.ReadUncommitted,
	})
	if err != nil {
		return err
	}
	defer q.cleanupTx()

	if _, err := q.doSave(target, model, nil, shouldCreate); err != nil {
		return err
	}
	return q.tx.Commit(q.ctx)
}

// doSave inserts or updates the single record of model held in val, inside
// q.tx, and recursively saves its relationships. It returns the record's
// primary key.
//
// val may be a struct or a pointer to one; a nil pointer, or a pointer to a
// zero-valued struct, is skipped and (nil, nil) is returned. Belongs-to
// parents are saved first and their returned keys are written into this
// record's foreign-key columns. parentFks supplies extra column/value pairs
// set by the caller (e.g. the owning record's key for has-many children).
//
// A row is inserted when its primary key is zero or no row with that key
// exists yet; otherwise it is updated. On insert, the key returned by the
// database is stored back into the struct.
//
// After the row is written, has-many children and embedded has-one values are
// saved. For many-to-many style relations, existing associations are removed
// once per join table via joinDelete, then each element is saved and, for
// m2m/ManyToMany relations, linked again via joinInsert.
//
// NOTE(review): shouldInsert is accepted but not consulted; the insert/update
// decision is made from the database state alone.
// NOTE(review): there is no cycle detection (an earlier seenIds mechanism was
// disabled), so mutually referencing relationships may recurse without
// bound — confirm against the model definitions.
func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any, shouldInsert bool) (any, error) {
	if val.Kind() == reflect.Pointer {
		if !val.Elem().IsValid() || val.Elem().IsZero() {
			return nil, nil
		}
		// Work on the pointed-to struct: Field/FieldByName panic on pointers.
		val = val.Elem()
	}

	idField := model.Fields[model.IDField]
	pkField := val.FieldByName(model.IDField)
	isNew := pkField.IsZero()

	exists := false
	if !isNew {
		existsSQL, existsArgs := sb.Select("1").
			Prefix("SELECT EXISTS (").
			From(model.TableName).
			Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()).
			Suffix(")").
			MustSQL()
		// Any failed statement aborts the surrounding PostgreSQL transaction,
		// so continuing would only produce a less informative error later.
		if err := q.tx.QueryRow(q.ctx, existsSQL, existsArgs...).Scan(&exists); err != nil {
			return nil, fmt.Errorf("existence check failed for model %s: %w", model.Name, err)
		}
	}
	doInsert := isNew || !exists

	var cols []string
	var args []any

	// Save belongs-to parents first; their keys become this row's foreign keys.
	for _, rel := range model.Relationships {
		if rel.Type != BelongsTo {
			continue
		}
		parentVal := val.FieldByName(rel.FieldName)
		if !parentVal.IsValid() {
			continue
		}
		nid, err := q.doSave(parentVal, rel.RelatedModel, nil, rel.RelatedModel.needsPrimaryKey(parentVal) && isNew)
		if err != nil {
			return nil, err
		}
		cols = append(cols, pascalToSnakeCase(rel.joinField()))
		args = append(args, nid)
	}

	for _, ff := range model.Fields {
		if ff.Index < 0 {
			continue
		}
		fv := val.Field(ff.Index)

		if ff.isAnonymous() {
			// Embedded struct: each of its fields maps to its own column.
			isDocument := ff.Name == documentField && canConvertTo[Document](ff.Type)
			for col, ef := range ff.embeddedFields {
				eif := fv.FieldByName(ef.Name)
				if isDocument {
					// The modified timestamp is refreshed on every save, and any
					// zero time.Time (e.g. an unset created timestamp, on insert
					// or update) is initialised to now.
					// NOTE(review): this stamps every zero time.Time field of the
					// document, not only created/modified — confirm intended.
					asTime, isTime := eif.Interface().(time.Time)
					if ef.Name == modifiedField || (isTime && asTime.IsZero()) {
						eif.Set(reflect.ValueOf(time.Now()))
					}
				}
				cols = append(cols, col)
				args = append(args, eif.Interface())
			}
			continue
		}

		if ff.Name == model.IDField {
			// A zero key is left out so the database can generate one.
			if !isNew {
				cols = append(cols, ff.ColumnName)
				args = append(args, fv.Interface())
			}
			continue
		}
		cols = append(cols, ff.ColumnName)
		args = append(args, fv.Interface())
	}

	for k, fk := range parentFks {
		cols = append(cols, k)
		args = append(args, fk)
	}

	if doInsert {
		var qq string
		var qa []any
		if len(cols) == 0 {
			qq = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", model.TableName, idField.ColumnName)
		} else {
			qq, qa = sb.Insert(model.TableName).
				Columns(cols...).
				Values(args...).
				Returning(idField.ColumnName).
				MustSQL()
		}
		q.engine.logQuery("insert", qq, qa)
		var nid any
		if err := q.tx.QueryRow(q.ctx, qq, qa...).Scan(&nid); err != nil {
			return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err)
		}
		if err := assignReturnedPK(pkField, nid); err != nil {
			return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err)
		}
	} else {
		upd := sb.Update(model.TableName)
		for i, col := range cols {
			upd = upd.Set(col, args[i])
		}
		qq, qa := upd.Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()).MustSQL()
		q.engine.logQuery("update", qq, qa)
		if _, err := q.tx.Exec(q.ctx, qq, qa...); err != nil {
			return nil, fmt.Errorf("update failed for model %s: %w", model.Name, err)
		}
	}

	pk := pkField.Interface()
	// Several relations may share one join table; clear it only once per call.
	clearedJoinTables := make(map[string]bool)
	for _, rel := range model.Relationships {
		if rel.Idx < 0 || rel.Idx >= val.NumField() {
			continue
		}
		fv := val.FieldByName(rel.FieldName)
		cm := rel.RelatedModel

		switch {
		case !model.embeddedIsh && rel.Type == HasMany:
			// Point each child's back-reference at this record when the child
			// model declares a many-to-one relation back to this model.
			pfks := map[string]any{}
			if rm := cm.Relationships[model.Name]; rm != nil && rm.Type == ManyToOne {
				pfks[pascalToSnakeCase(rm.joinField())] = pk
			}
			for j := range fv.Len() {
				child := fv.Index(j)
				if _, err := q.doSave(child, cm, pfks, cm.needsPrimaryKey(child)); err != nil {
					return nil, err
				}
			}

		case rel.Type == HasOne && cm.embeddedIsh:
			if _, err := q.doSave(fv, cm, nil, cm.needsPrimaryKey(fv)); err != nil {
				return nil, err
			}

		case rel.m2mIsh() || rel.Type == ManyToMany || (model.embeddedIsh && cm.embeddedIsh && rel.Type == HasMany):
			joinTable := rel.ComputeJoinTable()
			if !clearedJoinTables[joinTable] {
				clearedJoinTables[joinTable] = true
				if err := rel.joinDelete(pk, nil, q); err != nil {
					return nil, fmt.Errorf("error deleting existing association: %w", err)
				}
			}
			if fv.Kind() == reflect.Slice || fv.Kind() == reflect.Array {
				mpks := map[string]any{}
				if !model.embeddedIsh {
					mpks[model.TableName+"_"+idField.ColumnName] = pk
				}
				linkViaJoinTable := rel.m2mIsh() || rel.Type == ManyToMany
				for i := range fv.Len() {
					cur := fv.Index(i)
					if _, err := q.doSave(cur, cm, mpks, cm.needsPrimaryKey(cur) && pkField.IsZero()); err != nil {
						return nil, err
					}
					if linkViaJoinTable {
						if err := rel.joinInsert(cur, q, pk); err != nil {
							return nil, fmt.Errorf("failed to insert association for model %s: %w", model.Name, err)
						}
					}
				}
			}
		}
	}

	return pk, nil
}

// assignReturnedPK stores a primary key produced by INSERT ... RETURNING into
// the struct field dst. Values whose type is assignable to the field are set
// directly; integer values are converted between integer kinds (e.g. an int64
// scanned from the database into an int, uint or named-integer key field).
// A NULL or otherwise incompatible value is reported as an error instead of
// panicking inside reflect.Value.Set.
func assignReturnedPK(dst reflect.Value, id any) error {
	src := reflect.ValueOf(id)
	switch {
	case !src.IsValid():
		return fmt.Errorf("database returned a NULL primary key")
	case src.Type().AssignableTo(dst.Type()):
		dst.Set(src)
	case (src.CanInt() || src.CanUint()) && (dst.CanInt() || dst.CanUint()):
		dst.Set(src.Convert(dst.Type()))
	default:
		return fmt.Errorf("cannot store primary key of type %s in field of type %s", src.Type(), dst.Type())
	}
	return nil
}