diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml
new file mode 100644
index 0000000..5792f72
--- /dev/null
+++ b/.idea/dataSources.xml
@@ -0,0 +1,15 @@
+
+
+
+
+ postgresql
+ true
+ org.postgresql.Driver
+ jdbc:postgresql://localhost:5432/testbed_i_think
+
+
+
+ $ProjectFileDir$
+
+
+
\ No newline at end of file
diff --git a/diamond.go b/diamond.go
new file mode 100644
index 0000000..b3f9337
--- /dev/null
+++ b/diamond.go
@@ -0,0 +1,81 @@
+package orm
+
+import (
+ "context"
+ "github.com/jackc/pgx/v5/pgxpool"
+ "time"
+)
+
+type Engine struct {
+ modelMap *ModelMap
+ conn *pgxpool.Pool
+ m2mSeen map[string]bool
+ dryRun bool
+ cfg *pgxpool.Config
+ ctx context.Context
+}
+
+func (e *Engine) Models(v ...any) {
+ e.modelMap = makeModelMap(v...)
+}
+
+func (e *Engine) Model(val any) *Query {
+ qq := &Query{
+ engine: e,
+ ctx: context.Background(),
+ wheres: make(map[string][]any),
+ joins: make(map[*Relationship][3]string),
+ orders: make([]string, 0),
+ relatedModels: make(map[string]*Model),
+ }
+ return qq.Model(val)
+}
+
+func (e *Engine) Migrate() error {
+ failedMigrations := make(map[string]*Model)
+ var err error
+ for mk, m := range e.modelMap.Map {
+ err = m.migrate(e)
+ if err != nil {
+ failedMigrations[mk] = m
+ }
+ }
+
+ for len(failedMigrations) > 0 {
+ e.m2mSeen = make(map[string]bool)
+ for mk, m := range failedMigrations {
+ err = m.migrate(e)
+ if err == nil {
+ delete(failedMigrations, mk)
+ }
+ }
+ }
+ return err
+}
+
+func Open(connString string) (*Engine, error) {
+ e := &Engine{
+ modelMap: &ModelMap{
+ Map: make(map[string]*Model),
+ },
+ m2mSeen: make(map[string]bool),
+ dryRun: connString == "",
+ ctx: context.Background(),
+ }
+ if connString != "" {
+ var err error
+ e.cfg, err = pgxpool.ParseConfig(connString)
+ e.cfg.MinConns = 5
+ e.cfg.MaxConns = 10
+ e.cfg.MaxConnIdleTime = time.Minute * 2
+ if err != nil {
+ return nil, err
+ }
+ e.conn, err = pgxpool.NewWithConfig(e.ctx, e.cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ }
+ return e, nil
+}
diff --git a/field.go b/field.go
new file mode 100644
index 0000000..85f24e8
--- /dev/null
+++ b/field.go
@@ -0,0 +1,133 @@
+package orm
+
+import (
+ "fmt"
+ "net"
+ "reflect"
+ "time"
+)
+
+type Field struct {
+ Name string
+ ColumnName string
+ ColumnType string
+ Type reflect.Type
+ Original reflect.StructField
+ Model *Model
+ Index int
+ AutoIncrement bool
+ PrimaryKey bool
+ Nullable bool
+ isForeignKey bool
+ fk *Relationship
+}
+
+func (f *Field) alias() (string, string) {
+ columnName := f.Model.Fields[f.Model.IDField].ColumnName
+ if f.ColumnType != "" {
+ columnName = f.ColumnName
+ }
+ return fmt.Sprintf("%s.%s", f.Model.TableName, columnName), fmt.Sprintf("%s_%s", f.Model.TableName, f.ColumnName)
+}
+
+func (f *Field) aliasWith(a string) (string, string) {
+ first := fmt.Sprintf("%s.%s", a, f.ColumnName)
+ second := fmt.Sprintf("%s_%s", a, f.ColumnName)
+ return first, second
+}
+
+func (f *Field) key() string {
+ return fmt.Sprintf("%s.%s", f.Model.Name, f.Name)
+}
+
+func columnType(ty reflect.Type, isPk, isAutoInc bool) string {
+ it := ty
+ switch it.Kind() {
+ case reflect.Ptr:
+ for it.Kind() == reflect.Ptr {
+ it = it.Elem()
+ }
+ case reflect.Int32, reflect.Uint32:
+ if isPk || isAutoInc {
+ return "serial"
+ } else {
+ return "int"
+ }
+ case reflect.Int64, reflect.Uint64, reflect.Int, reflect.Uint:
+ if isPk || isAutoInc {
+ return "bigserial"
+ } else {
+ return "bigint"
+ }
+ case reflect.String:
+ return "text"
+ case reflect.Float32:
+ return "float4"
+ case reflect.Float64:
+ return "double precision"
+ case reflect.Bool:
+ return "boolean"
+ case reflect.Struct:
+ if canConvertTo[time.Time](ty) {
+ return "timestamptz"
+ }
+ if canConvertTo[net.IP](ty) {
+ return "inet"
+ }
+ if canConvertTo[net.IPNet](ty) {
+ return "cidr"
+ }
+
+ default:
+ return ""
+ }
+ return ""
+}
+func parseField(f reflect.StructField, minfo *Model, modelMap map[string]*Model, i int) *Field {
+ field := &Field{
+ Name: f.Name,
+ Original: f,
+ Index: i,
+ }
+ tags := parseTags(f.Tag.Get("d"))
+ if tags["-"] != "" {
+ return nil
+ }
+ field.PrimaryKey = tags["pk"] != "" || tags["primarykey"] != "" || field.Name == "ID"
+ field.AutoIncrement = tags["autoinc"] != ""
+ field.Nullable = tags["nullable"] != ""
+ field.ColumnType = tags["type"]
+ if field.ColumnType == "" {
+ field.ColumnType = columnType(f.Type, field.PrimaryKey, field.AutoIncrement)
+ }
+ field.ColumnName = tags["column"]
+ if field.ColumnName == "" {
+ field.ColumnName = pascalToSnakeCase(field.Name)
+ }
+ if field.PrimaryKey {
+ minfo.IDField = field.Name
+ }
+ elem := f.Type
+ for elem.Kind() == reflect.Ptr {
+ if !field.Nullable {
+ field.Nullable = true
+ }
+ elem = elem.Elem()
+ }
+ field.Type = elem
+
+ switch elem.Kind() {
+ case reflect.Array, reflect.Slice:
+ elem = elem.Elem()
+ fallthrough
+ case reflect.Struct:
+ if canConvertTo[Document](elem) && f.Anonymous {
+ minfo.TableName = tags["table"]
+ return nil
+ } else if field.ColumnType == "" {
+ minfo.Relationships[field.Name] = parseRelationship(f, modelMap, minfo.Type, i)
+ }
+ }
+
+ return field
+}
diff --git a/go.mod b/go.mod
index 25a0f46..0fbb4f4 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/henvic/pgq v0.0.4 // indirect
github.com/huandu/xstrings v1.4.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
diff --git a/go.sum b/go.sum
index ca7aa10..1f0a329 100644
--- a/go.sum
+++ b/go.sum
@@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-loremipsum/loremipsum v1.1.4 h1:RJaJlJwX4y9A2+CMgKIyPcjuFHFKTmaNMhxbL+sI6Vg=
github.com/go-loremipsum/loremipsum v1.1.4/go.mod h1:whNWskGoefTakPnCu2CO23v5Y7RwiG4LMOEtTDaBeOY=
+github.com/henvic/pgq v0.0.4 h1:BgLnxofZJSWWs+9VOf19Gr9uBkSVbHWGiu8wix1nsIY=
+github.com/henvic/pgq v0.0.4/go.mod h1:k0FMvOgmQ45MQ3TgCLe8I3+sDKy9lPAiC2m9gg37pVA=
github.com/huandu/go-assert v1.1.6 h1:oaAfYxq9KNDi9qswn/6aE0EydfxSa+tWZC1KabNitYs=
github.com/huandu/go-assert v1.1.6/go.mod h1:JuIfbmYG9ykwvuxoJ3V8TB5QP+3+ajIA54Y44TmkMxs=
github.com/huandu/go-sqlbuilder v1.35.1 h1:znTuAksxq3T1rYfr3nsD4P0brWDY8qNzdZnI6+vtia4=
diff --git a/model.go b/model.go
new file mode 100644
index 0000000..4500b83
--- /dev/null
+++ b/model.go
@@ -0,0 +1,297 @@
+package orm
+
+import (
+ "fmt"
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgxpool"
+ "reflect"
+ "strings"
+)
+
+type Model struct {
+ Name string
+ Type reflect.Type
+ Relationships map[string]*Relationship
+ IDField string
+ Fields map[string]*Field
+ FieldsByColumnName map[string]*Field
+ TableName string
+ embeddedIsh bool
+}
+
+func (m *Model) addField(field *Field) {
+ field.Model = m
+ m.Fields[field.Name] = field
+ m.FieldsByColumnName[field.ColumnName] = field
+}
+
+func (m *Model) getAliasFields() map[string]string {
+ fields := make(map[string]string)
+ for _, f := range m.Fields {
+ if f.fk != nil {
+ continue
+ }
+ fields[f.Name] = fmt.Sprintf("%s.%s as %s_%s", m.TableName, f.ColumnName, strings.ToLower(m.Name), f.ColumnName)
+ }
+ return fields
+}
+
+func (m *Model) getPrimaryKey(val reflect.Value) (string, any) {
+ colField := m.Fields[m.IDField]
+ if colField == nil {
+ return "", nil
+ }
+ colName := colField.ColumnName
+ idField := val.FieldByName(m.IDField)
+ if idField.IsValid() {
+ return colName, idField.Interface()
+ }
+ return "", nil
+}
+
+func (m *Model) insert(v reflect.Value, e *Query, parentFks map[string]any) (any, error) {
+ var isTopLevel bool
+ var err error
+ var cn *pgxpool.Conn
+ if e.tx == nil && !e.engine.dryRun {
+ isTopLevel = true
+ var tx pgx.Tx
+ cn, err = e.engine.conn.Acquire(e.ctx)
+ if err != nil {
+ return nil, err
+ }
+ tx, err = cn.Begin(e.ctx)
+ if err != nil {
+ return nil, err
+ }
+ e.tx = tx
+ defer func() {
+ if err != nil {
+ e.tx.Rollback(e.ctx)
+ }
+ fmt.Printf("[DBG] discarding tx @ %p\n", e.tx)
+ fmt.Printf("[DBG] discarding conn @ %p\n", e.tx.Conn())
+ fmt.Printf("[DBG] discarding outer conn @ %p\n", cn)
+ e.tx = nil
+ cn.Release()
+ }()
+ }
+ for v.Kind() == reflect.Pointer {
+ v = v.Elem()
+ }
+ t := v.Type()
+ var cols []string
+ var placeholders []string
+ var args []any
+ var returningField *Field
+ for k, vv := range parentFks {
+ cols = append(cols, k)
+ placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
+ args = append(args, vv)
+ }
+ for _, ff := range m.Fields {
+ //ft := t.Field(ff.Index)
+ var fv reflect.Value
+ if ff.Index > -1 {
+ fv = v.Field(ff.Index)
+ }
+
+ mfk := ff.fk
+ if mfk == nil {
+ mfk = m.Relationships[ff.Name]
+ }
+ if mfk != nil {
+ af := v.FieldByName(mfk.FieldName)
+ if af.Kind() == reflect.Struct {
+ idField := af.FieldByName(ff.fk.RelatedModel.IDField)
+ cols = append(cols, ff.ColumnName)
+ placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
+ if !idField.IsValid() {
+ var nid any
+ nid, err = ff.fk.RelatedModel.insert(af, e, make(map[string]any))
+ if err != nil {
+ return 0, err
+ }
+ args = append(args, nid)
+ } else {
+ args = append(args, idField.Interface())
+ }
+ }
+ continue
+ }
+ col := ff.ColumnName
+ if ff.PrimaryKey {
+ returningField = ff
+ if fv.IsValid() && !fv.IsZero() {
+ cols = append(cols, col)
+ placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
+ args = append(args, fv.Interface())
+ }
+ continue
+ }
+ if fv.IsValid() {
+ cols = append(cols, col)
+ placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
+ args = append(args, fv.Interface())
+ }
+ }
+ var rfc = "id"
+ if returningField != nil {
+ rfc = returningField.ColumnName
+ }
+ scols := fmt.Sprintf("(%s)", strings.Join(cols, ", "))
+ svals := fmt.Sprintf("VALUES (%s) ", strings.Join(placeholders, ", "))
+ sql := fmt.Sprintf("INSERT INTO %s ",
+ m.TableName,
+ )
+ if len(cols) > 0 {
+ sql += scols
+ sql += " "
+ sql += svals
+ } else {
+ sql += "DEFAULT VALUES "
+ }
+ sql += fmt.Sprintf("RETURNING %s", rfc)
+ fmt.Printf("[INSERT] %s { %s }\n", sql, logTrunc(args, 200))
+ var id any
+ {
+ if v.FieldByName(m.IDField).IsValid() {
+ id = v.FieldByName(m.IDField).Interface()
+ }
+ }
+ if !e.engine.dryRun {
+ qr := e.engine.conn.QueryRow(e.ctx, sql, args...)
+ err = qr.Scan(&id)
+ if err != nil {
+ return 0, err
+ }
+ }
+ {
+ retField := v.FieldByName(returningField.Name)
+ if retField.IsValid() {
+ retField.Set(reflect.ValueOf(id))
+ }
+ }
+ for i := range t.NumField() {
+ ft := t.Field(i)
+ fv := v.Field(i)
+ if ft.Type.Kind() == reflect.Slice {
+ for j := range fv.Len() {
+ child := fv.Index(j).Addr()
+ cm := e.engine.modelMap.Map[child.Type().Elem().Name()]
+ if cm != nil && cm.embeddedIsh {
+ cfk := map[string]any{
+ m.TableName + "_id": id,
+ }
+ if m.Relationships[ft.Name] != nil {
+ cfk = map[string]any{
+ pascalToSnakeCase(m.Relationships[ft.Name].JoinField()): id,
+ }
+ }
+ _, err = cm.insert(child, e, cfk)
+ if err != nil {
+ return 0, err
+ }
+ } else if cm != nil {
+ rel := m.Relationships[ft.Name]
+ if rel != nil {
+ err = rel.joinInsert(child, e, id)
+ if err != nil {
+ return 0, err
+ }
+ }
+ }
+ }
+ }
+ }
+ if isTopLevel && !e.engine.dryRun {
+ err = e.tx.Commit(e.ctx)
+ }
+ return id, err
+}
+
+func (m *Model) update(val reflect.Value, q *Query, parentFks map[string]any) error {
+ var isTopLevel bool
+ var err error
+ if q.tx == nil && !q.engine.dryRun {
+ isTopLevel = true
+ var tx pgx.Tx
+ tx, err = q.engine.conn.Begin(q.ctx)
+ if err != nil {
+ return err
+ }
+ q.tx = tx
+ defer func() {
+ if err != nil {
+ q.tx.Rollback(q.ctx)
+ q.tx = nil
+ }
+ }()
+ }
+ cnt := 1
+ sets := make([]string, 0)
+ vals := make([]any, 0)
+ for val.Kind() == reflect.Pointer {
+ val = val.Elem()
+ }
+ for _, field := range m.Fields {
+ if field.fk != nil || field.ColumnType == "" || field.Name == m.IDField {
+ continue
+ }
+ if field.Index > -1 && field.Index < val.NumField() && field.fk == nil {
+ f := val.Field(field.Index)
+ sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt))
+ vals = append(vals, f.Interface())
+ cnt++
+ } else if _, ok := parentFks[field.ColumnName]; ok {
+ sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt))
+ vals = append(vals, parentFks[field.ColumnName])
+ cnt++
+ }
+ }
+ mcol, mpk := m.getPrimaryKey(val)
+ sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s = %v", m.TableName, strings.Join(sets, ", "), mcol, mpk)
+ fmt.Printf("[UPDATE] %s { %s }\n", sql, logTrunc(vals, 200))
+ if !q.engine.dryRun {
+ if _, err = q.tx.Exec(q.ctx, sql, vals...); err != nil {
+ return err
+ }
+ }
+ for _, rel := range m.Relationships {
+ if rel.Idx > -1 && rel.Idx < val.NumField() {
+ f := val.Field(rel.Idx)
+ if f.Kind() == reflect.Slice ||
+ f.Kind() == reflect.Array {
+ err = diffSlices(q.engine, q, val, rel)
+ if err != nil {
+ return err
+ }
+ } else if rel.RelatedModel.embeddedIsh && !m.embeddedIsh {
+ elemish := f.Type()
+ for elemish.Kind() == reflect.Ptr {
+ elemish = elemish.Elem()
+ }
+ if elemish.Kind() == reflect.Struct {
+ err = rel.RelatedModel.update(f, q, map[string]any{
+ pascalToSnakeCase(rel.JoinField()): mpk,
+ })
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ }
+ var finerr error
+ if isTopLevel && !q.engine.dryRun {
+ finerr = q.tx.Commit(q.ctx)
+ }
+ return finerr
+}
+
+/*func (m *Model) ensure(q *Query, val reflect.Value) error {
+ if _, pk := m.getPrimaryKey(val); pk != nil && !reflect.ValueOf(pk).IsZero() {
+ return nil
+ }
+
+}*/
diff --git a/model_internals.go b/model_internals.go
new file mode 100644
index 0000000..67d3f2c
--- /dev/null
+++ b/model_internals.go
@@ -0,0 +1,106 @@
+package orm
+
+import (
+ "reflect"
+ "strings"
+)
+
+func parseModel(model any) *Model {
+ t := reflect.TypeOf(model)
+ for t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ minfo := &Model{
+ Name: t.Name(),
+ Relationships: make(map[string]*Relationship),
+ Fields: make(map[string]*Field),
+ FieldsByColumnName: make(map[string]*Field),
+ Type: t,
+ }
+ for i := range t.NumField() {
+ f := t.Field(i)
+ if !f.IsExported() {
+ continue
+ }
+ //minfo.Fields[f.Name] = parseField(f, minfo, i)
+ }
+ if minfo.TableName == "" {
+ minfo.TableName = pascalToSnakeCase(t.Name())
+ }
+ return minfo
+}
+
+func parseModelFields(model *Model, modelMap map[string]*Model) {
+ t := model.Type
+ for i := range t.NumField() {
+ f := t.Field(i)
+ fi := parseField(f, model, modelMap, i)
+ if fi != nil {
+ model.addField(fi)
+ }
+ }
+}
+func makeModelMap(models ...any) *ModelMap {
+ modelMap := &ModelMap{
+ Map: make(map[string]*Model),
+ }
+ //modelMap := make(map[string]*Model)
+ for _, model := range models {
+ minfo := parseModel(model)
+ // modelMap.Mux.Lock()
+ modelMap.Map[minfo.Name] = minfo
+ // modelMap.Mux.Unlock()
+ }
+ for _, model := range modelMap.Map {
+ // modelMap.Mux.Lock()
+ parseModelFields(model, modelMap.Map)
+ // modelMap.Mux.Unlock()
+ }
+ tagManyToMany(modelMap)
+ for _, model := range modelMap.Map {
+ // modelMap.Mux.Lock()
+ for _, ref := range model.Relationships {
+ if ref.Type != ManyToMany && ref.Idx != -1 {
+ addForeignKeyFields(ref)
+ }
+ }
+ // modelMap.Mux.Unlock()
+ }
+ return modelMap
+}
+
+func tagManyToMany(models *ModelMap) {
+ hasManys := make(map[string]*Relationship)
+ for _, model := range models.Map {
+ for relName := range model.Relationships {
+ hasManys[model.Name+"."+relName] = model.Relationships[relName]
+ }
+ }
+ for _, model := range models.Map {
+ // models.Mux.Lock()
+ for relName := range model.Relationships {
+
+ mb := model.Relationships[relName].RelatedModel
+ var name string
+ for n, reltmp := range hasManys {
+ if !strings.HasPrefix(n, mb.Name) || reltmp.Type != HasMany {
+ continue
+ }
+ if reltmp.RelatedType == model.Type {
+ name = reltmp.FieldName
+ break
+ }
+ }
+ if rel2, ok := mb.Relationships[name]; ok {
+ if name < relName &&
+ rel2.Type == HasMany && model.Relationships[relName].Type == HasMany {
+ mb.Relationships[name].Type = ManyToMany
+ mb.Relationships[name].m2mInverse = model.Relationships[relName]
+ model.Relationships[relName].Type = ManyToMany
+ model.Relationships[relName].m2mInverse = mb.Relationships[name]
+ }
+ }
+ }
+ // models.Mux.Unlock()
+ }
+}
diff --git a/model_map.go b/model_map.go
new file mode 100644
index 0000000..17effd3
--- /dev/null
+++ b/model_map.go
@@ -0,0 +1,10 @@
+package orm
+
+import (
+ "sync"
+)
+
+type ModelMap struct {
+ Map map[string]*Model
+ Mux sync.RWMutex
+}
diff --git a/model_migration.go b/model_migration.go
new file mode 100644
index 0000000..e1612f6
--- /dev/null
+++ b/model_migration.go
@@ -0,0 +1,110 @@
+package orm
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+func (m *Model) createTableSql() string {
+ var fields []string
+ var fks []string
+ for _, field := range m.Fields {
+ isStructOrSliceOfStructs := field.Type.Kind() == reflect.Struct ||
+ ((field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) &&
+ field.Type.Elem().Kind() == reflect.Struct)
+ if field.PrimaryKey {
+ fields = append(fields, fmt.Sprintf("%s %s PRIMARY KEY", field.ColumnName, field.ColumnType))
+ } else if (field.fk != nil && field.fk.Type != HasMany && field.fk.Type != ManyToMany) && field.isForeignKey {
+ colType := serialToRegular(field.ColumnType)
+ if !field.Nullable {
+ colType += " NOT NULL "
+ }
+ ffk := field.fk.RelatedModel.Fields[field.fk.RelatedModel.IDField]
+ if ffk != nil {
+ fks = append(fks, fmt.Sprintf("%s %s REFERENCES %s(%s)",
+ field.ColumnName, colType,
+ field.fk.RelatedModel.TableName,
+ field.fk.RelatedModel.Fields[field.fk.RelatedModel.IDField].ColumnName))
+ }
+ } else if !isStructOrSliceOfStructs || field.ColumnType != "" {
+ lalala := fmt.Sprintf("%s %s", field.ColumnName, field.ColumnType)
+ if !field.Nullable {
+ lalala += " NOT NULL"
+ }
+ fields = append(fields, lalala)
+ }
+ }
+ inter := strings.Join(fields, ", ")
+ if len(fks) > 0 {
+ inter += ", "
+ inter += strings.Join(fks, ", ")
+ }
+ return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);",
+ m.TableName, inter)
+}
+
+func (m *Model) createJoinTableSql(relName string) string {
+ ref, ok := m.Relationships[relName]
+ if !ok {
+ return ""
+ }
+ aTable := m.TableName
+ joinTableName := ref.JoinTable()
+ fct := serialToRegular(ref.Model.Fields[ref.Model.IDField].ColumnType)
+ rct := serialToRegular(ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnType)
+ pkSection := fmt.Sprintf(",\nPRIMARY KEY (%s, %s_id)",
+ fmt.Sprintf("%s_%s",
+ aTable, pascalToSnakeCase(ref.FieldName),
+ ),
+ ref.RelatedModel.TableName,
+ )
+ if ref.Type == HasMany || ref.Type == ManyToMany {
+ pkSection = ""
+ }
+ return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
+%s %s REFERENCES %s(%s),
+%s_id %s REFERENCES %s(%s)%s
+);`,
+ joinTableName,
+ fmt.Sprintf("%s_%s",
+ aTable, pascalToSnakeCase(ref.FieldName),
+ ),
+ fct,
+ ref.Model.TableName, ref.Model.Fields[ref.Model.IDField].ColumnName,
+ ref.RelatedModel.TableName,
+ rct,
+ ref.RelatedModel.TableName, ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnName,
+ pkSection,
+ )
+}
+
+func (m *Model) migrate(engine *Engine) error {
+ sql := m.createTableSql()
+ fmt.Println(sql)
+ if !engine.dryRun {
+ _, err := engine.conn.Exec(engine.ctx, sql)
+ if err != nil {
+ return err
+ }
+ }
+ for relName, rel := range m.Relationships {
+ relkey := rel.Model.Name
+ if (rel.Type == ManyToMany && !engine.m2mSeen[relkey]) ||
+ (rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh && rel.Type == HasMany) {
+ if rel.Type == ManyToMany {
+ engine.m2mSeen[relkey] = true
+ engine.m2mSeen[rel.RelatedModel.Name] = true
+ }
+ jtsql := m.createJoinTableSql(relName)
+ fmt.Println(jtsql)
+ if !engine.dryRun {
+ _, err := engine.conn.Exec(engine.ctx, jtsql)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
diff --git a/model_misc.go b/model_misc.go
new file mode 100644
index 0000000..a6abba4
--- /dev/null
+++ b/model_misc.go
@@ -0,0 +1,174 @@
+package orm
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+)
+
+func fetchJoinTableChildren(e *Engine, ctx context.Context, childRel *Relationship, pid any) (map[any]struct{}, error) {
+ rsql := fmt.Sprintf("SELECT %s_%s FROM %s where %s_id = $1",
+ childRel.Model.TableName, pascalToSnakeCase(childRel.FieldName),
+ childRel.JoinTable(), childRel.RelatedModel.TableName)
+ res := make(map[any]struct{})
+ if !e.dryRun {
+ rows, err := e.conn.Query(ctx, rsql, pid)
+ defer rows.Close()
+ if err != nil {
+ return nil, err
+ }
+ for rows.Next() {
+ var id any
+ if err = rows.Scan(&id); err != nil {
+ return nil, err
+ }
+ res[id] = struct{}{}
+ }
+ }
+ return res, nil
+}
+
+func fetchChildren(e *Engine, childRel *Relationship, pid any) (map[any]reflect.Value, error) {
+ res := make(map[any]reflect.Value)
+ qq := e.Model(reflect.New(childRel.RelatedModel.Type).Elem().Interface())
+ rrel := childRel.RelatedModel.Relationships[childRel.Model.Name]
+ rfield := childRel.RelatedModel.Fields[rrel.FieldName]
+ /*if rrel == nil {
+ return res, fmt.Errorf("please report this, it shouldn't have happened :(")
+ }*/
+ rawRows, err := qq.Where(fmt.Sprintf("%s.%s = $1", childRel.RelatedModel.TableName, rfield.ColumnName), pid).Find()
+ if err != nil {
+ return nil, err
+ }
+ inter := make([]any, 0)
+ rrv := reflect.ValueOf(rawRows)
+ for i := range rrv.Len() {
+ inter = append(inter, rrv.Index(i).Interface())
+ }
+ for _, row := range inter {
+ v := reflect.ValueOf(row)
+ bv := v
+ for bv.Kind() == reflect.Ptr {
+ bv = bv.Elem()
+ }
+ id := v.FieldByName(childRel.RelatedModel.IDField).Interface()
+ res[id] = v
+ }
+
+ return res, nil
+}
+
+func preDiff(e *Engine, value reflect.Value) (*Model, error) {
+ ptype := value.Type()
+ for ptype.Kind() == reflect.Pointer {
+ ptype = ptype.Elem()
+ }
+ model, ok := e.modelMap.Map[ptype.Name()]
+ if !ok {
+ return nil, fmt.Errorf("model '%s' not found", ptype.Name())
+ }
+ return model, nil
+}
+
+func diffManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
+ model, err := preDiff(e, value)
+ if err != nil {
+ return err
+ }
+ _, ppk := model.getPrimaryKey(value)
+ dbChildren, err := fetchChildren(e, rel, ppk)
+ if err != nil {
+ return err
+ }
+ memChildren := make(map[any]reflect.Value)
+ fv := value.FieldByName(rel.FieldName)
+ for i := range fv.Len() {
+ child := fv.Index(i)
+ _, cpk := rel.RelatedModel.getPrimaryKey(child)
+ if cpk != nil {
+ memChildren[cpk] = child
+ }
+ }
+ // deletions //
+ for pk := range dbChildren {
+ if _, found := memChildren[pk]; !found {
+ table := rel.RelatedModel.TableName
+ idField := rel.RelatedModel.Fields[rel.RelatedModel.IDField]
+ _, err = q.tx.Exec(q.ctx, fmt.Sprintf("DELETE FROM %s where %s = $1", table, idField.ColumnName), pk)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ mField := model.Fields[model.IDField]
+ mpks := map[string]any{}
+ if !model.embeddedIsh {
+ mpks[mField.ColumnName] = ppk
+ }
+ // update || insert //
+ for i := range fv.Len() {
+ cur := fv.Index(i)
+ _, cpk := rel.RelatedModel.getPrimaryKey(cur)
+ if cpk == nil || reflect.ValueOf(cpk).IsZero() {
+
+ _, err = rel.RelatedModel.insert(cur, q, mpks)
+ if err != nil {
+ return err
+ }
+ } else {
+ err = rel.RelatedModel.update(cur, q, mpks)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func diffManyToManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
+ model, err := preDiff(e, value)
+ if err != nil {
+ return err
+ }
+ _, ppk := model.getPrimaryKey(value)
+ ids, err := fetchJoinTableChildren(e, q.ctx, rel, ppk)
+ if err != nil {
+ return err
+ }
+ memIds := make(map[any]reflect.Value)
+ fv := value.FieldByName(rel.FieldName)
+ for i := range fv.Len() {
+ child := fv.Index(i)
+ _, cpk := rel.RelatedModel.getPrimaryKey(child)
+ if cpk != nil {
+ memIds[cpk] = child
+ }
+ }
+ for memId := range memIds {
+ if _, found := ids[memId]; !found {
+ err = rel.joinInsert(memIds[memId], q, ppk)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ for id := range ids {
+ if _, found := memIds[id]; !found {
+ err = rel.joinDelete(ppk, id, q)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func diffSlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
+ if rel.Type == ManyToMany || rel.m2mIsh() {
+ return diffManyToManySlices(e, q, value, rel)
+ }
+ if rel.Type == HasMany {
+ return diffManySlices(e, q, value, rel)
+ }
+ return nil
+}
diff --git a/query.go b/query.go
new file mode 100644
index 0000000..fed34ef
--- /dev/null
+++ b/query.go
@@ -0,0 +1,211 @@
+package orm
+
+import (
+ "context"
+ "fmt"
+ sb "github.com/henvic/pgq"
+ "github.com/jackc/pgx/v5"
+ "reflect"
+ "strings"
+)
+
+type Query struct {
+ model *Model
+ relatedModels map[string]*Model
+ wheres map[string][]any
+ orders []string
+ limit int
+ offset int
+ joins map[*Relationship][3]string
+ engine *Engine
+ ctx context.Context
+ tx pgx.Tx
+}
+
+func (q *Query) totalWheres() int {
+ total := 0
+ for _, w := range q.wheres {
+ total += len(w)
+ }
+ return total
+}
+
+func (q *Query) Model(val any) *Query {
+ tt := reflect.TypeOf(val)
+ for tt.Kind() == reflect.Ptr {
+ tt = tt.Elem()
+ }
+ q.model = q.engine.modelMap.Map[tt.Name()]
+ return q
+}
+
+func (q *Query) Where(cond any, args ...any) *Query {
+ switch v := cond.(type) {
+ case string:
+ q.wheres[strings.ReplaceAll(v, "$?", fmt.Sprintf("$%d", q.totalWheres()+1))] = args
+ default:
+ rv := reflect.ValueOf(cond)
+ for rv.Kind() == reflect.Ptr {
+ rv = rv.Elem()
+ }
+ rt := rv.Type()
+ for i := range rv.NumField() {
+ field := rt.Field(i)
+ fieldValue := rv.Field(i)
+ if isZero(fieldValue) {
+ continue
+ }
+ mm, ok := q.engine.modelMap.Map[rv.Type().Name()]
+ if !ok {
+ continue
+ }
+ ff, ok := mm.Fields[field.Name]
+ if !ok || ff.ColumnType == "" {
+ continue
+ }
+ whereClause := fmt.Sprintf("%s.%s = ?", mm.TableName, ff.ColumnName /*, q.totalWheres()+1*/)
+ args = append(args, fieldValue.Interface())
+ q.wheres[whereClause] = args
+ }
+ }
+ return q
+}
+
+func (q *Query) Order(order string) *Query {
+ q.orders = append(q.orders, order)
+ return q
+}
+
+func (q *Query) Limit(limit int) *Query {
+ q.limit = limit
+ return q
+}
+
+func (q *Query) Offset(offset int) *Query {
+ q.offset = offset
+ return q
+}
+
+func (q *Query) buildSelect() sb.SelectBuilder {
+ var fields []string
+
+ for _, f := range q.model.Fields {
+ if f.ColumnType == "" {
+ continue
+ }
+ tn, a := f.alias()
+ fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
+ }
+ seenModels := make(map[string]bool)
+ processField := func(f *Field, m *Model, pfk *Relationship) {
+ if f.ColumnType == "" {
+ if rel, ok := m.Relationships[f.Name]; ok {
+ data, ok2 := q.joins[rel]
+ if ok2 && f.ColumnType != "" {
+ tn, a := f.aliasWith(data[0])
+ fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
+ }
+ }
+ return
+ }
+ {
+ fk := f.fk
+ if fk == nil {
+ fk = m.Relationships[f.Name]
+ }
+ if fk == nil {
+ fk = pfk
+ }
+ if fk != nil && fk.FieldName == f.Name {
+ data, ok2 := q.joins[fk]
+ if ok2 {
+ var (
+ tn, a string
+ )
+ if fk.Type == HasOne {
+ tn, a = fk.RelatedModel.Fields[fk.RelatedModel.IDField].aliasWith(data[0])
+ } else {
+ tn, a = f.aliasWith(data[0])
+ }
+ fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
+ }
+ return
+ }
+ }
+ if f.Name == pfk.FieldName {
+ f.aliasWith(q.joins[pfk][0])
+ } else {
+ tn, a := f.alias()
+ fields = append(fields, fmt.Sprintf("%s AS %s", tn, a))
+ }
+
+ }
+ for r := range q.joins {
+ if !seenModels[r.aliasThingy()] {
+ seenModels[r.aliasThingy()] = true
+ for _, f := range r.Model.Fields {
+ processField(f, r.Model, r)
+ }
+ }
+ if !seenModels[r.relatedAlias()] {
+ seenModels[r.relatedAlias()] = true
+ for _, f := range r.RelatedModel.Fields {
+ processField(f, r.RelatedModel, r)
+ }
+ }
+ }
+ return sb.Select(fields...)
+}
+
+func (q *Query) buildSQL() (string, []any) {
+ sqlb := q.buildSelect().From(q.model.TableName)
+ whereargs := make([]any, 0)
+ if len(q.wheres) > 0 {
+ cnt := 0
+ for w, where := range q.wheres {
+ sqlb = sqlb.Where(w, where...)
+
+ whereargs = append(whereargs, where...)
+ cnt++
+ }
+ }
+ for _, j := range q.joins {
+ sqlb = sqlb.LeftJoin(fmt.Sprintf("%s as %s ON %s", j[1], j[0], j[2]))
+ }
+ool:
+ for _, o := range q.orders {
+ ac, ok := q.model.Fields[o]
+ if !ok {
+ var rel = ac.fk
+ if ac.ColumnType == "" || rel == nil {
+ rel = q.model.Relationships[o]
+ }
+ if rel != nil {
+ if strings.Contains(o, ".") {
+ split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".")
+ cm := rel.Model
+ for i, s := range split {
+ if rel != nil {
+ cm = rel.RelatedModel
+ } else if i == len(split)-1 {
+ break
+ } else {
+ continue ool
+ }
+ rel = cm.Relationships[s]
+ }
+ lf := split[len(split)-1]
+ ac, ok = cm.Fields[lf]
+ if !ok {
+ continue
+ }
+ }
+ }
+ }
+ sqlb = sqlb.OrderBy(ac.ColumnName)
+ }
+ if q.limit > 0 {
+ sqlb = sqlb.Limit(uint64(q.limit))
+ }
+ return sqlb.MustSQL()
+}
diff --git a/query_populate.go b/query_populate.go
new file mode 100644
index 0000000..e6f7e0f
--- /dev/null
+++ b/query_populate.go
@@ -0,0 +1,165 @@
+package orm
+
+import (
+ "fmt"
+ "strings"
+)
+
+const PopulateAll = "~~~ALL~~~"
+
+func join(r *Relationship) (string, string, string) {
+ rtable := r.RelatedModel.TableName
+ field := r.Model.Fields[r.FieldName]
+ var fk, pk, alias string
+ if !r.RelatedModel.embeddedIsh && !r.Model.embeddedIsh {
+ alias = pascalToSnakeCase(field.Name)
+ fk = fmt.Sprintf("%s.%s", alias, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName)
+ pk = fmt.Sprintf("%s.%s", r.Model.TableName, field.ColumnName)
+ alias = pascalToSnakeCase(field.Name)
+ } else if !r.Model.embeddedIsh {
+ alias = pascalToSnakeCase(r.FieldName)
+ sid := strings.TrimSuffix(r.JoinField(), "ID")
+ fk = fmt.Sprintf("%s.%s", alias, r.RelatedModel.Fields[sid].ColumnName)
+ pk = fmt.Sprintf("%s.%s", r.Model.TableName, r.Model.Fields[r.Model.IDField].ColumnName)
+ }
+ return alias, rtable, fmt.Sprintf("%s = %s", fk, pk)
+}
+
+func m2mJoin(r *Relationship) [][3]string {
+ result := make([][3]string, 0)
+ jt := r.JoinTable()
+ first := [3]string{
+ pascalToSnakeCase(r.FieldName),
+ jt,
+ fmt.Sprintf("%s = %s",
+ fmt.Sprintf("%s.%s",
+ jt,
+ r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName,
+ ),
+ fmt.Sprintf("%s.%s",
+ r.Model.TableName,
+ r.Model.Fields[r.Model.IDField].ColumnName,
+ ),
+ ),
+ }
+ second := [3]string{
+ pascalToSnakeCase(r.m2mInverse.FieldName),
+ r.RelatedModel.TableName,
+ fmt.Sprintf("%s = %s",
+ fmt.Sprintf("%s.%s", r.RelatedModel.TableName, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName),
+ fmt.Sprintf("%s.%s", jt, r.m2mInverse.JoinField()),
+ ),
+ }
+
+ /*first := fmt.Sprintf("%s AS %s ON %s = %s",
+ jt,
+ fmt.Sprintf("%s.%s",
+ jt,
+ r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName,
+ ),
+ fmt.Sprintf("%s.%s",
+ r.Model.TableName,
+ r.Model.Fields[r.Model.IDField].ColumnName,
+ ),
+ )
+ second := fmt.Sprintf("%s ON %s = %s",
+ r.RelatedModel.TableName,
+ fmt.Sprintf("%s.%s", r.RelatedModel.TableName, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName),
+ fmt.Sprintf("%s.%s", jt, r.m2mInverse.JoinField()),
+ )*/
+ result = append(result, first, second)
+ return result
+}
+
+func nestedJoin(m *Model, path string) (joins [][3]string, ree []*Relationship) {
+ splitPath := strings.Split(path, ".")
+ prevModel := m
+ for _, f := range splitPath {
+ rel, ok := m.Relationships[f]
+ if !ok {
+ break
+ }
+ var (
+ fk, pk string
+ )
+ if !rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh {
+ pk = prevModel.Fields[rel.FieldName].ColumnName
+ fk = rel.RelatedModel.Fields[rel.RelatedModel.IDField].ColumnName
+ } else if !rel.Model.embeddedIsh {
+ pk = prevModel.Fields[prevModel.IDField].ColumnName
+ fk = rel.RelatedModel.Fields[strings.TrimSuffix(rel.JoinField(), "ID")].ColumnName
+ }
+ ree = append(ree, rel)
+ j2 := [3]string{
+ rel.Model.Fields[rel.FieldName].ColumnName,
+ rel.RelatedModel.TableName,
+ fmt.Sprintf("%s.%s = %s.%s",
+ rel.RelatedModel.TableName, fk,
+ prevModel.TableName, pk),
+ }
+
+ /*j2 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
+ rel.RelatedModel.TableName,
+ rel.Model.Fields[rel.FieldName].ColumnName,
+ rel.RelatedModel.TableName, fk,
+ prevModel.TableName, pk,
+ )*/
+ joins = append(joins, j2)
+ prevModel = rel.RelatedModel
+ }
+ return
+}
+
+func (q *Query) Populate(relation string) *Query {
+ if relation == PopulateAll {
+ for _, rel := range q.model.Relationships {
+ if rel.Type == ManyToMany {
+ mjs := m2mJoin(rel)
+ for _, ajoin := range mjs {
+ q.joins[rel] = [3]string{
+ ajoin[0],
+ ajoin[1],
+ ajoin[2],
+ }
+ }
+ //q.joins = append(q.joins, m2mJoin(rel)...)
+ } else {
+ alias, tn, cond := join(rel)
+ q.joins[rel] = [3]string{
+ alias,
+ tn,
+ cond,
+ }
+ //q.joins = append(q.joins, tn)
+ }
+ q.relatedModels[rel.RelatedModel.Name] = rel.RelatedModel
+ }
+ } else {
+ if strings.Contains(relation, ".") {
+ njs, rmodels := nestedJoin(q.model, relation)
+ for i, m := range rmodels {
+ curTuple := njs[i]
+ q.joins[m] = [3]string{
+ curTuple[0],
+ curTuple[1],
+ curTuple[2],
+ }
+ q.relatedModels[m.Model.Name] = m.Model
+ }
+ //q.joins = append(q.joins, njs...)
+ } else {
+ rel, ok := q.model.Relationships[relation]
+ if ok {
+ alias, tn, j := join(rel)
+ q.joins[rel] = [3]string{
+ alias,
+ tn,
+ j,
+ }
+ //q.joins = append(q.joins, tn)
+ q.relatedModels[rel.RelatedModel.Name] = rel.RelatedModel
+ }
+ }
+ }
+ return q
+}
diff --git a/query_tail.go b/query_tail.go
new file mode 100644
index 0000000..0b33e4c
--- /dev/null
+++ b/query_tail.go
@@ -0,0 +1,42 @@
+package orm
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+)
+
+func (q *Query) Create(val any) (any, error) {
+ _, err := q.model.insert(reflect.ValueOf(val), q, make(map[string]any))
+ return val, err
+}
+
+func (q *Query) Find() (any, error) {
+ sql, args := q.buildSQL()
+ fmt.Printf("[FIND] %s { %+v }\n", sql, args)
+ if !q.engine.dryRun {
+ rows, err := q.engine.conn.Query(q.ctx, sql, args...)
+ defer rows.Close()
+ if err != nil {
+ return nil, err
+ }
+ rmaps, err := rowsToMaps(rows)
+ if err != nil {
+ return nil, err
+ }
+ wtype := q.model.Type
+ for wtype.Kind() == reflect.Pointer {
+ wtype = wtype.Elem()
+ }
+ return fillSlice(rmaps, wtype, q.engine.modelMap), nil
+ }
+ return make([]any, 0), nil
+}
+
+func (q *Query) Update(val any, cond any, args ...any) error {
+ if q.model != nil {
+ return q.model.update(reflect.ValueOf(val), q, make(map[string]any))
+ } else {
+ return errors.New("Please select a model")
+ }
+}
diff --git a/relationship.go b/relationship.go
new file mode 100644
index 0000000..10565ff
--- /dev/null
+++ b/relationship.go
@@ -0,0 +1,209 @@
+package orm
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+type RelationshipType int
+
+const (
+ HasOne RelationshipType = iota
+ HasMany
+ ManyToMany
+)
+
+type Relationship struct {
+ Type RelationshipType
+ Model *Model
+ FieldName string
+ Idx int
+ RelatedType reflect.Type
+ RelatedModel *Model
+ Kind reflect.Kind // field kind (struct, slice, ...)
+ m2mInverse *Relationship
+}
+
+func (r *Relationship) JoinTable() string {
+ return r.Model.TableName + "_" + r.RelatedModel.TableName
+}
+
+func (r *Relationship) JoinField() string {
+ isMany := r.Type == HasMany //|| r.Type == ManyToMany
+ if isMany && r.Model.embeddedIsh {
+ return r.Model.Name + r.FieldName + "ID"
+ } else if isMany && !r.Model.embeddedIsh && r.m2mInverse == nil {
+ return r.Model.Name + "ID"
+ } else if r.Type == ManyToMany && !r.Model.embeddedIsh {
+ return r.RelatedModel.Name + "ID"
+ }
+ return r.FieldName + "ID"
+}
+
+func (r *Relationship) aliasThingy() string {
+ return pascalToSnakeCase(r.Model.Name + "." + r.FieldName)
+}
+
+func (r *Relationship) relatedAlias() string {
+ return pascalToSnakeCase(r.RelatedModel.Name + "." + r.FieldName)
+}
+
+func (r *Relationship) m2mIsh() bool {
+ return r.Model.embeddedIsh && !r.RelatedModel.embeddedIsh && r.Type == HasMany
+}
+
+func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error {
+ if r.Type != ManyToMany &&
+ !r.m2mIsh() {
+ return nil
+ }
+ ichild := v
+ for ichild.Kind() == reflect.Ptr {
+ ichild = ichild.Elem()
+ }
+ if ichild.Kind() == reflect.Struct {
+ jtable := r.JoinTable()
+ jargs := make([]any, 0)
+ jcols := make([]string, 0)
+ jcols = append(jcols, fmt.Sprintf("%s_%s",
+ r.Model.TableName, pascalToSnakeCase(r.FieldName),
+ ))
+ jargs = append(jargs, pfk)
+
+ jcols = append(jcols, r.RelatedModel.TableName+"_id")
+ jargs = append(jargs, ichild.FieldByName(r.RelatedModel.IDField).Interface())
+ var ecnt int
+ e.tx.QueryRow(e.ctx,
+ fmt.Sprintf("SELECT count(*) from %s where %s = $1 and %s = $2", r.JoinTable(), jcols[0], jcols[1]), jargs...).Scan(&ecnt)
+ if ecnt > 0 {
+ return nil
+ }
+ jsql := fmt.Sprintf("INSERT INTO %s (%s) VALUES ($1, $2)", jtable, strings.Join(jcols, ", "))
+ fmt.Printf("[INSERT/JOIN] %s { %s }\n", jsql, logTrunc(jargs, 200))
+ if !e.engine.dryRun {
+ _ = e.tx.QueryRow(e.ctx, jsql, jargs...).Scan()
+ }
+ }
+ return nil
+}
+
+func (r *Relationship) joinDelete(pk, fk any, q *Query) error {
+ jc := fmt.Sprintf("%s_%s", r.Model.TableName, pascalToSnakeCase(r.FieldName))
+ ds := fmt.Sprintf("DELETE FROM %s where %s = $1 and %s = $2",
+ r.JoinTable(), jc, r.RelatedModel.TableName+"_id")
+ fmt.Printf("[DELETE/JOIN] %s { %s }\n", ds, logTrunc([]any{pk, fk}, 200))
+ if !q.engine.dryRun {
+ _, err := q.tx.Exec(q.ctx, ds, pk, fk)
+ return err
+ }
+ return nil
+}
+
+func parseRelationship(field reflect.StructField, modelMap map[string]*Model, outerType reflect.Type, idx int) *Relationship {
+ rel := &Relationship{
+ Model: modelMap[outerType.Name()],
+ RelatedModel: modelMap[field.Type.Name()],
+ RelatedType: field.Type,
+ Idx: idx,
+ Kind: field.Type.Kind(),
+ FieldName: field.Name,
+ }
+ if rel.RelatedType.Kind() == reflect.Slice || rel.RelatedType.Kind() == reflect.Array {
+ rel.RelatedType = rel.RelatedType.Elem()
+ }
+ if rel.RelatedModel == nil {
+ if rel.RelatedType.Name() == "" {
+ rt := rel.RelatedType
+ for rt.Kind() == reflect.Ptr || rt.Kind() == reflect.Slice || rt.Kind() == reflect.Array {
+ rel.RelatedType = rel.RelatedType.Elem()
+ rt = rel.RelatedType
+ }
+ }
+ rel.RelatedModel = modelMap[rel.RelatedType.Name()]
+ if _, ok := modelMap[rel.RelatedType.Name()]; !ok {
+ rel.RelatedModel = parseModel(reflect.New(rel.RelatedType).Interface())
+ modelMap[rel.RelatedType.Name()] = rel.RelatedModel
+ parseModelFields(rel.RelatedModel, modelMap)
+ rel.RelatedModel.embeddedIsh = true
+ }
+ }
+ switch field.Type.Kind() {
+ case reflect.Struct:
+ rel.Type = HasOne
+ case reflect.Slice:
+ rel.Type = HasMany
+ }
+ return rel
+}
+func addForeignKeyFields(ref *Relationship) {
+ rf := ref.RelatedModel.Fields[ref.RelatedModel.IDField]
+ if rf != nil {
+ if !ref.RelatedModel.embeddedIsh && !ref.Model.embeddedIsh {
+ ff := ref.Model.Fields[ref.FieldName]
+ ff.ColumnType = rf.ColumnType
+ ff.ColumnName = pascalToSnakeCase(ref.JoinField())
+ ff.isForeignKey = true
+ ff.fk = ref
+ } else if !ref.Model.embeddedIsh {
+ sid := strings.TrimSuffix(ref.JoinField(), "ID")
+ ref.RelatedModel.Relationships[sid] = &Relationship{
+ FieldName: sid,
+ Type: HasOne,
+ RelatedModel: ref.Model,
+ Model: ref.RelatedModel,
+ Kind: ref.RelatedModel.Type.Kind(),
+ Idx: -1,
+ RelatedType: ref.Model.Type,
+ }
+ ref.RelatedModel.addField(&Field{
+ ColumnType: rf.ColumnType,
+ ColumnName: pascalToSnakeCase(ref.JoinField()),
+ Name: sid,
+ isForeignKey: true,
+ Type: rf.Type,
+ Index: -1,
+ fk: ref.RelatedModel.Relationships[sid],
+ })
+ } else if ref.Model.embeddedIsh && !ref.RelatedModel.embeddedIsh {
+
+ }
+ } else {
+ ref.RelatedModel.addField(&Field{
+ Name: "ID",
+ ColumnName: "id",
+ ColumnType: "bigserial",
+ PrimaryKey: true,
+ Type: ref.RelatedType,
+ Index: -1,
+ AutoIncrement: true,
+ })
+ ff := ref.Model.Fields[ref.FieldName]
+ ff.ColumnType = "bigint"
+ ff.ColumnName = pascalToSnakeCase(ref.RelatedModel.Name + "ID")
+ ff.isForeignKey = true
+ ff.fk = ref
+ ref.RelatedModel.IDField = "ID"
+ /*
+ nn := ref.Model.Name + "ID"
+ ref.RelatedModel.Relationships[ref.Model.Name] = &Relationship{
+ Type: HasOne,
+ RelatedModel: ref.Model,
+ Model: ref.RelatedModel,
+ Idx: 65536,
+ Kind: ref.RelatedModel.Type.Kind(),
+ RelatedType: ref.RelatedModel.Type,
+ FieldName: nn,
+ }
+ ref.RelatedModel.addField(&Field{
+ Name: nn,
+ Type: ref.Model.Type,
+ fk: ref.RelatedModel.Relationships[nn],
+ ColumnName: pascalToSnakeCase(nn),
+ Index: -1,
+ ColumnType: ref.Model.Fields[ref.Model.IDField].ColumnType,
+ isForeignKey: true,
+ })*/
+
+ }
+}
diff --git a/scan.go b/scan.go
new file mode 100644
index 0000000..084f0a1
--- /dev/null
+++ b/scan.go
@@ -0,0 +1,85 @@
+package orm
+
+import (
+ "github.com/jackc/pgx/v5"
+ "reflect"
+)
+
+func rowsToMaps(rows pgx.Rows) ([]map[string]any, error) {
+ var result []map[string]any
+ fieldDescs := rows.FieldDescriptions()
+ for rows.Next() {
+ m := make(map[string]any)
+ scanArgs := make([]any, len(fieldDescs))
+ for i := range fieldDescs {
+ var v any
+ scanArgs[i] = &v
+ }
+ if err := rows.Scan(scanArgs...); err != nil {
+ return nil, err
+ }
+ for i, fd := range fieldDescs {
+ name := fd.Name
+ m[name] = *(scanArgs[i].(*any))
+ }
+ result = append(result, m)
+ }
+ return result, rows.Err()
+}
+
+func fillNested(row map[string]any, t reflect.Type, mm *ModelMap, depth, maxDepth int) any {
+ cm := mm.Map[t.Name()]
+ pp := reflect.New(cm.Type).Elem()
+ for _, field := range cm.Fields {
+ _, alias := field.alias()
+ if v, ok := row[alias]; ok && !field.isForeignKey {
+ reflectSet(pp.Field(field.Index), v)
+ }
+ }
+ for _, rel := range cm.Relationships {
+ if rel.Idx > -1 && rel.Idx < pp.NumField() {
+ relType := rel.RelatedModel.Type
+ for relType.Kind() == reflect.Pointer {
+ relType = relType.Elem()
+ }
+ nv := reflect.New(relType)
+ if rel.Kind == reflect.Struct || rel.Kind == reflect.Pointer {
+ if depth < maxDepth {
+ nv = reflect.ValueOf(fillNested(row, relType, mm, depth+1, maxDepth))
+ }
+ if rel.Kind != reflect.Pointer && nv.Kind() == reflect.Pointer {
+ nv = nv.Elem()
+ }
+ reflectSet(pp.Field(rel.Idx), nv)
+ } else if rel.Kind == reflect.Slice || rel.Kind == reflect.Array {
+ relType2 := relType
+ for relType2.Kind() == reflect.Slice || relType2.Kind() == reflect.Array {
+ relType2 = relType2.Elem()
+ }
+ if depth < maxDepth {
+ nv = reflect.ValueOf(fillNested(row, relType2, mm, depth+1, maxDepth))
+ }
+ if nv.Kind() == reflect.Pointer {
+ nv = nv.Elem()
+ }
+ reflectSet(pp.Field(rel.Idx), reflect.Append(pp.Field(rel.Idx), nv))
+ }
+ }
+ }
+ return pp.Interface()
+}
+
+// fillSlice - note that it's the caller's responsibility to indirect
+// the type in `t`.
+func fillSlice(rows []map[string]any, t reflect.Type, mm *ModelMap) any {
+ pslice := reflect.MakeSlice(reflect.SliceOf(t), 0, 0)
+ rt := t
+ for rt.Kind() == reflect.Ptr {
+ rt = rt.Elem()
+ }
+ for _, row := range rows {
+ pp := fillNested(row, rt, mm, 0, 10)
+ pslice = reflect.Append(pslice, reflect.ValueOf(pp))
+ }
+ return pslice.Interface()
+}
diff --git a/test_main.go b/test_main.go
new file mode 100644
index 0000000..bea347d
--- /dev/null
+++ b/test_main.go
@@ -0,0 +1,55 @@
+package orm
+
+import "fmt"
+
+const do_bootstrap = true
+
+func TestMain() {
+ e, err := Open("postgres://testbed_user:123@localhost/testbed_i_think")
+ if err != nil {
+ panic(err)
+ }
+ e.Models(user{}, story{}, band{}, role{})
+ if do_bootstrap {
+
+ err = e.Migrate()
+ if err != nil {
+ panic(err)
+ }
+ s := iti_multi()
+ u := &author
+ u.Favs.Authors = append(u.Favs.Authors, friend)
+ _, err = e.Model(&user{}).Create(&friend)
+ if err != nil {
+ panic(err)
+ }
+ _, err = e.Model(&user{}).Create(u)
+ if err != nil {
+ panic(err)
+ }
+ _, err = e.Model(&band{}).Create(&bodom)
+ if err != nil {
+ panic(err)
+ }
+ _, err = e.Model(&band{}).Create(&diamondHead)
+ if err != nil {
+ panic(err)
+ }
+ _, err = e.Model(&story{}).Create(s)
+ if err != nil {
+ panic(err)
+ }
+ s.Downloads = s.Downloads + 1
+ err = e.Model(&story{}).Update(s, nil)
+ if err != nil {
+ panic(err)
+ }
+ }
+ ns, err := e.Model(&story{}).Where(&story{
+ ID: 1,
+ }).Populate(PopulateAll).Find()
+ if err != nil {
+ panic(err)
+ }
+ fmt.Printf("%+v", ns)
+}
diff --git a/testing.go b/testing.go
index 3ee8694..401177c 100644
--- a/testing.go
+++ b/testing.go
@@ -1,48 +1,52 @@
package orm
import (
- "context"
"fmt"
- "github.com/stretchr/testify/assert"
+ "math/rand/v2"
"strings"
- "testing"
"time"
"github.com/go-loremipsum/loremipsum"
)
type chapter struct {
- ID bson.ObjectID `json:"_id"`
- Title string `json:"chapterTitle" form:"chapterTitle"`
- ChapterID int `json:"chapterID" autoinc:"chapters"`
- Index int `json:"index" form:"index"`
- Words int `json:"words"`
- Notes string `json:"notes" form:"notes"`
- Genre []string `json:"genre" form:"genre"`
- Bands []band `json:"bands" ref:"band,bands"`
- Characters []string `json:"characters" form:"characters"`
- Relationships [][]string `json:"relationships" form:"relationships"`
- Adult bool `json:"adult" form:"adult"`
- Summary string `json:"summary" form:"summary"`
- Hidden bool `json:"hidden" form:"hidden"`
- LoggedInOnly bool `json:"loggedInOnly" form:"loggedInOnly"`
- Posted time.Time `json:"datePosted"`
- FileName string `json:"fileName"`
- Text string `json:"text" gridfs:"story_text,/stories/{{.ChapterID}}.txt"`
+ ChapterID int64 `json:"chapterID" d:"pk:t;"`
+ Title string `json:"chapterTitle" form:"chapterTitle"`
+ Index int `json:"index" form:"index"`
+ Words int `json:"words"`
+ Notes string `json:"notes" form:"notes"`
+ Genre []string `json:"genre" form:"genre" d:"type:text[]"`
+ Bands []band `json:"bands" ref:"band,bands"`
+ Characters []string `json:"characters" form:"characters" d:"type:text[]"`
+ Relationships [][]string `json:"relationships" form:"relationships" d:"type:text[][]"`
+ Adult bool `json:"adult" form:"adult"`
+ Summary string `json:"summary" form:"summary"`
+ Hidden bool `json:"hidden" form:"hidden"`
+ LoggedInOnly bool `json:"loggedInOnly" form:"loggedInOnly"`
+ Posted time.Time `json:"datePosted"`
+ FileName string `json:"fileName" d:"-"`
+ Text string `json:"text" d:"column:content" gridfs:"story_text,/stories/{{.ChapterID}}.txt"`
}
type band struct {
- ID int64 `json:"_id"`
Document `json:",inline" d:"table:bands"`
+ ID int64 `json:"_id" d:"pk;"`
Name string `json:"name" form:"name"`
Locked bool `json:"locked" form:"locked"`
- Characters []string `json:"characters" form:"characters"`
+ Characters []string `json:"characters" form:"characters" d:"type:text[]"`
}
type user struct {
- ID int64 `json:"_id"`
Document `json:",inline" d:"table:users"`
+ ID int64 `json:"_id" d:"pk;"`
Username string `json:"username"`
Favs favs `json:"favs" ref:"user"`
+ Roles []role
+}
+
+type role struct {
+ ID int64 `d:"pk"`
+ Name string
+ Users []user
}
type favs struct {
@@ -50,10 +54,10 @@ type favs struct {
Authors []user
}
type story struct {
- ID int64 `json:"_id"`
- Document `json:",inline" coll:"stories"`
+ Document `json:",inline" d:"table:stories"`
+ ID int64 `json:"_id" d:"pk;"`
Title string `json:"title" form:"title"`
- Author *user `json:"author" ref:"user"`
+ Author user `json:"author" ref:"user"`
CoAuthor *user `json:"coAuthor" ref:"user"`
Chapters []chapter `json:"chapters"`
Recs int `json:"recs"`
@@ -69,54 +73,23 @@ type somethingWithNestedChapters struct {
NestedText string `json:"text" gridfs:"nested_text,/nested/{{.ID}}.txt"`
}
-func (s *somethingWithNestedChapters) Id() any {
- return s.ID
-}
-
-func (s *somethingWithNestedChapters) SetId(id any) {
- s.ID = id.(int64)
-}
-
-func (s *story) Id() any {
- return s.ID
-}
-
-func (s *band) Id() any {
- return s.ID
-}
-func (s *user) Id() any {
- return s.ID
-}
-
-func (s *story) SetId(id any) {
- s.ID = id.(int64)
- //var t IDocument =s
-}
-
-func (s *band) SetId(id any) {
- s.ID = id.(int64)
-}
-
-func (s *user) SetId(id any) {
- s.ID = id.(int64)
+var friend = user{
+ Username: "DarQuiel7",
+ ID: 83378,
}
var author = user{
Username: "tablet.exe",
- Favs: []user{
- {
- Username: "DarQuiel7",
- },
- },
+ ID: 85783,
}
-func genChaps(single bool) []chapter {
+func genChaps(single bool, aceil int) []chapter {
var ret []chapter
var ceil int
if single {
ceil = 1
} else {
- ceil = 5
+ ceil = aceil
}
emptyRel := make([][]string, 0)
emptyRel = append(emptyRel, make([]string, 0))
@@ -140,8 +113,8 @@ func genChaps(single bool) []chapter {
for i := 0; i < ceil; i++ {
spf := fmt.Sprintf("%d.md", i+1)
- ret = append(ret, chapter{
- ID: bson.NewObjectID(),
+ c := chapter{
+ ChapterID: int64(i + 1),
Title: fmt.Sprintf("-%d-", i+1),
Index: i + 1,
Words: 50,
@@ -156,7 +129,19 @@ func genChaps(single bool) []chapter {
LoggedInOnly: true,
FileName: spf,
Text: strings.Join(l.ParagraphList(10), "\n\n"),
- })
+ Posted: time.Now().Add(time.Hour * time.Duration(int64(24*7*i))),
+ }
+ {
+ randMin := max(i+1, 1)
+ randMax := min(i+1, randMin+1)
+ mod1 := max(rand.IntN(randMin), 1)
+ mod2 := max(rand.IntN(randMax+1), 1)
+ if (mod1%mod2 == 0 || (mod1%mod2) == 2) && i > 0 {
+ c.Bands = append(c.Bands, bodom)
+ }
+ }
+
+ ret = append(ret, c)
}
return ret
@@ -165,51 +150,35 @@ func genChaps(single bool) []chapter {
func doSomethingWithNested() somethingWithNestedChapters {
l := loremipsum.New()
swnc := somethingWithNestedChapters{
- Chapters: genChaps(false),
+ Chapters: genChaps(false, 7),
NestedText: strings.Join(l.ParagraphList(15), "\n\n"),
}
return swnc
}
-func iti_single() story {
- return story{
+func iti_single() *story {
+ return &story{
Title: "title",
Completed: true,
- Chapters: genChaps(true),
+ Author: author,
+ Chapters: genChaps(true, 0),
}
}
-func iti_multi() story {
- return story{
+func iti_multi() *story {
+ return &story{
Title: "Brian Tatler Fucked and Abused Sean Harris",
Completed: false,
- Chapters: genChaps(false),
+ Author: author,
+ Chapters: genChaps(false, 5),
}
}
-func iti_blank() story {
+func iti_blank() *story {
t := iti_single()
t.Chapters = make([]chapter, 0)
return t
}
-func initTest() {
- uri := "mongodb://127.0.0.1:27017"
- db := "rockfic_ormTest"
- ic, _ := mongo.Connect(options.Client().ApplyURI(uri))
- ic.Database(db).Drop(context.TODO())
- colls, _ := ic.Database(db).ListCollectionNames(context.TODO(), bson.M{})
- if len(colls) < 1 {
- mdb := ic.Database(db)
- mdb.CreateCollection(context.TODO(), "bands")
- mdb.CreateCollection(context.TODO(), "stories")
- mdb.CreateCollection(context.TODO(), "users")
- }
- defer ic.Disconnect(context.TODO())
- Connect(uri, db)
- author.ID = 696969
- ModelRegistry.Model(band{}, user{}, story{})
-}
-
var metallica = band{
ID: 1,
Name: "Metallica",
@@ -247,13 +216,3 @@ var bodom = band{
"Alexander Kuoppala",
},
}
-
-func saveDoc(t *testing.T, doc IDocument) {
- err := doc.Save()
- assert.Nil(t, err)
-}
-
-func createAndSave(t *testing.T, doc IDocument) {
- mdl := Create(doc).(IDocument)
- saveDoc(t, mdl)
-}
diff --git a/tests/main.go b/tests/main.go
new file mode 100644
index 0000000..21fde97
--- /dev/null
+++ b/tests/main.go
@@ -0,0 +1,7 @@
+package main
+
+import "rockfic.com/orm"
+
+func main() {
+ orm.TestMain()
+}
diff --git a/utils.go b/utils.go
new file mode 100644
index 0000000..0439454
--- /dev/null
+++ b/utils.go
@@ -0,0 +1,95 @@
+package orm
+
+import (
+ "fmt"
+ "reflect"
+ "regexp"
+ "strings"
+)
+
+var pascalRegex = regexp.MustCompile(`(?P[a-z])(?P[A-Z])`)
+var nonWordRegex = regexp.MustCompile(`[^a-zA-Z0-9_]`)
+
+func pascalToSnakeCase(str string) string {
+ step1 := pascalRegex.ReplaceAllString(str, `${lowercase}_${uppercase}`)
+ step2 := nonWordRegex.ReplaceAllString(step1, "_")
+ return strings.ToLower(step2)
+}
+
+func canConvertTo[T any](thisType reflect.Type) bool {
+ return thisType.ConvertibleTo(reflect.TypeFor[T]()) ||
+ thisType.ConvertibleTo(reflect.TypeFor[*T]()) ||
+ strings.TrimPrefix(thisType.Name(), "*") == strings.TrimPrefix(reflect.TypeFor[T]().Name(), "*")
+}
+
+func parseTags(t string) map[string]string {
+ tags := strings.Split(t, ";")
+ m := make(map[string]string)
+ for _, tag := range tags {
+ field := strings.Split(tag, ":")
+ if len(field) < 2 {
+ m[strings.ToLower(field[0])] = "t"
+ } else {
+ m[strings.ToLower(field[0])] = field[1]
+ }
+ }
+ return m
+}
+
+func capitalizeFirst(str string) string {
+ firstChar := strings.ToUpper(string([]byte{str[0]}))
+ return firstChar + string(str[1:])
+}
+
+func serialToRegular(str string) string {
+ return strings.ReplaceAll(strings.ToLower(str), "serial", "int")
+}
+func isZero(v reflect.Value) bool {
+ switch v.Kind() {
+ case reflect.String:
+ return v.String() == ""
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return v.Int() == 0
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return v.Uint() == 0
+ case reflect.Bool:
+ return !v.Bool()
+ case reflect.Ptr, reflect.Interface:
+ return v.IsNil()
+ }
+ return v.IsZero()
+}
+func reflectSet(f reflect.Value, v any) {
+ if !f.CanSet() || v == nil {
+ return
+ }
+ switch f.Kind() {
+ case reflect.Int, reflect.Int64:
+ switch val := v.(type) {
+ case int64:
+ f.SetInt(val)
+ case int32:
+ f.SetInt(int64(val))
+ case int:
+ f.SetInt(int64(val))
+ case uint64:
+ f.SetInt(int64(val))
+ }
+ case reflect.String:
+ if s, ok := v.(string); ok {
+ f.SetString(s)
+ }
+ }
+}
+
+func logTrunc(v any, length int) string {
+ if length < 5 {
+ length = 5
+ }
+ str := fmt.Sprintf("%+v", v)
+ trunced := str[:min(length, len(str))]
+ if len(trunced) < len(str) {
+ trunced += "..."
+ }
+ return trunced
+}