Compare commits

..

No commits in common. "main" and "backup/main" have entirely different histories.

41 changed files with 4836 additions and 3661 deletions

View File

@ -1,3 +0,0 @@
((nil . ((elcord-project-name . "the diamond ORM")
(elcord-project-description . "a mongodb ORM for golang that rocks!")
)))

3
.gitignore vendored
View File

@ -2,5 +2,4 @@
go.work.sum
go.work
muck/
/build/
/test-logs/
/build/

16
.idea/.gitignore generated vendored
View File

@ -1,8 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

15
.idea/dataSources.xml generated
View File

@ -1,15 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="testbed_i_think@localhost" uuid="ba7bd11e-a526-49f4-ab27-28ba42665ce5">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://localhost:5432/testbed_i_think</jdbc-url>
<jdbc-additional-properties>
<property name="JdbcLog.Enabled" value="true" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

View File

@ -1,49 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="CssDeprecatedValue" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidAtRule" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidCharsetRule" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssInvalidCustomPropertyAtRuleDeclaration" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidCustomPropertyAtRuleName" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidFunction" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidHtmlTagReference" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssInvalidImport" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssInvalidMediaFeature" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidPropertyValue" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssInvalidPseudoSelector" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssMissingComma" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssNegativeValue" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssNoGenericFontName" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssNonIntegerLengthInPixels" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="CssOverwrittenProperties" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssRedundantUnit" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="CssReplaceWithShorthandSafely" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="CssReplaceWithShorthandUnsafely" enabled="false" level="INFORMATION" enabled_by_default="false" />
<inspection_tool class="CssUnknownProperty" enabled="false" level="WARNING" enabled_by_default="false">
<option name="myCustomPropertiesEnabled" value="false" />
<option name="myIgnoreVendorSpecificProperties" value="false" />
<option name="myCustomPropertiesList">
<value>
<list size="0" />
</value>
</option>
</inspection_tool>
<inspection_tool class="CssUnknownTarget" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssUnknownUnit" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssUnresolvedClassInComposesRule" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssUnresolvedCustomProperty" enabled="false" level="ERROR" enabled_by_default="false" />
<inspection_tool class="CssUnusedSymbol" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="GoErrorStringFormat" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="GoMixedReceiverTypes" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="GoSwitchMissingCasesForIotaConsts" enabled="true" level="WEAK WARNING" enabled_by_default="true" editorAttributes="INFO_ATTRIBUTES" />
<inspection_tool class="GoUnhandledErrorResult" enabled="true" level="WEAK WARNING" enabled_by_default="true" editorAttributes="INFO_ATTRIBUTES" />
<inspection_tool class="GrazieInspection" enabled="false" level="GRAMMAR_ERROR" enabled_by_default="false" />
<inspection_tool class="LanguageDetectionInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
<option name="processCode" value="true" />
<option name="processLiterals" value="true" />
<option name="processComments" value="true" />
</inspection_tool>
</profile>
</component>

379
README.md Normal file
View File

@ -0,0 +1,379 @@
# diamond
### a golang ORM for mongodb that rocks 🎸~♬
# usage
## installation
run the following command in your terminal...
```shell
go get rockfic.com/orm
```
...and import the package at the top of your file(s) like so:
```go
package tutorial
import "rockfic.com/orm"
```
## Convert your database to a replica set (if you haven't already)
this is **very important!!** otherwise, functionality like
auto-incrementing IDs will not be available.
add these lines to `mongod.conf`:
```
replication:
replSetName: "rs0"
```
of course, you can replace `rs0` with whatever name you want.
## Connect to the database
```go
package tutorial
import "rockfic.com/orm"
func main() {
orm.Connect("mongodb://127.0.0.1", "your_database")
}
```
this will create a connection and store it in the `DB` global variable. This global variable is used internally to
interact with the underlying database.
if you need to <sub><sub>~~why?~~</sub></sub>, you may also access the mongoDB client directly via the `orm.Client`
variable.
## Create a Model
to create a new model, you need to define a struct like this:
```go
package tutorial
import "rockfic.com/orm"
type User struct {
orm.Document `bson:",inline" coll:"collection"`
ID int64 `bson:"_id"`
Username string `bson:"username"`
Email string `bson:"email"`
Friends []User `bson:"friends"`
}
```
this on its own is useless. to actually do anything useful with it, you need to *register* this struct as a model:
```go
package tutorial
import "rockfic.com/orm"
func main() {
/* ~ snip ~ */
orm.ModelRegistry.Model(User{})
}
```
you can also pass multiple arguments to `orm.Model`, so long as they embed the `Document` struct.
you can access the newly created model like this:
```go
package tutorial
import "rockfic.com/orm"
func main() {
/* ~ snip ~ */
userModel := orm.ModelRegistry.Get("User")
}
```
## Documents
any struct can be used as a document, so long as it embeds the `Document` struct. the `Document` struct is special, in
that it turns any structs which embed it into `IDocument` implementors.
`Document` should be embedded with a `bson:",inline" tag, otherwise you will end up with something like this in the
database:
```bson
{
"document": {
"createdAt": ISODate('2001-09-11T05:37:18.742Z'),
"updatedAt": ISODate('2001-09-11T05:37:18.742Z')
},
_id: 1,
username: "foobar",
email: "test@testi.ng",
friends: []
}
```
a `coll` or `collection` tag is also required, to assign the model to a mongodb collection. this tag is only valid on
the embedded `Document` field, and each document can only have one collection associated with it.
<sub>obviously.</sub>
<sub><sub>i mean seriously, who'd want to store one thing in two places?</sub></sub>
the recommended way to create a new document instance is via the `orm.Create` method:
```go
package tutorial
import "rockfic.com/orm"
func main() {
/* ~ snip ~ */
user := orm.Create(User{}).(*User)
}
```
similarly, to create a slice of documents, call `orm.CreateSlice`:
```go
package tutorial
import "rockfic.com/orm"
func main() {
/* ~ snip ~ */
users := orm.CreateSlice[User](User{})
}
```
lastly, let's implement the `HasID` interface on our document:
```go
package tutorial
import "rockfic.com/orm"
func (u *User) GetId() any {
return u.ID
}
func (u *User) SetId(id any) {
u.ID = id.(int64)
}
```
#### but why do i need to implement this? :(
other ORMs rudely assume that you want your `_id`s to be mongodb ObjectIDs, which are **fucking ugly**, but also, more
importantly, *may not match your existing schema.*
by doing things this way, you can set your _id to be any of the following types:
- int
- int32
- int64
- uint
- uint32
- uint64
- string
- ObjectId
<sub>(if you hate yourself that much)</sub>
## finding/querying
`Find`, `FindOne`, `FindByID`, and `FindPaged` all return an `*orm.Query` object.
to get the underlying value, simply call `.Exec()`, passing to it a pointer to the variable in which you wish to store
the results.
### `Find` example
```go
package tutorial
import (
"go.mongodb.org/mongo-driver/bson"
"rockfic.com/orm"
)
func main() {
/* ~ snip ~ */
userModel := orm.ModelRegistry.Get("User")
if userModel != nil {
res := orm.CreateSlice(User{})
jq, _ := userModel.Find(bson.M{"username": "foobar"})
q.Exec(&res)
}
}
```
## saving/updating
to save a document, simply call the `Save()` method. if the document doesn't exist, the ORM will create it, otherwise it
replaces the existing one.
## referencing other documents
it's possible to store references to other documents/structs, or slices to other documents. given the `User` struct we
defined above, which contains a slice (`Friends`) referencing other `User`s, we could change the `Friends` field in the
`User` type to be populateable, like so:
```diff
type User struct {
orm.Document `bson:",inline" coll:"collection"`
ID int64 `bson:"_id"`
Username string `bson:"username"`
Email string `bson:"email"`
- Friends []User `bson:"friends"`
+ Friends []User `bson:"friends" ref:"User"`
}
```
assuming we've filled a user's `Friends` with a few friends...
```go
package tutorial
import "rockfic.com/orm"
func main() {
/* ~ snip ~ */
user := orm.Create(User{}).(*User)
for i := 0; i < 10; i++ {
friend := orm.Create(User{ID: int64(i + 2)}).(*User)
user.Friends = append(user.Friends, friend)
}
}
```
...in the database, `friends` will be stored like this:
```json5
{
// ~ snip ~ //
friends: [
2,
3,
4,
5,
6,
7,
8,
9,
10,
11
]
}
```
after retrieving a `User` from the database, you can load their `Friends` and their info by using the `Populate`function
on the returned `Query` pointer, like in the following example:
```go
package tutorial
import (
"fmt"
"go.mongodb.org/mongo-driver/bson"
"rockfic.com/orm"
)
func main() {
/* ~ snip ~ */
query, err := userModel.FindOne(bson.M{"_id": 1})
if err == nil {
result := orm.Create(User{}).(*User)
query.Populate("Friends").Exec(result)
fmt.Printf("%+v", result)
}
}
```
### Models vs Documents: what's the difference?
a **`Model`** is a static type that contains information about a given Document-like type. a **`Document`**, on the
other hand, is a full-fledged instances with its own set of methods.
if you come from an object-oriented programming background, you can envision the relationship between a `Model` and
`Document` something like this (using Java as an example):
```java
public class Document {
public void Pull(field String, elements ...Document) {
// ...
}
public void Append(field String, elements ...Document) {
// ...
}
public void Save() {
//...
}
public void Delete() {
//...
}
// model methods //
public static Document FindOne(HashMap<String, Object> query, MongoOptions options) {
// ...
}
public static ArrayList<Document> FindAll(HashMap<String, Object> query, MongoOptions options) {
// ...
}
}
```
## Struct Tags
the following tags are recognized by the ORM:
### `ref`
the syntax of the `ref` tag is:
```
`ref:"struct_name"`
```
where `struct_name` is the name of the struct you're referencing, as shown
in [Referencing other documents](#referencing-other-documents).
### `gridfs`
you can load files stored in [gridfs](https://www.mongodb.com/docs/manual/core/gridfs/) directly into your structs using
this tag, which has the following syntax:
```
`gridfs:"bucket_name,file_fmt"
```
where:
- `bucket_name` is the name of the gridfs bucket
- `file_fmt` is a valid [go template string](https://pkg.go.dev/text/template) that resolves to the unique file name.
all exported methods and fields in the surrounding struct can be referenced in this template.
currently, the supported field types for the `gridfs` tag are:
- `string`
- `[]byte`
### `idx`/`index`
indexes can be defined using this tag on the field you want to create the index for. the syntax is as follows:
`idx:{field || field1,field2,...},keyword1,keyword2,...`
supported keywords are `unique` and `sparse`.
`background` is technically allowed, but it's deprecated, and so will essentially be a no-op on mongodb 4.2+.
## acknowledgements
- [goonode/mogo](https://github.com/goonode/mogo), the project which largely inspired this one
# Further reading
see the [godocs](https://git.tablet.sh/tablet/diamond-orm/src/branch/main/godoc.html) for more details :)

View File

@ -1,67 +0,0 @@
package orm
import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)
func testCreate1(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Save(&u)
assert.Nil(t, err)
assert.NotZero(t, u.Favs.ID)
}
func testCreate2(t assert.TestingT, e *Engine) {
u := friend(t)
err := e.Model(&user{}).Save(&u)
assert.Nil(t, err)
assert.NotZero(t, u.Favs.ID)
}
func testCreate3(t assert.TestingT, e *Engine) {
insertBands(t, e)
}
func testCreate4(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Save(&u)
if err != nil {
fmt.Println(err.Error())
}
assert.Nil(t, err)
insertBands(t, e)
storyBase(e, t, u)
}
func TestCreate1(t *testing.T) {
e := initTest(t)
testCreate1(t, e)
e.Disconnect()
}
func TestCreate2(t *testing.T) {
e := initTest(t)
testCreate2(t, e)
e.Disconnect()
}
func TestCreate3(t *testing.T) {
e := initTest(t)
testCreate3(t, e)
e.Disconnect()
}
func TestCreate4(t *testing.T) {
e := initTest(t)
testCreate4(t, e)
e.Disconnect()
}
func BenchmarkCreate(b *testing.B) {
b.Run("Create-1", bench(testCreate1))
b.Run("Create-2", bench(testCreate2))
b.Run("Create-3", bench(testCreate3))
b.Run("Create-4", bench(testCreate4))
}

View File

@ -1 +0,0 @@
package orm

View File

@ -1,201 +0,0 @@
package orm
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"io"
"log/slog"
"os"
"rockfic.com/orm/internal/logging"
"time"
)
// LevelQuery enables logging of SQL queries if passed to Config.LogLevel
const LevelQuery = slog.Level(-6)
const defaultKey = "default"
type Config struct {
DryRun bool // when true, queries will not run on the underlying database
LogLevel slog.Level // controls the level of information logged; defaults to slog.LevelInfo if not set
LogTo io.Writer // where to write log output to; defaults to os.Stdout
}
type Engine struct {
modelMap *internalModelMap
conn *pgxpool.Pool
m2mSeen map[string]bool
dryRun bool
pgCfg *pgxpool.Config
ctx context.Context
logger *slog.Logger
cfg *Config
levelVar *slog.LevelVar
connStr string
}
// Models - parse and register one or more types as persistable models
func (e *Engine) Models(v ...any) {
emm := makeModelMap(v...)
for k := range emm.Map {
if _, ok := e.modelMap.Map[k]; !ok {
e.modelMap.Mux.Lock()
e.modelMap.Map[k] = emm.Map[k]
e.modelMap.Mux.Unlock()
}
}
}
// Model - createes a Query and sets its model to
// the one corresponding to the type of `val`
func (e *Engine) Model(val any) *Query {
qq := &Query{
engine: e,
ctx: context.Background(),
wheres: make(map[string][]any),
orders: make([]string, 0),
populationTree: make(map[string]any),
joins: make([]string, 0),
}
return qq.setModel(val)
}
// QueryRaw - wrapper for the Query method of pgxpool.Pool
func (e *Engine) QueryRaw(sql string, args ...any) (pgx.Rows, error) {
return e.conn.Query(e.ctx, sql, args...)
}
// Migrate - non-destructive; run migrations to update the underlying schema, WITHOUT dropping tables beforehand
func (e *Engine) Migrate() error {
failedMigrations := make(map[string]*Model)
var err error
for mk, m := range e.modelMap.Map {
err = m.migrate(e)
if err != nil {
failedMigrations[mk] = m
}
}
for len(failedMigrations) > 0 {
e.m2mSeen = make(map[string]bool)
for mk, m := range failedMigrations {
err = m.migrate(e)
if err == nil {
delete(failedMigrations, mk)
}
}
}
return err
}
// MigrateDropping - destructive migration; DROP the necessary tables if they exist,
// then recreate them to match your models' schema
func (e *Engine) MigrateDropping() error {
for _, m := range e.modelMap.Map {
sql := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", m.TableName)
if _, err := e.conn.Exec(e.ctx, sql); err != nil {
return err
}
for _, r := range m.Relationships {
if r.m2mIsh() || r.Type == ManyToMany {
jsql := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", r.ComputeJoinTable())
if _, err := e.conn.Exec(e.ctx, jsql); err != nil {
return err
}
}
}
}
return e.Migrate()
}
func (e *Engine) logQuery(msg, sql string, args []any) {
e.logger.Log(e.ctx, LevelQuery, msg, "sql", sql, "args", logTrunc(200, args))
}
func (e *Engine) logSql(msg, sql string) {
e.logger.Log(e.ctx, LevelQuery, msg, "sql", sql)
}
// Disconnect - closes and disposes of this Engine's connection pool.
func (e *Engine) Disconnect() {
e.conn.Close()
if asFile, ok := e.cfg.LogTo.(*os.File); ok {
_ = asFile.Close()
}
}
// Open - creates a new connection according to `connString`
// and returns a brand new Engine to run FUCK operations on.
func Open(connString string, cfg *Config) (*Engine, error) {
if cfg == nil {
cfg = &Config{
LogLevel: slog.LevelInfo,
LogTo: os.Stdout,
DryRun: connString == "",
}
} else {
if cfg.LogTo == nil {
cfg.LogTo = os.Stdout
}
}
e := &Engine{
modelMap: &internalModelMap{
Map: make(map[string]*Model),
},
m2mSeen: make(map[string]bool),
dryRun: connString == "",
ctx: context.Background(),
levelVar: new(slog.LevelVar),
cfg: cfg,
}
e.levelVar.Set(cfg.LogLevel)
replacer := func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.LevelKey {
level := a.Value.Any().(slog.Level)
switch level {
case LevelQuery:
a.Value = slog.StringValue("query")
}
}
return a
}
e.logger = slog.New(logging.NewFormattedHandler(cfg.LogTo, logging.Options{
Level: e.levelVar,
ReplaceAttr: replacer,
Format: "{{.Time}} [{{.Level}}] {{.Message}} | {{ rest }}",
}))
slog.SetDefault(e.logger)
if connString != "" {
engines.Mux.Lock()
if len(engines.Engines) == 0 || engines.Engines[defaultKey] == nil {
engines.Engines[defaultKey] = e
} else {
engines.Engines[connString] = e
}
e.connStr = ""
engines.Mux.Unlock()
var err error
e.pgCfg, err = pgxpool.ParseConfig(connString)
e.pgCfg.MinConns = 5
e.pgCfg.MaxConns = 10
e.pgCfg.MaxConnIdleTime = time.Minute * 2
e.pgCfg.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
oldHandler := conn.Config().OnPgError
conn.Config().OnPgError = func(conn *pgconn.PgConn, pgError *pgconn.PgError) bool {
e.logger.Error("ERROR ->", "err", pgError.Error())
return oldHandler(conn, pgError)
}
return nil
}
if err != nil {
return nil, err
}
e.conn, err = pgxpool.NewWithConfig(e.ctx, e.pgCfg)
if err != nil {
return nil, err
}
}
return e, nil
}

View File

@ -1,15 +1,341 @@
package orm
import "time"
import (
"encoding/json"
"fmt"
"go.mongodb.org/mongo-driver/v2/bson"
"log"
"reflect"
"strings"
"time"
)
// Document - embed this into your structs anonymously to specify
// model parameters like table name via struct tags
type Document struct {
Created time.Time `json:"createdAt" tstype:"Date"`
Modified time.Time `json:"modifiedAt" tstype:"Date"`
// Created time. updated/added automatically.
Created time.Time `bson:"createdAt" json:"createdAt" tstype:"Date"`
// Modified time. updated/added automatically.
Modified time.Time `bson:"updatedAt" json:"updatedAt" tstype:"Date"`
model *Model `bson:"-"`
exists bool `bson:"-"`
self any `bson:"-"`
raw any `bson:"-"`
populatedFields map[string]bool `bson:"-"`
}
func (d *Document) UnmarshalJSON(bytes []byte) error {
var fiv interface{}
if err := json.Unmarshal(bytes, &fiv); err != nil {
return err
}
var err error
switch fiv.(type) {
case []interface{}:
tmpV := make(bson.A, 0)
err = json.Unmarshal(bytes, &tmpV)
typ := reflect.SliceOf(d.model.Type)
//d.SetSelf()
var arr []interface{}
reified := rerere(tmpV, typ, true)
if ta, ok := reified.(bson.A); ok {
arr = []interface{}(ta)
} else {
arr = reified.([]interface{})
}
fmt.Println(len(arr))
break
case map[string]interface{}:
tmpV := make(bson.M)
err = json.Unmarshal(bytes, &tmpV)
typ := reflect.PointerTo(d.model.Type)
self := reflect.ValueOf(d.self)
nself := reflect.NewAt(typ.Elem(), self.UnsafePointer())
reified := rerere(tmpV, typ, true)
nself.Elem().Set(reflect.ValueOf(reified).Elem())
d.self = nself.Interface()
break
}
return err
}
func (d *Document) MarshalJSON() ([]byte, error) {
v := serializeIDs((d).self, true, d.populatedFields, "")
return json.Marshal(v)
}
type IDocument interface {
Append(field string, a ...interface{}) error
Pull(field string, a ...any) error
Swap(field string, i, j int) error
Delete() error
Remove() error
Save() error
SaveWith(opts *SaveOptions) error
Populate(fields ...string)
SetSelf(arg interface{})
getExists() bool
setExists(n bool)
setModified(Modified time.Time)
setCreated(Modified time.Time)
getModified() time.Time
getCreated() time.Time
serializeToStore() any
getModel() *Model
setModel(m Model)
markPopulated(field string)
markDepopulated(field string)
newPopulationMap()
getPopulated() map[string]bool
getRaw() any
setRaw(raw any)
}
// SaveOptions - unused (for now)
type SaveOptions struct {
SetTimestamps bool
}
func (d *Document) getCreated() time.Time {
return d.Created
}
func (d *Document) setCreated(Created time.Time) {
d.Created = Created
}
func (d *Document) getModified() time.Time {
return d.Modified
}
func (d *Document) setModified(Modified time.Time) {
d.Modified = Modified
}
// SetSelf - don't call this lol
func (d *Document) SetSelf(arg interface{}) {
d.self = arg
}
func (d *Document) getModel() *Model {
return d.model
}
func (d *Document) setModel(m Model) {
d.model = &m
}
func (d *Document) getExists() bool {
return d.exists
}
func (d *Document) setExists(n bool) {
d.exists = n
}
// Delete - deletes a model instance from the database
func (d *Document) Delete() error {
var err error
val := valueOf(d.self)
if val.Kind() == reflect.Slice {
for i := 0; i < val.Len(); i++ {
cur := val.Index(i)
if err = doDelete(d, cur.Interface()); err != nil {
return err
}
}
return nil
} else {
return doDelete(d, d.self)
}
}
// Remove - alias for Delete
func (d *Document) Remove() error {
return d.Delete()
}
// SaveWith - updates this Model in the database,
// or inserts it if it doesn't exist, using the provided
// SaveOptions
func (d *Document) SaveWith(opts *SaveOptions) error {
val := valueOf(d.self)
if val.Kind() == reflect.Slice {
for i := range val.Len() {
cur := val.Index(i)
if err := doSave(d.model.Collection(), !d.exists, opts, cur.Interface()); err != nil {
return err
}
}
return nil
} else {
return doSave(d.model.Collection(), !d.exists, opts, d.self)
}
}
// Save - updates this Model in the database,
// or inserts it if it doesn't exist, using
// default SaveOptions
func (d *Document) Save() error {
return d.SaveWith(&SaveOptions{
SetTimestamps: true,
})
}
func (d *Document) serializeToStore() any {
return serializeIDs((d).self, false, d.populatedFields, "")
}
// Append appends one or more items to `field`.
// will error if this Model contains a reference
// to multiple documents, or if `field` is not a
// slice.
func (d *Document) Append(field string, a ...interface{}) error {
var d0 IDocument = d
d0.getCreated()
rv := reflect.ValueOf(d.self)
selfRef := rv
rt := reflect.TypeOf(d.self)
if selfRef.Kind() == reflect.Pointer {
selfRef = selfRef.Elem()
rt = rt.Elem()
}
if err := checkStruct(selfRef); err != nil {
return err
}
_, origV, err := getNested(field, selfRef)
if err != nil {
return err
}
origRef := makeSettable(*origV, (*origV).Interface())
fv := origRef
if fv.Kind() == reflect.Pointer {
fv = fv.Elem()
}
if fv.Kind() != reflect.Slice {
return ErrNotASlice
}
for _, b := range a {
val := reflect.ValueOf(incrementTagged(b))
fv.Set(reflect.Append(fv, val))
}
return nil
}
// Pull - removes elements from the subdocument slice stored in `field`.
func (d *Document) Pull(field string, a ...any) error {
rv := reflect.ValueOf(d.self)
selfRef := rv
rt := reflect.TypeOf(d.self)
if selfRef.Kind() == reflect.Pointer {
selfRef = selfRef.Elem()
rt = rt.Elem()
}
if err := checkStruct(selfRef); err != nil {
return err
}
_, origV, err := getNested(field, selfRef)
if err != nil {
return err
}
origRef := makeSettable(*origV, (*origV).Interface())
fv := origRef
if fv.Kind() == reflect.Pointer {
fv = fv.Elem()
}
if fv.Kind() != reflect.Slice {
return ErrNotASlice
}
for _, b := range a {
inner:
for i := 0; i < fv.Len(); i++ {
if reflect.DeepEqual(b, fv.Index(i).Interface()) {
fv.Set(pull(fv, i, fv.Index(i).Type()))
break inner
}
}
}
return nil
}
// Swap - swaps the elements at indexes `i` and `j` in the
// slice stored at `field`
func (d *Document) Swap(field string, i, j int) error {
rv := reflect.ValueOf(d.self)
selfRef := rv
rt := reflect.TypeOf(d.self)
if selfRef.Kind() == reflect.Pointer {
selfRef = selfRef.Elem()
rt = rt.Elem()
}
if err := checkStruct(selfRef); err != nil {
return err
}
_, origV, err := getNested(field, selfRef)
if err != nil {
return err
}
origRef := makeSettable(*origV, (*origV).Interface())
fv := origRef
if fv.Kind() == reflect.Pointer {
fv = fv.Elem()
}
if err = checkSlice(fv); err != nil {
return err
}
if i >= fv.Len() || j >= fv.Len() {
return ErrOutOfBounds
}
oi := fv.Index(i).Interface()
oj := fv.Index(j).Interface()
fv.Index(i).Set(reflect.ValueOf(oj))
fv.Index(j).Set(reflect.ValueOf(oi))
return nil
}
func (d *Document) Populate(fields ...string) {
_, cm, _ := ModelRegistry.HasByName(d.model.typeName)
if cm != nil {
rawDoc := d.raw
for _, field := range fields {
// 0 = fieldname, 1 = typename, 2 = bson name
r, refOk := cm.references[field]
if refOk {
// get self
// get ptr
// find
// unmarshal...
htt := r.HydratedType
if htt.Kind() == reflect.Pointer || htt.Kind() == reflect.Slice {
htt = htt.Elem()
}
if strings.HasSuffix(field, ".") || strings.HasPrefix(field, ".") {
log.Printf("WARN: invalid field name passed to Populate(). skipping...\n")
continue
}
tto := r.HydratedType
if tto.Kind() == reflect.Pointer || tto.Kind() == reflect.Slice {
tto = tto.Elem()
}
_, refColl, _ := ModelRegistry.HasByName(tto.Name())
var tmp1 interface{}
asIDocument, docOk := d.self.(IDocument)
if docOk {
asIDocument.markPopulated(field)
}
v := reflect.ValueOf(d.self)
tt := v.Elem().Type()
tmp1 = populate(r, d.populatedFields, refColl.collection, rawDoc, field, d.self)
nv := reflect.NewAt(tt, v.UnsafePointer())
nv.Elem().Set(reflect.ValueOf(tmp1).Elem())
}
}
}
}

340
document_internals.go Normal file
View File

@ -0,0 +1,340 @@
package orm
import (
"context"
"errors"
"fmt"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"reflect"
"time"
)
func serializeIDs(input interface{}, isJson bool, populated map[string]bool, parent string) interface{} {
var key string
if isJson {
key = "json"
} else {
key = "bson"
}
vp := reflect.ValueOf(input)
mt := reflect.TypeOf(input)
var ret interface{}
if vp.Kind() != reflect.Ptr {
if vp.CanAddr() {
vp = vp.Addr()
} else {
vp = makeSettable(vp, input)
}
}
if mt.Kind() == reflect.Pointer {
mt = mt.Elem()
}
getID := func(bbb interface{}) interface{} {
mptr := reflect.ValueOf(bbb)
if mptr.Kind() != reflect.Pointer {
mptr = makeSettable(mptr, bbb)
}
ifc, ok := mptr.Interface().(HasID)
if ok {
return ifc.Id()
} else {
return nil
}
}
switch vp.Elem().Kind() {
case reflect.Struct:
ret0 := bson.M{}
for i := 0; i < vp.Elem().NumField(); i++ {
fv := vp.Elem().Field(i)
ft := mt.Field(i)
var descent string
if parent != "" {
descent = parent + "." + ft.Name
} else {
descent = ft.Name
}
tag, err := structtag.Parse(string(ft.Tag))
panik(err)
bbson, err := tag.Get(key)
if err != nil || bbson.Name == "-" {
continue
}
if bbson.Name == "" {
marsh, _ := bson.Marshal(fv.Interface())
unmarsh := bson.M{}
bson.Unmarshal(marsh, &unmarsh)
for k, v := range unmarsh {
ret0[k] = v
}
} else {
_, terr := tag.Get("ref")
if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer {
vp1 := reflect.New(fv.Type())
vp1.Elem().Set(reflect.ValueOf(fv.Interface()))
fv.Set(vp1.Elem())
}
var ip bool
for k1, v1 := range populated {
if k1 == descent {
ip = v1
break
}
}
if iidoc, ok := input.(IDocument); ok && !ip {
for k1, v1 := range iidoc.getPopulated() {
if k1 == descent || k1 == ft.Name {
ip = v1
break
}
}
}
if terr == nil {
ifc, ok := fv.Interface().(HasID)
if fv.Kind() == reflect.Slice {
rarr := bson.A{}
for j := 0; j < fv.Len(); j++ {
if !isJson {
rarr = append(rarr, getID(fv.Index(j).Interface()))
} else {
if ip {
rarr = append(rarr, serializeIDs(fv.Index(j).Interface(), isJson, populated, descent))
} else {
rarr = append(rarr, getID(fv.Index(j).Interface()))
}
}
}
ret0[bbson.Name] = rarr
} else if !ok {
panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name))
} else {
if reflect.ValueOf(ifc).IsNil() || reflect.ValueOf(ifc.Id()).IsZero() {
ret0[bbson.Name] = nil
} else {
if !isJson {
ret0[bbson.Name] = ifc.Id()
} else {
if ip && bbson.Name != "-" {
ret0[bbson.Name] = serializeIDs(fv.Interface(), isJson, populated, descent)
} else if bbson.Name != "-" {
ret0[bbson.Name] = ifc.Id()
}
}
}
}
} else {
if fv.Type() == reflect.TypeFor[time.Time]() {
ret0[bbson.Name] = fv.Interface()
} else {
ret0[bbson.Name] = serializeIDs(fv.Interface(), isJson, populated, descent)
}
}
}
ret = ret0
}
case reflect.Slice:
ret0 := bson.A{}
for i := 0; i < vp.Elem().Len(); i++ {
ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface(), isJson, populated, parent))
}
ret = ret0
default:
ret = vp.Elem().Interface()
}
return ret
}
func doSave(c *mongo.Collection, isNew bool, opts *SaveOptions, arg interface{}) error {
var err error
d, ok := arg.(IDocument)
if !ok {
return fmt.Errorf(errFmtNotAModel, nameOf(arg))
}
d.SetSelf(d)
now := time.Now()
selfo := reflect.ValueOf(d)
vp := selfo
if vp.Kind() != reflect.Ptr {
vp = reflect.New(selfo.Type())
vp.Elem().Set(selfo)
}
var asHasId = vp.Interface().(HasID)
var asModel = vp.Interface().(IDocument)
if (isNew && reflect.ValueOf(asHasId.Id()).IsZero()) && opts.SetTimestamps {
d.setCreated(now)
}
if opts.SetTimestamps {
d.setModified(now)
}
idxs := d.getModel().getIdxs()
for _, i := range idxs {
_, err = c.Indexes().CreateOne(context.TODO(), *i)
if err != nil {
return err
}
}
if isNew {
nid := getLastInColl(c.Name(), asHasId.Id())
pnid := incrementInterface(nid)
if reflect.ValueOf(asHasId.Id()).IsZero() {
(asHasId).SetId(pnid)
}
incrementAll(asHasId)
_, im, _ := ModelRegistry.HasByName(asModel.getModel().getTypeName())
_ = gridFsSave(asHasId, *im)
_, err = c.InsertOne(context.TODO(), d.serializeToStore())
if err == nil {
d.setExists(true)
} else {
_, err = c.UpdateOne(context.TODO(), bson.D{{Key: "_id", Value: d.(HasID).Id()}}, bson.M{
"$set": d.serializeToStore(),
})
//_, err = c.ReplaceOne(context.TO_DO(), bson.D{{Key: "_id", Value: d.(HasID).Id()}}, d.serializeToStore())
}
} else {
//_, err = c.ReplaceOne(context.TO_DO(), bson.D{{Key: "_id", Value: d.(HasID).Id()}}, d.serializeToStore())
_, im, _ := ModelRegistry.HasByName(asModel.getModel().getTypeName())
_ = gridFsSave(asHasId, *im)
_, err = c.UpdateOne(context.TODO(), bson.D{{Key: "_id", Value: d.(HasID).Id()}}, bson.M{
"$set": d.serializeToStore(),
})
if errors.Is(err, mongo.ErrNoDocuments) {
_, err = c.InsertOne(context.TODO(), d.serializeToStore())
if err == nil {
d.setExists(true)
}
}
}
return err
}
func doDelete(d *Document, arg interface{}) error {
self, ok := arg.(HasID)
if !ok {
return fmt.Errorf(errFmtNotHasID, nameOf(arg))
}
c := d.model.Collection()
_, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()})
if err == nil {
d.exists = false
err = gridFsDel(arg, *d.model)
}
return err
}
func incrementTagged(item interface{}) interface{} {
rv := reflect.ValueOf(item)
rt := reflect.TypeOf(item)
if rv.Kind() != reflect.Pointer {
rv = makeSettable(rv, item)
}
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
}
if rt.Kind() != reflect.Struct {
if rt.Kind() == reflect.Slice {
for i := 0; i < rv.Elem().Len(); i++ {
incrementTagged(rv.Elem().Index(i).Addr().Interface())
}
} else {
return item
}
}
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
cur := rv.Elem().Field(i)
tags, err := structtag.Parse(string(structField.Tag))
if err != nil {
continue
}
incTag, err := tags.Get("autoinc")
if err != nil {
continue
}
nid := getLastInColl(incTag.Name, cur.Interface())
if cur.IsZero() {
coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur)
if coerced != nil {
cur.Set(reflect.ValueOf(coerced))
} else {
cur.Set(reflect.ValueOf(incrementInterface(nid)))
}
counterColl := DB.Collection(COUNTER_COL)
counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.UpdateOne().SetUpsert(true))
}
}
return rv.Elem().Interface()
}
func incrementAll(item interface{}) {
if item == nil {
return
}
vp := reflect.ValueOf(item)
el := vp
if vp.Kind() == reflect.Pointer {
el = vp.Elem()
}
if vp.Kind() == reflect.Pointer && vp.IsNil() {
return
}
vt := el.Type()
switch el.Kind() {
case reflect.Struct:
incrementTagged(item)
for i := 0; i < el.NumField(); i++ {
fv := el.Field(i)
fst := vt.Field(i)
if !fst.IsExported() {
continue
}
incrementAll(fv.Interface())
}
case reflect.Slice:
for i := 0; i < el.Len(); i++ {
incd := incrementTagged(el.Index(i).Addr().Interface())
if reflect.ValueOf(incd).Kind() == reflect.Pointer {
el.Index(i).Set(reflect.ValueOf(incd).Elem())
} else {
el.Index(i).Set(reflect.ValueOf(incd))
}
}
default:
}
}
func (d *Document) markPopulated(field string) {
d.newPopulationMap()
d.populatedFields[field] = true
}
func (d *Document) markDepopulated(field string) {
d.newPopulationMap()
d.populatedFields[field] = false
}
func (d *Document) newPopulationMap() {
if d.populatedFields == nil {
d.populatedFields = make(map[string]bool)
}
}
func (d *Document) getPopulated() map[string]bool {
return d.populatedFields
}
func (d *Document) setRaw(raw any) {
d.raw = raw
}
func (d *Document) getRaw() any {
return d.raw
}

47
document_slice.go Normal file
View File

@ -0,0 +1,47 @@
package orm
type IDocumentSlice interface {
Delete() error
Remove() error
Save() error
setExists(n bool)
getModel() *Model
}
type DocumentSlice []IDocument
func (d *DocumentSlice) Delete() error {
var err error
for _, doc := range *d {
err = doc.Delete()
if err != nil {
return err
}
}
return err
}
func (d *DocumentSlice) Remove() error {
return d.Delete()
}
func (d *DocumentSlice) Save() error {
var err error
for _, doc := range *d {
err = doc.Save()
if err != nil {
return err
}
}
return err
}
func (d *DocumentSlice) setExists(b bool) {
for _, s := range *d {
s.setExists(b)
}
}
func (d *DocumentSlice) getModel() *Model {
return (*d)[0].getModel()
}

View File

@ -1,7 +1,21 @@
package orm
import "fmt"
import (
"errors"
)
var ErrNoConditionOnDeleteOrUpdate = fmt.Errorf("refusing to delete/update with no conditions specified.\n"+
" (hint: call `.WhereRaw(%s)` or `.WhereRaw(%s)` to do so anyways)",
`"true"`, `"1 = 1"`)
var (
ErrNotASlice = errors.New("Current object or field is not a slice!")
ErrNotAStruct = errors.New("Current object or field is not a struct!")
ErrOutOfBounds = errors.New("Index(es) out of bounds!")
ErrAppendMultipleDocuments = errors.New("Cannot append to multiple documents!")
ErrNotSliceOrStruct = errors.New("Current object or field is not a slice nor a struct!")
ErrUnsupportedID = errors.New("Unknown or unsupported id type provided")
)
const (
errFmtMalformedField = "Malformed field name passed: '%s'"
errFmtNotAModel = "Type '%s' is not a model"
errFmtNotHasID = "Type '%s' does not implement HasID"
errFmtModelNotRegistered = "Model not registered for type: '%s'"
)

159
field.go
View File

@ -1,159 +0,0 @@
package orm
import (
"net"
"reflect"
"time"
)
// Field - represents a field with a valid SQL type in a Model
type Field struct {
Name string // the name of this field as it appears in its Model's type definition
ColumnName string // this field's snake_cased column name as it appears in database
ColumnType string // the SQL type of this field's column (bigint, bigserial, text, ...)
Type reflect.Type // the reflect.Type of the struct field this Field represents
Original reflect.StructField // the raw struct field, as obtained by using reflect.Type.Field or reflect.Type.FieldByName
Model *Model // the Model this field belongs to
Index int // the index at which Original appears in its struct
AutoIncrement bool // whether this field's column is an auto-incrementing column
PrimaryKey bool // true if this field's column is a primary key
Nullable bool // true if this field's column can be NULL
embeddedFields map[string]*Field // mapping of column names to Field pointers that correspond to the surrounding struct's fields
}
func (f *Field) isAnonymous() bool {
return f.Original.Anonymous
}
func (f *Field) anonymousColumnNames() []string {
cols := make([]string, 0)
if !f.isAnonymous() {
return cols
}
for _, ef := range f.embeddedFields {
cols = append(cols, ef.ColumnName)
}
return cols
}
func defaultColumnValue(ty reflect.Type) any {
switch ty.Kind() {
case reflect.Int32, reflect.Uint32, reflect.Int, reflect.Uint, reflect.Int64, reflect.Uint64:
return 0
case reflect.Bool:
return false
case reflect.String:
return "''"
case reflect.Float32, reflect.Float64:
return 0.0
case reflect.Struct:
if canConvertTo[time.Time](ty) {
return "now()"
}
if canConvertTo[net.IP](ty) {
return "'0.0.0.0'::INET"
}
if canConvertTo[net.IPNet](ty) {
return "'0.0.0.0/0'::CIDR"
}
case reflect.Slice:
return "'{}'"
}
return "NULL"
}
func columnType(ty reflect.Type, isPk, isAutoInc bool) string {
it := ty
switch it.Kind() {
case reflect.Ptr:
for it.Kind() == reflect.Ptr {
it = it.Elem()
}
case reflect.Int32, reflect.Uint32, reflect.Int, reflect.Uint:
if isPk || isAutoInc {
return "serial"
} else {
return "int"
}
case reflect.Int64, reflect.Uint64:
if isPk || isAutoInc {
return "bigserial"
} else {
return "bigint"
}
case reflect.String:
return "text"
case reflect.Float32:
return "float4"
case reflect.Float64:
return "double precision"
case reflect.Bool:
return "boolean"
case reflect.Struct:
if canConvertTo[time.Time](ty) {
return "timestamptz"
}
if canConvertTo[net.IP](ty) {
return "inet"
}
if canConvertTo[net.IPNet](ty) {
return "cidr"
}
default:
return ""
}
return ""
}
func parseField(f reflect.StructField, minfo *Model, modelMap map[string]*Model, i int) *Field {
field := &Field{
Name: f.Name,
Original: f,
Index: i,
}
tags := parseTags(f.Tag.Get("d"))
if tags["-"] != "" {
return nil
}
field.PrimaryKey = tags["pk"] != "" || tags["primarykey"] != "" || field.Name == "ID"
field.AutoIncrement = tags["autoinc"] != ""
field.Nullable = tags["nullable"] != ""
field.ColumnType = tags["type"]
if field.ColumnType == "" {
field.ColumnType = columnType(f.Type, field.PrimaryKey, field.AutoIncrement)
}
field.ColumnName = tags["column"]
if field.ColumnName == "" {
field.ColumnName = pascalToSnakeCase(field.Name)
}
if field.PrimaryKey {
minfo.IDField = field.Name
}
elem := f.Type
for elem.Kind() == reflect.Ptr {
if !field.Nullable {
field.Nullable = true
}
elem = elem.Elem()
}
field.Type = elem
switch elem.Kind() {
case reflect.Array, reflect.Slice:
elem = elem.Elem()
fallthrough
case reflect.Struct:
if canConvertTo[Document](elem) && f.Anonymous {
minfo.TableName = tags["table"]
field.embeddedFields = make(map[string]*Field)
for j := range elem.NumField() {
efield := elem.Field(j)
field.embeddedFields[pascalToSnakeCase(efield.Name)] = parseField(efield, minfo, modelMap, j)
}
} else if field.ColumnType == "" {
minfo.Relationships[field.Name] = parseRelationship(f, modelMap, minfo.Type, i, tags)
}
}
return field
}

View File

@ -1,171 +0,0 @@
package orm
import (
"github.com/stretchr/testify/assert"
"testing"
)
func storyBaseLoop(t assert.TestingT, e *Engine, u user, count int, additionalPopulateFields ...string) {
if !isTestBench(t) {
for range count {
storyBase(e, t, u, additionalPopulateFields...)
}
} else {
storyBase(e, t, u, additionalPopulateFields...)
}
}
func storyBase(e *Engine, t assert.TestingT, u user, additionalPopulateFields ...string) *story {
s := iti_multi(u)
err := e.Model(&story{}).Save(s)
assert.Nil(t, err)
assert.NotZero(t, s.ID)
checkChapters(t, s)
var ns story
fields := []string{
PopulateAll,
}
fields = append(fields, additionalPopulateFields...)
err = e.Model(&story{}).Where("ID = ?", s.ID).Populate(fields...).Find(&ns)
assert.Nil(t, err)
assert.NotZero(t, ns.ID)
assert.NotZero(t, ns.Author.ID)
assert.NotZero(t, ns.Author.Username)
assert.Equal(t, len(s.Chapters), len(ns.Chapters))
checkChapters(t, &ns)
return &ns
}
func testFind1(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Save(&u)
assert.Nil(t, err)
var nu user
err = e.Model(&user{}).Where("ID = ?", u.ID).Find(&nu)
assert.Nil(t, err)
assert.NotZero(t, nu.ID)
assert.NotZero(t, nu.Username)
}
func testJoin1(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Save(&u)
assert.Nil(t, err)
insertBands(t, e)
storyBaseLoop(t, e, u, 7)
var withBodom []story
err = e.Model(&story{}).
Join("Chapters.Bands").
In("Chapters.Bands.ID", bodom.ID).
Find(&withBodom)
assert.Nil(t, err)
assert.NotEmpty(t, withBodom)
}
func testIn(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Save(&u)
assert.Nil(t, err)
insertBands(t, e)
storyBaseLoop(t, e, u, 10)
var threes []story
err = e.Model(&story{}).
In("ID", 1, 2, 3).
Find(&threes)
assert.Nil(t, err)
assert.GreaterOrEqual(t, len(threes), 1)
}
func testPopulateAll(t assert.TestingT, e *Engine) {
insertBands(t, e)
u := author(t)
err := e.Model(&user{}).Create(&u)
assert.Nil(t, err)
storyBase(e, t, u)
}
func testPopulateNested1(t assert.TestingT, e *Engine) {
u := author(t)
f := friend(t)
err := e.Model(&user{}).Create(&f)
assert.Nil(t, err)
assert.NotZero(t, f.ID)
err = e.Model(&user{}).Create(&u)
assert.NotZero(t, u.Favs.ID)
assert.Nil(t, err)
u.Favs.Authors = append(u.Favs.Authors, f)
err = e.Model(&user{}).Save(&u)
assert.Nil(t, err)
var nu user
err = e.Model(&user{}).
Populate(PopulateAll, "Favs.Authors").
Where("ID = ?", u.ID).
Find(&nu)
assert.Nil(t, err)
assert.NotZero(t, nu.Favs.ID)
assert.NotEmpty(t, nu.Favs.Authors)
}
func testPopulateNested2(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Create(&u)
assert.Nil(t, err)
insertBands(t, e)
ns := storyBase(e, t, u, "Chapters.Bands")
for _, c := range ns.Chapters {
assert.NotEmpty(t, c.Bands)
for _, b := range c.Bands {
assert.NotZero(t, b.Name)
assert.NotEmpty(t, b.Characters)
}
}
}
func TestFind(t *testing.T) {
e := initTest(t)
testFind1(t, e)
}
func TestJoin1(t *testing.T) {
e := initTest(t)
testJoin1(t, e)
}
func TestIn(t *testing.T) {
e := initTest(t)
testIn(t, e)
}
func TestPopulateAll(t *testing.T) {
e := initTest(t)
testPopulateAll(t, e)
}
func TestPopulateNested1(t *testing.T) {
e := initTest(t)
testPopulateNested1(t, e)
}
func TestPopulateNested2(t *testing.T) {
e := initTest(t)
testPopulateNested2(t, e)
}
func BenchmarkFind(b *testing.B) {
b.Run("Plain-1", bench(testFind1))
b.Run("Join-1", bench(testJoin1))
b.Run("In-1", bench(testIn))
}
func BenchmarkPopulate(b *testing.B) {
b.Run("Populate", func(b *testing.B) {
b.Run("Simple-1", bench(testPopulateAll))
b.Run("Populate-Nested-1", bench(testPopulateNested1))
b.Run("Populate-Nested-2", bench(testPopulateNested2))
})
}

35
go.mod
View File

@ -3,30 +3,29 @@ module rockfic.com/orm
go 1.24.1
require (
github.com/go-loremipsum/loremipsum v1.1.4
github.com/huandu/go-sqlbuilder v1.35.1
github.com/jackc/pgx/v5 v5.7.5
github.com/jinzhu/now v1.1.5
github.com/stretchr/testify v1.10.0
github.com/fatih/structtag v1.2.0
golang.org/x/net v0.37.0
)
require (
github.com/Jeffail/gabs v1.4.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fatih/structtag v1.2.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/henvic/pgq v0.0.4 // indirect
github.com/huandu/xstrings v1.4.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/text v0.24.0 // indirect
go.mongodb.org/mongo-driver/v2 v2.1.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (
github.com/go-loremipsum/loremipsum v1.1.3
github.com/golang/snappy v0.0.4 // indirect
github.com/klauspost/compress v1.16.7 // indirect
github.com/stretchr/testify v1.10.0
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/text v0.23.0 // indirect
)
replace rockfic.com/orm => C:/rockfic/orm

95
go.sum
View File

@ -1,55 +1,62 @@
github.com/Jeffail/gabs v1.4.0 h1://5fYRRTq1edjfIrQGvdkcd22pkYUrHZ5YC/H2GJVAo=
github.com/Jeffail/gabs v1.4.0/go.mod h1:6xMvQMK4k33lb7GUUpaAPh6nKMmemQeg5d4gn7/bOXc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/go-loremipsum/loremipsum v1.1.4 h1:RJaJlJwX4y9A2+CMgKIyPcjuFHFKTmaNMhxbL+sI6Vg=
github.com/go-loremipsum/loremipsum v1.1.4/go.mod h1:whNWskGoefTakPnCu2CO23v5Y7RwiG4LMOEtTDaBeOY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/henvic/pgq v0.0.4 h1:BgLnxofZJSWWs+9VOf19Gr9uBkSVbHWGiu8wix1nsIY=
github.com/henvic/pgq v0.0.4/go.mod h1:k0FMvOgmQ45MQ3TgCLe8I3+sDKy9lPAiC2m9gg37pVA=
github.com/huandu/go-assert v1.1.6 h1:oaAfYxq9KNDi9qswn/6aE0EydfxSa+tWZC1KabNitYs=
github.com/huandu/go-assert v1.1.6/go.mod h1:JuIfbmYG9ykwvuxoJ3V8TB5QP+3+ajIA54Y44TmkMxs=
github.com/huandu/go-sqlbuilder v1.35.1 h1:znTuAksxq3T1rYfr3nsD4P0brWDY8qNzdZnI6+vtia4=
github.com/huandu/go-sqlbuilder v1.35.1/go.mod h1:mS0GAtrtW+XL6nM2/gXHRJax2RwSW1TraavWDFAc1JA=
github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU=
github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/go-loremipsum/loremipsum v1.1.3 h1:ZRhA0ZmJ49lGe5HhWeMONr+iGftWDsHfrYBl5ktDXso=
github.com/go-loremipsum/loremipsum v1.1.3/go.mod h1:OJQjXdvwlG9hsyhmMQoT4HOm4DG4l62CYywebw0XBoo=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I=
github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mongodb.org/mongo-driver/v2 v2.1.0 h1:/ELnVNjmfUKDsoBisXxuJL0noR9CfeUIrP7Yt3R+egg=
go.mongodb.org/mongo-driver/v2 v2.1.0/go.mod h1:AWiLRShSrk5RHQS3AEn3RL19rqOzVq49MCpWQ3x/huI=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

1229
godoc.html Normal file

File diff suppressed because one or more lines are too long

247
gridfs.go Normal file
View File

@ -0,0 +1,247 @@
package orm
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"html/template"
"io"
"reflect"
"strings"
)
type GridFSFile struct {
ID bson.ObjectID `bson:"_id"`
Name string `bson:"filename"`
Length int `bson:"length"`
}
func parseFmt(format string, value any) (string, error) {
tmpl, err := template.New("filename").Parse(format)
panik(err)
w := new(strings.Builder)
err = tmpl.Execute(w, value)
return w.String(), err
}
func bucket(gfsRef gridFSReference) *mongo.GridFSBucket {
b := DB.GridFSBucket(options.GridFSBucket().SetName(gfsRef.BucketName))
return b
}
func hasTag(rtype reflect.Type, imodel Model) bool {
if rtype.Kind() == reflect.Pointer {
rtype = rtype.Elem()
}
if rtype.Kind() == reflect.Slice {
return hasTag(rtype.Elem(), imodel)
} else if rtype.Kind() == reflect.Struct {
for i := 0; i < rtype.NumField(); i++ {
f := rtype.Field(i)
tags, err := structtag.Parse(string(f.Tag))
if err != nil {
return false
}
if _, err = tags.Get("gridfs"); err == nil {
return true
}
for kk := range imodel.gridFSReferences {
if strings.HasPrefix(kk, f.Name) {
return true
}
}
}
}
return false
}
func gridFsLoad(val any, g gridFSReference, field string) any {
doc := reflect.ValueOf(val)
rdoc := reflect.ValueOf(val)
if doc.Kind() != reflect.Pointer {
doc = reflect.New(reflect.TypeOf(val))
doc.Elem().Set(reflect.ValueOf(val))
}
var next string
if len(strings.Split(field, ".")) > 1 {
next = strings.Join(strings.Split(field, ".")[1:], ".")
field = strings.Split(field, ".")[0]
} else {
next = field
}
_, rfield, ferr := getNested(field, rdoc)
if ferr != nil {
return nil
}
switch rfield.Kind() {
case reflect.Slice:
for i := 0; i < rfield.Len(); i++ {
cur := rfield.Index(i)
if cur.Kind() != reflect.Pointer {
tmp := reflect.New(cur.Type())
tmp.Elem().Set(cur)
cur = tmp
}
intermediate := gridFsLoad(cur.Interface(), g, next)
if intermediate == nil {
continue
}
ival := reflect.ValueOf(intermediate)
if ival.Kind() == reflect.Pointer {
ival = ival.Elem()
}
rfield.Index(i).Set(ival)
}
case reflect.Struct:
intermediate := gridFsLoad(rfield.Interface(), g, next)
if intermediate != nil {
rfield.Set(reflect.ValueOf(intermediate))
}
default:
b := bucket(g)
var found GridFSFile
fname, err := parseFmt(g.FilenameFmt, val)
if err != nil {
return nil
}
cursor, err := b.Find(context.TODO(), bson.M{"filename": fname})
if err != nil {
return nil
}
cursor.Next(context.TODO())
_ = cursor.Decode(&found)
bb := bytes.NewBuffer(nil)
_, err = b.DownloadToStream(context.TODO(), found.ID, bb)
if err != nil {
return nil
}
if rfield.Type().AssignableTo(reflect.TypeFor[[]byte]()) {
rfield.Set(reflect.ValueOf(bb.Bytes()))
} else if rfield.Type().AssignableTo(reflect.TypeFor[string]()) {
rfield.Set(reflect.ValueOf(bb.String()))
}
}
if rdoc.Kind() != reflect.Pointer {
return doc.Elem().Interface()
}
return doc.Interface()
}
func gridFsGen(val any, imodel Model, isSaving bool) error {
var rerr error
v := reflect.ValueOf(val)
el := v
if v.Kind() == reflect.Pointer {
el = el.Elem()
}
switch el.Kind() {
case reflect.Struct:
for i := 0; i < el.NumField(); i++ {
ft := el.Type().Field(i)
fv := el.Field(i)
if !ft.IsExported() {
continue
}
_, err := structtag.Parse(string(ft.Tag))
panik(err)
var gfsRef *gridFSReference
for kk, vv := range imodel.gridFSReferences {
if strings.HasPrefix(kk, ft.Name) {
gfsRef = &vv
break
}
}
var inner = func(b *mongo.GridFSBucket, it reflect.Value) error {
filename, err2 := parseFmt(gfsRef.FilenameFmt, it.Interface())
if err2 != nil {
return err2
}
contents := GridFSFile{}
curs, err2 := b.Find(context.TODO(), bson.M{"filename": filename})
if !errors.Is(err2, mongo.ErrNoDocuments) {
for {
if !curs.Next(context.TODO()) {
break
}
_ = curs.Decode(&contents)
if !reflect.ValueOf(contents).IsZero() {
b.Delete(context.TODO(), contents.ID)
}
}
}
if isSaving {
c := it.Field(gfsRef.Idx)
var rdr io.Reader
if c.Type().AssignableTo(reflect.TypeOf([]byte{})) {
rdr = bytes.NewReader(c.Interface().([]byte))
} else if c.Type().AssignableTo(reflect.TypeOf("")) {
rdr = strings.NewReader(c.Interface().(string))
} else {
return fmt.Errorf("gridfs loader type '%s' not supported", c.Type().String())
}
_, err = b.UploadFromStream(context.TODO(), filename, rdr)
}
return err
}
if gfsRef != nil {
b := bucket(*gfsRef)
if fv.Kind() == reflect.Slice {
for j := 0; j < fv.Len(); j++ {
lerr := inner(b, fv.Index(j))
if lerr != nil {
continue
}
}
} else if fv.Kind() == reflect.Struct {
lerr := inner(b, fv)
if lerr != nil {
continue
}
} else {
lerr := inner(b, el)
if lerr != nil {
continue
}
}
}
if hasTag(ft.Type, imodel) {
err = gridFsGen(fv.Interface(), imodel, isSaving)
if err != nil {
return err
}
}
}
case reflect.Slice:
for i := 0; i < el.Len(); i++ {
rerr = gridFsGen(el.Index(i).Interface(), imodel, isSaving)
if rerr != nil {
return rerr
}
}
default:
break
}
return rerr
}
func gridFsDel(val any, imodel Model) error {
return gridFsGen(val, imodel, false)
}
func gridFsSave(val any, imodel Model) error {
t := reflect.TypeOf(val)
if hasTag(t, imodel) {
return gridFsGen(val, imodel, true)
}
return nil
}

43
idcounter.go Normal file
View File

@ -0,0 +1,43 @@
package orm
import (
"context"
"reflect"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
const COUNTER_COL = "@counters"
type Counter struct {
Current any `bson:"current"`
Collection string `bson:"collection"`
}
func getLastInColl(cname string, id interface{}) interface{} {
var opts = options.FindOne()
switch id.(type) {
case int, int64, int32, uint, uint32, uint64, bson.ObjectID:
opts.SetSort(bson.M{"_id": -1})
case string:
opts.SetSort(bson.M{"createdAt": -1})
default:
panic(ErrUnsupportedID)
}
var cnt Counter
if !reflect.ValueOf(id).IsZero() {
return id
}
if err := DB.Collection(COUNTER_COL).FindOne(context.TODO(), bson.M{
"collection": cname,
}, opts).Decode(&cnt); err != nil {
cnt = Counter{
Current: id,
}
cnt.Collection = cname
}
return cnt.Current
}

148
indexes.go Normal file
View File

@ -0,0 +1,148 @@
package orm
import (
"go/scanner"
"go/token"
"strings"
"unicode"
"go.mongodb.org/mongo-driver/v2/mongo"
)
var optionKeywords = [...]string{"unique", "sparse", "background", "dropdups"}
// prolly won't need to use indexes, but just in case...
type InternalIndex struct {
Fields []string
Options []string
lastFieldIdx int
sticky bool
}
func (in *InternalIndex) appendOption(o string) {
o = strings.ToLower(o)
o = strings.Trim(o, " ")
for i := range optionKeywords {
if optionKeywords[i] == o {
in.Options = append(in.Options, o)
}
}
}
func (in *InternalIndex) appendField(f string) {
in.Fields = append(in.Fields, f)
}
func (in *InternalIndex) appendDotField(f string) {
in.Fields[in.lastFieldIdx] = in.Fields[in.lastFieldIdx] + "." + f
}
func (in *InternalIndex) updateLastField() {
if len(in.Fields) > 0 {
in.lastFieldIdx = len(in.Fields) - 1
return
}
in.lastFieldIdx = 0
}
func (in *InternalIndex) setSticky(s bool) {
in.sticky = s
}
func (in *InternalIndex) getSticky() bool {
return in.sticky
}
func scanIndex(src string) []InternalIndex {
var s scanner.Scanner
var parsed []InternalIndex
src = func(ss string) string {
return strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return -1
}
return r
}, ss)
}(src)
fset := token.NewFileSet()
file := fset.AddFile("", fset.Base(), len(src))
s.Init(file, []byte(src), nil, scanner.ScanComments)
lb := false
p := &InternalIndex{}
for {
_, tok, lit := s.Scan()
switch tok {
case token.LBRACE:
if lb {
goto _panik
}
lb = true
case token.RBRACE:
if !lb || len(p.Fields) == 0 {
goto _panik
}
lb = false
case token.IDENT:
if lb {
if p.getSticky() {
p.appendDotField(lit)
p.setSticky(false)
break
}
p.appendField(lit)
break
}
p.appendOption(lit)
case token.PERIOD:
if p.getSticky() {
goto _panik
}
p.setSticky(true)
p.updateLastField()
case token.COMMA:
case token.COLON:
if lb {
goto _panik
}
case token.SEMICOLON:
if lb {
goto _panik
}
parsed = append(parsed, *p)
p = &InternalIndex{}
case token.EOF:
if lb {
goto _panik
}
return parsed
default:
goto _panik
}
}
_panik:
panic("parsing error in index expression!")
}
func buildIndex(i InternalIndex) *mongo.IndexModel {
idx := &mongo.IndexModel{
Keys: i.Fields,
}
for _, o := range i.Options {
switch o {
case "unique":
idx.Options.SetUnique(true)
//case "background":
// idx.Options.SetBackground(true)
case "sparse":
idx.Options.SetSparse(true)
}
}
return idx
}

View File

@ -1,207 +0,0 @@
package logging
import (
"bytes"
"context"
"io"
"log/slog"
"runtime"
"slices"
"strings"
"sync"
"text/template"
"time"
)
type FormattedHandler struct {
mu *sync.Mutex
out io.Writer
opts Options
attrs map[string]slog.Value
groups []string
groupLvl int
}
type Options struct {
Level slog.Leveler
Format string
ReplaceAttr func(groups []string, attr slog.Attr) slog.Attr
}
type locData struct {
FileName string
Function string
Line int
}
func NewFormattedHandler(out io.Writer, options Options) *FormattedHandler {
h := &FormattedHandler{
opts: options,
out: out,
mu: &sync.Mutex{},
groups: make([]string, 0),
}
if h.opts.Format == "" {
h.opts.Format = "{{.Time}} [{{.Level}}]"
}
if h.opts.Level == nil {
h.opts.Level = slog.LevelInfo
}
return h
}
func (f *FormattedHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= f.opts.Level.Level()
}
func (f *FormattedHandler) Handle(ctx context.Context, r slog.Record) error {
bufp := allocBuf()
buf := *bufp
defer func() {
*bufp = buf
freeBuf(bufp)
}()
rep := f.opts.ReplaceAttr
key := slog.LevelKey
val := r.Level
if rep == nil {
r.AddAttrs(slog.String(key, val.String()))
} else {
nattr := slog.Any(key, val)
nattr.Value = rep(f.groups, nattr).Value
r.AddAttrs(nattr)
}
f.mu.Lock()
defer f.mu.Unlock()
tctx, tmpl := f.newFmtCtx(r)
wr := bytes.NewBuffer(buf)
parsed, err := tmpl.Parse(f.opts.Format)
if err != nil {
return err
}
err = parsed.Execute(wr, tctx)
if err != nil {
return err
}
wr.WriteByte('\n')
_, err = f.out.Write(wr.Bytes())
return err
}
func (f *FormattedHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
if len(attrs) == 0 {
return f
}
nf := f.clone()
bufp := allocBuf()
buf := *bufp
defer func() {
*bufp = buf
freeBuf(bufp)
}()
s := f.newState(bytes.NewBuffer(buf))
defer s.free()
pos := s.buf.Len()
s.startGroups()
if !s.appendAttrs(attrs) {
s.buf.Truncate(pos)
} else {
nf.groupLvl = len(nf.groups)
}
return nf
}
func (f *FormattedHandler) WithGroup(name string) slog.Handler {
if name == "" {
return f
}
f2 := f.clone()
f2.groups = append(f2.groups, name)
return f2
}
func (f *FormattedHandler) clone() *FormattedHandler {
return &FormattedHandler{
opts: f.opts,
groups: slices.Clip(f.groups),
out: f.out,
mu: f.mu,
groupLvl: f.groupLvl,
}
}
type tmplData struct {
Level string
Message string
RawTime time.Time
Time string
PC uintptr
Location locData
Record slog.Record
}
func hasBuiltInKey(a slog.Attr) bool {
return a.Key == slog.MessageKey ||
a.Key == slog.TimeKey ||
a.Key == slog.SourceKey
}
func (f *FormattedHandler) newFmtCtx(r slog.Record) (ctx *tmplData, tmpl *template.Template) {
tmpl = template.New("log")
ctx = &tmplData{
Message: r.Message,
RawTime: r.Time,
PC: r.PC,
Location: locData{},
}
if !r.Time.IsZero() {
ctx.Time = r.Time.Format(time.RFC3339Nano)
}
r.Attrs(func(a slog.Attr) bool {
if a.Key == slog.LevelKey {
str := strings.ToUpper(a.Value.String())
if rep := f.opts.ReplaceAttr; rep != nil {
str = strings.ToUpper(a.Value.String())
}
ctx.Level = str
}
return true
})
if r.PC != 0 {
frames := runtime.CallersFrames([]uintptr{r.PC})
frame, _ := frames.Next()
ctx.Location.FileName = frame.File
ctx.Location.Function = frame.Function
ctx.Location.Line = frame.Line
}
fm := make(map[string]any)
fm["rest"] = func() string {
bb := new(bytes.Buffer)
s := f.newState(bb)
defer s.free()
s.begin(r)
return s.buf.String()
}
tmpl = tmpl.Funcs(fm)
return
}
var bufPool = sync.Pool{
New: func() any {
b := make([]byte, 0, 4096)
return &b
},
}
func allocBuf() *[]byte {
return bufPool.Get().(*[]byte)
}
func freeBuf(b *[]byte) {
const maxBufferSize = 16 << 10
if cap(*b) <= maxBufferSize {
*b = (*b)[:0]
bufPool.Put(b)
}
}

View File

@ -1,145 +0,0 @@
package logging
import (
"bytes"
"fmt"
"log/slog"
"strings"
"sync"
"time"
)
func (f *FormattedHandler) newState(sb *bytes.Buffer) state {
s := state{
fh: f,
buf: sb,
}
if f.opts.ReplaceAttr != nil {
s.groups = groupPool.Get().(*[]string)
*s.groups = append(*s.groups, f.groups[:f.groupLvl]...)
}
return s
}
type state struct {
buf *bytes.Buffer
fh *FormattedHandler
groups *[]string
}
func (s *state) startGroups() {
for _, n := range s.fh.groups[s.fh.groupLvl:] {
s.startGroup(n)
}
}
func (s *state) startGroup(name string) {
s.buf.WriteByte('\n')
if s.groups != nil {
*s.groups = append(*s.groups, name)
}
}
func (s *state) endGroup() {
if s.groups != nil {
*s.groups = (*s.groups)[:len(*s.groups)-1]
}
}
func (s *state) appendAttr(a slog.Attr) bool {
a.Value = a.Value.Resolve()
if rep := s.fh.opts.ReplaceAttr; rep != nil && a.Value.Kind() != slog.KindGroup {
var gs []string
if s.groups != nil {
gs = *s.groups
}
a = rep(gs, a)
a.Value = a.Value.Resolve()
}
if a.Equal(slog.Attr{}) ||
hasBuiltInKey(a) ||
a.Key == slog.LevelKey {
return false
}
if a.Value.Kind() == slog.KindGroup {
pos := s.buf.Len()
attrs := a.Value.Group()
if len(attrs) > 0 {
if a.Key != "" {
s.startGroup(a.Key)
}
if !s.appendAttrs(attrs) {
s.buf.Truncate(pos)
return false
}
if a.Key != "" {
s.endGroup()
}
}
} else {
s.writeAttr(a)
}
return true
}
func (s *state) appendAttrs(as []slog.Attr) bool {
nonEmpty := false
for _, a := range as {
if s.appendAttr(a) {
nonEmpty = true
}
}
return nonEmpty
}
func (s *state) writeAttr(a slog.Attr) {
if s.buf.Len() > 0 {
s.buf.WriteString(";")
}
if len(*s.groups) > 0 {
s.buf.WriteString(fmt.Sprintf("%*s", len(*s.groups)*2, ""))
s.buf.WriteString(strings.Join(*s.groups, "."))
s.buf.WriteString(".")
}
s.buf.WriteString(a.Key)
s.buf.WriteString("=")
switch a.Value.Kind() {
case slog.KindDuration:
s.buf.WriteString(a.Value.Duration().String())
case slog.KindTime:
s.buf.WriteString(a.Value.Time().Format(time.RFC3339Nano))
default:
s.buf.WriteString(fmt.Sprintf("%+v", a.Value.Any()))
}
}
func (s *state) begin(r slog.Record) {
if r.NumAttrs() > 0 {
pos := s.buf.Len()
s.startGroups()
empty := true
r.Attrs(func(a slog.Attr) bool {
isBuiltIn := hasBuiltInKey(a) || a.Key == slog.LevelKey
if !isBuiltIn && s.appendAttr(a) {
empty = false
}
return true
})
if empty {
s.buf.Truncate(pos)
}
}
}
func (s *state) free() {
if gs := s.groups; gs != nil {
*gs = (*gs)[:0]
groupPool.Put(gs)
}
s.buf.Reset()
}
var groupPool = sync.Pool{New: func() any {
s := make([]string, 0, 10)
return &s
}}

View File

@ -1,34 +0,0 @@
package logging
import (
"context"
"log/slog"
"os"
"testing"
)
const LevelQ = slog.Level(-6)
func TestDoAFlip(t *testing.T) {
t.Name()
replacer := func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.LevelKey {
level := a.Value.Any().(slog.Level)
switch level {
case LevelQ:
a.Value = slog.StringValue("q")
}
}
return a
}
h := NewFormattedHandler(os.Stderr, Options{
Format: "{{.Time}} [{{.Level}}] {{.Message}} | {{ rest }}",
Level: LevelQ,
ReplaceAttr: replacer,
})
logger := slog.New(h)
slog.SetDefault(logger)
logger.Debug("hello", "btfash", true)
logger.Log(context.TODO(), LevelQ, "hi")
}

319
json.go
View File

@ -1,319 +0,0 @@
package orm
import (
"encoding/json"
"fmt"
"github.com/fatih/structtag"
"reflect"
"time"
)
func defaultEngine() *Engine {
return engines.Engines[defaultKey]
}
func anyToModel(input any) *Model {
rv := reflect.TypeOf(input)
for rv.Kind() == reflect.Ptr ||
rv.Kind() == reflect.Interface ||
rv.Kind() == reflect.Slice || rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
maybeEngine := defaultEngine()
if maybeEngine == nil {
return nil
}
return maybeEngine.modelMap.Map[rv.Name()]
}
func JSONSerialize(input any, pretty bool) ([]byte, error) {
vp := reflect.ValueOf(input)
vt := reflect.TypeOf(input)
if vt.Kind() != reflect.Pointer {
return nil, fmt.Errorf("Argument must be a pointer or pointer to a slice; got: %v", vt.Kind())
}
ser, err := innerSerialize(vp)
if err != nil {
return nil, err
}
if pretty {
return json.MarshalIndent(ser, "", "\t")
}
return json.Marshal(ser)
}
func JSONDeserialize(val any, ser []byte) error {
var fiv any
if err := json.Unmarshal(ser, &fiv); err != nil {
return err
}
vp := reflect.ValueOf(val)
if vp.Kind() != reflect.Pointer {
return fmt.Errorf("Argument must be a pointer or pointer to a slice; got: %v", vp.Kind())
}
m := anyToModel(val)
if m == nil {
return fmt.Errorf("No model found for type '%s'", vp.Type().Name())
}
maybeEngine := defaultEngine()
if maybeEngine == nil {
return fmt.Errorf("No engines have been created!?")
}
fv, err := innerDeserialize(fiv, m, maybeEngine)
if err != nil {
return err
}
vp.Elem().Set(fv)
return nil
}
func innerSerialize(v reflect.Value) (ret any, err error) {
switch v.Kind() {
case reflect.Interface:
v = v.Elem()
fallthrough
case reflect.Pointer:
if v.IsNil() {
return ret, nil
}
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.IsZero() {
return ret, nil
}
fallthrough
case reflect.Struct:
m := anyToModel(v.Interface())
if m == nil {
if canConvertTo[time.Time](v.Type()) {
ret = v.Interface().(time.Time).Format(time.RFC3339)
} else {
var bytes []byte
bytes, err = json.Marshal(v.Interface())
if err != nil {
return nil, err
}
ser := make(map[string]any)
err = json.Unmarshal(bytes, &ser)
if err != nil {
return nil, err
}
ret = ser
}
} else {
depopulated, depopulatedId := isDepopulated(v, m.IDField)
if depopulated {
ret = depopulatedId
} else {
rmap := make(map[string]any)
for i := range v.NumField() {
fv := v.Field(i)
ft := v.Type().Field(i)
var tag *structtag.Tags
tag, err = structtag.Parse(string(ft.Tag))
if err != nil {
return nil, err
}
var jsonTag *structtag.Tag
jsonTag, err = tag.Get("json")
if err != nil || jsonTag.Name == "-" {
continue
}
if jsonTag.Name == "" {
// we are dealing with an inlined/anonymous struct
var maybeMap any
maybeMap, err = innerSerialize(fv)
if amap, ok := maybeMap.(map[string]any); ok {
for k, vv := range amap {
rmap[k] = vv
}
}
} else {
rmap[jsonTag.Name], err = innerSerialize(fv)
if err != nil {
return nil, err
}
}
}
ret = rmap
}
}
case reflect.Slice, reflect.Array:
ret0 := make([]any, 0)
for i := range v.Len() {
var ser any
ser, err = innerSerialize(v.Index(i))
if err != nil {
return nil, err
}
ret0 = append(ret0, ser)
}
ret = ret0
default:
ret = v.Interface()
}
return ret, nil
}
func innerDeserialize(input any, m *Model, e *Engine) (nv reflect.Value, err error) {
t := m.Type
irv := reflect.ValueOf(input)
if irv.Kind() == reflect.Slice || irv.Kind() == reflect.Array {
nv = reflect.MakeSlice(reflect.SliceOf(t), 0, 0)
for i := range irv.Len() {
var snv reflect.Value
cur := irv.Index(i)
snv, err = innerDeserialize(cur.Interface(), m, e)
if err != nil {
return snv, err
}
nv = reflect.Append(nv, snv)
}
} else { // it's a map or primitive value
nv = reflect.New(t).Elem()
if asMap, ok := input.(map[string]any); ok {
for _, f := range m.Fields {
if f.Index < 0 {
continue
}
ft := f.Original
fv := nv.Field(f.Index)
var tags *structtag.Tags
var btag *structtag.Tag
tags, err = structtag.Parse(string(ft.Tag))
if err != nil {
return
}
btag, err = tags.Get("json")
if err != nil || btag.Name == "-" {
continue
}
interm := asMap[btag.Name]
var tmp any
if str, sok := interm.(string); sok {
if ttmp, terr := time.Parse(time.RFC3339, str); terr == nil {
tmp = ttmp
} else {
tmp = interm
}
} else {
tmp = interm
}
switch fv.Kind() {
case reflect.Int64, reflect.Int32, reflect.Int, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if tmp != nil {
fv.Set(reflect.ValueOf(tmp).Convert(ft.Type))
}
case reflect.Array, reflect.Slice:
if interm != nil {
slic := reflect.ValueOf(interm)
fv.Set(handleSliceMaybe(slic, fv.Type().Elem()))
}
default:
if ft.Anonymous {
var nfv reflect.Value
nfv, err = handleAnon(input, ft.Type)
if err != nil {
return
}
fv.Set(nfv)
} else {
fv.Set(reflect.ValueOf(tmp))
}
}
}
for _, r := range m.Relationships {
if r.Type == BelongsTo || r.Type == ManyToOne || r.Idx < 0 {
continue
}
ft := r.OriginalField
fv := nv.Field(r.Idx)
var tags *structtag.Tags
var btag *structtag.Tag
tags, err = structtag.Parse(string(ft.Tag))
if err != nil {
return
}
btag, err = tags.Get("json")
if err != nil || btag.Name == "-" {
continue
}
var rv reflect.Value
interm := asMap[btag.Name]
rv, err = innerDeserialize(interm, r.RelatedModel, e)
if err != nil {
return reflect.Value{}, err
}
fv.Set(rv)
}
} else {
iface := nv.Addr().Interface()
err = e.Model(iface).Where(fmt.Sprintf("%s = ?", m.IDField), input).Find(iface)
if err != nil {
return reflect.Value{}, err
}
}
}
return
}
func handleAnon(raw any, rtype reflect.Type) (nv reflect.Value, err error) {
nv = reflect.New(rtype).Elem()
amap, ok := raw.(map[string]any)
if ok {
for i := range rtype.NumField() {
ft := rtype.Field(i)
fv := nv.Field(i)
var tags *structtag.Tags
var btag *structtag.Tag
tags, err = structtag.Parse(string(ft.Tag))
if err != nil {
return
}
btag, terr := tags.Get("json")
if terr != nil || btag.Name == "-" || !ft.IsExported() {
continue
}
fval := amap[btag.Name]
if reflect.TypeOf(fval) == reflect.TypeFor[string]() && ft.Type == reflect.TypeFor[time.Time]() {
tt, _ := time.Parse(time.RFC3339, fval.(string))
fv.Set(reflect.ValueOf(tt))
} else if fval != nil {
fv.Set(reflect.ValueOf(fval))
}
}
}
return
}
func handleSliceMaybe(iv reflect.Value, dstType reflect.Type) reflect.Value {
if iv.Kind() != reflect.Slice && iv.Kind() != reflect.Pointer {
return iv
}
dst := reflect.MakeSlice(reflect.SliceOf(dstType), 0, 0)
for i := range iv.Len() {
//dst.Set(reflect.Append(fv, handleSliceMaybe(iv.Index(i).Elem())))
maybeIface := iv.Index(i)
if maybeIface.Kind() == reflect.Interface {
maybeIface = maybeIface.Elem()
}
maybeNdest := dstType
if maybeNdest.Kind() == reflect.Slice || maybeNdest.Kind() == reflect.Array {
maybeNdest = maybeNdest.Elem()
}
dst = reflect.Append(dst, handleSliceMaybe(maybeIface, maybeNdest))
}
return dst
}
func isDepopulated(v reflect.Value, idField string) (bool, any) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
syn := reflect.New(v.Type()).Elem()
syn.FieldByName(idField).Set(v.FieldByName(idField))
finalId := v.FieldByName(idField).Interface()
return reflect.DeepEqual(v.Interface(), syn.Interface()), finalId
}

View File

@ -1,74 +0,0 @@
package orm
import (
"encoding/json"
"fmt"
"github.com/Jeffail/gabs"
"github.com/stretchr/testify/assert"
"testing"
)
func TestJsonSerialize(t *testing.T) {
e := initTest(t)
defer e.Disconnect()
u := author(t)
err := e.Model(&user{}).Save(&u)
if err != nil {
fmt.Println(err.Error())
}
assert.Nil(t, err)
insertBands(t, e)
ns := storyBase(e, t, u, "Chapters.Bands")
bytes, err := JSONSerialize(ns, true)
assert.Nil(t, err)
fmt.Println(string(bytes))
}
func TestJSONDeserialize(t *testing.T) {
e := initTest(t)
defer e.Disconnect()
u := author(t)
err := e.Model(&user{}).Save(&u)
if err != nil {
fmt.Println(err.Error())
}
assert.Nil(t, err)
insertBands(t, e)
ns := storyBase(e, t, u, "Chapters.Bands")
bytes, err := JSONSerialize(ns, true)
assert.Nil(t, err)
fmt.Println(string(bytes))
msi := make(map[string]any)
err = json.Unmarshal(bytes, &msi)
assert.Nil(t, err)
obj, err := gabs.Consume(msi)
assert.Nil(t, err)
children, err := obj.S("chapters").Children()
assert.Nil(t, err)
for _, child := range children {
bands := child.S("bands")
var bcontainer []*gabs.Container
bcontainer, err = bands.Children()
assert.Nil(t, err)
if err != nil {
break
}
for j := range bcontainer {
id := bcontainer[j].S("_id").Data()
//obj.S("chapters").Index(i).S("bands").Index
_, err = bands.SetIndex(id, j)
assert.Nil(t, err)
if err != nil {
break
}
}
}
nbytes := obj.Bytes()
var des story
err = JSONDeserialize(&des, nbytes)
assert.Nil(t, err)
for _, c := range des.Chapters {
assert.NotNil(t, c.Bands)
assert.GreaterOrEqual(t, len(c.Bands), 1)
}
}

294
model.go
View File

@ -1,66 +1,270 @@
package orm
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"reflect"
)
// Model - an intermediate representation of a Go struct
// Model - type which contains "static" methods like
// Find, FindOne, etc.
type Model struct {
Name string // the name, almost always the name of the underlying Type
Type reflect.Type // the Go type this model represents
Relationships map[string]*Relationship // a mapping of struct field names to Relationship pointers
IDField string // the name of the field containing this model's ID
Fields map[string]*Field // mapping of struct field names to Field pointers
FieldsByColumnName map[string]*Field // mapping of database column names to Field pointers
TableName string // the name of the table where this model's data is stored. defaults to a snake_cased version of the type/struct name if not provided explicitly via tag
embeddedIsh bool // INTERNAL - whether this model is: 1) contained in another type and 2) wasn't explicitly passed to the `Models` method (i.e., it can't "exist" on its own)
Indexes map[string][]InternalIndex
Type reflect.Type
collection string
gridFSReferences map[string]gridFSReference
idx int
references map[string]Reference
typeName string `bson:"-"`
}
func (m *Model) addField(field *Field) {
field.Model = m
m.Fields[field.Name] = field
m.FieldsByColumnName[field.ColumnName] = field
// HasID is a simple interface that you must implement
// in your models, using a pointer receiver.
// This allows for more flexibility in cases where
// your ID isn't an ObjectID (e.g., int, uint, string...).
//
// and yes, those darn ugly ObjectIDs are supported :)
type HasID interface {
Id() any
SetId(id any)
}
const (
documentField = "Document"
createdField = "Created"
modifiedField = "Modified"
)
type HasIDSlice []HasID
func (m *Model) docField() *Field {
return m.Fields[documentField]
type IModel interface {
FindRaw(query interface{}, opts *options.FindOptionsBuilder) (*mongo.Cursor, error)
Find(query interface{}, opts *options.FindOptionsBuilder) (*Query, error)
FindByID(id interface{}) (*Query, error)
FindOne(query interface{}, options *options.FindOneOptionsBuilder) (*Query, error)
FindPaged(query interface{}, page int64, perPage int64, options *options.FindOptionsBuilder) (*Query, error)
Exists(query interface{}) (bool, error)
ExistsID(id interface{}) (bool, error)
Collection() *mongo.Collection
getIdxs() []*mongo.IndexModel
getParsedIdxs() map[string][]InternalIndex
getTypeName() string
setTypeName(str string)
}
func (m *Model) idField() *Field {
return m.Fields[m.IDField]
func (m *Model) getTypeName() string {
return m.typeName
}
func (m *Model) getPrimaryKey(val reflect.Value) (string, any) {
colField := m.Fields[m.IDField]
if colField == nil {
return "", nil
}
colName := colField.ColumnName
wasPtr := false
if val.Kind() == reflect.Ptr {
if val.IsNil() {
return "", nil
func (m *Model) setTypeName(str string) {
m.typeName = str
}
// Collection - returns the collection associated with this Model
func (m *Model) Collection() *mongo.Collection {
return DB.Collection(m.collection)
}
func (m *Model) getIdxs() []*mongo.IndexModel {
mi := make([]*mongo.IndexModel, 0)
if mpi := m.getParsedIdxs(); mpi != nil {
for _, v := range mpi {
for _, i := range v {
mi = append(mi, buildIndex(i))
}
}
val = val.Elem()
wasPtr = true
return mi
}
if val.IsZero() && wasPtr {
return "", nil
}
idField := val.FieldByName(m.IDField)
if idField.IsValid() {
return colName, idField.Interface()
}
return "", nil
return nil
}
func (m *Model) needsPrimaryKey(val reflect.Value) bool {
_, pk := m.getPrimaryKey(val)
return pk == nil || reflect.ValueOf(pk).IsZero()
func (m *Model) getParsedIdxs() map[string][]InternalIndex {
_, ri, ok := ModelRegistry.HasByName(m.typeName)
if !ok {
panic(fmt.Sprintf(errFmtModelNotRegistered, m.typeName))
}
return ri.Indexes
}
// FindRaw - find documents satisfying `query` and return a plain mongo cursor.
func (m *Model) FindRaw(query interface{}, opts *options.FindOptionsBuilder) (*mongo.Cursor, error) {
coll := m.Collection()
if opts == nil {
opts = options.Find()
}
var fo options.FindOptions
for _, setter := range opts.Opts {
_ = setter(&fo)
}
cursor, err := coll.Find(context.TODO(), query, opts)
return cursor, err
}
// Find - find all documents satisfying `query`.
// returns a pointer to a Query for further chaining.
func (m *Model) Find(query interface{}, opts *options.FindOptionsBuilder) (*Query, error) {
qqn := ModelRegistry.new_(m.typeName)
qqt := reflect.SliceOf(reflect.TypeOf(qqn))
qqv := reflect.New(qqt)
qqv.Elem().Set(reflect.MakeSlice(qqt, 0, 0))
qq := &Query{
model: m,
collection: m.Collection(),
doc: qqv.Interface(),
op: OP_FIND_ALL,
}
q, err := m.FindRaw(query, opts)
//idoc := (*DocumentSlice)(qqv.Elem().UnsafePointer())
idoc := make(DocumentSlice, 0)
if err == nil {
rawRes := bson.A{}
err = q.All(context.TODO(), &rawRes)
if err == nil {
idoc.setExists(true)
}
qq.rawDoc = rawRes
err = q.All(context.TODO(), &qq.doc)
if err != nil {
qq.reOrganize()
err = nil
}
for i := 0; i < qqv.Elem().Len(); i++ {
idoc = append(idoc, qqv.Elem().Index(i).Interface().(IDocument))
}
for i, doc := range idoc {
doc.setModel(*m)
doc.SetSelf(doc)
doc.setRaw(rawRes[i])
}
}
return qq, err
}
// FindPaged - Wrapper around FindAll with the Skip and Limit options populated.
// returns a pointer to a Query for further chaining.
func (m *Model) FindPaged(query interface{}, page int64, perPage int64, opts *options.FindOptionsBuilder) (*Query, error) {
skipAmt := perPage * (page - 1)
if skipAmt < 0 {
skipAmt = 0
}
opts.SetSkip(skipAmt).SetLimit(perPage)
q, err := m.Find(query, opts)
q.op = OP_FIND_PAGED
return q, err
}
// FindByID - find a single document by its _id field.
// Wrapper around FindOne with an ID query as its first argument
func (m *Model) FindByID(id interface{}) (*Query, error) {
return m.FindOne(bson.D{{"_id", id}}, nil)
}
// FindOne - find a single document satisfying `query`.
// returns a pointer to a Query for further chaining.
func (m *Model) FindOne(query interface{}, options *options.FindOneOptionsBuilder) (*Query, error) {
coll := m.Collection()
rip := coll.FindOne(context.TODO(), query, options)
raw := bson.M{}
err := rip.Decode(&raw)
if err != nil {
return nil, err
}
qqn := ModelRegistry.new_(m.typeName)
idoc, ok := qqn.(IDocument)
idoc.setRaw(raw)
qq := &Query{
collection: m.Collection(),
rawDoc: raw,
doc: idoc,
op: OP_FIND_ONE,
model: m,
}
qq.rawDoc = raw
err = rip.Decode(qq.doc)
if err != nil {
qq.reOrganize()
err = nil
}
if ok {
idoc.setExists(true)
idoc.setModel(*m)
idoc.setRaw(raw)
}
idoc.SetSelf(idoc)
return qq, err
}
func (m *Model) Count(query interface{}, options *options.CountOptionsBuilder) (int64, error) {
coll := m.Collection()
return coll.CountDocuments(context.TODO(), query, options)
}
func (m *Model) Exists(query interface{}) (bool, error) {
cnt, err := m.Count(query, options.Count())
return cnt > 0, err
}
func (m *Model) ExistsID(id interface{}) (bool, error) {
cnt, err := m.Count(bson.M{"_id": id}, options.Count())
return cnt > 0, err
}
func createBase(d any) (reflect.Value, int, string) {
var n string
var ri *Model
var ok bool
n, ri, ok = ModelRegistry.HasByName(nameOf(d))
if !ok {
ModelRegistry.Model(d)
n, ri, _ = ModelRegistry.Has(d)
}
t := ri.Type
v := valueOf(d)
i := ModelRegistry.Index(n)
r := reflect.New(reflect.PointerTo(t)).Elem()
r.Addr().Elem().Set(reflect.New(t))
r.Addr().Elem().Elem().Set(v)
if reflect.ValueOf(d).Kind() == reflect.Pointer {
r.Addr().Elem().Elem().Set(reflect.ValueOf(d).Elem())
} else {
r.Addr().Elem().Elem().Set(reflect.ValueOf(d))
}
ri.setTypeName(n)
r.Interface().(IDocument).setModel(*ri)
r.Interface().(IDocument).newPopulationMap()
return r, i, n
}
// Create creates a new instance of a given Document
// type and returns a pointer to it.
func Create(d any) any {
r, _, n := createBase(d)
//df := r.Elem().Field(i)
dm := r.Interface().(IDocument)
dm.getModel().setTypeName(n)
what := r.Interface()
dm.SetSelf(what)
//df.Set(reflect.ValueOf(dm))
return what
}
// CreateSlice - convenience method which creates a new slice
// of type *T (where T is a type which embeds Document) and
// returns it
func CreateSlice[T any](d T) []*T {
r, _, _ := createBase(d)
rtype := r.Type()
rslice := reflect.SliceOf(rtype)
newItem := reflect.New(rslice)
newItem.Elem().Set(reflect.MakeSlice(rslice, 0, 0))
return newItem.Elem().Interface().([]*T)
}

View File

@ -1,105 +0,0 @@
package orm
import (
"reflect"
"strings"
)
func parseModel(model any) *Model {
t := reflect.TypeOf(model)
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
minfo := &Model{
Name: t.Name(),
Relationships: make(map[string]*Relationship),
Fields: make(map[string]*Field),
FieldsByColumnName: make(map[string]*Field),
Type: t,
}
for i := range t.NumField() {
f := t.Field(i)
if !f.IsExported() {
continue
}
}
if minfo.TableName == "" {
minfo.TableName = pascalToSnakeCase(t.Name())
}
return minfo
}
func parseModelFields(model *Model, modelMap map[string]*Model) {
t := model.Type
for i := range t.NumField() {
f := t.Field(i)
fi := parseField(f, model, modelMap, i)
if fi != nil && (fi.ColumnType != "" || fi.isAnonymous()) {
model.addField(fi)
}
}
}
func makeModelMap(models ...any) *internalModelMap {
modelMap := &internalModelMap{
Map: make(map[string]*Model),
}
//internalModelMap := make(map[string]*Model)
for _, model := range models {
minfo := parseModel(model)
modelMap.Mux.Lock()
modelMap.Map[minfo.Name] = minfo
modelMap.Mux.Unlock()
}
for _, model := range modelMap.Map {
modelMap.Mux.Lock()
parseModelFields(model, modelMap.Map)
modelMap.Mux.Unlock()
}
tagManyToMany(modelMap)
for _, model := range modelMap.Map {
for _, ref := range model.Relationships {
if ref.Type != ManyToMany && ref.Idx != -1 {
modelMap.Mux.Lock()
addForeignKeyFields(ref)
modelMap.Mux.Unlock()
}
}
}
return modelMap
}
func tagManyToMany(models *internalModelMap) {
hasManys := make(map[string]*Relationship)
for _, model := range models.Map {
for relName := range model.Relationships {
hasManys[model.Name+"."+relName] = model.Relationships[relName]
}
}
for _, model := range models.Map {
models.Mux.Lock()
for relName := range model.Relationships {
mb := model.Relationships[relName].RelatedModel
var name string
for n, reltmp := range hasManys {
if !strings.HasPrefix(n, mb.Name) || reltmp.Type != HasMany {
continue
}
if reltmp.RelatedType == model.Type {
name = reltmp.FieldName
break
}
}
if rel2, ok := mb.Relationships[name]; ok {
if name < relName &&
rel2.Type == HasMany && model.Relationships[relName].Type == HasMany {
mb.Relationships[name].Type = ManyToMany
mb.Relationships[name].m2mInverse = model.Relationships[relName]
model.Relationships[relName].Type = ManyToMany
model.Relationships[relName].m2mInverse = mb.Relationships[name]
}
}
}
models.Mux.Unlock()
}
}

View File

@ -1,20 +0,0 @@
package orm
import (
"sync"
)
type internalModelMap struct {
Map map[string]*Model
Mux sync.RWMutex
}
type engineHolder struct {
Engines map[string]*Engine
Mux sync.RWMutex
}
var engines = &engineHolder{
Engines: make(map[string]*Engine),
Mux: sync.RWMutex{},
}

View File

@ -1,187 +0,0 @@
package orm
import (
"fmt"
"reflect"
"strings"
)
type schemaSnapshotColumn struct {
Document `d:"table:__schemas"`
ID int64 `d:"pk"`
ModelName string
FieldName string
FieldType string
FieldIndex int
IsRelationship bool
IsSynthetic bool
}
func (m *Model) toSnapshotColumns() (ssc []*schemaSnapshotColumn) {
for _, field := range m.Fields {
ssc = append(ssc, &schemaSnapshotColumn{
ModelName: m.Name,
FieldName: field.Name,
FieldType: field.Type.String(),
FieldIndex: field.Index,
})
}
for _, rel := range m.Relationships {
rt := rel.RelatedType
if rel.Kind == reflect.Slice {
rt = reflect.SliceOf(rel.RelatedType)
}
ssc = append(ssc, &schemaSnapshotColumn{
ModelName: m.Name,
FieldName: rel.FieldName,
FieldType: rt.String(),
FieldIndex: rel.Idx,
IsRelationship: true,
IsSynthetic: rel.Idx < 0,
})
}
return
}
func (m *Model) createTableSql() string {
var fields []string
for _, field := range m.Fields {
if !field.isAnonymous() {
isStructOrSliceOfStructs := field.Type.Kind() == reflect.Struct ||
((field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) &&
field.Type.Elem().Kind() == reflect.Struct)
if field.PrimaryKey {
fields = append(fields, fmt.Sprintf("%s %s PRIMARY KEY", field.ColumnName, field.ColumnType))
} else if !isStructOrSliceOfStructs || field.ColumnType != "" {
lalala := fmt.Sprintf("%s %s", field.ColumnName, field.ColumnType)
if !field.Nullable {
lalala += " NOT NULL"
}
lalala += fmt.Sprintf(" DEFAULT %v", defaultColumnValue(field.Type))
fields = append(fields, lalala)
}
} else {
ft := field.Type
for ft.Kind() == reflect.Pointer {
ft = ft.Elem()
}
for i := range ft.NumField() {
efield := field.Type.Field(i)
ctype := columnType(efield.Type, false, false)
if ctype != "" {
def := fmt.Sprintf("%s %s NOT NULL DEFAULT %v", pascalToSnakeCase(efield.Name), ctype, defaultColumnValue(efield.Type))
fields = append(fields, def)
}
}
}
}
inter := strings.Join(fields, ", ")
return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);",
m.TableName, inter)
}
func (m *Model) createJoinTableSql(relName string) string {
ref, ok := m.Relationships[relName]
if !ok {
return ""
}
aTable := m.TableName
joinTableName := ref.ComputeJoinTable()
fct := serialToRegular(ref.primaryID().ColumnType)
rct := serialToRegular(ref.relatedID().ColumnType)
pkSection := fmt.Sprintf(",\nPRIMARY KEY (%s_id, %s_id)",
aTable,
ref.RelatedModel.TableName,
)
if ref.m2mIsh() {
pkSection = ""
}
return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
%s_id %s REFERENCES %s(%s) ON DELETE CASCADE,
%s_id %s REFERENCES %s(%s) ON DELETE CASCADE %s
);`,
joinTableName,
ref.Model.TableName,
fct,
ref.Model.TableName, ref.Model.Fields[ref.Model.IDField].ColumnName,
ref.RelatedModel.TableName,
rct,
ref.RelatedModel.TableName, ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnName,
pkSection,
)
}
func (m *Model) generateConstraints(engine *Engine) error {
for _, rel := range m.Relationships {
field := rel.relatedID()
if rel.Type != ManyToMany && rel.Type != HasMany && !rel.m2mIsh() {
colType := serialToRegular(field.ColumnType)
if !field.Nullable && !rel.Nullable {
colType += " NOT NULL"
}
/*constraint := fmt.Sprintf("%s %s REFERENCES %s(%s)",
pascalToSnakeCase(rel.joinField()), colType,
rel.RelatedModel.TableName,
field.ColumnName)
if rel.Type != ManyToOne && rel.Type != BelongsTo {
constraint += " ON DELETE CASCADE ON UPDATE CASCADE"
}*/
fk := fmt.Sprintf("fk_%s", pascalToSnakeCase(capitalizeFirst(rel.Model.Name)+rel.FieldName+rel.relatedID().Name))
q := fmt.Sprintf(`ALTER TABLE %s
ADD COLUMN IF NOT EXISTS %s %s,
ADD CONSTRAINT %s
FOREIGN KEY (%s) REFERENCES %s(%s)
ON DELETE CASCADE
ON UPDATE CASCADE;`,
rel.Model.TableName,
pascalToSnakeCase(rel.joinField()), colType,
fk,
pascalToSnakeCase(rel.joinField()),
rel.RelatedModel.TableName, field.ColumnName,
)
dq := fmt.Sprintf(`ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s;`, m.TableName, fk)
engine.logSql("drop constraint", dq)
engine.logSql("alter table", q)
if _, err := engine.conn.Exec(engine.ctx, dq); err != nil {
return err
}
if _, err := engine.conn.Exec(engine.ctx, q); err != nil {
return err
}
}
}
return nil
}
func (m *Model) migrate(engine *Engine) error {
sql := m.createTableSql()
engine.logSql("create table", sql)
if !engine.dryRun {
_, err := engine.conn.Exec(engine.ctx, sql)
if err != nil {
return err
}
}
for relName, rel := range m.Relationships {
relkey := rel.ComputeJoinTable()
if (rel.Type == ManyToMany && !engine.m2mSeen[relkey]) ||
(rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh && rel.Type == HasMany) {
if rel.Type == ManyToMany {
engine.m2mSeen[relkey] = true
engine.m2mSeen[rel.Model.Name] = true
engine.m2mSeen[rel.RelatedModel.Name] = true
}
jtsql := m.createJoinTableSql(relName)
engine.logSql("crate join table", jtsql)
if !engine.dryRun {
_, err := engine.conn.Exec(engine.ctx, jtsql)
if err != nil {
return err
}
}
}
}
return m.generateConstraints(engine)
}

281
model_test.go Normal file
View File

@ -0,0 +1,281 @@
package orm
import (
"encoding/json"
"fmt"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNew(t *testing.T) {
initTest()
is := iti_single()
doc := Create(is).(*story)
assert.Equal(t, is.Title, doc.Title)
assert.Equal(t, is.Chapters[0].Summary, doc.Chapters[0].Summary)
}
func TestSave(t *testing.T) {
initTest()
storyDoc := Create(iti_multi()).(*story)
lauthor := Create(author).(*user)
storyDoc.Author = lauthor
assert.Equal(t, storyDoc.Id(), int64(0))
assert.Equal(t, lauthor.ID, storyDoc.Author.ID)
aerr := lauthor.Save()
assert.Equal(t, nil, aerr)
serr := storyDoc.Save()
assert.Equal(t, nil, serr)
assert.Less(t, int64(0), storyDoc.ID)
assert.Less(t, int64(0), lauthor.ID)
for _, c := range storyDoc.Chapters {
assert.NotZero(t, c.ChapterID)
}
}
func TestPopulate(t *testing.T) {
initTest()
bandDoc := Create(iti_single().Chapters[0].Bands[0]).(*band)
storyDoc := Create(iti_single()).(*story)
mauthor := Create(author).(*user)
saveDoc(t, mauthor)
saveDoc(t, bandDoc)
storyDoc.Author = mauthor
saveDoc(t, storyDoc)
assert.Greater(t, storyDoc.ID, int64(0))
smodel := ModelRegistry["story"]
q, err := smodel.FindByID(storyDoc.ID)
assert.Equal(t, nil, err)
assert.NotPanics(t, func() {
foundDoc := &story{}
q.Populate("Author", "Chapters.Bands").Exec(foundDoc)
j, _ := q.JSON()
fmt.Printf("%s\n", j)
})
for _, c := range storyDoc.Chapters {
assert.NotZero(t, c.Bands[0].Name)
}
}
func TestUpdate(t *testing.T) {
initTest()
nb := Create(metallica).(*band)
saveDoc(t, nb)
nb.Locked = true
saveDoc(t, nb)
foundM := ModelRegistry["band"]
q, err := foundM.FindByID(int64(1))
assert.Equal(t, nil, err)
found := &band{}
q.Exec(found)
assert.Equal(t, int64(1), found.ID)
assert.Equal(t, nil, err)
assert.Equal(t, true, found.Locked)
}
func TestModel_FindAll(t *testing.T) {
initTest()
im := iti_multi()
createAndSave(t, &im)
smodel := ModelRegistry["story"]
query, err := smodel.Find(bson.M{}, options.Find())
assert.Equal(t, nil, err)
final := CreateSlice(story{})
query.Exec(&final)
assert.Greater(t, len(final), 0)
}
func TestModel_PopulateMulti(t *testing.T) {
initTest()
bandDoc := Create(iti_single().Chapters[0].Bands[0]).(*band)
saveDoc(t, bandDoc)
mauthor := Create(author).(*user)
saveDoc(t, mauthor)
im := iti_multi()
im.Author = mauthor
createAndSave(t, &im)
smodel := ModelRegistry["story"]
query, err := smodel.Find(bson.M{}, options.Find())
assert.Equal(t, nil, err)
final := CreateSlice(story{})
query.Populate("Author", "Chapters.Bands").Exec(&final)
assert.Greater(t, len(final), 0)
for _, s := range final {
assert.NotZero(t, s.Chapters[0].Bands[0].Name)
}
bytes, _ := json.MarshalIndent(final, "", "\t")
fmt.Println(string(bytes))
}
func TestModel_PopulateChained_Multi(t *testing.T) {
initTest()
im := iti_multi()
bandDoc := Create(iti_single().Chapters[0].Bands[0]).(*band)
saveDoc(t, bandDoc)
mauthor := Create(author).(*user)
saveDoc(t, mauthor)
im.Author = mauthor
createAndSave(t, &im)
smodel := ModelRegistry["story"]
query, err := smodel.Find(bson.M{}, options.Find())
assert.Equal(t, nil, err)
final := CreateSlice(story{})
query.Populate("Author").Populate("Chapters.Bands").Exec(&final)
assert.Greater(t, len(final), 0)
for _, s := range final {
assert.NotZero(t, s.Chapters[0].Bands[0].Name)
}
}
func TestPopulate_Chained(t *testing.T) {
initTest()
bandDoc := Create(iti_single().Chapters[0].Bands[0]).(*band)
storyDoc := Create(iti_single()).(*story)
mauthor := Create(author).(*user)
saveDoc(t, mauthor)
saveDoc(t, bandDoc)
storyDoc.Author = mauthor
saveDoc(t, storyDoc)
assert.Greater(t, storyDoc.ID, int64(0))
smodel := ModelRegistry["story"]
q, err := smodel.FindByID(storyDoc.ID)
assert.Equal(t, nil, err)
assert.NotPanics(t, func() {
foundDoc := &story{}
q.Populate("Author").Populate("Chapters.Bands").Exec(foundDoc)
j, _ := q.JSON()
fmt.Printf("%s\n", j)
})
for _, c := range storyDoc.Chapters {
assert.NotZero(t, c.Bands[0].Name)
}
}
func TestModel_Append(t *testing.T) {
initTest()
bandDoc := Create(metallica).(*band)
saveDoc(t, bandDoc)
bmodel := ModelRegistry["band"]
query, err := bmodel.FindByID(int64(1))
assert.Equal(t, nil, err)
fin := &band{}
query.Exec(fin)
assert.Greater(t, fin.ID, int64(0))
err = fin.Append("Characters", "Robert Trujillo")
assert.Equal(t, nil, err)
saveDoc(t, fin)
fin = &band{}
query, _ = bmodel.FindByID(int64(1))
query.Exec(fin)
assert.Greater(t, len(fin.Characters), 4)
}
func TestModel_Delete(t *testing.T) {
initTest()
bandDoc := Create(metallica).(*band)
saveDoc(t, bandDoc)
err := bandDoc.Delete()
assert.Nil(t, err)
}
func TestModel_Pull(t *testing.T) {
initTest()
storyDoc := Create(iti_multi()).(*story)
smodel := ModelRegistry["story"]
saveDoc(t, storyDoc)
err := storyDoc.Pull("Chapters", storyDoc.Chapters[4])
assert.Nil(t, err)
assert.NotZero(t, storyDoc.ID)
saveDoc(t, storyDoc)
fin := &story{}
query, err := smodel.FindByID(storyDoc.ID)
assert.Nil(t, err)
query.Exec(fin)
assert.Equal(t, 4, len(fin.Chapters))
}
func TestModel_Swap(t *testing.T) {
initTest()
is := iti_single()
is.Author = &author
storyDoc := Create(iti_single()).(*story)
saveDoc(t, storyDoc)
storyDoc.Chapters[0].Bands = append(storyDoc.Chapters[0].Bands, bodom)
assert.Equal(t, 2, len(storyDoc.Chapters[0].Bands))
err := storyDoc.Swap("Chapters[0].Bands", 0, 1)
assert.Nil(t, err)
c := storyDoc.Chapters[0].Bands
assert.Equal(t, bodom.ID, c[0].ID)
assert.Equal(t, diamondHead.ID, c[1].ID)
saveDoc(t, storyDoc)
}
func TestModel_GridFSLoad(t *testing.T) {
initTest()
ModelRegistry.Model(somethingWithNestedChapters{})
model := ModelRegistry["somethingWithNestedChapters"]
thingDoc := Create(doSomethingWithNested()).(*somethingWithNestedChapters)
found := &somethingWithNestedChapters{}
saveDoc(t, thingDoc)
assert.NotZero(t, thingDoc.ID)
fq, err := model.FindByID(thingDoc.ID)
assert.Nil(t, err)
fq.LoadFile("NestedText", "Chapters.Text").Exec(found)
assert.NotZero(t, found.NestedText)
assert.NotZero(t, len(found.Chapters))
for _, c := range found.Chapters {
assert.NotZero(t, c.Text)
}
}
func TestModel_GridFSLoad_Chained(t *testing.T) {
initTest()
ModelRegistry.Model(somethingWithNestedChapters{})
model := ModelRegistry["somethingWithNestedChapters"]
thingDoc := Create(doSomethingWithNested()).(*somethingWithNestedChapters)
found := &somethingWithNestedChapters{}
saveDoc(t, thingDoc)
assert.NotZero(t, thingDoc.ID)
fq, err := model.FindByID(thingDoc.ID)
assert.Nil(t, err)
fq.LoadFile("NestedText").LoadFile("Chapters.Text").Exec(found)
assert.NotZero(t, found.NestedText)
assert.NotZero(t, len(found.Chapters))
for _, c := range found.Chapters {
assert.NotZero(t, c.Text)
}
}
func TestModel_GridFSLoad_Complex(t *testing.T) {
initTest()
model := ModelRegistry["story"]
bandDoc := Create(iti_single().Chapters[0].Bands[0]).(*band)
thingDoc := Create(iti_multi()).(*story)
mauthor := Create(author).(*user)
found := &story{}
saveDoc(t, bandDoc)
saveDoc(t, mauthor)
thingDoc.Author = mauthor
saveDoc(t, thingDoc)
assert.NotZero(t, thingDoc.ID)
fq, err := model.FindByID(thingDoc.ID)
assert.Nil(t, err)
fq.Populate("Author", "Chapters.Bands").LoadFile("Chapters.Text").Exec(found)
assert.NotZero(t, len(found.Chapters))
for _, c := range found.Chapters {
assert.NotZero(t, c.Text)
assert.NotZero(t, c.Bands[0].Name)
}
j, _ := fq.JSON()
fmt.Printf("%s\n", j)
}

888
query.go
View File

@ -2,321 +2,637 @@ package orm
import (
"context"
"encoding/json"
"errors"
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"reflect"
"strings"
"time"
)
// Query - contains the state and other details
// pertaining to the current FUCK operation (Find, Update/Create, Kill [Delete])
type Query struct {
engine *Engine // the Engine instance that created this Query
model *Model // the primary Model this Query pertains to
tx pgx.Tx // the transaction for insert, update and delete operations
ctx context.Context // does nothing, but is needed by some pgx functions
populationTree map[string]any // a tree-like map representing the dot-separated paths of fields to populate
wheres map[string][]any // a mapping of where clauses to a list of their arguments
joins []string // slice of tables to join on before executing Find. useful for hwne you have Where clauses referencing fields/columns in other structs/tables
orders []string // slice of `ORDER BY` clauses
limit int // argument to a LIMIT clause, if non-zero
offset int // unused (for now)
collection *mongo.Collection
op string
model *Model
done bool
rawDoc any
doc any
}
func (q *Query) setModel(val any) *Query {
tt := reflect.TypeOf(val)
for tt.Kind() == reflect.Ptr {
tt = tt.Elem()
const (
OP_FIND_ONE = "findOne"
OP_FIND_PAGED = "findPaged"
OP_FIND_ALL = "findAll"
OP_FIND = "find"
)
func populate(r Reference,
alreadyPopulated map[string]bool,
rcoll string, rawDoc interface{},
curDescent string, src interface{}) any {
rt := reflect.TypeOf(src)
rv := reflect.ValueOf(src)
srt := rt
if srt.Kind() == reflect.Pointer {
srt = rt.Elem()
}
q.model = q.engine.modelMap.Map[tt.Name()]
return q
}
func (q *Query) cleanupTx() {
q.tx.Rollback(q.ctx)
q.tx = nil
}
// Order - add an `ORDER BY` clause to the current Query.
// Only applicable for Find queries
func (q *Query) Order(order string) *Query {
q.orders = append(q.orders, order)
return q
}
// Limit - limit resultset to at most `limit` results.
// does nothing if `limit` <= 0 or the final operation isn't Find
func (q *Query) Limit(limit int) *Query {
if limit > -1 {
q.limit = limit
if rv.Kind() == reflect.Pointer && reflect.ValueOf(src).IsNil() {
return src
}
return q
}
// Offset - skip to the nth result, where n = `offset`.
// does nothing if `offset` <= 0 or the final operation isn't Find
func (q *Query) Offset(offset int) *Query {
if offset > -1 {
q.offset = offset
if rv.Kind() != reflect.Pointer {
rv = reflect.New(rt)
rv.Elem().Set(reflect.ValueOf(src))
}
return q
}
if srt.Kind() == reflect.Struct && !isObject(rawDoc) {
nrd := ModelRegistry.Get(srt.Name())
if nrd != nil && nrd.collection != rcoll {
q, err := nrd.FindByID(rawDoc)
if err == nil {
rawDoc = q.rawDoc
toPopulate := []string{curDescent}
if asIDoc, ok := rv.Interface().(IDocument); ok {
for k, v := range asIDoc.getPopulated() {
if k != curDescent && v {
toPopulate = append(toPopulate, k)
}
}
}
q.Populate(toPopulate...)
// Where - add a `WHERE` clause to this query.
// struct field names can be passed to this method,
// and they will be automatically converted
func (q *Query) Where(cond string, args ...any) *Query {
q.processWheres(cond, "eq", args...)
return q
}
// WhereRaw - add a `WHERE` clause to this query, except `cond` is passed as-is.
func (q *Query) WhereRaw(cond string, args ...any) *Query {
q.wheres[cond] = args
return q
}
// In - add a `WHERE ... IN(...)` clause to this query
func (q *Query) In(cond string, args ...any) *Query {
q.processWheres(cond, "in", args...)
return q
}
// Join - join the current model's table with the table
// representing the type of struct field named `field`.
// Must be called before Where if referencing other
// structs/types to avoid errors
func (q *Query) Join(field string) *Query {
var clauses []string
parts := strings.Split(field, ".")
cur := q.model
found := false
aliasMap := q.getNestedAliases(field)
for _, part := range parts {
rel, ok := cur.Relationships[part]
if !ok {
found = false
break
}
if rel.FieldName != part {
found = false
break
}
found = true
aliases := aliasMap[rel]
curAlias := aliases[0]
nalias := aliases[1]
if rel.m2mIsh() || rel.Type == ManyToMany {
joinAlias := aliases[2]
jc1 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s_id",
rel.ComputeJoinTable(), joinAlias,
curAlias, cur.idField().ColumnName,
joinAlias, rel.Model.TableName,
)
jc2 := fmt.Sprintf("%s AS %s ON %s.%s_id = %s.%s",
rel.RelatedModel.TableName, nalias,
joinAlias, rel.RelatedModel.TableName,
nalias, rel.relatedID().ColumnName,
)
clauses = append(clauses, jc1, jc2)
}
if rel.Type == HasMany || rel.Type == HasOne {
fkr := rel.RelatedModel.Relationships[cur.Name]
if fkr != nil {
jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
rel.RelatedModel.TableName, nalias,
curAlias, cur.idField().ColumnName,
nalias, pascalToSnakeCase(fkr.joinField()),
)
clauses = append(clauses, jc)
q.Exec(rv.Interface())
src = rv.Interface()
}
}
if rel.Type == BelongsTo {
jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
rel.RelatedModel.TableName, nalias,
curAlias, pascalToSnakeCase(rel.joinField()),
nalias, rel.RelatedModel.idField().ColumnName,
)
clauses = append(clauses, jc)
}
curAlias = nalias
cur = rel.RelatedModel
}
if found {
q.joins = append(q.joins, clauses...)
}
return q
}
func (q *Query) getNestedAliases(field string) (amap map[*Relationship][]string) {
amap = make(map[*Relationship][]string)
parts := strings.Split(field, ".")
cur := q.model
curAlias := q.model.TableName
first := curAlias
found := false
for _, part := range parts {
rel, ok := cur.Relationships[part]
if !ok {
found = false
break
}
if rel.FieldName != part {
found = false
break
}
found = true
amap[rel] = make([]string, 0)
nalias := pascalToSnakeCase(part)
if rel.m2mIsh() || rel.Type == ManyToMany {
joinAlias := rel.ComputeJoinTable() + "_joined"
amap[rel] = append(amap[rel], curAlias, nalias, joinAlias)
} else if rel.Type == HasMany || rel.Type == HasOne || rel.Type == BelongsTo {
amap[rel] = append(amap[rel], curAlias, nalias)
}
curAlias = nalias
cur = rel.RelatedModel
var fieldsMap [3]string
type bsonWhat struct {
What string
}
if !found {
return
}
amap[nil] = []string{first}
return
}
func (q *Query) processWheres(cond string, exprKind string, args ...any) {
parts := strings.SplitN(cond, " ", 2)
var translatedColumn string
fieldPath := parts[0]
ncond := ""
if len(parts) > 1 {
ncond = " " + parts[1]
}
pathParts := strings.Split(fieldPath, ".")
if len(pathParts) > 1 {
relPath := pathParts[:len(pathParts)-1]
fieldName := pathParts[len(pathParts)-1]
relPathStr := strings.Join(relPath, ".")
aliasMap := q.getNestedAliases(relPathStr)
for r, a := range aliasMap {
if r == nil {
continue
}
f, ok := r.RelatedModel.Fields[fieldName]
if ok {
translatedColumn = fmt.Sprintf("%s.%s", a[1], f.ColumnName)
}
}
} else if pf := q.model.Fields[pathParts[0]]; pf != nil {
translatedColumn = fmt.Sprintf("%s.%s", q.model.TableName, pf.ColumnName)
}
var tq string
switch strings.ToLower(exprKind) {
case "in":
tq = fmt.Sprintf("%s IN (%s)", translatedColumn, MakePlaceholders(len(args)))
var isNotStructOrSlice bool
var w bsonWhat
switch rawDoc.(type) {
case bson.A:
w.What = "A"
case bson.D:
w.What = "D"
case bson.M:
w.What = "M"
default:
tq = fmt.Sprintf("%s%s", translatedColumn, ncond)
w.What = "-"
}
q.wheres[tq] = args
var fld string
var next string
if len(strings.Split(curDescent, ".")) > 1 {
next = strings.Join(strings.Split(curDescent, ".")[1:], ".")
fld = strings.Split(curDescent, ".")[0]
} else {
fld = curDescent
next = curDescent
}
var toReturn interface{}
switch w.What {
case "A":
rvs := reflect.MakeSlice(rt, 0, 0)
var rahh reflect.Value
if rv.Kind() == reflect.Ptr {
rahh = rv.Elem()
} else {
rahh = rv
}
if len(rawDoc.(bson.A)) > 0 {
if !isObject(rawDoc.(bson.A)[0]) {
next = curDescent
}
}
for i, el := range rawDoc.(bson.A) {
it := rahh.Index(i)
rel := el
popped := populate(r, alreadyPopulated, rcoll, rel, next, it.Interface())
if pidoc, pok := popped.(IDocument); pok {
pidoc.setRaw(rel)
}
poppedVal := reflect.ValueOf(popped)
if poppedVal.Kind() == reflect.Pointer {
rvs = reflect.Append(rvs, poppedVal.Elem())
} else {
rvs = reflect.Append(rvs, poppedVal)
}
}
if rv.CanSet() {
rv.Set(rvs)
} else if rv.Kind() == reflect.Pointer {
rv.Elem().Set(rvs)
} else {
src = rvs.Interface()
toReturn = src
}
case "D":
loc := rawDoc.(bson.D)
nrd := bson.M{}
for _, el := range loc {
nrd[el.Key] = el.Value
}
rawDoc = nrd
fallthrough
case "M":
dd := rawDoc.(bson.M)
var sf reflect.Value
var rsf reflect.Value
if rv.Kind() == reflect.Pointer {
sf = rv.Elem().FieldByName(fld)
} else {
sf = rv.FieldByName(fld)
}
if rv.Kind() == reflect.Pointer {
rsf = rv.Elem().FieldByName(fld)
} else {
rsf = rv.FieldByName(fld)
}
var ff reflect.StructField
var ok bool
if rt.Kind() == reflect.Pointer {
ff, ok = rt.Elem().FieldByName(fld)
} else {
ff, ok = rt.FieldByName(fld)
}
if ok {
tag, err := structtag.Parse(string(ff.Tag))
if err == nil {
val, err2 := tag.Get("bson")
if err2 == nil && val.Name != "-" {
fttt := ff.Type
if fttt.Kind() == reflect.Pointer || fttt.Kind() == reflect.Slice {
fttt = fttt.Elem()
}
fieldsMap = [3]string{fld, fttt.Name(), val.Name}
}
}
} else {
fmt.Println("todo")
}
intermediate := populate(r, alreadyPopulated, rcoll, dd[fieldsMap[2]], next, sf.Interface())
/*if iidoc, idocOk := intermediate.(IDocument); idocOk {
if (reflect.ValueOf(intermediate).CanAddr() && !reflect.ValueOf(intermediate).IsNil()) || !reflect.ValueOf(intermediate).IsZero() {
iiid, iok := intermediate.(HasID)
if intermediate != nil && iok && !reflect.ValueOf(iiid.Id()).IsZero() {
iidoc.setRaw(dd[fieldsMap[2]])
}
}
}*/
if rsf.CanSet() {
ival := reflect.ValueOf(intermediate)
if ival.Kind() != reflect.Pointer && rsf.Kind() == reflect.Pointer {
rsf.Set(ival.Elem())
} else if ival.Kind() == reflect.Pointer && rsf.Kind() != reflect.Pointer {
rsf.Set(ival.Elem())
} else {
rsf.Set(ival)
}
} else {
src = intermediate
}
default:
isNotStructOrSlice = true
if r.exists {
tto := r.HydratedType
if tto.Kind() == reflect.Pointer || tto.Kind() == reflect.Slice {
tto = tto.Elem()
}
rawt := ModelRegistry.new_(tto.Name())
t := rawt.(IDocument)
q := bson.M{"_id": rawDoc}
reso := DB.Collection(rcoll).FindOne(context.TODO(), q)
if !errors.Is(reso.Err(), mongo.ErrNoDocuments) {
var anotherMap = make(bson.M)
reso.Decode(&anotherMap)
reflect.ValueOf(t).Elem().Set(reflect.ValueOf(rerere(anotherMap, tto, false)).Elem())
t.setRaw(anotherMap)
}
hatred := rv
if hatred.Kind() == reflect.Pointer {
hatred = hatred.Elem()
}
if hatred.CanSet() {
if reflect.ValueOf(t).Kind() == reflect.Pointer {
if hatred.Kind() == reflect.Pointer {
hatred.Set(reflect.ValueOf(t))
} else {
hatred.Set(reflect.ValueOf(t).Elem())
}
} else {
hatred.Set(reflect.ValueOf(t))
}
} else {
src = t
toReturn = src
}
t.SetSelf(t)
t.setExists(true)
}
}
if toReturn == nil {
sidoc, sok := rv.Interface().(IDocument)
if sok {
sidoc.SetSelf(rv.Interface())
if !isNotStructOrSlice {
sidoc.setRaw(rawDoc)
}
} else if rv.Kind() == reflect.Pointer && rt.Kind() != reflect.Pointer {
rv = rv.Elem()
}
return rv.Interface()
}
sidoc, sok := src.(IDocument)
if sok {
sidoc.SetSelf(src)
if !isNotStructOrSlice {
sidoc.setRaw(rawDoc)
}
}
return src
}
// buildSQL - aggregates the information in this Query into a pgq.SelectBuilder.
// it returns a slice of column names as well to avoid issues with scanning
func (q *Query) buildSQL() (cols []string, anonymousCols map[string][]string, finalSb sb.SelectBuilder, err error) {
var inParents []any
anonymousCols = make(map[string][]string)
for _, field := range q.model.Fields {
if field.isAnonymous() {
for _, ef := range field.embeddedFields {
anonymousCols[field.ColumnName] = append(anonymousCols[field.ColumnName], ef.ColumnName)
// LoadFile - loads the contents of one or more files
// stored in gridFS into the fields named by `fields`.
//
// gridFS fields can be either a `string` or `[]byte`, and are
// tagged with `gridfs:"BUCKET,FILE_FORMAT`
// where:
// - `BUCKET` is the name of the bucket where the files are stored
// - `FILE_FORMAT` is any valid go template string that resolves to
// the unique file name.
// all exported values and methods present in the surrounding
// struct can be used in this template.
func (q *Query) LoadFile(fields ...string) *Query {
_, cm, _ := ModelRegistry.HasByName(q.model.typeName)
if cm != nil {
for _, field := range fields {
var r gridFSReference
hasAnnotated := false
for k2, v := range cm.gridFSReferences {
if strings.HasPrefix(k2, field) {
r = v
hasAnnotated = true
break
}
}
if hasAnnotated {
q.doc = gridFsLoad(q.doc, r, field)
}
}
}
return q
}
func readFields(field string, m *Model) map[string]Reference {
r := make(map[string]Reference)
if m == nil {
return r
}
for k, v := range m.references {
if strings.HasPrefix(field, k) {
r[k] = v
}
}
if vv, ok := m.references[field]; ok {
r[field] = vv
}
return r
}
// Populate populates document references via reflection
func (q *Query) Populate(fields ...string) *Query {
vvv := reflect.ValueOf(q.doc)
if vvv.Kind() == reflect.Pointer {
vvv = vvv.Elem()
}
if vvv.Kind() == reflect.Slice {
typ := reflect.PointerTo(q.model.Type)
slic := reflect.New(reflect.SliceOf(typ))
for i := 0; i < vvv.Len(); i++ {
val2 := vvv.Index(i).Interface()
aid, docOk := val2.(IDocument)
if docOk {
rdoc := q.rawDoc.(bson.A)
aid.setRaw(rdoc[i])
aid.Populate(fields...)
if reflect.ValueOf(aid).Kind() == reflect.Pointer {
slic.Elem().Set(reflect.Append(slic.Elem(), reflect.ValueOf(aid)))
} else {
slic.Elem().Set(reflect.Append(slic, reflect.ValueOf(aid)))
}
}
}
q.doc = slic.Interface()
} else if asDoc, ok2 := q.doc.(IDocument); ok2 {
asDoc.setRaw(q.rawDoc)
asDoc.Populate(fields...)
}
return q
}
func (q *Query) reOrganize() {
var trvo reflect.Value
if arr, ok := q.rawDoc.(bson.A); ok {
typ := ModelRegistry[q.model.typeName].Type
if typ.Kind() != reflect.Pointer {
typ = reflect.PointerTo(typ)
}
slic := reflect.New(reflect.SliceOf(typ))
for _, v2 := range arr {
inter := reflect.ValueOf(rerere(v2, typ, false))
slic.Elem().Set(reflect.Append(slic.Elem(), inter))
}
trvo = slic.Elem()
} else {
trvo = reflect.ValueOf(rerere(q.rawDoc, reflect.TypeOf(q.doc), false))
}
resV := reflect.ValueOf(q.doc)
for {
if resV.Kind() == reflect.Pointer {
if resV.Elem().Kind() == reflect.Slice {
resV = resV.Elem()
} else {
break
}
} else {
break
}
}
if resV.CanSet() {
resV.Set(trvo)
} else {
resV.Elem().Set(trvo.Elem())
}
}
func rerere(input interface{}, resType reflect.Type, isJson bool) interface{} {
t := reflect.TypeOf(input)
v := reflect.ValueOf(input)
var key string
if isJson {
key = "json"
} else {
key = "bson"
}
if input == nil {
return nil
}
if v.Type().Kind() == reflect.Pointer {
v = v.Elem()
}
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if resType.Kind() == reflect.Pointer {
resType = resType.Elem()
}
var resV reflect.Value
var newInstance interface{}
if _, _, has := ModelRegistry.Has(reflect.New(resType).Elem().Interface()); has {
newInstance = Create(reflect.New(resType).Elem().Interface())
resV = reflect.ValueOf(newInstance)
} else {
newInstance = ModelRegistry.newForType(resType)
if newInstance == nil && resType.Kind() == reflect.Pointer {
newInstance = ModelRegistry.newForType(resType.Elem())
if newInstance == nil {
resV = reflect.New(resType)
} else {
resV = reflect.ValueOf(newInstance)
}
} else if newInstance != nil {
resV = reflect.ValueOf(newInstance)
} else {
resV = reflect.New(resType)
}
}
var rve = resV
if rve.Kind() == reflect.Pointer {
rve = resV.Elem()
}
if d, isD := v.Interface().(bson.D); isD {
m := bson.M{}
for _, el := range d {
m[el.Key] = el.Value
}
input = m
v = reflect.ValueOf(input)
}
switch resType.Kind() {
case reflect.Struct:
shouldBreak := false
mipmap, ok := v.Interface().(bson.M)
if !ok {
var omap map[string]interface{}
if omap, ok = v.Interface().(map[string]interface{}); ok {
mipmap = bson.M(omap)
}
}
if ok {
for i := 0; i < resType.NumField(); i++ {
ft := resType.Field(i)
fv := rve.Field(i)
if ft.Anonymous {
fv.Set(handleAnon(input, ft.Type, fv))
continue
}
tags, err := structtag.Parse(string(ft.Tag))
panik(err)
btag, err := tags.Get(key)
if err != nil {
continue
}
if btag.Name == "-" {
continue
}
intermediate := mipmap[btag.Name]
_, err = tags.Get("ref")
if err != nil {
var tmp interface{}
if ttmp, tok := intermediate.(bson.DateTime); tok {
tmp = ttmp.Time()
} else {
tmp = rerere(intermediate, ft.Type, isJson)
}
fuck := reflect.ValueOf(tmp)
if tmp != nil {
if fuck.Type().Kind() == reflect.Pointer {
fuck = fuck.Elem()
}
fv.Set(fuck)
}
shouldBreak = true
} else {
tt := ft.Type
if tt.Kind() == reflect.Pointer {
tt = tt.Elem()
}
tmp := rerere(intermediate, ft.Type, isJson)
if tmp != nil {
if reflect.ValueOf(tmp).Kind() == reflect.Pointer && fv.Kind() != reflect.Pointer {
fv.Set(reflect.ValueOf(tmp).Elem())
} else {
fv.Set(reflect.ValueOf(tmp))
}
}
}
}
if shouldBreak {
// break
}
} else {
nunu := ModelRegistry.new_(resType.Name())
ider, ok := nunu.(HasID)
if ok {
toConvert := reflect.ValueOf(ider.Id()).Type()
if v.Type() != toConvert {
if v.CanConvert(toConvert) {
ider.SetId(v.Convert(toConvert).Interface())
}
} else {
ider.SetId(v.Interface())
}
if reflect.ValueOf(ider).Kind() == reflect.Pointer {
nunu = reflect.ValueOf(ider).Elem().Interface()
}
rve.Set(reflect.ValueOf(nunu))
}
}
case reflect.Slice:
_, aOk := v.Interface().(bson.A)
var arr []interface{}
if !aOk {
arr = v.Interface().([]interface{})
} else {
arr = []interface{}(v.Interface().(bson.A))
}
for _, it := range arr {
if it != nil {
tmp := reflect.ValueOf(rerere(it, rve.Type().Elem(), isJson))
if tmp.Kind() == reflect.Pointer {
tmp = tmp.Elem()
}
rve.Set(reflect.Append(rve, tmp))
}
}
default:
if resType.AssignableTo(v.Type()) {
rve.Set(reflect.ValueOf(input))
} else {
switch rve.Interface().(type) {
case int, int32, int64, uint, uint32, uint64:
rve.Set(reflect.ValueOf(coerceInt(v, rve)))
}
}
}
return resV.Interface()
}
func handleAnon(raw interface{}, rtype reflect.Type, rval reflect.Value) reflect.Value {
f := rval
g := rtype
if rtype.Kind() == reflect.Pointer {
g = rtype.Elem()
}
if !f.CanSet() {
f = reflect.New(g)
f.Elem().Set(rval)
}
if rtype.Kind() != reflect.Struct {
return rval
}
for i := 0; i < rtype.NumField(); i++ {
typeField := rtype.Field(i)
valueField := f.Field(i)
tags, err := structtag.Parse(string(typeField.Tag))
if !typeField.IsExported() {
continue
}
cols = append(cols, field.ColumnName)
}
finalSb = sb.Select(cols...)
for _, cc := range anonymousCols {
finalSb = finalSb.Columns(cc...)
}
finalSb = finalSb.From(q.model.TableName)
if len(q.joins) > 0 {
idq := sb.Select(fmt.Sprintf("%s.%s", q.model.TableName, q.model.idField().ColumnName)).
Distinct().
From(q.model.TableName)
for w, arg := range q.wheres {
idq = idq.Where(w, arg...)
}
for _, j := range q.joins {
idq = idq.Join(j)
}
qq, qa := idq.MustSQL()
var rows pgx.Rows
rows, err = q.engine.conn.Query(q.ctx, qq, qa...)
if err != nil {
return
}
defer rows.Close()
for rows.Next() {
var id any
if err = rows.Scan(&id); err != nil {
return
if err == nil {
var btag *structtag.Tag
btag, err = tags.Get("bson")
if err != nil {
continue
}
inParents = append(inParents, id)
}
if len(inParents) == 0 {
return
}
}
if len(inParents) > 0 {
finalSb = finalSb.Where(
fmt.Sprintf("%s IN (%s)",
q.model.idField().ColumnName,
MakePlaceholders(len(inParents))), inParents...)
} else if len(q.wheres) > 0 {
for k, vv := range q.wheres {
finalSb = finalSb.Where(k, vv...)
}
}
ool:
for _, o := range q.orders {
ac, ok := q.model.Fields[o]
if !ok {
var rel = q.model.Relationships[o]
if btag.Name == "-" {
continue
}
amap, ok := raw.(bson.M)
if ok {
fval := amap[btag.Name]
if reflect.TypeOf(fval) == reflect.TypeFor[bson.DateTime]() {
fval = fval.(bson.DateTime).Time()
}
if valueField.Kind() == reflect.Pointer {
valueField.Elem().Set(reflect.ValueOf(fval))
} else {
if reflect.TypeOf(fval) == reflect.TypeFor[string]() && typeField.Type == reflect.TypeFor[time.Time]() {
tt, _ := time.Parse(time.RFC3339, fval.(string))
valueField.Set(reflect.ValueOf(tt))
} else if fval != nil {
valueField.Set(reflect.ValueOf(fval))
if rel != nil {
if strings.Contains(o, ".") {
split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".")
cm := rel.Model
for i, s := range split {
if rel != nil {
cm = rel.RelatedModel
} else if i == len(split)-1 {
break
} else {
continue ool
}
rel = cm.Relationships[s]
}
lf := split[len(split)-1]
ac, ok = cm.Fields[lf]
if !ok {
continue
}
}
}
}
finalSb = finalSb.OrderBy(ac.ColumnName)
}
if q.limit > 0 {
finalSb = finalSb.Limit(uint64(q.limit))
}
if q.offset > 0 {
finalSb = finalSb.Offset(uint64(q.offset))
}
return
return f
}
// JSON - marshals this Query's results into json format
func (q *Query) JSON() (string, error) {
res, err := json.MarshalIndent(q.doc, "", "\t")
if err != nil {
return "", err
}
return string(res[:]), nil
}
// Exec - executes the query and puts its results into the
// provided argument.
//
// Will panic if called more than once on the same Query instance.
func (q *Query) Exec(result interface{}) {
if q.done {
panic("Exec() has already been called!")
}
doc := reflect.ValueOf(q.doc)
if doc.Elem().Kind() == reflect.Slice {
for i := 0; i < doc.Elem().Len(); i++ {
cur := doc.Elem().Index(i)
imodel, ok := cur.Interface().(IDocument)
if ok {
imodel.setExists(true)
imodel.SetSelf(imodel)
doc.Elem().Index(i).Set(reflect.ValueOf(imodel))
}
}
}
if idoc, ok := q.doc.(IDocument); ok {
idoc.SetSelf(result)
}
if rdoc, ok2 := result.(IDocument); ok2 {
rdoc.SetSelf(result)
}
reflect.ValueOf(result).Elem().Set(reflect.ValueOf(q.doc).Elem())
q.done = true
}

View File

@ -1,400 +0,0 @@
package orm
import (
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"reflect"
"strings"
)
const PopulateAll = "~~~ALL~~~"
// Populate - allows you to pre-load embedded structs/slices within the current model.
// use dots between field names to specify nested paths. use the PopulateAll constant to populate all
// relationships non-recursively
func (q *Query) Populate(fields ...string) *Query {
if q.populationTree == nil {
q.populationTree = make(map[string]any)
}
for _, field := range fields {
if field == PopulateAll {
for k := range q.model.Relationships {
if _, ok := q.populationTree[k]; !ok {
q.populationTree[k] = make(map[string]any)
}
}
continue
}
cur := q.populationTree
parts := strings.Split(field, ".")
for _, part := range parts {
if _, ok := cur[part]; !ok {
cur[part] = make(map[string]any)
}
cur = cur[part].(map[string]any)
}
}
return q
}
func (q *Query) processPopulate(parent reflect.Value, model *Model, populationTree map[string]any) error {
if parent.Len() == 0 {
return nil
}
pids := make([]any, 0)
var err error
idField := model.IDField
for i := range parent.Len() {
pval := parent.Index(i)
if pval.Kind() == reflect.Pointer {
pval = pval.Elem()
}
pids = append(pids, pval.FieldByName(idField).Interface())
}
toClose := make([]pgx.Rows, 0)
defer func() {
for _, c := range toClose {
c.Close()
}
}()
for p, nested := range populationTree {
var rel *Relationship
for _, r := range model.Relationships {
if r.FieldName == p {
rel = r
break
}
}
if rel == nil {
return fmt.Errorf("field '%s' not found in model '%s'", p, model.Name)
}
childSlice := reflect.Value{}
if (rel.Type == HasMany || rel.Type == HasOne) && !rel.m2mIsh() {
childSlice, err = q.populateHas(rel, parent, pids)
} else if rel.Type == BelongsTo {
childSlice, err = q.populateBelongsTo(rel, parent, pids)
} else if rel.Type == ManyToMany || rel.m2mIsh() {
childSlice, err = q.populateManyToMany(rel, parent, pids)
}
if err != nil {
return fmt.Errorf("failed to populate field at '%s': %w", p, err)
}
ntree, ok := nested.(map[string]any)
if ok && len(ntree) > 0 && childSlice.IsValid() && childSlice.Len() > 0 {
if err = q.processPopulate(childSlice, rel.RelatedModel, ntree); err != nil {
return err
}
}
}
return nil
}
func (q *Query) populateHas(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) {
fkf := rel.primaryID()
var fk string
if fkf != nil && fkf.ColumnType != "" {
fk = fkf.ColumnName
} else if rel.relatedID() != nil {
fk = pascalToSnakeCase(rel.RelatedModel.Name + rel.relatedID().Name)
}
if rel.RelatedModel.embeddedIsh && !rel.Model.embeddedIsh && rel.Type == HasMany {
arel := rel.RelatedModel.Relationships[rel.Model.Name]
fk = pascalToSnakeCase(arel.joinField())
}
ccols := make([]string, 0)
anonymousCols := make(map[string]map[string]*Field)
for _, f := range rel.RelatedModel.Fields {
if !f.isAnonymous() {
ccols = append(ccols, f.ColumnName)
}
}
for _, f := range rel.RelatedModel.Fields {
if f.isAnonymous() {
ccols = append(ccols, f.anonymousColumnNames()...)
anonymousCols[f.Name] = f.embeddedFields
}
}
for _, r := range rel.RelatedModel.Relationships {
if r.Type != ManyToOne {
continue
}
ccols = append(ccols, pascalToSnakeCase(r.joinField()))
}
/*var tableName string
if rel.Type == HasOne {
tableName = rel.Model.TableName
}
if rel.Type == HasMany {
tableName = rel.RelatedModel.TableName
}*/
aq, aa := sb.Select(ccols...).
From(rel.RelatedModel.TableName).
Where(fmt.Sprintf("%s IN (%s)", fk, MakePlaceholders(len(parentIds))), parentIds...).MustSQL()
q.engine.logQuery("populate", aq, aa)
rows, err := q.engine.conn.Query(q.ctx, aq, aa...)
if err != nil {
return reflect.Value{}, err
}
defer rows.Close()
idFieldName := rel.Model.IDField
idField := rel.Model.Fields[idFieldName]
if rel.Type == HasMany {
childMap := reflect.MakeMap(reflect.MapOf(
idField.Type,
reflect.SliceOf(rel.RelatedModel.Type),
))
for rows.Next() {
child := reflect.New(rel.RelatedModel.Type).Elem()
var fkValue any
scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &fkValue)
if err = rows.Scan(scanDest...); err != nil {
return reflect.Value{}, err
}
fkVal := reflect.ValueOf(fkValue)
childrenOfParent := childMap.MapIndex(fkVal)
if !childrenOfParent.IsValid() {
childrenOfParent = reflect.MakeSlice(reflect.SliceOf(rel.RelatedModel.Type), 0, 0)
}
childrenOfParent = reflect.Append(childrenOfParent, child)
childMap.SetMapIndex(fkVal, childrenOfParent)
}
for i := range parent.Len() {
ps := parent.Index(i)
if ps.Kind() == reflect.Pointer {
ps = ps.Elem()
}
pid := ps.FieldByName(idFieldName)
c := childMap.MapIndex(pid)
if c.IsValid() {
ps.FieldByName(rel.FieldName).Set(c)
}
}
} else {
childMap := reflect.MakeMap(reflect.MapOf(idField.Type, rel.RelatedModel.Type))
for rows.Next() {
child := reflect.New(rel.RelatedModel.Type).Elem()
var fkValue any
scanDest, _ := buildScanDest(child, rel.Model, rel, ccols, anonymousCols, &fkValue)
if err = rows.Scan(scanDest...); err != nil {
return reflect.Value{}, err
}
fkVal := reflect.ValueOf(fkValue)
childMap.SetMapIndex(fkVal, child)
}
for i := range parent.Len() {
ps := parent.Index(i)
if ps.Kind() == reflect.Pointer {
ps = ps.Elem()
}
parentID := ps.FieldByName(idFieldName)
if child := childMap.MapIndex(parentID); child.IsValid() {
ps.FieldByName(rel.FieldName).Set(child)
}
}
}
childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0)
for i := range parent.Len() {
ps := parent.Index(i)
if ps.Kind() == reflect.Ptr {
ps = ps.Elem()
}
childField := ps.FieldByName(rel.FieldName)
if !childField.IsValid() {
continue
}
if rel.Type == HasMany {
for j := range childField.Len() {
childSlice = reflect.Append(childSlice, childField.Index(j).Addr())
}
} else {
if !childField.IsZero() {
childSlice = reflect.Append(childSlice, childField.Addr())
}
}
}
return childSlice, nil
}
func (q *Query) populateManyToMany(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) {
inPlaceholders := MakePlaceholders(len(parentIds))
ccols := make([]string, 0)
anonymousCols := make(map[string]map[string]*Field)
for _, f := range rel.RelatedModel.Fields {
if !f.isAnonymous() {
ccols = append(ccols, "m."+f.ColumnName)
}
}
for _, f := range rel.RelatedModel.Fields {
if f.isAnonymous() {
for ecol := range f.embeddedFields {
ccols = append(ccols, "m."+ecol)
}
anonymousCols[f.Name] = f.embeddedFields
}
}
ccols = append(ccols, fmt.Sprintf("jt.%s_id", rel.Model.TableName))
mq, ma := sb.Select(ccols...).
From(fmt.Sprintf("%s AS m", rel.RelatedModel.TableName)).
Join(
fmt.Sprintf("%s AS jt ON m.%s = jt.%s_id",
rel.ComputeJoinTable(),
rel.relatedID().ColumnName, rel.RelatedModel.TableName)).
Where(fmt.Sprintf("jt.%s_id IN (%s)",
rel.Model.TableName, inPlaceholders), parentIds...).MustSQL()
q.engine.logQuery("populate/join", mq, ma)
rows, err := q.engine.conn.Query(q.ctx, mq, ma...)
if err != nil {
return reflect.Value{}, err
}
defer rows.Close()
idFieldName := rel.Model.IDField
idField := rel.Model.Fields[idFieldName]
childMap := reflect.MakeMap(reflect.MapOf(
idField.Type,
reflect.SliceOf(rel.RelatedModel.Type)))
for rows.Next() {
child := reflect.New(rel.RelatedModel.Type).Elem()
var foreignKeyValue any
scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &foreignKeyValue)
if err = rows.Scan(scanDest...); err != nil {
return reflect.Value{}, err
}
fkVal := reflect.ValueOf(foreignKeyValue)
childrenOfParent := childMap.MapIndex(fkVal)
if !childrenOfParent.IsValid() {
childrenOfParent = reflect.MakeSlice(reflect.SliceOf(rel.RelatedModel.Type), 0, 0)
}
childrenOfParent = reflect.Append(childrenOfParent, child)
childMap.SetMapIndex(fkVal, childrenOfParent)
}
for i := range parent.Len() {
p := parent.Index(i)
if p.Kind() == reflect.Ptr {
p = p.Elem()
}
parentID := p.FieldByName(rel.primaryID().Name)
if children := childMap.MapIndex(parentID); children.IsValid() {
p.FieldByName(rel.FieldName).Set(children)
}
}
childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0)
for i := range parent.Len() {
ps := parent.Index(i)
if ps.Kind() == reflect.Ptr {
ps = ps.Elem()
}
childField := ps.FieldByName(rel.FieldName)
if childField.IsValid() {
for j := range childField.Len() {
childSlice = reflect.Append(childSlice, childField.Index(j).Addr())
}
}
}
return childSlice, nil
}
func (q *Query) populateBelongsTo(rel *Relationship, childrenSlice reflect.Value, childIDs []any) (reflect.Value, error) {
childIdField := rel.Model.Fields[rel.Model.IDField]
parentIdField := rel.RelatedModel.Fields[rel.RelatedModel.IDField]
fk := pascalToSnakeCase(rel.joinField())
qs, qa := sb.Select(childIdField.ColumnName, fk).
From(rel.Model.TableName).
Where(fmt.Sprintf("%s IN (%s)",
childIdField.ColumnName, MakePlaceholders(len(childIDs)),
), childIDs...).MustSQL()
q.engine.logQuery("populate/belongs-to", qs, qa)
rows, err := q.engine.conn.Query(q.ctx, qs, qa...)
if err != nil {
return reflect.Value{}, err
}
childParentKeyMap := make(map[any]any)
parentKeyValues := make([]any, 0)
parentKeySet := make(map[any]bool)
for rows.Next() {
var cid, pfk any
err = rows.Scan(&cid, &pfk)
if err != nil {
rows.Close()
return reflect.Value{}, err
}
childParentKeyMap[cid] = pfk
if !parentKeySet[pfk] {
parentKeySet[pfk] = true
parentKeyValues = append(parentKeyValues, pfk)
}
}
rows.Close()
if len(parentKeyValues) == 0 {
return reflect.Value{}, nil
}
pcols := make([]string, 0)
anonymousCols := make(map[string]map[string]*Field)
for _, f := range rel.RelatedModel.Fields {
if !f.isAnonymous() {
pcols = append(pcols, f.ColumnName)
}
}
for _, f := range rel.RelatedModel.Fields {
if f.isAnonymous() {
pcols = append(pcols, f.anonymousColumnNames()...)
anonymousCols[f.Name] = f.embeddedFields
}
}
pquery, pqargs := sb.Select(pcols...).
From(rel.RelatedModel.TableName).
Where(fmt.Sprintf("%s IN (%s)",
parentIdField.ColumnName,
MakePlaceholders(len(parentKeyValues))), parentKeyValues...).
MustSQL()
q.engine.logQuery("populate/belongs-to->parent", pquery, pqargs)
parentRows, err := q.engine.conn.Query(q.ctx, pquery, pqargs...)
if err != nil {
return reflect.Value{}, err
}
defer parentRows.Close()
parentMap := reflect.MakeMap(reflect.MapOf(
parentIdField.Type,
rel.RelatedModel.Type,
))
for parentRows.Next() {
parent := reflect.New(rel.RelatedModel.Type).Elem()
scanDst, _ := buildScanDest(parent, rel.RelatedModel, rel, pcols, anonymousCols, nil)
if err = parentRows.Scan(scanDst...); err != nil {
return reflect.Value{}, err
}
parentId := parent.FieldByName(rel.RelatedModel.IDField)
parentMap.SetMapIndex(parentId, parent)
}
for i := range childrenSlice.Len() {
child := childrenSlice.Index(i)
childID := child.FieldByName(rel.Model.IDField)
if parentKey, ok := childParentKeyMap[childID.Interface()]; ok && parentKey != nil {
if parent := parentMap.MapIndex(reflect.ValueOf(parentKey)); parent.IsValid() {
child.FieldByName(rel.FieldName).Set(parent)
}
}
}
ntype := rel.RelatedModel.Type
if rel.Kind == reflect.Pointer {
ntype = reflect.PointerTo(rel.RelatedModel.Type)
}
parentSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(ntype)), 0, 0)
for i := range childrenSlice.Len() {
ps := childrenSlice.Index(i)
if ps.Kind() == reflect.Ptr {
ps = ps.Elem()
}
childField := ps.FieldByName(rel.FieldName)
if childField.IsValid() {
parentSlice = reflect.Append(parentSlice, childField.Addr())
}
}
return parentSlice, nil
}

View File

@ -1,369 +0,0 @@
package orm
import (
"fmt"
sb "github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"reflect"
"time"
)
// Find - transpiles this query into SQL and places the result in `dest`
func (q *Query) Find(dest any) error {
dstVal := reflect.ValueOf(dest)
if dstVal.Kind() != reflect.Ptr {
return fmt.Errorf("destination must be a pointer, got: %v", dstVal.Kind())
}
maybeSlice := dstVal.Elem()
cols, acols, sqlb, err := q.buildSQL()
if err != nil {
return err
}
qq, qa := sqlb.MustSQL()
q.engine.logQuery("find", qq, qa)
if maybeSlice.Kind() == reflect.Struct {
row := q.engine.conn.QueryRow(q.ctx, qq, qa...)
if err = scanRow(row, cols, acols, maybeSlice, q.model); err != nil {
return err
}
} else if maybeSlice.Kind() == reflect.Slice ||
maybeSlice.Kind() == reflect.Array {
var rows pgx.Rows
rows, err = q.engine.conn.Query(q.ctx, qq, qa...)
if err != nil {
return err
}
defer rows.Close()
etype := maybeSlice.Type().Elem()
for rows.Next() {
nelem := reflect.New(etype).Elem()
if err = scanRow(rows, cols, acols, nelem, q.model); err != nil {
return err
}
maybeSlice.Set(reflect.Append(maybeSlice, nelem))
}
} else {
return fmt.Errorf("unsupported destination type: %s", maybeSlice.Kind())
}
if len(q.populationTree) > 0 {
nslice := maybeSlice
var wasPassedStruct bool
if nslice.Kind() == reflect.Struct {
nslice = reflect.MakeSlice(reflect.SliceOf(maybeSlice.Type()), 0, 0)
wasPassedStruct = true
nslice = reflect.Append(nslice, maybeSlice)
}
err = q.processPopulate(nslice, q.model, q.populationTree)
if err == nil && wasPassedStruct {
maybeSlice.Set(nslice.Index(0))
}
return err
}
return nil
}
// Save - create or update `val` in the database
func (q *Query) Save(val any) error {
return q.saveOrCreate(val, false)
}
// Create - like Save, but hints to the query processor that you want to insert, not update.
// useful if you're importing data and want to keep the IDs intact.
func (q *Query) Create(val any) error {
return q.saveOrCreate(val, true)
}
// UpdateRaw - takes a mapping of struct field names to
// SQL expressions, updating each field's associated column accordingly
func (q *Query) UpdateRaw(values map[string]any) (int64, error) {
var err error
var subQuery sb.SelectBuilder
stmt := sb.Update(q.model.TableName)
_, _, subQuery, err = q.buildSQL()
if err != nil {
return 0, err
}
subQuery = sb.Select(q.model.idField().ColumnName).FromSelect(subQuery, "subQuery")
stmt = stmt.Where(wrapQueryIn(subQuery,
q.model.idField().ColumnName))
for k, v := range values {
asString, isString := v.(string)
if f, ok := q.model.Fields[k]; ok {
if isString {
stmt = stmt.Set(f.ColumnName, sb.Expr(asString))
} else {
stmt = stmt.Set(f.ColumnName, v)
}
}
if _, ok := q.model.FieldsByColumnName[k]; ok {
if isString {
stmt = stmt.Set(k, sb.Expr(asString))
} else {
stmt = stmt.Set(k, v)
}
}
}
sql, args := stmt.MustSQL()
q.engine.logQuery("update/raw", sql, args)
q.tx, err = q.engine.conn.Begin(q.ctx)
if err != nil {
return 0, err
}
defer q.cleanupTx()
ctag, err := q.tx.Exec(q.ctx, sql, args...)
if err != nil {
return 0, err
}
return ctag.RowsAffected(), q.tx.Commit(q.ctx)
}
// Delete - delete one or more entities matching previous conditions specified
// by methods like Where, WhereRaw, or In. will refuse to execute if no
// conditions were specified for safety reasons. to override this, call
// WhereRaw("true") or WhereRaw("1 = 1") before this method.
func (q *Query) Delete() (int64, error) {
var err error
var subQuery sb.SelectBuilder
if len(q.wheres) < 1 {
return 0, ErrNoConditionOnDeleteOrUpdate
}
q.tx, err = q.engine.conn.Begin(q.ctx)
if err != nil {
return 0, err
}
defer q.cleanupTx()
_, _, subQuery, err = q.buildSQL()
if err != nil {
return 0, err
}
sqlb := sb.Delete(q.model.TableName).Where(subQuery)
sql, sqla := sqlb.MustSQL()
q.engine.logQuery("delete", sql, sqla)
cmdTag, err := q.tx.Exec(q.ctx, sql, sqla...)
if err != nil {
return 0, fmt.Errorf("failed to delete: %w", err)
}
return cmdTag.RowsAffected(), nil
}
func (q *Query) saveOrCreate(val any, shouldCreate bool) error {
v := reflect.ValueOf(val)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("Save() must be called with a pointer to a struct")
}
var err error
q.tx, err = q.engine.conn.BeginTx(q.ctx, pgx.TxOptions{
AccessMode: pgx.ReadWrite,
IsoLevel: pgx.ReadUncommitted,
})
if err != nil {
return err
}
defer q.cleanupTx()
if _, err = q.doSave(v.Elem(), q.engine.modelMap.Map[v.Elem().Type().Name()], nil, shouldCreate); err != nil {
return err
}
return q.tx.Commit(q.ctx)
}
func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any, shouldInsert bool) (any, error) {
idField := model.Fields[model.IDField]
var pkField reflect.Value
if val.Kind() == reflect.Pointer {
if !val.Elem().IsValid() || val.Elem().IsZero() {
return nil, nil
}
pkField = val.Elem().FieldByName(model.IDField)
} else {
pkField = val.FieldByName(model.IDField)
}
isNew := pkField.IsZero()
var exists bool
if !pkField.IsZero() {
eb := sb.Select("1").
Prefix("SELECT EXISTS (").
From(model.TableName).
Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()).
Suffix(")")
ebs, eba := eb.MustSQL()
var ex bool
err := q.tx.QueryRow(q.ctx, ebs, eba...).Scan(&ex)
if err != nil {
q.engine.logger.Warn("error while checking existence", "err", err.Error())
}
exists = ex
}
/*{
el, ok := q.seenIds[model]
if !ok {
q.seenIds[model] = make(map[any]bool)
}
if ok && el[pkField.Interface()] {
return pkField.Interface(), nil
}
if !isNew {
q.seenIds[model][pkField.Interface()] = true
}
}*/
doInsert := isNew || !exists
var cols []string
args := make([]any, 0)
seenJoinTables := make(map[string]map[any]bool)
for _, rel := range model.Relationships {
if rel.Type != BelongsTo {
continue
}
parentVal := val.FieldByName(rel.FieldName)
if parentVal.IsValid() {
nid, err := q.doSave(parentVal, rel.RelatedModel, nil, rel.RelatedModel.needsPrimaryKey(parentVal) && isNew)
if err != nil {
return nil, err
}
cols = append(cols, pascalToSnakeCase(rel.joinField()))
args = append(args, nid)
} else if parentVal.IsValid() {
_, nid := rel.RelatedModel.getPrimaryKey(parentVal)
cols = append(cols, pascalToSnakeCase(rel.joinField()))
args = append(args, nid)
}
}
for _, ff := range model.Fields {
var fv reflect.Value
if ff.Index > -1 && !ff.isAnonymous() {
fv = val.Field(ff.Index)
} else if ff.Index > -1 {
for col, ef := range ff.embeddedFields {
fv = val.Field(ff.Index)
cols = append(cols, col)
eif := fv.FieldByName(ef.Name)
if ff.Name == documentField && canConvertTo[Document](ff.Type) {
asTime, ok := eif.Interface().(time.Time)
shouldCreate := ok && (asTime.IsZero() || eif.IsZero())
if doInsert && ef.Name == createdField && shouldCreate {
eif.Set(reflect.ValueOf(time.Now()))
} else if ef.Name == modifiedField || shouldCreate {
eif.Set(reflect.ValueOf(time.Now()))
}
args = append(args, eif.Interface())
continue
}
args = append(args, fv.FieldByName(ef.Name).Interface())
}
continue
}
if ff.Name == model.IDField {
if !isNew && fv.IsValid() {
cols = append(cols, ff.ColumnName)
args = append(args, fv.Interface())
}
continue
}
if fv.IsValid() {
cols = append(cols, ff.ColumnName)
args = append(args, fv.Interface())
}
}
for k, fk := range parentFks {
cols = append(cols, k)
args = append(args, fk)
}
var qq string
var qa []any
if doInsert {
osb := sb.Insert(model.TableName)
if len(cols) == 0 {
qq = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", model.TableName, idField.ColumnName)
} else {
osb = osb.Columns(cols...).Values(args...)
qq, qa = osb.Returning(idField.ColumnName).MustSQL()
}
} else {
osb := sb.Update(model.TableName)
for i := range cols {
osb = osb.Set(cols[i], args[i])
}
osb = osb.Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface())
qq, qa = osb.MustSQL()
}
if doInsert {
var nid any
q.engine.logQuery("insert", qq, qa)
row := q.tx.QueryRow(q.ctx, qq, qa...)
err := row.Scan(&nid)
if err != nil {
return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err)
}
pkField.Set(reflect.ValueOf(nid))
} else {
q.engine.logQuery("update", qq, qa)
_, err := q.tx.Exec(q.ctx, qq, qa...)
if err != nil {
return nil, fmt.Errorf("update failed for model %s: %w", model.Name, err)
}
}
/*if _, ok := q.seenIds[model]; !ok {
q.seenIds[model] = make(map[any]bool)
}
q.seenIds[model][pkField.Interface()] = true*/
for _, rel := range model.Relationships {
if rel.Idx > -1 && rel.Idx < val.NumField() {
fv := val.FieldByName(rel.FieldName)
cm := rel.RelatedModel
pfks := map[string]any{}
if !model.embeddedIsh && rel.Type == HasMany {
{
rm := cm.Relationships[model.Name]
if rm != nil && rm.Type == ManyToOne {
pfks[pascalToSnakeCase(rm.joinField())] = pkField.Interface()
}
}
for j := range fv.Len() {
child := fv.Index(j).Addr().Elem()
if _, err := q.doSave(child, cm, pfks, cm.needsPrimaryKey(child)); err != nil {
return nil, err
}
}
} else if rel.Type == HasOne && cm.embeddedIsh {
if _, err := q.doSave(fv, cm, pfks, cm.needsPrimaryKey(fv)); err != nil {
return nil, err
}
} else if rel.m2mIsh() || rel.Type == ManyToMany || (model.embeddedIsh && cm.embeddedIsh && rel.Type == HasMany) {
if seenJoinTables[rel.ComputeJoinTable()] == nil {
seenJoinTables[rel.ComputeJoinTable()] = make(map[any]bool)
}
if !seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] {
seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] = true
if err := rel.joinDelete(pkField.Interface(), nil, q); err != nil {
return nil, fmt.Errorf("error deleting existing association: %w", err)
}
}
if fv.Kind() == reflect.Slice || fv.Kind() == reflect.Array {
mField := model.Fields[model.IDField]
mpks := map[string]any{}
if !model.embeddedIsh {
mpks[model.TableName+"_"+mField.ColumnName] = pkField.Interface()
}
for i := range fv.Len() {
cur := fv.Index(i)
if _, err := q.doSave(cur, cm, mpks, cm.needsPrimaryKey(cur) && pkField.IsZero()); err != nil {
return nil, err
}
if rel.m2mIsh() || rel.Type == ManyToMany {
if err := rel.joinInsert(cur, q, pkField.Interface()); err != nil {
return nil, fmt.Errorf("failed to insert association for model %s: %w", model.Name, err)
}
}
}
}
}
}
}
return pkField.Interface(), nil
}

417
registry.go Normal file
View File

@ -0,0 +1,417 @@
package orm
import (
context2 "context"
"fmt"
"log"
"reflect"
"sync"
"time"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"golang.org/x/net/context"
)
// Reference stores a typed document reference
type Reference struct {
// owning model name
Model string
// the name of the struct field
FieldName string
// index of field in owning struct
Idx int
// the type of the referenced object
HydratedType reflect.Type
// field kind (struct, slice, ...)
Kind reflect.Kind
exists bool
}
type gridFSReference struct {
BucketName string
FilenameFmt string
LoadType reflect.Type
Idx int
}
type TModelRegistry map[string]*Model
// ModelRegistry - the ModelRegistry stores a map containing
// pointers to Model instances, keyed by an associated
// model name
var ModelRegistry = make(TModelRegistry)
// DB - The mongodb database handle
var DB *mongo.Database
// DBClient - The mongodb client
var DBClient *mongo.Client
// NextStringID - Override this function with your own
// string ID generator!
var NextStringID func() string
var mutex sync.Mutex
func makeGfsRef(tag *structtag.Tag, idx int) gridFSReference {
opts := tag.Options
var ffmt string
if len(opts) < 1 {
ffmt = "%s"
} else {
ffmt = opts[0]
}
var typ reflect.Type
if len(opts) < 2 {
typ = reflect.TypeOf("")
} else {
switch opts[1] {
case "bytes":
typ = reflect.TypeOf([]byte{})
case "string":
typ = reflect.TypeOf("")
default:
typ = reflect.TypeOf("")
}
}
return gridFSReference{
FilenameFmt: ffmt,
BucketName: tag.Name,
LoadType: typ,
Idx: idx,
}
}
func makeRef(idx int, modelName string, fieldName string, ht reflect.Type) Reference {
if modelName != "" {
if ModelRegistry.Index(modelName) != -1 {
return Reference{
Idx: idx,
Model: modelName,
HydratedType: ht,
Kind: ht.Kind(),
exists: true,
FieldName: fieldName,
}
}
return Reference{
Idx: idx,
Model: modelName,
FieldName: fieldName,
HydratedType: ht,
Kind: ht.Kind(),
exists: true,
}
}
panic("model name was empty")
}
type parseResult []string
func (p parseResult) includes(str string) bool {
for _, v := range p {
if v == str {
return true
}
}
return false
}
func parseTags(t reflect.Type, v reflect.Value, lastParsed parseResult, depth int) (map[string][]InternalIndex, map[string]Reference, map[string]gridFSReference, string) {
coll := ""
refs := make(map[string]Reference)
idcs := make(map[string][]InternalIndex)
gfsRefs := make(map[string]gridFSReference)
if depth >= 4 {
return idcs, refs, gfsRefs, coll
}
for i := 0; i < v.NumField(); i++ {
sft := t.Field(i)
ft := sft.Type
tags, err := structtag.Parse(string(sft.Tag))
panik(err)
switch ft.Kind() {
case reflect.Slice:
ft = ft.Elem()
fallthrough
case reflect.Pointer:
if ft.Kind() == reflect.Pointer {
ft = ft.Elem()
}
fallthrough
case reflect.Struct:
if ft.ConvertibleTo(reflect.TypeOf(Document{})) {
collTag, err := tags.Get("coll")
panik(err)
coll = collTag.Name
idxTag, err := tags.Get("idx")
if err == nil {
idcs[sft.Type.Name()] = scanIndex(idxTag.Value())
}
continue
}
if lastParsed.includes(sft.Name) {
continue
}
blip := lastParsed
blip = append(blip, sft.Name)
if ft.Kind() == reflect.Struct && ft != reflect.TypeFor[time.Time]() {
ii2, rr2, gg2, _ := parseTags(ft, reflect.New(ft).Elem(), blip, depth+1)
for k, vv := range ii2 {
idcs[sft.Name+"."+k] = vv
}
for k, vv := range rr2 {
refs[sft.Name+"."+k] = vv
}
for k, vv := range gg2 {
gfsRefs[sft.Name+"."+k] = vv
}
}
if refTag, ok := tags.Get("ref"); ok == nil {
sname := sft.Name
refs[sname] = makeRef(i, refTag.Name, sft.Name, sft.Type)
}
if gtag, ok := tags.Get("gridfs"); ok == nil {
sname := sft.Name + "@" + gtag.Name
gfsRefs[sname] = makeGfsRef(gtag, i)
}
fallthrough
default:
idxTag, err := tags.Get("idx")
if err == nil {
idcs[sft.Name] = scanIndex(idxTag.Value())
}
if gtag, ok := tags.Get("gridfs"); ok == nil {
sname := sft.Name + "@" + gtag.Name
gfsRefs[sname] = makeGfsRef(gtag, i)
}
}
}
return idcs, refs, gfsRefs, coll
}
// Has returns the model typename and Model instance corresponding
// to the argument passed, as well as a boolean indicating whether it
// was found. otherwise returns `"", nil, false`
func (r TModelRegistry) Has(i interface{}) (string, *Model, bool) {
t := reflect.TypeOf(i)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
n := t.Name()
if rT, ok := ModelRegistry[n]; ok {
return n, rT, true
}
return "", nil, false
}
// HasByName functions almost identically to Has,
// except that it takes a string as its argument.
func (r TModelRegistry) HasByName(n string) (string, *Model, bool) {
if t, ok := ModelRegistry[n]; ok {
return n, t, true
}
return "", nil, false
}
// Index returns the index at which the Document struct is embedded
func (r TModelRegistry) Index(n string) int {
if v, ok := ModelRegistry[n]; ok {
return v.idx
}
return -1
}
func (r TModelRegistry) new_(n string) interface{} {
if name, m, ok := ModelRegistry.HasByName(n); ok {
v := reflect.New(m.Type)
av := reflect.New(reflect.PointerTo(m.Type)).Elem()
av.Addr().Elem().Set(v)
df := av.Addr().Elem().Elem().Field(m.idx)
ado := reflect.New(reflect.PointerTo(df.Type())).Elem()
do := reflect.New(df.Type())
ado.Addr().Elem().Set(do)
d := ado.Addr().Elem().Interface().(IDocument)
d.newPopulationMap()
//d := df.Interface().(IDocument)
for k := range m.references {
d.markDepopulated(k)
}
d.setModel(*m)
d.getModel().typeName = name
d.SetSelf(av.Interface())
df.Set(reflect.ValueOf(d).Elem())
return av.Interface()
}
return nil
}
func (r TModelRegistry) newForType(rt reflect.Type) interface{} {
return r.new_(rt.Name())
}
func (r TModelRegistry) Get(name string) *Model {
model, ok := r[name]
if !ok {
return nil
}
return model
}
// Model registers models in the ModelRegistry, where
// they can be accessed via a model's struct name
func (r TModelRegistry) Model(mdl ...any) {
defer mutex.Unlock()
mutex.Lock()
for _, m := range mdl {
t := reflect.TypeOf(m)
v := reflect.ValueOf(m)
vp := v
if vp.Kind() != reflect.Ptr {
vp = reflect.New(v.Type())
vp.Elem().Set(v)
}
id, ok := vp.Interface().(HasID)
if !ok {
panic(fmt.Sprintf("you MUST implement the HasID interface!!! skipping...\n"))
}
switch (id).Id().(type) {
case int, int64, int32, string, bson.ObjectID, uint, uint32, uint64:
break
default:
log.Printf("invalid ID type specified!!! skipping...\n")
}
if t.Kind() == reflect.Ptr {
t = reflect.Indirect(reflect.ValueOf(m)).Type()
v = reflect.ValueOf(m).Elem()
}
n := t.Name()
if t.Kind() != reflect.Struct {
panic(fmt.Sprintf("Only structs can be passed to this function, silly! (passed type: %s)", n))
}
idx := -1
for i := 0; i < v.NumField(); i++ {
ft := t.Field(i)
if (ft.Type.ConvertibleTo(reflect.TypeOf(Document{}))) {
idx = i
break
}
}
if idx < 0 {
panic("A model must embed the Document struct!")
}
inds, refs, gfs, coll := parseTags(t, v, make(parseResult, 0), 0)
if coll == "" {
panic(fmt.Sprintf("a Document needs to be given a collection name! (passed type: %s)", n))
}
ModelRegistry[n] = &Model{
idx: idx,
Type: t,
collection: coll,
Indexes: inds,
references: refs,
gridFSReferences: gfs,
typeName: n,
}
}
for k, v := range ModelRegistry {
for k2, v2 := range v.references {
if !v2.exists {
if _, ok := ModelRegistry[v2.FieldName]; ok {
tmp := ModelRegistry[k].references[k2]
ModelRegistry[k].references[k2] = Reference{
Model: k,
Idx: tmp.Idx,
FieldName: tmp.FieldName,
Kind: tmp.Kind,
HydratedType: tmp.HydratedType,
exists: true,
}
}
}
}
}
}
func innerWatch(coll *mongo.Collection) {
sspipeline := mongo.Pipeline{
bson.D{{"$match", bson.D{{"$or",
bson.A{
bson.D{{
"operationType", "insert",
}},
bson.D{{"operationType", "update"}},
},
}},
}},
}
stream, err := coll.Watch(context.TODO(), sspipeline, options.ChangeStream().SetFullDocument(options.UpdateLookup).SetFullDocumentBeforeChange(options.WhenAvailable))
if err != nil {
return
}
defer func(stream *mongo.ChangeStream, ctx context2.Context) {
err := stream.Close(ctx)
panik(err)
}(stream, context.TODO())
for stream.Next(context.TODO()) {
var data bson.M
if err := stream.Decode(&data); err != nil {
log.Fatal(err)
}
var uid any
docKey := data["documentKey"]
switch docKey.(type) {
case bson.M:
uid = docKey.(bson.M)["_id"]
case bson.D:
for _, vv := range docKey.(bson.D) {
if vv.Key == "_id" {
uid = vv.Value
break
}
}
}
//var uid = data["documentKey"].(bson.M)["_id"]
if data["operationType"] == "insert" {
counterColl := DB.Collection(COUNTER_COL)
counterColl.UpdateOne(context.TODO(), bson.M{"collection": coll.Name()}, bson.M{"$set": bson.M{
"current": uid,
}}, options.UpdateOne().SetUpsert(true))
}
}
}
func Connect(uri string, dbName string) {
cli, err := mongo.Connect(options.Client().ApplyURI(uri))
if err != nil {
log.Fatal("failed to open database")
}
panik(err)
DB = cli.Database(dbName)
colls, err := DB.ListCollectionNames(context.TODO(), bson.M{"name": bson.M{"$ne": COUNTER_COL}}, options.ListCollections().SetNameOnly(true))
for _, c := range colls {
if c == COUNTER_COL {
continue
}
go innerWatch(DB.Collection(c))
}
DBClient = cli
}

View File

@ -1,200 +0,0 @@
package orm
import (
"fmt"
sb "github.com/henvic/pgq"
"reflect"
"strings"
)
// RelationshipType - used to distinguish a Relationship between various
// common entity relationship types
type RelationshipType int
const (
HasOne RelationshipType = iota
HasMany
BelongsTo
ManyToOne // the other side of a HasMany relationship
ManyToMany
)
// Relationship - intermediate representation of how two types
// relate to each other. i.e., if struct A embeds struct B,
// a Relationship will be created for those two while parsing the Model for A.
type Relationship struct {
Type RelationshipType // the type of this relationship (see RelationshipType)
JoinTable string // the name of the join table, if specified explicitly via struct tag, otherwise blank
Model *Model // the primary Model which contains this relationship
FieldName string // the name of the struct field with this relationship
Idx int // the index of the struct field with this relationship
RelatedType reflect.Type // the reflect.Type for the struct field named by FieldName
RelatedModel *Model // the Model representing the type of the embedded slice/struct
Kind reflect.Kind // field kind (struct, slice, ...)
m2mInverse *Relationship // the "inverse" side of an explicit ManyToMany relationship
Nullable bool // whether the foreign key for this relationship can have a nullable column
OriginalField reflect.StructField // the original reflect.StructField object associated with this relationship
}
// ComputeJoinTable - computes the name of the join table for ManyToMany relationships.
// will return a snake_cased autogenerated name for unidirectional ManyToMany .
// the "implicit" behavior is invoked upon one of the following conditions being met:
// - the primary Relationship.Model
func (r *Relationship) ComputeJoinTable() string {
if r.JoinTable != "" {
return r.JoinTable
}
otherSide := r.RelatedModel.TableName
if r.Model.embeddedIsh {
otherSide = pascalToSnakeCase(r.FieldName)
}
return r.Model.TableName + "_" + otherSide
}
func (r *Relationship) relatedID() *Field {
return r.RelatedModel.Fields[r.RelatedModel.IDField]
}
func (r *Relationship) primaryID() *Field {
return r.Model.Fields[r.Model.IDField]
}
func (r *Relationship) joinField() string {
if r.Type == ManyToOne {
return r.RelatedModel.Name + "ID"
}
if r.Type == ManyToMany && !r.Model.embeddedIsh {
return r.RelatedModel.Name + "ID"
}
return r.FieldName + "ID"
}
func (r *Relationship) m2mIsh() bool {
needsMany := false
if !r.Model.embeddedIsh && r.RelatedModel.embeddedIsh {
rr, ok := r.RelatedModel.Relationships[r.Model.Name]
if ok && rr.Type != ManyToOne {
needsMany = true
}
}
return ((r.Model.embeddedIsh && !r.RelatedModel.embeddedIsh) || needsMany) &&
r.Type == HasMany
}
func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error {
if r.Type != ManyToMany &&
!r.m2mIsh() {
return nil
}
ichild := v
for ichild.Kind() == reflect.Ptr {
ichild = ichild.Elem()
}
if ichild.Kind() == reflect.Struct {
jtable := r.ComputeJoinTable()
jargs := make([]any, 0)
jcols := make([]string, 0)
jcols = append(jcols, fmt.Sprintf("%s_id",
r.Model.TableName,
))
jargs = append(jargs, pfk)
jcols = append(jcols, r.RelatedModel.TableName+"_id")
jargs = append(jargs, ichild.FieldByName(r.RelatedModel.IDField).Interface())
var ecnt int
e.tx.QueryRow(e.ctx,
fmt.Sprintf("SELECT count(*) from %s where %s = $1 and %s = $2", r.ComputeJoinTable(), jcols[0], jcols[1]), jargs...).Scan(&ecnt)
if ecnt > 0 {
return nil
}
jsql := fmt.Sprintf("INSERT INTO %s (%s) VALUES ($1, $2)", jtable, strings.Join(jcols, ", "))
e.engine.logQuery("insert/join", jsql, jargs)
if !e.engine.dryRun {
_ = e.tx.QueryRow(e.ctx, jsql, jargs...).Scan()
}
}
return nil
}
func (r *Relationship) joinDelete(pk, fk any, q *Query) error {
dq := sb.Delete(r.ComputeJoinTable()).Where(fmt.Sprintf("%s_id = ?", r.Model.TableName), pk)
if fk != nil {
dq = dq.Where(fmt.Sprintf("%s_id = ?", r.RelatedModel.TableName), fk)
}
ds, aa := dq.MustSQL()
q.engine.logQuery("delete/join", ds, aa)
if !q.engine.dryRun {
_, err := q.tx.Exec(q.ctx, ds, aa...)
return err
}
return nil
}
func parseRelationship(field reflect.StructField, modelMap map[string]*Model, outerType reflect.Type, idx int, settings map[string]string) *Relationship {
rel := &Relationship{
Model: modelMap[outerType.Name()],
RelatedModel: modelMap[field.Type.Name()],
RelatedType: field.Type,
Idx: idx,
Kind: field.Type.Kind(),
FieldName: field.Name,
OriginalField: field,
}
if rel.RelatedType.Kind() == reflect.Slice || rel.RelatedType.Kind() == reflect.Array {
rel.RelatedType = rel.RelatedType.Elem()
}
if rel.RelatedModel == nil {
if rel.RelatedType.Name() == "" {
rt := rel.RelatedType
for rt.Kind() == reflect.Ptr || rt.Kind() == reflect.Slice || rt.Kind() == reflect.Array {
rel.Nullable = true
rel.RelatedType = rel.RelatedType.Elem()
rt = rel.RelatedType
}
}
rel.RelatedModel = modelMap[rel.RelatedType.Name()]
if _, ok := modelMap[rel.RelatedType.Name()]; !ok {
rel.RelatedModel = parseModel(reflect.New(rel.RelatedType).Interface())
modelMap[rel.RelatedType.Name()] = rel.RelatedModel
parseModelFields(rel.RelatedModel, modelMap)
rel.RelatedModel.embeddedIsh = true
}
}
switch field.Type.Kind() {
case reflect.Struct:
rel.Type = HasOne
case reflect.Slice, reflect.Array:
rel.Type = HasMany
}
maybeM2m := settings["m2m"]
if maybeM2m == "" {
maybeM2m = settings["manytomany"]
}
if rel.Type == HasMany && maybeM2m != "" {
rel.JoinTable = maybeM2m
}
return rel
}
func addForeignKeyFields(ref *Relationship) {
if !ref.RelatedModel.embeddedIsh && !ref.Model.embeddedIsh {
ref.Type = BelongsTo
} else if !ref.Model.embeddedIsh && ref.RelatedModel.embeddedIsh {
if ref.Type == HasMany {
nr := &Relationship{
RelatedModel: ref.Model,
Model: ref.RelatedModel,
Kind: ref.RelatedModel.Type.Kind(),
Idx: -1,
RelatedType: ref.Model.Type,
}
nr.Type = ManyToOne
nr.FieldName = nr.RelatedModel.Name
ref.RelatedModel.Relationships[nr.FieldName] = nr
} else if ref.Type == HasOne {
ref.Type = BelongsTo
}
} else if ref.Model.embeddedIsh && !ref.RelatedModel.embeddedIsh {
}
}

View File

@ -1,114 +0,0 @@
package orm
import (
"github.com/stretchr/testify/assert"
"testing"
)
func save1(t assert.TestingT, e *Engine) {
u := author(t)
err := e.Model(&user{}).Create(&u)
assert.Nil(t, err)
f := friend(t)
err = e.Model(&user{}).Create(&f)
assert.Nil(t, err)
assert.NotZero(t, f.ID)
oldFavid := u.Favs.ID
u.Favs.Authors = append(u.Favs.Authors, f)
err = e.Model(&user{}).Save(&u)
assert.Nil(t, err)
assert.NotZero(t, u.Favs.ID)
assert.Equal(t, oldFavid, u.Favs.ID)
}
func save2(t assert.TestingT, e *Engine) {
insertBands(t, e)
u := author(t)
err := e.Model(&user{}).Create(&u)
assert.Nil(t, err)
assert.NotZero(t, u.Favs.ID)
s := iti_multi(u)
err = e.Model(&story{}).Save(s)
assert.Nil(t, err)
assert.NotZero(t, s.ID)
checkChapters(t, s)
s.Downloads = s.Downloads + 1
err = e.Model(&story{}).Save(s)
assert.Nil(t, err)
var ns story
err = e.Model(&story{}).Where("ID = ?", s.ID).Find(&ns)
assert.Nil(t, err)
assert.NotZero(t, ns.ID)
assert.NotZero(t, ns.Title)
assert.Equal(t, ns.Downloads, s.Downloads)
}
func updateRaw1(t assert.TestingT, e *Engine) {
insertBands(t, e)
u := author(t)
err := e.Model(&user{}).Create(&u)
assert.Nil(t, err)
s1 := iti_multi(u)
err = e.Model(&story{}).Save(s1)
assert.Nil(t, err)
assert.NotZero(t, s1.ID)
checkChapters(t, s1)
s2 := iti_single(u)
err = e.Model(&story{}).Save(s2)
assert.Nil(t, err)
assert.NotZero(t, s2.ID)
checkChapters(t, s2)
umap := make(map[string]any)
umap["Characters"] = `array_remove(characters, 'Brian Tatler')`
ra, err := e.Model(&chapter{}).WhereRaw("1 = ?", 1).UpdateRaw(umap)
assert.NotZero(t, ra)
assert.Nil(t, err)
stories := make([]story, 0)
err = e.Model(&story{}).Populate(PopulateAll, "Chapters.Bands").Find(&stories)
assert.Nil(t, err)
assert.NotEmpty(t, stories)
for _, ss := range stories {
checkChapters(t, &ss)
for _, c := range ss.Chapters {
steppedInShit := false
for _, b := range c.Characters {
if b == "Brian Tatler" {
steppedInShit = true
break
}
}
assert.False(t, steppedInShit)
}
}
}
func TestSave1(t *testing.T) {
e := initTest(t)
save1(t, e)
e.Disconnect()
}
func TestSave2(t *testing.T) {
e := initTest(t)
save2(t, e)
e.Disconnect()
}
func TestUpdateRaw1(t *testing.T) {
e := initTest(t)
updateRaw1(t, e)
e.Disconnect()
}
func BenchmarkSave(b *testing.B) {
b.Run("Save-1", bench(save1))
b.Run("Save-2", bench(save2))
}
func BenchmarkUpdateRaw(b *testing.B) {
bench(updateRaw1)(b)
}

53
scan.go
View File

@ -1,53 +0,0 @@
package orm
import (
"github.com/jackc/pgx/v5"
"reflect"
"strings"
)
func buildScanDest(val reflect.Value, model *Model, fk *Relationship, cols []string, anonymousCols map[string]map[string]*Field, fkDest any) ([]any, error) {
var dest []any
for _, col := range cols {
bcol := col
if strings.Contains(bcol, ".") {
_, bcol, _ = strings.Cut(bcol, ".")
}
field := model.FieldsByColumnName[bcol]
if field != nil && !field.isAnonymous() {
dest = append(dest, val.FieldByName(field.Name).Addr().Interface())
}
}
for fn, a := range anonymousCols {
iv := val.FieldByName(fn)
for _, field := range a {
dest = append(dest, iv.FieldByName(field.Name).Addr().Interface())
}
}
if fk.Type != BelongsTo {
dest = append(dest, fkDest)
}
return dest, nil
}
func scanRow(row pgx.Row, cols []string, anonymousCols map[string][]string, destVal reflect.Value, m *Model) error {
var scanDest []any
for _, col := range cols {
f := m.FieldsByColumnName[col]
if f != nil && f.ColumnType != "" && !f.isAnonymous() {
scanDest = append(scanDest, destVal.FieldByIndex(f.Original.Index).Addr().Interface())
}
}
for kcol := range anonymousCols {
f := m.FieldsByColumnName[kcol]
if f != nil {
for _, ef := range f.embeddedFields {
scanDest = append(scanDest, destVal.FieldByIndex(f.Original.Index).FieldByName(ef.Name).Addr().Interface())
}
}
}
return row.Scan(scanDest...)
}

View File

@ -1,196 +1,152 @@
package orm
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"math/rand/v2"
"os"
"slices"
"strings"
"testing"
"time"
"github.com/go-loremipsum/loremipsum"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
type chapter struct {
ChapterID int64 `json:"chapterID" d:"pk:t;"`
Title string `json:"chapterTitle" form:"chapterTitle"`
Index int `json:"index" form:"index"`
Words int `json:"words"`
Notes string `json:"notes" form:"notes"`
Genre []string `json:"genre" form:"genre" d:"type:text[]"`
Bands []band `json:"bands" ref:"band,bands"`
Characters []string `json:"characters" form:"characters" d:"type:text[]"`
Relationships [][]string `json:"relationships" form:"relationships" d:"type:jsonb"`
Adult bool `json:"adult" form:"adult"`
Summary string `json:"summary" form:"summary"`
Hidden bool `json:"hidden" form:"hidden"`
LoggedInOnly bool `json:"loggedInOnly" form:"loggedInOnly"`
Posted time.Time `json:"datePosted"`
FileName string `json:"fileName" d:"-"`
Text string `json:"text" d:"column:content" gridfs:"story_text,/stories/{{.ChapterID}}.txt"`
ID bson.ObjectID `bson:"_id" json:"_id"`
Title string `bson:"chapterTitle" json:"chapterTitle" form:"chapterTitle"`
ChapterID int `bson:"id" json:"chapterID" autoinc:"chapters"`
Index int `bson:"index" json:"index" form:"index"`
Words int `bson:"words" json:"words"`
Notes string `bson:"notes" json:"notes" form:"notes"`
Genre []string `bson:"genre" json:"genre" form:"genre"`
Bands []band `json:"bands" bson:"bands" ref:"band,bands"`
Characters []string `bson:"characters" json:"characters" form:"characters"`
Relationships [][]string `bson:"relationships" json:"relationships" form:"relationships"`
Adult bool `bson:"adult" json:"adult" form:"adult"`
Summary string `bson:"summary" json:"summary" form:"summary"`
Hidden bool `bson:"hidden" json:"hidden" form:"hidden"`
LoggedInOnly bool `bson:"loggedInOnly" json:"loggedInOnly" form:"loggedInOnly"`
Posted time.Time `bson:"datePosted,omitempty" json:"datePosted"`
FileName string `json:"fileName" bson:"-"`
Text string `json:"text" bson:"-" gridfs:"story_text,/stories/{{.ChapterID}}.txt"`
}
type band struct {
Document `json:",inline" d:"table:bands"`
ID int64 `json:"_id" d:"pk;"`
Name string `json:"name" form:"name"`
Locked bool `json:"locked" form:"locked"`
Characters []string `json:"characters" form:"characters" d:"type:text[]"`
ID int64 `bson:"_id" json:"_id"`
Document `bson:",inline" json:",inline" coll:"bands"`
Name string `bson:"name" json:"name" form:"name"`
Locked bool `bson:"locked" json:"locked" form:"locked"`
Characters []string `bson:"characters" json:"characters" form:"characters"`
}
type user struct {
Document `json:",inline" d:"table:users"`
ID int64 `json:"_id" d:"pk;"`
Username string `json:"username"`
Favs favs `json:"favs" ref:"user"`
Roles []role `d:"m2m:user_roles"`
}
type role struct {
ID int64 `d:"pk"`
Name string
Users []user `d:"m2m:user_roles"`
}
type favs struct {
ID int64 `d:"pk"`
Stories []story
Authors []user
ID int64 `bson:"_id" json:"_id"`
Document `bson:",inline" json:",inline" coll:"users"`
Username string `bson:"username" json:"username"`
Favs []user `bson:"favs" json:"favs" ref:"user"`
}
type story struct {
Document `json:",inline" d:"table:stories"`
ID int64 `json:"_id" d:"pk;"`
Title string `json:"title" form:"title"`
Author user `json:"author" ref:"user"`
CoAuthor *user `json:"coAuthor" ref:"user"`
Chapters []chapter `json:"chapters"`
Recs int `json:"recs"`
Favs int `json:"favs"`
Views int `json:"views"`
Completed bool `json:"completed" form:"completed"`
Downloads int `json:"downloads"`
ID int64 `bson:"_id" json:"_id"`
Document `bson:",inline" json:",inline" coll:"stories"`
Title string `bson:"title" json:"title" form:"title"`
Author *user `bson:"author" json:"author" ref:"user"`
CoAuthor *user `bson:"coAuthor" json:"coAuthor" ref:"user"`
Chapters []chapter `bson:"chapters" json:"chapters"`
Recs int `bson:"recs" json:"recs"`
Favs int `bson:"favs" json:"favs"`
Views int `bson:"views" json:"views"`
Completed bool `bson:"completed" json:"completed" form:"completed"`
Downloads int `bson:"downloads" json:"downloads"`
}
type somethingWithNestedChapters struct {
ID int64 `json:"_id"`
Document `json:",inline" coll:"nested_stuff"`
Chapters []chapter `json:"chapters"`
NestedText string `json:"text" gridfs:"nested_text,/nested/{{.ID}}.txt"`
ID int64 `bson:"_id" json:"_id"`
Document `bson:",inline" json:",inline" coll:"nested_stuff"`
Chapters []chapter `bson:"chapters" json:"chapters"`
NestedText string `json:"text" bson:"-" gridfs:"nested_text,/nested/{{.ID}}.txt"`
}
func isTestBench(t assert.TestingT) bool {
_, ok := t.(*testing.B)
return ok
func (s *somethingWithNestedChapters) Id() any {
return s.ID
}
func friend(t assert.TestingT) user {
ID := int64(83378)
if isTestBench(t) {
//ID = 0
//ID = rand.Int64N(100000) + 1
}
return user{
Username: "DarQuiel7",
ID: ID,
}
}
func author(t assert.TestingT) user {
ID := int64(85783)
if isTestBench(t) {
//ID = 0
}
return user{
Username: "tablet.exe",
ID: ID,
}
func (s *somethingWithNestedChapters) SetId(id any) {
s.ID = id.(int64)
}
func genChaps(single bool, aceil int) []chapter {
func (s *story) Id() any {
return s.ID
}
func (s *band) Id() any {
return s.ID
}
func (s *user) Id() any {
return s.ID
}
func (s *story) SetId(id any) {
s.ID = id.(int64)
//var t IDocument =s
}
func (s *band) SetId(id any) {
s.ID = id.(int64)
}
func (s *user) SetId(id any) {
s.ID = id.(int64)
}
var author = user{
Username: "tablet.exe",
Favs: []user{
{
Username: "DarQuiel7",
},
},
}
func genChaps(single bool) []chapter {
var ret []chapter
var ceil int
if single {
ceil = 1
} else {
ceil = aceil
ceil = 5
}
relMap := make([][][]string, 0)
bands := make([][]band, 0)
charMap := make([][]string, 0)
for i := range ceil {
curChars := make([]string, 0)
curBands := make([]band, 0)
curBands = append(curBands, diamondHead)
curChars = append(curChars, diamondHead.Characters...)
emptyRel := make([][]string, 0)
emptyRel = append(emptyRel, make([]string, 0))
relMap := [][][]string{
{
randMin := max(i+1, 1)
randMax := min(i+1, randMin+1)
mod1 := max(rand.IntN(randMin), 1)
mod2 := max(rand.IntN(randMax+1), 1)
if (mod1%mod2 == 0 || (mod1%mod2) == 2) && i > 0 {
curBands = append(curBands, bodom)
curChars = append(curChars, bodom.Characters...)
}
}
crel := make([][]string, 0)
numRels := rand.IntN(3)
seenRels := make(map[string]bool)
for len(crel) <= numRels {
arel := make([]string, 0)
randRelChars := rand.IntN(3)
numRelChars := 0
if randRelChars == 1 {
numRelChars = 3
} else if randRelChars == 2 {
numRelChars = 2
}
if numRelChars == 0 {
continue
}
seen := make(map[string]bool)
for len(arel) < numRelChars {
char := diamondHead.Characters[rand.IntN(len(diamondHead.Characters))]
if !seen[char] {
arel = append(arel, char)
seen[char] = true
}
}
slices.Sort(arel)
maybeSeen := strings.Join(arel, "/")
if maybeSeen != "" && !seenRels[maybeSeen] {
seenRels[maybeSeen] = true
crel = append(crel, arel)
}
}
{"Sean Harris", "Brian Tatler"},
},
{
numChars := rand.IntN(len(curChars)-1) + 1
seen := make(map[string]bool)
cchars := make([]string, 0)
for len(cchars) <= numChars {
char := curChars[rand.IntN(len(curChars))]
if !seen[char] {
cchars = append(cchars, char)
seen[char] = true
}
}
charMap = append(charMap, cchars)
}
relMap = append(relMap, crel)
bands = append(bands, curBands)
{"Sean Harris", "Brian Tatler"},
{"Duncan Scott", "Colin Kimberley"},
},
{
{"Duncan Scott", "Colin Kimberley"},
},
emptyRel,
{
{"Sean Harris", "Colin Kimberley", "Brian Tatler"},
},
}
l := loremipsum.New()
for i := range ceil {
for i := 0; i < ceil; i++ {
spf := fmt.Sprintf("%d.md", i+1)
c := chapter{
ret = append(ret, chapter{
ID: bson.NewObjectID(),
Title: fmt.Sprintf("-%d-", i+1),
Index: i + 1,
Words: 50,
Notes: "notenotenote !!!",
Genre: []string{"Slash"},
Bands: bands[i],
Characters: charMap[i],
Bands: []band{diamondHead},
Characters: []string{"Sean Harris", "Brian Tatler", "Duncan Scott", "Colin Kimberley"},
Relationships: relMap[i],
Adult: true,
Summary: l.Paragraph(),
@ -198,10 +154,7 @@ func genChaps(single bool, aceil int) []chapter {
LoggedInOnly: true,
FileName: spf,
Text: strings.Join(l.ParagraphList(10), "\n\n"),
Posted: time.Now().Add(time.Hour * time.Duration(int64(24*7*i))),
}
ret = append(ret, c)
})
}
return ret
@ -210,35 +163,51 @@ func genChaps(single bool, aceil int) []chapter {
func doSomethingWithNested() somethingWithNestedChapters {
l := loremipsum.New()
swnc := somethingWithNestedChapters{
Chapters: genChaps(false, 7),
Chapters: genChaps(false),
NestedText: strings.Join(l.ParagraphList(15), "\n\n"),
}
return swnc
}
func iti_single(a user) *story {
return &story{
func iti_single() story {
return story{
Title: "title",
Completed: true,
Author: a,
Chapters: genChaps(true, 1),
Chapters: genChaps(true),
}
}
func iti_multi(a user) *story {
return &story{
func iti_multi() story {
return story{
Title: "Brian Tatler Fucked and Abused Sean Harris",
Completed: false,
Author: a,
Chapters: genChaps(false, 5),
Chapters: genChaps(false),
}
}
func iti_blank(a user) *story {
t := iti_single(a)
func iti_blank() story {
t := iti_single()
t.Chapters = make([]chapter, 0)
return t
}
func initTest() {
uri := "mongodb://127.0.0.1:27017"
db := "rockfic_ormTest"
ic, _ := mongo.Connect(options.Client().ApplyURI(uri))
ic.Database(db).Drop(context.TODO())
colls, _ := ic.Database(db).ListCollectionNames(context.TODO(), bson.M{})
if len(colls) < 1 {
mdb := ic.Database(db)
mdb.CreateCollection(context.TODO(), "bands")
mdb.CreateCollection(context.TODO(), "stories")
mdb.CreateCollection(context.TODO(), "users")
}
defer ic.Disconnect(context.TODO())
Connect(uri, db)
author.ID = 696969
ModelRegistry.Model(band{}, user{}, story{})
}
var metallica = band{
ID: 1,
Name: "Metallica",
@ -277,66 +246,12 @@ var bodom = band{
},
}
type commonTestFunc func(t assert.TestingT)
var logTime = time.Now()
func initCommonTest(t assert.TestingT) *Engine {
f, err := os.OpenFile(
fmt.Sprintf("test-logs/test-%d.log", logTime.UnixMilli()),
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
func saveDoc(t *testing.T, doc IDocument) {
err := doc.Save()
assert.Nil(t, err)
e, err := Open("postgres://testbed_user:123@localhost/real_test_db", &Config{
LogLevel: LevelQuery,
LogTo: f,
})
assert.Nil(t, err)
e.Models(user{}, story{}, band{}, role{})
return e
}
func deleteAll(e *Engine) {
models := []any{&user{}, &story{}, &band{}, &role{}}
for _, model := range models {
e.Model(model).WhereRaw("true").Delete()
}
}
func initTest(t assert.TestingT) *Engine {
e := initCommonTest(t)
err := e.MigrateDropping()
assert.Nil(t, err)
return e
}
func insertBands(t assert.TestingT, e *Engine) {
toInsert := []*band{&bodom, &diamondHead}
/*if isTestBench(t) {
for i := range toInsert {
toInsert[i].ID = 0
}
}*/
for _, b := range toInsert {
err := e.Model(&band{}).Save(b)
assert.Nil(t, err)
}
}
func checkChapters(t assert.TestingT, s *story) {
for _, c := range s.Chapters {
assert.NotZero(t, c.ChapterID)
assert.NotZero(t, c.Text)
assert.NotZero(t, c.Posted)
}
}
func bench(fn func(assert.TestingT, *Engine)) func(b *testing.B) {
return func(b *testing.B) {
e := initCommonTest(b)
for b.Loop() {
deleteAll(e)
fn(b, e)
}
e.Disconnect()
}
func createAndSave(t *testing.T, doc IDocument) {
mdl := Create(doc).(IDocument)
saveDoc(t, mdl)
}

290
util.go Normal file
View File

@ -0,0 +1,290 @@
package orm
import (
"fmt"
"go.mongodb.org/mongo-driver/v2/bson"
"reflect"
"regexp"
"strconv"
"strings"
)
func panik(err error) {
if err != nil {
panic(err)
}
}
func nameOf(i interface{}) string {
v := reflect.ValueOf(i)
var n string
switch v.Kind() {
case reflect.Slice, reflect.Map:
if v.Type().Elem().Kind() == reflect.Pointer {
n = v.Type().Elem().Elem().Name()
}
case reflect.Pointer:
n = nameOf(reflect.Indirect(v).Interface())
default:
n = v.Type().Name()
}
return n
}
func valueOf(i interface{}) reflect.Value {
v := reflect.ValueOf(i)
if v.Type().Kind() == reflect.Pointer {
v = valueOf(reflect.Indirect(v).Interface())
}
return v
}
func asId(i interface{}) HasID {
v := reflect.ValueOf(i)
var asHasId HasID
var ok bool
switch v.Kind() {
case reflect.Struct:
asHasId, ok = v.Interface().(HasID)
if ok {
return asHasId
}
v = reflect.New(v.Type())
v.Elem().Set(reflect.ValueOf(i))
fallthrough
case reflect.Pointer:
asHasId, ok = v.Interface().(HasID)
if ok {
return asHasId
} else {
panic("value does not implemenet `HasId`!")
}
default:
break
}
return asHasId
}
func coerceInt(input reflect.Value, dst reflect.Value) interface{} {
if input.Type().Kind() == reflect.Pointer {
input = input.Elem()
}
if dst.Type().Kind() == reflect.Pointer {
dst = dst.Elem()
}
if input.Type().ConvertibleTo(dst.Type()) {
return input.Convert(dst.Type()).Interface()
}
return nil
}
var arrRegex, _ = regexp.Compile(`\[(?P<index>\d+)]$`)
func getNested(field string, aValue reflect.Value) (*reflect.Type, *reflect.Value, error) {
if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") {
return nil, nil, fmt.Errorf(errFmtMalformedField, field)
}
value := aValue
if value.Kind() == reflect.Pointer {
value = value.Elem()
}
aft := value.Type()
dots := strings.Split(field, ".")
if value.Kind() != reflect.Struct {
if value.Kind() == reflect.Slice {
st := reflect.MakeSlice(value.Type().Elem(), 0, 0)
for i := 0; i < value.Len(); i++ {
cur := value.Index(i)
if len(dots) > 1 {
_, cv, _ := getNested(strings.Join(dots[1:], "."), cur.FieldByName(dots[0]))
reflect.Append(st, *cv)
//return getNested(, "."), fv)
} else {
reflect.Append(st, cur)
}
}
typ := st.Type().Elem()
return &typ, &st, nil
}
if len(dots) > 1 {
return nil, nil, ErrNotSliceOrStruct
} else {
return &aft, &value, nil
}
}
ref := value
if ref.Kind() == reflect.Pointer {
ref = ref.Elem()
}
var fv = ref.FieldByName(arrRegex.ReplaceAllString(dots[0], ""))
if arrRegex.FindString(dots[0]) != "" && fv.Kind() == reflect.Slice {
matches := arrRegex.FindStringSubmatch(dots[0])
ridx, _ := strconv.Atoi(matches[0])
idx := ridx
fv = fv.Index(idx)
}
ft, _ := ref.Type().FieldByName(arrRegex.ReplaceAllString(dots[0], ""))
if len(dots) > 1 {
return getNested(strings.Join(dots[1:], "."), fv)
} else {
return &ft.Type, &fv, nil
}
}
func makeSettable(rval reflect.Value, value interface{}) reflect.Value {
if !rval.CanSet() {
nv := reflect.New(rval.Type())
nv.Elem().Set(reflect.ValueOf(value))
return nv
}
return rval
}
func incrementInterface(t interface{}) interface{} {
switch pt := t.(type) {
case uint:
t = pt + 1
case uint32:
t = pt + 1
case uint64:
t = pt + 1
case int:
t = pt + 1
case int32:
t = pt + 1
case int64:
t = pt + 1
case string:
t = NextStringID()
case bson.ObjectID:
t = bson.NewObjectID()
default:
panic(ErrUnsupportedID)
}
return t
}
func isValidId(t interface{}) bool {
switch t.(type) {
case uint, uint32, uint64, int, int32, int64, string, bson.ObjectID:
return true
default:
return false
}
}
func isObject(t interface{}) bool {
switch t.(type) {
case bson.M, bson.D:
return true
default:
return false
}
}
func traverseFields(f string, val interface{}) (ret interface{}, remaining string) {
split := strings.Split(f, ".")
rv := reflect.ValueOf(val)
for {
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
} else {
break
}
}
{
asAny := make([]any, 0)
for _, s := range split {
asAny = append(asAny, s)
}
fmt.Println(asAny...)
}
if rv.Kind() == reflect.Slice {
ret = rv.Interface()
remaining = strings.Join(split[1:], ".")
fmt.Println("returning?")
return
}
structField := rv.FieldByName(split[0])
if structField.IsValid() {
fmt.Println(structField.Interface())
if len(split) > 1 {
if structField.Kind() == reflect.Slice {
ret = structField
remaining = strings.Join(split[1:], ".")
return
}
ret, remaining = traverseFields(strings.Join(split[1:], "."), structField.Interface())
fmt.Printf("remaining = %s\n", remaining)
} else {
ret = structField.Interface()
remaining = ""
}
}
return
}
func pull(s reflect.Value, idx int, typ reflect.Type) reflect.Value {
retI := reflect.New(reflect.SliceOf(typ))
for i := 0; i < s.Len(); i++ {
if i == idx {
continue
}
retI.Elem().Set(reflect.Append(retI.Elem(), s.Index(i)))
}
return retI.Elem()
}
func checkStruct(ref reflect.Value) error {
if ref.Kind() == reflect.Slice {
return ErrAppendMultipleDocuments
}
if ref.Kind() != reflect.Struct {
return ErrNotAStruct
}
return nil
}
func checkSlice(ref reflect.Value) error {
if ref.Kind() != reflect.Slice {
return ErrNotASlice
}
return nil
}
func convertSlice[In, Out any](in []In) []Out {
out := make([]Out, 0)
for _, i := range in {
ii, ok := any(i).(Out)
if ok {
out = append(out, ii)
}
}
return out
}
func normalizeSliceToDocumentSlice(in any) *DocumentSlice {
ret := make(DocumentSlice, 0)
val := reflect.ValueOf(in)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() == reflect.Slice {
for i := 0; i < val.Len(); i++ {
if idoc, ok := val.Index(i).Interface().(IDocument); ok {
ret = append(ret, idoc)
}
}
}
return &ret
}
func filterMap[k comparable, v any](input map[k]v, pred func(key k, val v) bool) map[k]v {
ret := make(map[k]v)
for k1, v1 := range input {
if pred(k1, v1) {
ret[k1] = v1
}
}
return ret
}

133
utils.go
View File

@ -1,133 +0,0 @@
package orm
import (
"fmt"
sb "github.com/henvic/pgq"
"reflect"
"regexp"
"strings"
)
var pascalRegex = regexp.MustCompile(`(?P<lowercase>[a-z])(?P<uppercase>[A-Z])`)
var nonWordRegex = regexp.MustCompile(`[^a-zA-Z0-9_]`)
func pascalToSnakeCase(str string) string {
step1 := pascalRegex.ReplaceAllString(str, `${lowercase}_${uppercase}`)
step2 := nonWordRegex.ReplaceAllString(step1, "_")
return strings.ToLower(step2)
}
func canConvertTo[T any](thisType reflect.Type) bool {
return thisType.ConvertibleTo(reflect.TypeFor[T]()) ||
thisType.ConvertibleTo(reflect.TypeFor[*T]()) ||
strings.TrimPrefix(thisType.Name(), "*") == strings.TrimPrefix(reflect.TypeFor[T]().Name(), "*")
}
func parseTags(t string) map[string]string {
tags := strings.Split(t, ";")
m := make(map[string]string)
for _, tag := range tags {
field := strings.Split(tag, ":")
if len(field) < 2 {
m[strings.ToLower(field[0])] = "t"
} else {
m[strings.ToLower(field[0])] = field[1]
}
}
return m
}
func capitalizeFirst(str string) string {
firstChar := strings.ToUpper(string([]byte{str[0]}))
return firstChar + string(str[1:])
}
func serialToRegular(str string) string {
return strings.ReplaceAll(strings.ToLower(str), "serial", "int")
}
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.String:
return v.String() == ""
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Ptr, reflect.Interface:
return v.IsNil()
}
return v.IsZero()
}
func checkInsertable(v reflect.Value) {
}
func reflectSet(f reflect.Value, v any) {
if !f.CanSet() || v == nil {
return
}
switch f.Kind() {
case reflect.Int, reflect.Int64:
switch val := v.(type) {
case int64:
f.SetInt(val)
case int32:
f.SetInt(int64(val))
case int:
f.SetInt(int64(val))
case uint64:
f.SetInt(int64(val))
}
case reflect.String:
if s, ok := v.(string); ok {
f.SetString(s)
}
}
}
func logTrunc(length int, v []any) []any {
if length < 5 {
length = 5
}
trunced := make([]any, 0)
for _, it := range v {
if str, ok := it.(string); ok {
ntrunc := str[:min(length, len(str))]
if len(ntrunc) < len(str) {
ntrunc += "..."
}
trunced = append(trunced, ntrunc)
} else {
trunced = append(trunced, it)
}
}
return trunced
}
func isSliceOfStructs(rv reflect.Value) bool {
return rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() == reflect.Struct
}
// MakePlaceholders - generates a string with `count`
// occurences of a placeholder (`?`), delimited by a
// comma and a space
func MakePlaceholders(count int) string {
if count < 1 {
return ""
}
var ph []string
for range count {
ph = append(ph, "?")
}
return strings.Join(ph, ", ")
}
func wrapQueryIn(s sb.SelectBuilder, idName string) sb.SelectBuilder {
return s.Prefix(
fmt.Sprintf("%s in (",
idName)).Suffix(")")
}