diff --git a/.gitignore b/.gitignore index 0741d58..32648f2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ go.work.sum go.work muck/ -/build/ \ No newline at end of file +/build/ +/test-logs/ \ No newline at end of file diff --git a/diamond.go b/diamond.go index b3f9337..f00b4ac 100644 --- a/diamond.go +++ b/diamond.go @@ -2,33 +2,53 @@ package orm import ( "context" + "fmt" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + "log/slog" "time" ) +const LevelQuery = slog.Level(-6) +const defaultKey = "default" + type Engine struct { - modelMap *ModelMap + modelMap *internalModelMap conn *pgxpool.Pool m2mSeen map[string]bool dryRun bool cfg *pgxpool.Config ctx context.Context + logger *slog.Logger + levelVar *slog.LevelVar + connStr string } func (e *Engine) Models(v ...any) { - e.modelMap = makeModelMap(v...) + emm := makeModelMap(v...) + for k := range emm.Map { + if _, ok := e.modelMap.Map[k]; !ok { + e.modelMap.Mux.Lock() + e.modelMap.Map[k] = emm.Map[k] + e.modelMap.Mux.Unlock() + } + } } func (e *Engine) Model(val any) *Query { qq := &Query{ - engine: e, - ctx: context.Background(), - wheres: make(map[string][]any), - joins: make(map[*Relationship][3]string), - orders: make([]string, 0), - relatedModels: make(map[string]*Model), + engine: e, + ctx: context.Background(), + wheres: make(map[string][]any), + orders: make([]string, 0), + populationTree: make(map[string]any), + joins: make([]string, 0), } - return qq.Model(val) + return qq.setModel(val) +} + +func (e *Engine) QueryRaw(sql string, args ...any) (pgx.Rows, error) { + return e.conn.Query(e.ctx, sql, args...) } func (e *Engine) Migrate() error { @@ -53,9 +73,31 @@ func (e *Engine) Migrate() error { return err } +func (e *Engine) MigrateDropping() error { + for _, m := range e.modelMap.Map { + sql := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", m.TableName) + if _, err := e.conn.Exec(e.ctx, sql); err != nil { + return err + } + for _, r := range m.Relationships { + if r.m2mIsh() || r.Type == ManyToMany { + jsql := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", r.ComputeJoinTable()) + if _, err := e.conn.Exec(e.ctx, jsql); err != nil { + return err + } + } + } + } + return e.Migrate() +} + +func (e *Engine) Disconnect() { + e.conn.Close() +} + func Open(connString string) (*Engine, error) { e := &Engine{ - modelMap: &ModelMap{ + modelMap: &internalModelMap{ Map: make(map[string]*Model), }, m2mSeen: make(map[string]bool), @@ -63,6 +105,14 @@ func Open(connString string) (*Engine, error) { ctx: context.Background(), } if connString != "" { + engines.Mux.Lock() + if len(engines.Engines) == 0 || engines.Engines[defaultKey] == nil { + engines.Engines[defaultKey] = e + } else { + engines.Engines[connString] = e + } + e.connStr = "" + engines.Mux.Unlock() var err error e.cfg, err = pgxpool.ParseConfig(connString) e.cfg.MinConns = 5 @@ -75,7 +125,6 @@ func Open(connString string) (*Engine, error) { if err != nil { return nil, err } - } return e, nil } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..dbd5866 --- /dev/null +++ b/errors.go @@ -0,0 +1,7 @@ +package orm + +import "fmt" + +var ErrNoConditionOnDeleteOrUpdate = fmt.Errorf("refusing to delete/update with no conditions specified.\n"+ + " (hint: call `.Where(%s)` or `.Where(%s)` to do so anyways)", + `"true"`, `"1 = 1"`) diff --git a/field.go b/field.go index 85f24e8..38e66d5 100644 --- a/field.go +++ b/field.go @@ -1,43 +1,65 @@ package orm import ( - "fmt" "net" "reflect" "time" ) +// Field - represents a field with a valid SQL type in a Model type Field struct { - Name string - ColumnName string - ColumnType string - Type reflect.Type - Original reflect.StructField - Model *Model - Index int - AutoIncrement bool - PrimaryKey bool - Nullable bool - isForeignKey bool - fk *Relationship + Name string + ColumnName string + ColumnType string + Type reflect.Type + Original reflect.StructField + Model *Model + Index int + AutoIncrement bool + PrimaryKey bool + Nullable bool + embeddedFields map[string]*Field } -func (f *Field) alias() (string, string) { - columnName := f.Model.Fields[f.Model.IDField].ColumnName - if f.ColumnType != "" { - columnName = f.ColumnName +func (f *Field) isAnonymous() bool { + return f.Original.Anonymous +} + +func (f *Field) anonymousColumnNames() []string { + cols := make([]string, 0) + if !f.isAnonymous() { + return cols } - return fmt.Sprintf("%s.%s", f.Model.TableName, columnName), fmt.Sprintf("%s_%s", f.Model.TableName, f.ColumnName) + for _, ef := range f.embeddedFields { + cols = append(cols, ef.ColumnName) + } + return cols } -func (f *Field) aliasWith(a string) (string, string) { - first := fmt.Sprintf("%s.%s", a, f.ColumnName) - second := fmt.Sprintf("%s_%s", a, f.ColumnName) - return first, second -} - -func (f *Field) key() string { - return fmt.Sprintf("%s.%s", f.Model.Name, f.Name) +func defaultColumnValue(ty reflect.Type) any { + switch ty.Kind() { + case reflect.Int32, reflect.Uint32, reflect.Int, reflect.Uint, reflect.Int64, reflect.Uint64: + return 0 + case reflect.Bool: + return false + case reflect.String: + return "''" + case reflect.Float32, reflect.Float64: + return 0.0 + case reflect.Struct: + if canConvertTo[time.Time](ty) { + return "now()" + } + if canConvertTo[net.IP](ty) { + return "'0.0.0.0'::INET" + } + if canConvertTo[net.IPNet](ty) { + return "'0.0.0.0/0'::CIDR" + } + case reflect.Slice: + return "'{}'" + } + return "NULL" } func columnType(ty reflect.Type, isPk, isAutoInc bool) string { @@ -47,13 +69,13 @@ func columnType(ty reflect.Type, isPk, isAutoInc bool) string { for it.Kind() == reflect.Ptr { it = it.Elem() } - case reflect.Int32, reflect.Uint32: + case reflect.Int32, reflect.Uint32, reflect.Int, reflect.Uint: if isPk || isAutoInc { return "serial" } else { return "int" } - case reflect.Int64, reflect.Uint64, reflect.Int, reflect.Uint: + case reflect.Int64, reflect.Uint64: if isPk || isAutoInc { return "bigserial" } else { @@ -123,9 +145,13 @@ func parseField(f reflect.StructField, minfo *Model, modelMap map[string]*Model, case reflect.Struct: if canConvertTo[Document](elem) && f.Anonymous { minfo.TableName = tags["table"] - return nil + field.embeddedFields = make(map[string]*Field) + for j := range elem.NumField() { + efield := elem.Field(j) + field.embeddedFields[pascalToSnakeCase(efield.Name)] = parseField(efield, minfo, modelMap, j) + } } else if field.ColumnType == "" { - minfo.Relationships[field.Name] = parseRelationship(f, modelMap, minfo.Type, i) + minfo.Relationships[field.Name] = parseRelationship(f, modelMap, minfo.Type, i, tags) } } diff --git a/model.go b/model.go index 4500b83..1cdd989 100644 --- a/model.go +++ b/model.go @@ -1,11 +1,7 @@ package orm import ( - "fmt" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" "reflect" - "strings" ) type Model struct { @@ -25,15 +21,18 @@ func (m *Model) addField(field *Field) { m.FieldsByColumnName[field.ColumnName] = field } -func (m *Model) getAliasFields() map[string]string { - fields := make(map[string]string) - for _, f := range m.Fields { - if f.fk != nil { - continue - } - fields[f.Name] = fmt.Sprintf("%s.%s as %s_%s", m.TableName, f.ColumnName, strings.ToLower(m.Name), f.ColumnName) - } - return fields +const ( + documentField = "Document" + createdField = "Created" + modifiedField = "Modified" +) + +func (m *Model) docField() *Field { + return m.Fields[documentField] +} + +func (m *Model) idField() *Field { + return m.Fields[m.IDField] } func (m *Model) getPrimaryKey(val reflect.Value) (string, any) { @@ -42,6 +41,17 @@ func (m *Model) getPrimaryKey(val reflect.Value) (string, any) { return "", nil } colName := colField.ColumnName + wasPtr := false + if val.Kind() == reflect.Ptr { + if val.IsNil() { + return "", nil + } + val = val.Elem() + wasPtr = true + } + if val.IsZero() && wasPtr { + return "", nil + } idField := val.FieldByName(m.IDField) if idField.IsValid() { return colName, idField.Interface() @@ -49,249 +59,21 @@ func (m *Model) getPrimaryKey(val reflect.Value) (string, any) { return "", nil } -func (m *Model) insert(v reflect.Value, e *Query, parentFks map[string]any) (any, error) { - var isTopLevel bool - var err error - var cn *pgxpool.Conn - if e.tx == nil && !e.engine.dryRun { - isTopLevel = true - var tx pgx.Tx - cn, err = e.engine.conn.Acquire(e.ctx) - if err != nil { - return nil, err - } - tx, err = cn.Begin(e.ctx) - if err != nil { - return nil, err - } - e.tx = tx - defer func() { - if err != nil { - e.tx.Rollback(e.ctx) - } - fmt.Printf("[DBG] discarding tx @ %p\n", e.tx) - fmt.Printf("[DBG] discarding conn @ %p\n", e.tx.Conn()) - fmt.Printf("[DBG] discarding outer conn @ %p\n", cn) - e.tx = nil - cn.Release() - }() - } - for v.Kind() == reflect.Pointer { - v = v.Elem() - } - t := v.Type() - var cols []string - var placeholders []string - var args []any - var returningField *Field - for k, vv := range parentFks { - cols = append(cols, k) - placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) - args = append(args, vv) - } - for _, ff := range m.Fields { - //ft := t.Field(ff.Index) - var fv reflect.Value - if ff.Index > -1 { - fv = v.Field(ff.Index) - } - - mfk := ff.fk - if mfk == nil { - mfk = m.Relationships[ff.Name] - } - if mfk != nil { - af := v.FieldByName(mfk.FieldName) - if af.Kind() == reflect.Struct { - idField := af.FieldByName(ff.fk.RelatedModel.IDField) - cols = append(cols, ff.ColumnName) - placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) - if !idField.IsValid() { - var nid any - nid, err = ff.fk.RelatedModel.insert(af, e, make(map[string]any)) - if err != nil { - return 0, err - } - args = append(args, nid) - } else { - args = append(args, idField.Interface()) - } - } - continue - } - col := ff.ColumnName - if ff.PrimaryKey { - returningField = ff - if fv.IsValid() && !fv.IsZero() { - cols = append(cols, col) - placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) - args = append(args, fv.Interface()) - } - continue - } - if fv.IsValid() { - cols = append(cols, col) - placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) - args = append(args, fv.Interface()) - } - } - var rfc = "id" - if returningField != nil { - rfc = returningField.ColumnName - } - scols := fmt.Sprintf("(%s)", strings.Join(cols, ", ")) - svals := fmt.Sprintf("VALUES (%s) ", strings.Join(placeholders, ", ")) - sql := fmt.Sprintf("INSERT INTO %s ", - m.TableName, - ) - if len(cols) > 0 { - sql += scols - sql += " " - sql += svals - } else { - sql += "DEFAULT VALUES " - } - sql += fmt.Sprintf("RETURNING %s", rfc) - fmt.Printf("[INSERT] %s { %s }\n", sql, logTrunc(args, 200)) - var id any - { - if v.FieldByName(m.IDField).IsValid() { - id = v.FieldByName(m.IDField).Interface() - } - } - if !e.engine.dryRun { - qr := e.engine.conn.QueryRow(e.ctx, sql, args...) - err = qr.Scan(&id) - if err != nil { - return 0, err - } - } - { - retField := v.FieldByName(returningField.Name) - if retField.IsValid() { - retField.Set(reflect.ValueOf(id)) - } - } - for i := range t.NumField() { - ft := t.Field(i) - fv := v.Field(i) - if ft.Type.Kind() == reflect.Slice { - for j := range fv.Len() { - child := fv.Index(j).Addr() - cm := e.engine.modelMap.Map[child.Type().Elem().Name()] - if cm != nil && cm.embeddedIsh { - cfk := map[string]any{ - m.TableName + "_id": id, - } - if m.Relationships[ft.Name] != nil { - cfk = map[string]any{ - pascalToSnakeCase(m.Relationships[ft.Name].JoinField()): id, - } - } - _, err = cm.insert(child, e, cfk) - if err != nil { - return 0, err - } - } else if cm != nil { - rel := m.Relationships[ft.Name] - if rel != nil { - err = rel.joinInsert(child, e, id) - if err != nil { - return 0, err - } - } - } - } - } - } - if isTopLevel && !e.engine.dryRun { - err = e.tx.Commit(e.ctx) - } - return id, err +func (m *Model) needsPrimaryKey(val reflect.Value) bool { + _, pk := m.getPrimaryKey(val) + return pk == nil || reflect.ValueOf(pk).IsZero() } -func (m *Model) update(val reflect.Value, q *Query, parentFks map[string]any) error { - var isTopLevel bool - var err error - if q.tx == nil && !q.engine.dryRun { - isTopLevel = true - var tx pgx.Tx - tx, err = q.engine.conn.Begin(q.ctx) - if err != nil { - return err - } - q.tx = tx - defer func() { - if err != nil { - q.tx.Rollback(q.ctx) - q.tx = nil - } - }() - } - cnt := 1 - sets := make([]string, 0) - vals := make([]any, 0) - for val.Kind() == reflect.Pointer { - val = val.Elem() - } - for _, field := range m.Fields { - if field.fk != nil || field.ColumnType == "" || field.Name == m.IDField { - continue - } - if field.Index > -1 && field.Index < val.NumField() && field.fk == nil { - f := val.Field(field.Index) - sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt)) - vals = append(vals, f.Interface()) - cnt++ - } else if _, ok := parentFks[field.ColumnName]; ok { - sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt)) - vals = append(vals, parentFks[field.ColumnName]) - cnt++ +func (m *Model) columnsWith(rel *Relationship) (cols []string, err error) { + for _, f := range m.Fields { + if f.ColumnType != "" { + cols = append(cols, f.ColumnName) } } - mcol, mpk := m.getPrimaryKey(val) - sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s = %v", m.TableName, strings.Join(sets, ", "), mcol, mpk) - fmt.Printf("[UPDATE] %s { %s }\n", sql, logTrunc(vals, 200)) - if !q.engine.dryRun { - if _, err = q.tx.Exec(q.ctx, sql, vals...); err != nil { - return err + for _, r2 := range m.Relationships { + if r2.Type == ManyToOne { + cols = append(cols, pascalToSnakeCase(r2.joinField())) } } - for _, rel := range m.Relationships { - if rel.Idx > -1 && rel.Idx < val.NumField() { - f := val.Field(rel.Idx) - if f.Kind() == reflect.Slice || - f.Kind() == reflect.Array { - err = diffSlices(q.engine, q, val, rel) - if err != nil { - return err - } - } else if rel.RelatedModel.embeddedIsh && !m.embeddedIsh { - elemish := f.Type() - for elemish.Kind() == reflect.Ptr { - elemish = elemish.Elem() - } - if elemish.Kind() == reflect.Struct { - err = rel.RelatedModel.update(f, q, map[string]any{ - pascalToSnakeCase(rel.JoinField()): mpk, - }) - if err != nil { - return err - } - } - } - } - } - var finerr error - if isTopLevel && !q.engine.dryRun { - finerr = q.tx.Commit(q.ctx) - } - return finerr + return } - -/*func (m *Model) ensure(q *Query, val reflect.Value) error { - if _, pk := m.getPrimaryKey(val); pk != nil && !reflect.ValueOf(pk).IsZero() { - return nil - } - -}*/ diff --git a/model_internals.go b/model_internals.go index 67d3f2c..838a253 100644 --- a/model_internals.go +++ b/model_internals.go @@ -22,7 +22,6 @@ func parseModel(model any) *Model { if !f.IsExported() { continue } - //minfo.Fields[f.Name] = parseField(f, minfo, i) } if minfo.TableName == "" { minfo.TableName = pascalToSnakeCase(t.Name()) @@ -35,41 +34,41 @@ func parseModelFields(model *Model, modelMap map[string]*Model) { for i := range t.NumField() { f := t.Field(i) fi := parseField(f, model, modelMap, i) - if fi != nil { + if fi != nil && (fi.ColumnType != "" || fi.isAnonymous()) { model.addField(fi) } } } -func makeModelMap(models ...any) *ModelMap { - modelMap := &ModelMap{ +func makeModelMap(models ...any) *internalModelMap { + modelMap := &internalModelMap{ Map: make(map[string]*Model), } //modelMap := make(map[string]*Model) for _, model := range models { minfo := parseModel(model) - // modelMap.Mux.Lock() + modelMap.Mux.Lock() modelMap.Map[minfo.Name] = minfo - // modelMap.Mux.Unlock() + modelMap.Mux.Unlock() } for _, model := range modelMap.Map { - // modelMap.Mux.Lock() + modelMap.Mux.Lock() parseModelFields(model, modelMap.Map) - // modelMap.Mux.Unlock() + modelMap.Mux.Unlock() } tagManyToMany(modelMap) for _, model := range modelMap.Map { - // modelMap.Mux.Lock() for _, ref := range model.Relationships { if ref.Type != ManyToMany && ref.Idx != -1 { + modelMap.Mux.Lock() addForeignKeyFields(ref) + modelMap.Mux.Unlock() } } - // modelMap.Mux.Unlock() } return modelMap } -func tagManyToMany(models *ModelMap) { +func tagManyToMany(models *internalModelMap) { hasManys := make(map[string]*Relationship) for _, model := range models.Map { for relName := range model.Relationships { @@ -77,7 +76,7 @@ func tagManyToMany(models *ModelMap) { } } for _, model := range models.Map { - // models.Mux.Lock() + models.Mux.Lock() for relName := range model.Relationships { mb := model.Relationships[relName].RelatedModel @@ -101,6 +100,6 @@ func tagManyToMany(models *ModelMap) { } } } - // models.Mux.Unlock() + models.Mux.Unlock() } } diff --git a/model_map.go b/model_map.go index 17effd3..95a16b7 100644 --- a/model_map.go +++ b/model_map.go @@ -4,7 +4,17 @@ import ( "sync" ) -type ModelMap struct { +type internalModelMap struct { Map map[string]*Model Mux sync.RWMutex } + +type engineHolder struct { + Engines map[string]*Engine + Mux sync.RWMutex +} + +var engines = &engineHolder{ + Engines: make(map[string]*Engine), + Mux: sync.RWMutex{}, +} diff --git a/model_migration.go b/model_migration.go index e1612f6..0e31bc3 100644 --- a/model_migration.go +++ b/model_migration.go @@ -6,40 +6,78 @@ import ( "strings" ) +type schemaSnapshotColumn struct { + Document `d:"table:__schemas"` + ID int64 `d:"pk"` + ModelName string + FieldName string + FieldType string + FieldIndex int + IsRelationship bool + IsSynthetic bool +} + +func (m *Model) toSnapshotColumns() (ssc []*schemaSnapshotColumn) { + for _, field := range m.Fields { + ssc = append(ssc, &schemaSnapshotColumn{ + ModelName: m.Name, + FieldName: field.Name, + FieldType: field.Type.String(), + FieldIndex: field.Index, + }) + } + for _, rel := range m.Relationships { + rt := rel.RelatedType + if rel.Kind == reflect.Slice { + rt = reflect.SliceOf(rel.RelatedType) + } + ssc = append(ssc, &schemaSnapshotColumn{ + ModelName: m.Name, + FieldName: rel.FieldName, + FieldType: rt.String(), + FieldIndex: rel.Idx, + IsRelationship: true, + IsSynthetic: rel.Idx < 0, + }) + } + return +} + func (m *Model) createTableSql() string { var fields []string - var fks []string for _, field := range m.Fields { - isStructOrSliceOfStructs := field.Type.Kind() == reflect.Struct || - ((field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) && - field.Type.Elem().Kind() == reflect.Struct) - if field.PrimaryKey { - fields = append(fields, fmt.Sprintf("%s %s PRIMARY KEY", field.ColumnName, field.ColumnType)) - } else if (field.fk != nil && field.fk.Type != HasMany && field.fk.Type != ManyToMany) && field.isForeignKey { - colType := serialToRegular(field.ColumnType) - if !field.Nullable { - colType += " NOT NULL " + if !field.isAnonymous() { + isStructOrSliceOfStructs := field.Type.Kind() == reflect.Struct || + ((field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) && + field.Type.Elem().Kind() == reflect.Struct) + if field.PrimaryKey { + fields = append(fields, fmt.Sprintf("%s %s PRIMARY KEY", field.ColumnName, field.ColumnType)) + } else if !isStructOrSliceOfStructs || field.ColumnType != "" { + lalala := fmt.Sprintf("%s %s", field.ColumnName, field.ColumnType) + if !field.Nullable { + lalala += " NOT NULL" + } + lalala += fmt.Sprintf(" DEFAULT %v", defaultColumnValue(field.Type)) + + fields = append(fields, lalala) } - ffk := field.fk.RelatedModel.Fields[field.fk.RelatedModel.IDField] - if ffk != nil { - fks = append(fks, fmt.Sprintf("%s %s REFERENCES %s(%s)", - field.ColumnName, colType, - field.fk.RelatedModel.TableName, - field.fk.RelatedModel.Fields[field.fk.RelatedModel.IDField].ColumnName)) + } else { + ft := field.Type + for ft.Kind() == reflect.Pointer { + ft = ft.Elem() } - } else if !isStructOrSliceOfStructs || field.ColumnType != "" { - lalala := fmt.Sprintf("%s %s", field.ColumnName, field.ColumnType) - if !field.Nullable { - lalala += " NOT NULL" + for i := range ft.NumField() { + efield := field.Type.Field(i) + ctype := columnType(efield.Type, false, false) + if ctype != "" { + def := fmt.Sprintf("%s %s NOT NULL DEFAULT %v", pascalToSnakeCase(efield.Name), ctype, defaultColumnValue(efield.Type)) + fields = append(fields, def) + } } - fields = append(fields, lalala) } } + inter := strings.Join(fields, ", ") - if len(fks) > 0 { - inter += ", " - inter += strings.Join(fks, ", ") - } return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);", m.TableName, inter) } @@ -50,26 +88,22 @@ func (m *Model) createJoinTableSql(relName string) string { return "" } aTable := m.TableName - joinTableName := ref.JoinTable() - fct := serialToRegular(ref.Model.Fields[ref.Model.IDField].ColumnType) - rct := serialToRegular(ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnType) - pkSection := fmt.Sprintf(",\nPRIMARY KEY (%s, %s_id)", - fmt.Sprintf("%s_%s", - aTable, pascalToSnakeCase(ref.FieldName), - ), + joinTableName := ref.ComputeJoinTable() + fct := serialToRegular(ref.primaryID().ColumnType) + rct := serialToRegular(ref.relatedID().ColumnType) + pkSection := fmt.Sprintf(",\nPRIMARY KEY (%s_id, %s_id)", + aTable, ref.RelatedModel.TableName, ) - if ref.Type == HasMany || ref.Type == ManyToMany { + if ref.m2mIsh() { pkSection = "" } return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( -%s %s REFERENCES %s(%s), -%s_id %s REFERENCES %s(%s)%s +%s_id %s REFERENCES %s(%s) ON DELETE CASCADE, +%s_id %s REFERENCES %s(%s) ON DELETE CASCADE %s );`, joinTableName, - fmt.Sprintf("%s_%s", - aTable, pascalToSnakeCase(ref.FieldName), - ), + ref.Model.TableName, fct, ref.Model.TableName, ref.Model.Fields[ref.Model.IDField].ColumnName, ref.RelatedModel.TableName, @@ -79,6 +113,47 @@ func (m *Model) createJoinTableSql(relName string) string { ) } +func (m *Model) generateConstraints(engine *Engine) error { + for _, rel := range m.Relationships { + field := rel.relatedID() + if rel.Type != ManyToMany && rel.Type != HasMany && !rel.m2mIsh() { + colType := serialToRegular(field.ColumnType) + if !field.Nullable && !rel.Nullable { + colType += " NOT NULL" + } + /*constraint := fmt.Sprintf("%s %s REFERENCES %s(%s)", + pascalToSnakeCase(rel.joinField()), colType, + rel.RelatedModel.TableName, + field.ColumnName) + if rel.Type != ManyToOne && rel.Type != BelongsTo { + constraint += " ON DELETE CASCADE ON UPDATE CASCADE" + }*/ + fk := fmt.Sprintf("fk_%s", pascalToSnakeCase(capitalizeFirst(rel.Model.Name)+rel.FieldName+rel.relatedID().Name)) + q := fmt.Sprintf(`ALTER TABLE %s +ADD COLUMN IF NOT EXISTS %s %s, +ADD CONSTRAINT %s +FOREIGN KEY (%s) REFERENCES %s(%s) +ON DELETE CASCADE +ON UPDATE CASCADE;`, + rel.Model.TableName, + pascalToSnakeCase(rel.joinField()), colType, + fk, + pascalToSnakeCase(rel.joinField()), + rel.RelatedModel.TableName, field.ColumnName, + ) + dq := fmt.Sprintf(`ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s;`, m.TableName, fk) + fmt.Printf("%s\n%s\n", dq, q) + if _, err := engine.conn.Exec(engine.ctx, dq); err != nil { + return err + } + if _, err := engine.conn.Exec(engine.ctx, q); err != nil { + return err + } + } + } + return nil +} + func (m *Model) migrate(engine *Engine) error { sql := m.createTableSql() fmt.Println(sql) @@ -89,11 +164,12 @@ func (m *Model) migrate(engine *Engine) error { } } for relName, rel := range m.Relationships { - relkey := rel.Model.Name + relkey := rel.ComputeJoinTable() if (rel.Type == ManyToMany && !engine.m2mSeen[relkey]) || (rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh && rel.Type == HasMany) { if rel.Type == ManyToMany { engine.m2mSeen[relkey] = true + engine.m2mSeen[rel.Model.Name] = true engine.m2mSeen[rel.RelatedModel.Name] = true } jtsql := m.createJoinTableSql(relName) @@ -106,5 +182,5 @@ func (m *Model) migrate(engine *Engine) error { } } } - return nil + return m.generateConstraints(engine) } diff --git a/model_misc.go b/model_misc.go deleted file mode 100644 index a6abba4..0000000 --- a/model_misc.go +++ /dev/null @@ -1,174 +0,0 @@ -package orm - -import ( - "context" - "fmt" - "reflect" -) - -func fetchJoinTableChildren(e *Engine, ctx context.Context, childRel *Relationship, pid any) (map[any]struct{}, error) { - rsql := fmt.Sprintf("SELECT %s_%s FROM %s where %s_id = $1", - childRel.Model.TableName, pascalToSnakeCase(childRel.FieldName), - childRel.JoinTable(), childRel.RelatedModel.TableName) - res := make(map[any]struct{}) - if !e.dryRun { - rows, err := e.conn.Query(ctx, rsql, pid) - defer rows.Close() - if err != nil { - return nil, err - } - for rows.Next() { - var id any - if err = rows.Scan(&id); err != nil { - return nil, err - } - res[id] = struct{}{} - } - } - return res, nil -} - -func fetchChildren(e *Engine, childRel *Relationship, pid any) (map[any]reflect.Value, error) { - res := make(map[any]reflect.Value) - qq := e.Model(reflect.New(childRel.RelatedModel.Type).Elem().Interface()) - rrel := childRel.RelatedModel.Relationships[childRel.Model.Name] - rfield := childRel.RelatedModel.Fields[rrel.FieldName] - /*if rrel == nil { - return res, fmt.Errorf("please report this, it shouldn't have happened :(") - }*/ - rawRows, err := qq.Where(fmt.Sprintf("%s.%s = $1", childRel.RelatedModel.TableName, rfield.ColumnName), pid).Find() - if err != nil { - return nil, err - } - inter := make([]any, 0) - rrv := reflect.ValueOf(rawRows) - for i := range rrv.Len() { - inter = append(inter, rrv.Index(i).Interface()) - } - for _, row := range inter { - v := reflect.ValueOf(row) - bv := v - for bv.Kind() == reflect.Ptr { - bv = bv.Elem() - } - id := v.FieldByName(childRel.RelatedModel.IDField).Interface() - res[id] = v - } - - return res, nil -} - -func preDiff(e *Engine, value reflect.Value) (*Model, error) { - ptype := value.Type() - for ptype.Kind() == reflect.Pointer { - ptype = ptype.Elem() - } - model, ok := e.modelMap.Map[ptype.Name()] - if !ok { - return nil, fmt.Errorf("model '%s' not found", ptype.Name()) - } - return model, nil -} - -func diffManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error { - model, err := preDiff(e, value) - if err != nil { - return err - } - _, ppk := model.getPrimaryKey(value) - dbChildren, err := fetchChildren(e, rel, ppk) - if err != nil { - return err - } - memChildren := make(map[any]reflect.Value) - fv := value.FieldByName(rel.FieldName) - for i := range fv.Len() { - child := fv.Index(i) - _, cpk := rel.RelatedModel.getPrimaryKey(child) - if cpk != nil { - memChildren[cpk] = child - } - } - // deletions // - for pk := range dbChildren { - if _, found := memChildren[pk]; !found { - table := rel.RelatedModel.TableName - idField := rel.RelatedModel.Fields[rel.RelatedModel.IDField] - _, err = q.tx.Exec(q.ctx, fmt.Sprintf("DELETE FROM %s where %s = $1", table, idField.ColumnName), pk) - if err != nil { - return err - } - } - } - mField := model.Fields[model.IDField] - mpks := map[string]any{} - if !model.embeddedIsh { - mpks[mField.ColumnName] = ppk - } - // update || insert // - for i := range fv.Len() { - cur := fv.Index(i) - _, cpk := rel.RelatedModel.getPrimaryKey(cur) - if cpk == nil || reflect.ValueOf(cpk).IsZero() { - - _, err = rel.RelatedModel.insert(cur, q, mpks) - if err != nil { - return err - } - } else { - err = rel.RelatedModel.update(cur, q, mpks) - if err != nil { - return err - } - } - } - return nil -} - -func diffManyToManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error { - model, err := preDiff(e, value) - if err != nil { - return err - } - _, ppk := model.getPrimaryKey(value) - ids, err := fetchJoinTableChildren(e, q.ctx, rel, ppk) - if err != nil { - return err - } - memIds := make(map[any]reflect.Value) - fv := value.FieldByName(rel.FieldName) - for i := range fv.Len() { - child := fv.Index(i) - _, cpk := rel.RelatedModel.getPrimaryKey(child) - if cpk != nil { - memIds[cpk] = child - } - } - for memId := range memIds { - if _, found := ids[memId]; !found { - err = rel.joinInsert(memIds[memId], q, ppk) - if err != nil { - return err - } - } - } - for id := range ids { - if _, found := memIds[id]; !found { - err = rel.joinDelete(ppk, id, q) - if err != nil { - return err - } - } - } - return nil -} - -func diffSlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error { - if rel.Type == ManyToMany || rel.m2mIsh() { - return diffManyToManySlices(e, q, value, rel) - } - if rel.Type == HasMany { - return diffManySlices(e, q, value, rel) - } - return nil -} diff --git a/query.go b/query.go index fed34ef..31721b0 100644 --- a/query.go +++ b/query.go @@ -10,27 +10,19 @@ import ( ) type Query struct { - model *Model - relatedModels map[string]*Model - wheres map[string][]any - orders []string - limit int - offset int - joins map[*Relationship][3]string - engine *Engine - ctx context.Context - tx pgx.Tx + engine *Engine + model *Model + tx pgx.Tx + ctx context.Context + populationTree map[string]any + wheres map[string][]any + joins []string + orders []string + limit int + offset int } -func (q *Query) totalWheres() int { - total := 0 - for _, w := range q.wheres { - total += len(w) - } - return total -} - -func (q *Query) Model(val any) *Query { +func (q *Query) setModel(val any) *Query { tt := reflect.TypeOf(val) for tt.Kind() == reflect.Ptr { tt = tt.Elem() @@ -39,36 +31,9 @@ func (q *Query) Model(val any) *Query { return q } -func (q *Query) Where(cond any, args ...any) *Query { - switch v := cond.(type) { - case string: - q.wheres[strings.ReplaceAll(v, "$?", fmt.Sprintf("$%d", q.totalWheres()+1))] = args - default: - rv := reflect.ValueOf(cond) - for rv.Kind() == reflect.Ptr { - rv = rv.Elem() - } - rt := rv.Type() - for i := range rv.NumField() { - field := rt.Field(i) - fieldValue := rv.Field(i) - if isZero(fieldValue) { - continue - } - mm, ok := q.engine.modelMap.Map[rv.Type().Name()] - if !ok { - continue - } - ff, ok := mm.Fields[field.Name] - if !ok || ff.ColumnType == "" { - continue - } - whereClause := fmt.Sprintf("%s.%s = ?", mm.TableName, ff.ColumnName /*, q.totalWheres()+1*/) - args = append(args, fieldValue.Interface()) - q.wheres[whereClause] = args - } - } - return q +func (q *Query) cleanupTx() { + q.tx.Rollback(q.ctx) + q.tx = nil } func (q *Query) Order(order string) *Query { @@ -77,109 +42,233 @@ func (q *Query) Order(order string) *Query { } func (q *Query) Limit(limit int) *Query { - q.limit = limit + if limit > -1 { + q.limit = limit + } return q } func (q *Query) Offset(offset int) *Query { - q.offset = offset + if offset > -1 { + q.offset = offset + } return q } -func (q *Query) buildSelect() sb.SelectBuilder { - var fields []string - - for _, f := range q.model.Fields { - if f.ColumnType == "" { - continue - } - tn, a := f.alias() - fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) - } - seenModels := make(map[string]bool) - processField := func(f *Field, m *Model, pfk *Relationship) { - if f.ColumnType == "" { - if rel, ok := m.Relationships[f.Name]; ok { - data, ok2 := q.joins[rel] - if ok2 && f.ColumnType != "" { - tn, a := f.aliasWith(data[0]) - fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) - } - } - return - } - { - fk := f.fk - if fk == nil { - fk = m.Relationships[f.Name] - } - if fk == nil { - fk = pfk - } - if fk != nil && fk.FieldName == f.Name { - data, ok2 := q.joins[fk] - if ok2 { - var ( - tn, a string - ) - if fk.Type == HasOne { - tn, a = fk.RelatedModel.Fields[fk.RelatedModel.IDField].aliasWith(data[0]) - } else { - tn, a = f.aliasWith(data[0]) - } - fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) - } - return - } - } - if f.Name == pfk.FieldName { - f.aliasWith(q.joins[pfk][0]) - } else { - tn, a := f.alias() - fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) - } - - } - for r := range q.joins { - if !seenModels[r.aliasThingy()] { - seenModels[r.aliasThingy()] = true - for _, f := range r.Model.Fields { - processField(f, r.Model, r) - } - } - if !seenModels[r.relatedAlias()] { - seenModels[r.relatedAlias()] = true - for _, f := range r.RelatedModel.Fields { - processField(f, r.RelatedModel, r) - } - } - } - return sb.Select(fields...) +func (q *Query) Where(cond string, args ...any) *Query { + q.processWheres(cond, "eq", args...) + return q } -func (q *Query) buildSQL() (string, []any) { - sqlb := q.buildSelect().From(q.model.TableName) - whereargs := make([]any, 0) - if len(q.wheres) > 0 { - cnt := 0 - for w, where := range q.wheres { - sqlb = sqlb.Where(w, where...) +func (q *Query) WhereRaw(cond string, args ...any) *Query { + q.wheres[cond] = args + return q +} - whereargs = append(whereargs, where...) - cnt++ +func (q *Query) In(cond string, args ...any) *Query { + q.processWheres(cond, "in", args...) + return q +} + +func (q *Query) Join(field string) *Query { + var clauses []string + parts := strings.Split(field, ".") + cur := q.model + found := false + aliasMap := q.getNestedAliases(field) + + for _, part := range parts { + rel, ok := cur.Relationships[part] + if !ok { + found = false + break + } + if rel.FieldName != part { + found = false + break + } + found = true + aliases := aliasMap[rel] + curAlias := aliases[0] + nalias := aliases[1] + if rel.m2mIsh() || rel.Type == ManyToMany { + joinAlias := aliases[2] + jc1 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s_id", + rel.ComputeJoinTable(), joinAlias, + curAlias, cur.idField().ColumnName, + joinAlias, rel.Model.TableName, + ) + jc2 := fmt.Sprintf("%s AS %s ON %s.%s_id = %s.%s", + rel.RelatedModel.TableName, nalias, + joinAlias, rel.RelatedModel.TableName, + nalias, rel.relatedID().ColumnName, + ) + clauses = append(clauses, jc1, jc2) + } + if rel.Type == HasMany || rel.Type == HasOne { + fkr := rel.RelatedModel.Relationships[cur.Name] + if fkr != nil { + jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", + rel.RelatedModel.TableName, nalias, + curAlias, cur.idField().ColumnName, + nalias, pascalToSnakeCase(fkr.joinField()), + ) + clauses = append(clauses, jc) + } + } + if rel.Type == BelongsTo { + jc := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", + rel.RelatedModel.TableName, nalias, + curAlias, pascalToSnakeCase(rel.joinField()), + nalias, rel.RelatedModel.idField().ColumnName, + ) + clauses = append(clauses, jc) + } + curAlias = nalias + cur = rel.RelatedModel + } + if found { + q.joins = append(q.joins, clauses...) + } + return q +} + +func (q *Query) getNestedAliases(field string) (amap map[*Relationship][]string) { + amap = make(map[*Relationship][]string) + parts := strings.Split(field, ".") + cur := q.model + curAlias := q.model.TableName + first := curAlias + found := false + + for _, part := range parts { + rel, ok := cur.Relationships[part] + if !ok { + found = false + break + } + if rel.FieldName != part { + found = false + break + } + found = true + amap[rel] = make([]string, 0) + + nalias := pascalToSnakeCase(part) + if rel.m2mIsh() || rel.Type == ManyToMany { + joinAlias := rel.ComputeJoinTable() + "_joined" + amap[rel] = append(amap[rel], curAlias, nalias, joinAlias) + } else if rel.Type == HasMany || rel.Type == HasOne || rel.Type == BelongsTo { + amap[rel] = append(amap[rel], curAlias, nalias) + } + + curAlias = nalias + cur = rel.RelatedModel + } + if !found { + return + } + amap[nil] = []string{first} + return +} + +func (q *Query) processWheres(cond string, exprKind string, args ...any) { + parts := strings.SplitN(cond, " ", 2) + var translatedColumn string + fieldPath := parts[0] + ncond := "" + if len(parts) > 1 { + ncond = " " + parts[1] + } + pathParts := strings.Split(fieldPath, ".") + if len(pathParts) > 1 { + relPath := pathParts[:len(pathParts)-1] + fieldName := pathParts[len(pathParts)-1] + relPathStr := strings.Join(relPath, ".") + aliasMap := q.getNestedAliases(relPathStr) + for r, a := range aliasMap { + if r == nil { + continue + } + f, ok := r.RelatedModel.Fields[fieldName] + if ok { + translatedColumn = fmt.Sprintf("%s.%s", a[1], f.ColumnName) + } + } + } else if pf := q.model.Fields[pathParts[0]]; pf != nil { + translatedColumn = fmt.Sprintf("%s.%s", q.model.TableName, pf.ColumnName) + } + var tq string + switch strings.ToLower(exprKind) { + case "in": + tq = fmt.Sprintf("%s IN (%s)", translatedColumn, MakePlaceholders(len(args))) + default: + tq = fmt.Sprintf("%s%s", translatedColumn, ncond) + } + q.wheres[tq] = args +} + +func (q *Query) buildSQL() (cols []string, anonymousCols map[string][]string, finalSb sb.SelectBuilder, err error) { + var inParents []any + anonymousCols = make(map[string][]string) + for _, field := range q.model.Fields { + if field.isAnonymous() { + for _, ef := range field.embeddedFields { + anonymousCols[field.ColumnName] = append(anonymousCols[field.ColumnName], ef.ColumnName) + } + continue + } + cols = append(cols, field.ColumnName) + } + finalSb = sb.Select(cols...) + for _, cc := range anonymousCols { + finalSb = finalSb.Columns(cc...) + } + finalSb = finalSb.From(q.model.TableName) + if len(q.joins) > 0 { + idq := sb.Select(fmt.Sprintf("%s.%s", q.model.TableName, q.model.idField().ColumnName)). + Distinct(). + From(q.model.TableName) + for w, arg := range q.wheres { + idq = idq.Where(w, arg...) + } + for _, j := range q.joins { + idq = idq.Join(j) + } + qq, qa := idq.MustSQL() + var rows pgx.Rows + rows, err = q.engine.conn.Query(q.ctx, qq, qa...) + if err != nil { + return + } + defer rows.Close() + for rows.Next() { + var id any + if err = rows.Scan(&id); err != nil { + return + } + inParents = append(inParents, id) + } + if len(inParents) == 0 { + return } } - for _, j := range q.joins { - sqlb = sqlb.LeftJoin(fmt.Sprintf("%s as %s ON %s", j[1], j[0], j[2])) + if len(inParents) > 0 { + finalSb = finalSb.Where( + fmt.Sprintf("%s IN (%s)", + q.model.idField().ColumnName, + MakePlaceholders(len(inParents))), inParents...) + } else if len(q.wheres) > 0 { + for k, vv := range q.wheres { + finalSb = finalSb.Where(k, vv...) + } } ool: for _, o := range q.orders { ac, ok := q.model.Fields[o] if !ok { - var rel = ac.fk - if ac.ColumnType == "" || rel == nil { - rel = q.model.Relationships[o] - } + var rel = q.model.Relationships[o] + if rel != nil { if strings.Contains(o, ".") { split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".") @@ -202,10 +291,13 @@ ool: } } } - sqlb = sqlb.OrderBy(ac.ColumnName) + finalSb = finalSb.OrderBy(ac.ColumnName) } if q.limit > 0 { - sqlb = sqlb.Limit(uint64(q.limit)) + finalSb = finalSb.Limit(uint64(q.limit)) } - return sqlb.MustSQL() + if q.offset > 0 { + finalSb = finalSb.Offset(uint64(q.offset)) + } + return } diff --git a/query_populate.go b/query_populate.go index e6f7e0f..fe80327 100644 --- a/query_populate.go +++ b/query_populate.go @@ -2,164 +2,396 @@ package orm import ( "fmt" + sb "github.com/henvic/pgq" + "github.com/jackc/pgx/v5" + "reflect" "strings" ) const PopulateAll = "~~~ALL~~~" -func join(r *Relationship) (string, string, string) { - rtable := r.RelatedModel.TableName - field := r.Model.Fields[r.FieldName] - var fk, pk, alias string - if !r.RelatedModel.embeddedIsh && !r.Model.embeddedIsh { - alias = pascalToSnakeCase(field.Name) - fk = fmt.Sprintf("%s.%s", alias, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName) - pk = fmt.Sprintf("%s.%s", r.Model.TableName, field.ColumnName) - alias = pascalToSnakeCase(field.Name) - } else if !r.Model.embeddedIsh { - alias = pascalToSnakeCase(r.FieldName) - sid := strings.TrimSuffix(r.JoinField(), "ID") - fk = fmt.Sprintf("%s.%s", alias, r.RelatedModel.Fields[sid].ColumnName) - pk = fmt.Sprintf("%s.%s", r.Model.TableName, r.Model.Fields[r.Model.IDField].ColumnName) +func (q *Query) Populate(fields ...string) *Query { + if q.populationTree == nil { + q.populationTree = make(map[string]any) } - return alias, rtable, fmt.Sprintf("%s = %s", fk, pk) -} - -func m2mJoin(r *Relationship) [][3]string { - result := make([][3]string, 0) - jt := r.JoinTable() - first := [3]string{ - pascalToSnakeCase(r.FieldName), - jt, - fmt.Sprintf("%s = %s", - fmt.Sprintf("%s.%s", - jt, - r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName, - ), - fmt.Sprintf("%s.%s", - r.Model.TableName, - r.Model.Fields[r.Model.IDField].ColumnName, - ), - ), - } - second := [3]string{ - pascalToSnakeCase(r.m2mInverse.FieldName), - r.RelatedModel.TableName, - fmt.Sprintf("%s = %s", - fmt.Sprintf("%s.%s", r.RelatedModel.TableName, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName), - fmt.Sprintf("%s.%s", jt, r.m2mInverse.JoinField()), - ), - } - - /*first := fmt.Sprintf("%s AS %s ON %s = %s", - jt, - fmt.Sprintf("%s.%s", - jt, - r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName, - ), - fmt.Sprintf("%s.%s", - r.Model.TableName, - r.Model.Fields[r.Model.IDField].ColumnName, - ), - ) - second := fmt.Sprintf("%s ON %s = %s", - r.RelatedModel.TableName, - fmt.Sprintf("%s.%s", r.RelatedModel.TableName, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName), - fmt.Sprintf("%s.%s", jt, r.m2mInverse.JoinField()), - )*/ - result = append(result, first, second) - return result -} - -func nestedJoin(m *Model, path string) (joins [][3]string, ree []*Relationship) { - splitPath := strings.Split(path, ".") - prevModel := m - for _, f := range splitPath { - rel, ok := m.Relationships[f] - if !ok { - break - } - var ( - fk, pk string - ) - if !rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh { - pk = prevModel.Fields[rel.FieldName].ColumnName - fk = rel.RelatedModel.Fields[rel.RelatedModel.IDField].ColumnName - } else if !rel.Model.embeddedIsh { - pk = prevModel.Fields[prevModel.IDField].ColumnName - fk = rel.RelatedModel.Fields[strings.TrimSuffix(rel.JoinField(), "ID")].ColumnName - } - ree = append(ree, rel) - j2 := [3]string{ - rel.Model.Fields[rel.FieldName].ColumnName, - rel.RelatedModel.TableName, - fmt.Sprintf("%s.%s = %s.%s", - rel.RelatedModel.TableName, fk, - prevModel.TableName, pk), - } - - /*j2 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", - rel.RelatedModel.TableName, - rel.Model.Fields[rel.FieldName].ColumnName, - rel.RelatedModel.TableName, fk, - prevModel.TableName, pk, - )*/ - joins = append(joins, j2) - prevModel = rel.RelatedModel - } - return -} - -func (q *Query) Populate(relation string) *Query { - if relation == PopulateAll { - for _, rel := range q.model.Relationships { - if rel.Type == ManyToMany { - mjs := m2mJoin(rel) - for _, ajoin := range mjs { - q.joins[rel] = [3]string{ - ajoin[0], - ajoin[1], - ajoin[2], - } + for _, field := range fields { + if field == PopulateAll { + for k := range q.model.Relationships { + if _, ok := q.populationTree[k]; !ok { + q.populationTree[k] = make(map[string]any) } - //q.joins = append(q.joins, m2mJoin(rel)...) - } else { - alias, tn, cond := join(rel) - q.joins[rel] = [3]string{ - alias, - tn, - cond, - } - //q.joins = append(q.joins, tn) } - q.relatedModels[rel.RelatedModel.Name] = rel.RelatedModel + continue } - } else { - if strings.Contains(relation, ".") { - njs, rmodels := nestedJoin(q.model, relation) - for i, m := range rmodels { - curTuple := njs[i] - q.joins[m] = [3]string{ - curTuple[0], - curTuple[1], - curTuple[2], - } - q.relatedModels[m.Model.Name] = m.Model - } - //q.joins = append(q.joins, njs...) - } else { - rel, ok := q.model.Relationships[relation] - if ok { - alias, tn, j := join(rel) - q.joins[rel] = [3]string{ - alias, - tn, - j, - } - //q.joins = append(q.joins, tn) - q.relatedModels[rel.RelatedModel.Name] = rel.RelatedModel + cur := q.populationTree + parts := strings.Split(field, ".") + for _, part := range parts { + if _, ok := cur[part]; !ok { + cur[part] = make(map[string]any) } + cur = cur[part].(map[string]any) } } return q } + +func (q *Query) processPopulate(parent reflect.Value, model *Model, populationTree map[string]any) error { + if parent.Len() == 0 { + return nil + } + pids := make([]any, 0) + var err error + idField := model.IDField + for i := range parent.Len() { + pval := parent.Index(i) + if pval.Kind() == reflect.Pointer { + pval = pval.Elem() + } + pids = append(pids, pval.FieldByName(idField).Interface()) + } + toClose := make([]pgx.Rows, 0) + defer func() { + for _, c := range toClose { + c.Close() + } + }() + for p, nested := range populationTree { + var rel *Relationship + for _, r := range model.Relationships { + if r.FieldName == p { + rel = r + break + } + } + if rel == nil { + return fmt.Errorf("field '%s' not found in model '%s'", p, model.Name) + } + + childSlice := reflect.Value{} + + if (rel.Type == HasMany || rel.Type == HasOne) && !rel.m2mIsh() { + childSlice, err = q.populateHas(rel, parent, pids) + } else if rel.Type == BelongsTo { + childSlice, err = q.populateBelongsTo(rel, parent, pids) + } else if rel.Type == ManyToMany || rel.m2mIsh() { + childSlice, err = q.populateManyToMany(rel, parent, pids) + } + if err != nil { + return fmt.Errorf("failed to populate field at '%s': %w", p, err) + } + ntree, ok := nested.(map[string]any) + if ok && len(ntree) > 0 && childSlice.IsValid() && childSlice.Len() > 0 { + if err = q.processPopulate(childSlice, rel.RelatedModel, ntree); err != nil { + return err + } + } + } + + return nil +} + +func (q *Query) populateHas(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) { + fkf := rel.primaryID() + var fk string + if fkf != nil && fkf.ColumnType != "" { + fk = fkf.ColumnName + } else if rel.relatedID() != nil { + fk = pascalToSnakeCase(rel.RelatedModel.Name + rel.relatedID().Name) + } + + if rel.RelatedModel.embeddedIsh && !rel.Model.embeddedIsh && rel.Type == HasMany { + arel := rel.RelatedModel.Relationships[rel.Model.Name] + fk = pascalToSnakeCase(arel.joinField()) + } + ccols := make([]string, 0) + anonymousCols := make(map[string]map[string]*Field) + for _, f := range rel.RelatedModel.Fields { + if !f.isAnonymous() { + ccols = append(ccols, f.ColumnName) + } + } + for _, f := range rel.RelatedModel.Fields { + if f.isAnonymous() { + ccols = append(ccols, f.anonymousColumnNames()...) + anonymousCols[f.Name] = f.embeddedFields + } + } + for _, r := range rel.RelatedModel.Relationships { + if r.Type != ManyToOne { + continue + } + ccols = append(ccols, pascalToSnakeCase(r.joinField())) + } + /*var tableName string + if rel.Type == HasOne { + tableName = rel.Model.TableName + } + if rel.Type == HasMany { + tableName = rel.RelatedModel.TableName + }*/ + aq, aa := sb.Select(ccols...). + From(rel.RelatedModel.TableName). + Where(fmt.Sprintf("%s IN (%s)", fk, MakePlaceholders(len(parentIds))), parentIds...).MustSQL() + fmt.Printf("[POPULATE] %s %+v\n", aq, aa) + rows, err := q.engine.conn.Query(q.ctx, aq, aa...) + if err != nil { + return reflect.Value{}, err + } + defer rows.Close() + idFieldName := rel.Model.IDField + idField := rel.Model.Fields[idFieldName] + if rel.Type == HasMany { + childMap := reflect.MakeMap(reflect.MapOf( + idField.Type, + reflect.SliceOf(rel.RelatedModel.Type), + )) + for rows.Next() { + child := reflect.New(rel.RelatedModel.Type).Elem() + var fkValue any + scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &fkValue) + if err = rows.Scan(scanDest...); err != nil { + return reflect.Value{}, err + } + fkVal := reflect.ValueOf(fkValue) + childrenOfParent := childMap.MapIndex(fkVal) + if !childrenOfParent.IsValid() { + childrenOfParent = reflect.MakeSlice(reflect.SliceOf(rel.RelatedModel.Type), 0, 0) + } + childrenOfParent = reflect.Append(childrenOfParent, child) + childMap.SetMapIndex(fkVal, childrenOfParent) + } + for i := range parent.Len() { + ps := parent.Index(i) + if ps.Kind() == reflect.Pointer { + ps = ps.Elem() + } + pid := ps.FieldByName(idFieldName) + c := childMap.MapIndex(pid) + if c.IsValid() { + ps.FieldByName(rel.FieldName).Set(c) + } + } + } else { + childMap := reflect.MakeMap(reflect.MapOf(idField.Type, rel.RelatedModel.Type)) + for rows.Next() { + child := reflect.New(rel.RelatedModel.Type).Elem() + var fkValue any + scanDest, _ := buildScanDest(child, rel.Model, rel, ccols, anonymousCols, &fkValue) + if err = rows.Scan(scanDest...); err != nil { + return reflect.Value{}, err + } + fkVal := reflect.ValueOf(fkValue) + childMap.SetMapIndex(fkVal, child) + } + for i := range parent.Len() { + ps := parent.Index(i) + if ps.Kind() == reflect.Pointer { + ps = ps.Elem() + } + parentID := ps.FieldByName(idFieldName) + if child := childMap.MapIndex(parentID); child.IsValid() { + ps.FieldByName(rel.FieldName).Set(child) + } + } + } + + childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0) + for i := range parent.Len() { + ps := parent.Index(i) + if ps.Kind() == reflect.Ptr { + ps = ps.Elem() + } + childField := ps.FieldByName(rel.FieldName) + if !childField.IsValid() { + continue + } + if rel.Type == HasMany { + for j := range childField.Len() { + childSlice = reflect.Append(childSlice, childField.Index(j).Addr()) + } + } else { + if !childField.IsZero() { + childSlice = reflect.Append(childSlice, childField.Addr()) + } + } + } + return childSlice, nil +} + +func (q *Query) populateManyToMany(rel *Relationship, parent reflect.Value, parentIds []any) (reflect.Value, error) { + inPlaceholders := MakePlaceholders(len(parentIds)) + ccols := make([]string, 0) + anonymousCols := make(map[string]map[string]*Field) + for _, f := range rel.RelatedModel.Fields { + if !f.isAnonymous() { + ccols = append(ccols, "m."+f.ColumnName) + } + } + for _, f := range rel.RelatedModel.Fields { + if f.isAnonymous() { + for ecol := range f.embeddedFields { + ccols = append(ccols, "m."+ecol) + } + anonymousCols[f.Name] = f.embeddedFields + } + } + ccols = append(ccols, fmt.Sprintf("jt.%s_id", rel.Model.TableName)) + + mq, ma := sb.Select(ccols...). + From(fmt.Sprintf("%s AS m", rel.RelatedModel.TableName)). + Join( + fmt.Sprintf("%s AS jt ON m.%s = jt.%s_id", + rel.ComputeJoinTable(), + rel.relatedID().ColumnName, rel.RelatedModel.TableName)). + Where(fmt.Sprintf("jt.%s_id IN (%s)", + rel.Model.TableName, inPlaceholders), parentIds...).MustSQL() + fmt.Printf("[POPULATE/JOIN] %s %+v\n", mq, ma) + rows, err := q.engine.conn.Query(q.ctx, mq, ma...) + if err != nil { + return reflect.Value{}, err + } + defer rows.Close() + idFieldName := rel.Model.IDField + idField := rel.Model.Fields[idFieldName] + childMap := reflect.MakeMap(reflect.MapOf( + idField.Type, + reflect.SliceOf(rel.RelatedModel.Type))) + for rows.Next() { + child := reflect.New(rel.RelatedModel.Type).Elem() + var foreignKeyValue any + scanDest, _ := buildScanDest(child, rel.RelatedModel, rel, ccols, anonymousCols, &foreignKeyValue) + if err = rows.Scan(scanDest...); err != nil { + return reflect.Value{}, err + } + fkVal := reflect.ValueOf(foreignKeyValue) + childrenOfParent := childMap.MapIndex(fkVal) + if !childrenOfParent.IsValid() { + childrenOfParent = reflect.MakeSlice(reflect.SliceOf(rel.RelatedModel.Type), 0, 0) + } + childrenOfParent = reflect.Append(childrenOfParent, child) + childMap.SetMapIndex(fkVal, childrenOfParent) + } + for i := range parent.Len() { + p := parent.Index(i) + if p.Kind() == reflect.Ptr { + p = p.Elem() + } + parentID := p.FieldByName(rel.primaryID().Name) + if children := childMap.MapIndex(parentID); children.IsValid() { + p.FieldByName(rel.FieldName).Set(children) + } + } + childSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.RelatedModel.Type)), 0, 0) + for i := range parent.Len() { + ps := parent.Index(i) + if ps.Kind() == reflect.Ptr { + ps = ps.Elem() + } + childField := ps.FieldByName(rel.FieldName) + if childField.IsValid() { + for j := range childField.Len() { + childSlice = reflect.Append(childSlice, childField.Index(j).Addr()) + } + } + } + return childSlice, nil +} + +func (q *Query) populateBelongsTo(rel *Relationship, childrenSlice reflect.Value, childIDs []any) (reflect.Value, error) { + childIdField := rel.Model.Fields[rel.Model.IDField] + parentIdField := rel.RelatedModel.Fields[rel.RelatedModel.IDField] + fk := pascalToSnakeCase(rel.joinField()) + qs, qa := sb.Select(childIdField.ColumnName, fk). + From(rel.Model.TableName). + Where(fmt.Sprintf("%s IN (%s)", + childIdField.ColumnName, MakePlaceholders(len(childIDs)), + ), childIDs...).MustSQL() + fmt.Printf("[POPULATE/BELONGS-TO] %s %+v\n", qs, qa) + rows, err := q.engine.conn.Query(q.ctx, qs, qa...) + if err != nil { + return reflect.Value{}, err + } + childParentKeyMap := make(map[any]any) + parentKeyValues := make([]any, 0) + parentKeySet := make(map[any]bool) + for rows.Next() { + var cid, pfk any + err = rows.Scan(&cid, &pfk) + if err != nil { + rows.Close() + return reflect.Value{}, err + } + childParentKeyMap[cid] = pfk + if !parentKeySet[pfk] { + parentKeySet[pfk] = true + parentKeyValues = append(parentKeyValues, pfk) + } + } + rows.Close() + if len(parentKeyValues) == 0 { + return reflect.Value{}, nil + } + pcols := make([]string, 0) + anonymousCols := make(map[string]map[string]*Field) + for _, f := range rel.RelatedModel.Fields { + if !f.isAnonymous() { + pcols = append(pcols, f.ColumnName) + } + } + for _, f := range rel.RelatedModel.Fields { + if f.isAnonymous() { + pcols = append(pcols, f.anonymousColumnNames()...) + anonymousCols[f.Name] = f.embeddedFields + } + } + pquery, pqargs := sb.Select(pcols...). + From(rel.RelatedModel.TableName). + Where(fmt.Sprintf("%s IN (%s)", + parentIdField.ColumnName, + MakePlaceholders(len(parentKeyValues))), parentKeyValues...). + MustSQL() + fmt.Printf("[POPULATE/BELONGS-TO->PARENT] %s %+v\n", pquery, pqargs) + parentRows, err := q.engine.conn.Query(q.ctx, pquery, pqargs...) + if err != nil { + return reflect.Value{}, err + } + defer parentRows.Close() + parentMap := reflect.MakeMap(reflect.MapOf( + parentIdField.Type, + rel.RelatedModel.Type, + )) + for parentRows.Next() { + parent := reflect.New(rel.RelatedModel.Type).Elem() + scanDst, _ := buildScanDest(parent, rel.RelatedModel, rel, pcols, anonymousCols, nil) + if err = parentRows.Scan(scanDst...); err != nil { + return reflect.Value{}, err + } + parentId := parent.FieldByName(rel.RelatedModel.IDField) + parentMap.SetMapIndex(parentId, parent) + } + for i := range childrenSlice.Len() { + child := childrenSlice.Index(i) + childID := child.FieldByName(rel.Model.IDField) + if parentKey, ok := childParentKeyMap[childID.Interface()]; ok && parentKey != nil { + if parent := parentMap.MapIndex(reflect.ValueOf(parentKey)); parent.IsValid() { + child.FieldByName(rel.FieldName).Set(parent) + } + } + } + ntype := rel.RelatedModel.Type + if rel.Kind == reflect.Pointer { + ntype = reflect.PointerTo(rel.RelatedModel.Type) + } + parentSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(ntype)), 0, 0) + for i := range childrenSlice.Len() { + ps := childrenSlice.Index(i) + if ps.Kind() == reflect.Ptr { + ps = ps.Elem() + } + childField := ps.FieldByName(rel.FieldName) + if childField.IsValid() { + parentSlice = reflect.Append(parentSlice, childField.Addr()) + } + } + return parentSlice, nil +} diff --git a/query_tail.go b/query_tail.go index 0b33e4c..ea496e4 100644 --- a/query_tail.go +++ b/query_tail.go @@ -1,42 +1,361 @@ package orm import ( - "errors" "fmt" + sb "github.com/henvic/pgq" + "github.com/jackc/pgx/v5" "reflect" + "time" ) -func (q *Query) Create(val any) (any, error) { - _, err := q.model.insert(reflect.ValueOf(val), q, make(map[string]any)) - return val, err -} +func (q *Query) Find(dest any) error { + dstVal := reflect.ValueOf(dest) + if dstVal.Kind() != reflect.Ptr { + return fmt.Errorf("destination must be a pointer, got: %v", dstVal.Kind()) + } + maybeSlice := dstVal.Elem() + cols, acols, sqlb, err := q.buildSQL() + if err != nil { + return err + } + qq, qa := sqlb.MustSQL() + fmt.Printf("[FIND] %s %+v\n", qq, qa) -func (q *Query) Find() (any, error) { - sql, args := q.buildSQL() - fmt.Printf("[FIND] %s { %+v }\n", sql, args) - if !q.engine.dryRun { - rows, err := q.engine.conn.Query(q.ctx, sql, args...) + if maybeSlice.Kind() == reflect.Struct { + row := q.engine.conn.QueryRow(q.ctx, qq, qa...) + if err = scanRow(row, cols, acols, maybeSlice, q.model); err != nil { + return err + } + } else if maybeSlice.Kind() == reflect.Slice || + maybeSlice.Kind() == reflect.Array { + var rows pgx.Rows + rows, err = q.engine.conn.Query(q.ctx, qq, qa...) + if err != nil { + return err + } defer rows.Close() - if err != nil { - return nil, err + + etype := maybeSlice.Type().Elem() + for rows.Next() { + nelem := reflect.New(etype).Elem() + if err = scanRow(rows, cols, acols, nelem, q.model); err != nil { + return err + } + maybeSlice.Set(reflect.Append(maybeSlice, nelem)) } - rmaps, err := rowsToMaps(rows) - if err != nil { - return nil, err - } - wtype := q.model.Type - for wtype.Kind() == reflect.Pointer { - wtype = wtype.Elem() - } - return fillSlice(rmaps, wtype, q.engine.modelMap), nil + } else { + return fmt.Errorf("unsupported destination type: %s", maybeSlice.Kind()) } - return make([]any, 0), nil + if len(q.populationTree) > 0 { + nslice := maybeSlice + var wasPassedStruct bool + if nslice.Kind() == reflect.Struct { + nslice = reflect.MakeSlice(reflect.SliceOf(maybeSlice.Type()), 0, 0) + wasPassedStruct = true + nslice = reflect.Append(nslice, maybeSlice) + } + err = q.processPopulate(nslice, q.model, q.populationTree) + if err == nil && wasPassedStruct { + maybeSlice.Set(nslice.Index(0)) + } + return err + } + return nil } -func (q *Query) Update(val any, cond any, args ...any) error { - if q.model != nil { - return q.model.update(reflect.ValueOf(val), q, make(map[string]any)) - } else { - return errors.New("Please select a model") - } +func (q *Query) Save(val any) error { + return q.saveOrCreate(val, false) +} + +func (q *Query) Create(val any) error { + return q.saveOrCreate(val, true) +} + +// UpdateRaw - takes a mapping of struct field names to +// SQL expressions, updating each field's associated column accordingly +func (q *Query) UpdateRaw(values map[string]any) (int64, error) { + var err error + var subQuery sb.SelectBuilder + stmt := sb.Update(q.model.TableName) + _, _, subQuery, err = q.buildSQL() + if err != nil { + return 0, err + } + subQuery = sb.Select(q.model.idField().ColumnName).FromSelect(subQuery, "subQuery") + stmt = stmt.Where(wrapQueryIn(subQuery, + q.model.idField().ColumnName)) + for k, v := range values { + asString, isString := v.(string) + if f, ok := q.model.Fields[k]; ok { + if isString { + stmt = stmt.Set(f.ColumnName, sb.Expr(asString)) + } else { + stmt = stmt.Set(f.ColumnName, v) + } + } + if _, ok := q.model.FieldsByColumnName[k]; ok { + if isString { + stmt = stmt.Set(k, sb.Expr(asString)) + } else { + stmt = stmt.Set(k, v) + } + } + } + sql, args := stmt.MustSQL() + fmt.Printf("[UPDATE/RAW] %s %+v\n", sql, args) + q.tx, err = q.engine.conn.Begin(q.ctx) + if err != nil { + return 0, err + } + defer q.cleanupTx() + + ctag, err := q.tx.Exec(q.ctx, sql, args...) + if err != nil { + return 0, err + } + return ctag.RowsAffected(), q.tx.Commit(q.ctx) +} + +func (q *Query) Delete() (int64, error) { + var err error + var subQuery sb.SelectBuilder + if len(q.wheres) < 1 { + return 0, ErrNoConditionOnDeleteOrUpdate + } + q.tx, err = q.engine.conn.Begin(q.ctx) + if err != nil { + return 0, err + } + defer q.cleanupTx() + _, _, subQuery, err = q.buildSQL() + if err != nil { + return 0, err + } + + sqlb := sb.Delete(q.model.TableName).Where(subQuery) + sql, sqla := sqlb.MustSQL() + fmt.Printf("[DELETE] %s %+v\n", sql, sqla) + cmdTag, err := q.tx.Exec(q.ctx, sql, sqla...) + if err != nil { + return 0, fmt.Errorf("failed to delete: %w", err) + } + return cmdTag.RowsAffected(), nil +} + +func (q *Query) saveOrCreate(val any, shouldCreate bool) error { + v := reflect.ValueOf(val) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return fmt.Errorf("Save() must be called with a pointer to a struct") + } + var err error + q.tx, err = q.engine.conn.BeginTx(q.ctx, pgx.TxOptions{ + AccessMode: pgx.ReadWrite, + IsoLevel: pgx.ReadUncommitted, + }) + if err != nil { + return err + } + defer q.cleanupTx() + if _, err = q.doSave(v.Elem(), q.engine.modelMap.Map[v.Elem().Type().Name()], nil, shouldCreate); err != nil { + return err + } + return q.tx.Commit(q.ctx) +} + +func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any, shouldInsert bool) (any, error) { + idField := model.Fields[model.IDField] + var pkField reflect.Value + if val.Kind() == reflect.Pointer { + if !val.Elem().IsValid() || val.Elem().IsZero() { + return nil, nil + } + pkField = val.Elem().FieldByName(model.IDField) + } else { + pkField = val.FieldByName(model.IDField) + } + isNew := pkField.IsZero() + var exists bool + if !pkField.IsZero() { + eb := sb.Select("1"). + Prefix("SELECT EXISTS ("). + From(model.TableName). + Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()). + Suffix(")") + ebs, eba := eb.MustSQL() + var ex bool + err := q.tx.QueryRow(q.ctx, ebs, eba...).Scan(&ex) + if err != nil { + q.engine.logger.Warn("error while checking existence", "err", err.Error()) + } + exists = ex + } + /*{ + el, ok := q.seenIds[model] + if !ok { + q.seenIds[model] = make(map[any]bool) + } + if ok && el[pkField.Interface()] { + return pkField.Interface(), nil + } + if !isNew { + q.seenIds[model][pkField.Interface()] = true + } + }*/ + doInsert := isNew || !exists + var cols []string + args := make([]any, 0) + seenJoinTables := make(map[string]map[any]bool) + for _, rel := range model.Relationships { + if rel.Type != BelongsTo { + continue + } + parentVal := val.FieldByName(rel.FieldName) + if parentVal.IsValid() { + nid, err := q.doSave(parentVal, rel.RelatedModel, nil, rel.RelatedModel.needsPrimaryKey(parentVal) && isNew) + if err != nil { + return nil, err + } + cols = append(cols, pascalToSnakeCase(rel.joinField())) + args = append(args, nid) + } else if parentVal.IsValid() { + _, nid := rel.RelatedModel.getPrimaryKey(parentVal) + cols = append(cols, pascalToSnakeCase(rel.joinField())) + args = append(args, nid) + } + } + for _, ff := range model.Fields { + var fv reflect.Value + if ff.Index > -1 && !ff.isAnonymous() { + fv = val.Field(ff.Index) + } else if ff.Index > -1 { + for col, ef := range ff.embeddedFields { + fv = val.Field(ff.Index) + cols = append(cols, col) + eif := fv.FieldByName(ef.Name) + if ff.Name == documentField && canConvertTo[Document](ff.Type) { + asTime, ok := eif.Interface().(time.Time) + shouldCreate := ok && (asTime.IsZero() || eif.IsZero()) + if doInsert && ef.Name == createdField && shouldCreate { + eif.Set(reflect.ValueOf(time.Now())) + } else if ef.Name == modifiedField || shouldCreate { + eif.Set(reflect.ValueOf(time.Now())) + } + args = append(args, eif.Interface()) + continue + } + args = append(args, fv.FieldByName(ef.Name).Interface()) + } + continue + } + if ff.Name == model.IDField { + if !isNew && fv.IsValid() { + cols = append(cols, ff.ColumnName) + args = append(args, fv.Interface()) + } + continue + } + + if fv.IsValid() { + cols = append(cols, ff.ColumnName) + args = append(args, fv.Interface()) + } + } + for k, fk := range parentFks { + cols = append(cols, k) + args = append(args, fk) + } + var qq string + var qa []any + if doInsert { + osb := sb.Insert(model.TableName) + if len(cols) == 0 { + qq = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", model.TableName, idField.ColumnName) + } else { + osb = osb.Columns(cols...).Values(args...) + qq, qa = osb.Returning(idField.ColumnName).MustSQL() + } + } else { + osb := sb.Update(model.TableName) + for i := range cols { + osb = osb.Set(cols[i], args[i]) + } + osb = osb.Where(fmt.Sprintf("%s = ?", idField.ColumnName), pkField.Interface()) + qq, qa = osb.MustSQL() + } + if doInsert { + var nid any + fmt.Printf("[INSERT] %s %+v\n", qq, qa) + row := q.tx.QueryRow(q.ctx, qq, qa...) + err := row.Scan(&nid) + if err != nil { + return nil, fmt.Errorf("insert failed for model %s: %w", model.Name, err) + } + pkField.Set(reflect.ValueOf(nid)) + } else { + fmt.Printf("[UPDATE] %s %+v\n", qq, qa) + _, err := q.tx.Exec(q.ctx, qq, qa...) + if err != nil { + return nil, fmt.Errorf("update failed for model %s: %w", model.Name, err) + } + } + /*if _, ok := q.seenIds[model]; !ok { + q.seenIds[model] = make(map[any]bool) + } + q.seenIds[model][pkField.Interface()] = true*/ + for _, rel := range model.Relationships { + + if rel.Idx > -1 && rel.Idx < val.NumField() { + fv := val.FieldByName(rel.FieldName) + cm := rel.RelatedModel + pfks := map[string]any{} + if !model.embeddedIsh && rel.Type == HasMany { + { + rm := cm.Relationships[model.Name] + if rm != nil && rm.Type == ManyToOne { + pfks[pascalToSnakeCase(rm.joinField())] = pkField.Interface() + } + } + for j := range fv.Len() { + child := fv.Index(j).Addr().Elem() + if _, err := q.doSave(child, cm, pfks, cm.needsPrimaryKey(child)); err != nil { + return nil, err + } + } + } else if rel.Type == HasOne && cm.embeddedIsh { + if _, err := q.doSave(fv, cm, pfks, cm.needsPrimaryKey(fv)); err != nil { + return nil, err + } + } else if rel.m2mIsh() || rel.Type == ManyToMany || (model.embeddedIsh && cm.embeddedIsh && rel.Type == HasMany) { + if seenJoinTables[rel.ComputeJoinTable()] == nil { + seenJoinTables[rel.ComputeJoinTable()] = make(map[any]bool) + } + if !seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] { + seenJoinTables[rel.ComputeJoinTable()][pkField.Interface()] = true + if err := rel.joinDelete(pkField.Interface(), nil, q); err != nil { + return nil, fmt.Errorf("error deleting existing association: %w", err) + } + } + if fv.Kind() == reflect.Slice || fv.Kind() == reflect.Array { + mField := model.Fields[model.IDField] + mpks := map[string]any{} + if !model.embeddedIsh { + mpks[model.TableName+"_"+mField.ColumnName] = pkField.Interface() + } + for i := range fv.Len() { + cur := fv.Index(i) + if _, err := q.doSave(cur, cm, mpks, cm.needsPrimaryKey(cur) && pkField.IsZero()); err != nil { + return nil, err + } + + if rel.m2mIsh() || rel.Type == ManyToMany { + if err := rel.joinInsert(cur, q, pkField.Interface()); err != nil { + return nil, fmt.Errorf("failed to insert association for model %s: %w", model.Name, err) + } + } + } + } + } + } + } + return pkField.Interface(), nil } diff --git a/relationship.go b/relationship.go index 10565ff..a4a3970 100644 --- a/relationship.go +++ b/relationship.go @@ -2,6 +2,7 @@ package orm import ( "fmt" + sb "github.com/henvic/pgq" "reflect" "strings" ) @@ -11,11 +12,14 @@ type RelationshipType int const ( HasOne RelationshipType = iota HasMany + BelongsTo + ManyToOne ManyToMany ) type Relationship struct { Type RelationshipType + JoinTable string Model *Model FieldName string Idx int @@ -23,34 +27,49 @@ type Relationship struct { RelatedModel *Model Kind reflect.Kind // field kind (struct, slice, ...) m2mInverse *Relationship + Nullable bool + OriginalField reflect.StructField } -func (r *Relationship) JoinTable() string { - return r.Model.TableName + "_" + r.RelatedModel.TableName +func (r *Relationship) ComputeJoinTable() string { + if r.JoinTable != "" { + return r.JoinTable + } + otherSide := r.RelatedModel.TableName + if r.Model.embeddedIsh { + otherSide = pascalToSnakeCase(r.FieldName) + } + return r.Model.TableName + "_" + otherSide } -func (r *Relationship) JoinField() string { - isMany := r.Type == HasMany //|| r.Type == ManyToMany - if isMany && r.Model.embeddedIsh { - return r.Model.Name + r.FieldName + "ID" - } else if isMany && !r.Model.embeddedIsh && r.m2mInverse == nil { - return r.Model.Name + "ID" - } else if r.Type == ManyToMany && !r.Model.embeddedIsh { +func (r *Relationship) relatedID() *Field { + return r.RelatedModel.Fields[r.RelatedModel.IDField] +} + +func (r *Relationship) primaryID() *Field { + return r.Model.Fields[r.Model.IDField] +} + +func (r *Relationship) joinField() string { + if r.Type == ManyToOne { + return r.RelatedModel.Name + "ID" + } + if r.Type == ManyToMany && !r.Model.embeddedIsh { return r.RelatedModel.Name + "ID" } return r.FieldName + "ID" } -func (r *Relationship) aliasThingy() string { - return pascalToSnakeCase(r.Model.Name + "." + r.FieldName) -} - -func (r *Relationship) relatedAlias() string { - return pascalToSnakeCase(r.RelatedModel.Name + "." + r.FieldName) -} - func (r *Relationship) m2mIsh() bool { - return r.Model.embeddedIsh && !r.RelatedModel.embeddedIsh && r.Type == HasMany + needsMany := false + if !r.Model.embeddedIsh && r.RelatedModel.embeddedIsh { + rr, ok := r.RelatedModel.Relationships[r.Model.Name] + if ok && rr.Type != ManyToOne { + needsMany = true + } + } + return ((r.Model.embeddedIsh && !r.RelatedModel.embeddedIsh) || needsMany) && + r.Type == HasMany } func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error { @@ -63,11 +82,11 @@ func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error { ichild = ichild.Elem() } if ichild.Kind() == reflect.Struct { - jtable := r.JoinTable() + jtable := r.ComputeJoinTable() jargs := make([]any, 0) jcols := make([]string, 0) - jcols = append(jcols, fmt.Sprintf("%s_%s", - r.Model.TableName, pascalToSnakeCase(r.FieldName), + jcols = append(jcols, fmt.Sprintf("%s_id", + r.Model.TableName, )) jargs = append(jargs, pfk) @@ -75,12 +94,12 @@ func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error { jargs = append(jargs, ichild.FieldByName(r.RelatedModel.IDField).Interface()) var ecnt int e.tx.QueryRow(e.ctx, - fmt.Sprintf("SELECT count(*) from %s where %s = $1 and %s = $2", r.JoinTable(), jcols[0], jcols[1]), jargs...).Scan(&ecnt) + fmt.Sprintf("SELECT count(*) from %s where %s = $1 and %s = $2", r.ComputeJoinTable(), jcols[0], jcols[1]), jargs...).Scan(&ecnt) if ecnt > 0 { return nil } jsql := fmt.Sprintf("INSERT INTO %s (%s) VALUES ($1, $2)", jtable, strings.Join(jcols, ", ")) - fmt.Printf("[INSERT/JOIN] %s { %s }\n", jsql, logTrunc(jargs, 200)) + e.engine.logQuery("insert/join", jsql, jargs) if !e.engine.dryRun { _ = e.tx.QueryRow(e.ctx, jsql, jargs...).Scan() } @@ -89,25 +108,28 @@ func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error { } func (r *Relationship) joinDelete(pk, fk any, q *Query) error { - jc := fmt.Sprintf("%s_%s", r.Model.TableName, pascalToSnakeCase(r.FieldName)) - ds := fmt.Sprintf("DELETE FROM %s where %s = $1 and %s = $2", - r.JoinTable(), jc, r.RelatedModel.TableName+"_id") - fmt.Printf("[DELETE/JOIN] %s { %s }\n", ds, logTrunc([]any{pk, fk}, 200)) + dq := sb.Delete(r.ComputeJoinTable()).Where(fmt.Sprintf("%s_id = ?", r.Model.TableName), pk) + if fk != nil { + dq = dq.Where(fmt.Sprintf("%s_id = ?", r.RelatedModel.TableName), fk) + } + ds, aa := dq.MustSQL() + fmt.Printf("[DELETE/JOIN] %s %+v \n", ds, logTrunc(200, aa)) if !q.engine.dryRun { - _, err := q.tx.Exec(q.ctx, ds, pk, fk) + _, err := q.tx.Exec(q.ctx, ds, aa...) return err } return nil } -func parseRelationship(field reflect.StructField, modelMap map[string]*Model, outerType reflect.Type, idx int) *Relationship { +func parseRelationship(field reflect.StructField, modelMap map[string]*Model, outerType reflect.Type, idx int, settings map[string]string) *Relationship { rel := &Relationship{ - Model: modelMap[outerType.Name()], - RelatedModel: modelMap[field.Type.Name()], - RelatedType: field.Type, - Idx: idx, - Kind: field.Type.Kind(), - FieldName: field.Name, + Model: modelMap[outerType.Name()], + RelatedModel: modelMap[field.Type.Name()], + RelatedType: field.Type, + Idx: idx, + Kind: field.Type.Kind(), + FieldName: field.Name, + OriginalField: field, } if rel.RelatedType.Kind() == reflect.Slice || rel.RelatedType.Kind() == reflect.Array { rel.RelatedType = rel.RelatedType.Elem() @@ -116,6 +138,7 @@ func parseRelationship(field reflect.StructField, modelMap map[string]*Model, ou if rel.RelatedType.Name() == "" { rt := rel.RelatedType for rt.Kind() == reflect.Ptr || rt.Kind() == reflect.Slice || rt.Kind() == reflect.Array { + rel.Nullable = true rel.RelatedType = rel.RelatedType.Elem() rt = rel.RelatedType } @@ -131,79 +154,38 @@ func parseRelationship(field reflect.StructField, modelMap map[string]*Model, ou switch field.Type.Kind() { case reflect.Struct: rel.Type = HasOne - case reflect.Slice: + case reflect.Slice, reflect.Array: rel.Type = HasMany } + maybeM2m := settings["m2m"] + if maybeM2m == "" { + maybeM2m = settings["manytomany"] + } + if rel.Type == HasMany && maybeM2m != "" { + rel.JoinTable = maybeM2m + } return rel } func addForeignKeyFields(ref *Relationship) { - rf := ref.RelatedModel.Fields[ref.RelatedModel.IDField] - if rf != nil { - if !ref.RelatedModel.embeddedIsh && !ref.Model.embeddedIsh { - ff := ref.Model.Fields[ref.FieldName] - ff.ColumnType = rf.ColumnType - ff.ColumnName = pascalToSnakeCase(ref.JoinField()) - ff.isForeignKey = true - ff.fk = ref - } else if !ref.Model.embeddedIsh { - sid := strings.TrimSuffix(ref.JoinField(), "ID") - ref.RelatedModel.Relationships[sid] = &Relationship{ - FieldName: sid, - Type: HasOne, + if !ref.RelatedModel.embeddedIsh && !ref.Model.embeddedIsh { + ref.Type = BelongsTo + } else if !ref.Model.embeddedIsh && ref.RelatedModel.embeddedIsh { + + if ref.Type == HasMany { + nr := &Relationship{ RelatedModel: ref.Model, Model: ref.RelatedModel, Kind: ref.RelatedModel.Type.Kind(), Idx: -1, RelatedType: ref.Model.Type, } - ref.RelatedModel.addField(&Field{ - ColumnType: rf.ColumnType, - ColumnName: pascalToSnakeCase(ref.JoinField()), - Name: sid, - isForeignKey: true, - Type: rf.Type, - Index: -1, - fk: ref.RelatedModel.Relationships[sid], - }) - } else if ref.Model.embeddedIsh && !ref.RelatedModel.embeddedIsh { - + nr.Type = ManyToOne + nr.FieldName = nr.RelatedModel.Name + ref.RelatedModel.Relationships[nr.FieldName] = nr + } else if ref.Type == HasOne { + ref.Type = BelongsTo } - } else { - ref.RelatedModel.addField(&Field{ - Name: "ID", - ColumnName: "id", - ColumnType: "bigserial", - PrimaryKey: true, - Type: ref.RelatedType, - Index: -1, - AutoIncrement: true, - }) - ff := ref.Model.Fields[ref.FieldName] - ff.ColumnType = "bigint" - ff.ColumnName = pascalToSnakeCase(ref.RelatedModel.Name + "ID") - ff.isForeignKey = true - ff.fk = ref - ref.RelatedModel.IDField = "ID" - /* - nn := ref.Model.Name + "ID" - ref.RelatedModel.Relationships[ref.Model.Name] = &Relationship{ - Type: HasOne, - RelatedModel: ref.Model, - Model: ref.RelatedModel, - Idx: 65536, - Kind: ref.RelatedModel.Type.Kind(), - RelatedType: ref.RelatedModel.Type, - FieldName: nn, - } - ref.RelatedModel.addField(&Field{ - Name: nn, - Type: ref.Model.Type, - fk: ref.RelatedModel.Relationships[nn], - ColumnName: pascalToSnakeCase(nn), - Index: -1, - ColumnType: ref.Model.Fields[ref.Model.IDField].ColumnType, - isForeignKey: true, - })*/ + } else if ref.Model.embeddedIsh && !ref.RelatedModel.embeddedIsh { } } diff --git a/scan.go b/scan.go index 084f0a1..aa99be5 100644 --- a/scan.go +++ b/scan.go @@ -3,83 +3,51 @@ package orm import ( "github.com/jackc/pgx/v5" "reflect" + "strings" ) -func rowsToMaps(rows pgx.Rows) ([]map[string]any, error) { - var result []map[string]any - fieldDescs := rows.FieldDescriptions() - for rows.Next() { - m := make(map[string]any) - scanArgs := make([]any, len(fieldDescs)) - for i := range fieldDescs { - var v any - scanArgs[i] = &v - } - if err := rows.Scan(scanArgs...); err != nil { - return nil, err - } - for i, fd := range fieldDescs { - name := fd.Name - m[name] = *(scanArgs[i].(*any)) - } - result = append(result, m) - } - return result, rows.Err() -} +func buildScanDest(val reflect.Value, model *Model, fk *Relationship, cols []string, anonymousCols map[string]map[string]*Field, fkDest any) ([]any, error) { + var dest []any -func fillNested(row map[string]any, t reflect.Type, mm *ModelMap, depth, maxDepth int) any { - cm := mm.Map[t.Name()] - pp := reflect.New(cm.Type).Elem() - for _, field := range cm.Fields { - _, alias := field.alias() - if v, ok := row[alias]; ok && !field.isForeignKey { - reflectSet(pp.Field(field.Index), v) + for _, col := range cols { + bcol := col + if strings.Contains(bcol, ".") { + _, bcol, _ = strings.Cut(bcol, ".") + } + field := model.FieldsByColumnName[bcol] + if field != nil && !field.isAnonymous() { + dest = append(dest, val.FieldByName(field.Name).Addr().Interface()) } } - for _, rel := range cm.Relationships { - if rel.Idx > -1 && rel.Idx < pp.NumField() { - relType := rel.RelatedModel.Type - for relType.Kind() == reflect.Pointer { - relType = relType.Elem() - } - nv := reflect.New(relType) - if rel.Kind == reflect.Struct || rel.Kind == reflect.Pointer { - if depth < maxDepth { - nv = reflect.ValueOf(fillNested(row, relType, mm, depth+1, maxDepth)) - } - if rel.Kind != reflect.Pointer && nv.Kind() == reflect.Pointer { - nv = nv.Elem() - } - reflectSet(pp.Field(rel.Idx), nv) - } else if rel.Kind == reflect.Slice || rel.Kind == reflect.Array { - relType2 := relType - for relType2.Kind() == reflect.Slice || relType2.Kind() == reflect.Array { - relType2 = relType2.Elem() - } - if depth < maxDepth { - nv = reflect.ValueOf(fillNested(row, relType2, mm, depth+1, maxDepth)) - } - if nv.Kind() == reflect.Pointer { - nv = nv.Elem() - } - reflectSet(pp.Field(rel.Idx), reflect.Append(pp.Field(rel.Idx), nv)) + for fn, a := range anonymousCols { + iv := val.FieldByName(fn) + for _, field := range a { + dest = append(dest, iv.FieldByName(field.Name).Addr().Interface()) + } + } + + if fk.Type != BelongsTo { + dest = append(dest, fkDest) + } + + return dest, nil +} +func scanRow(row pgx.Row, cols []string, anonymousCols map[string][]string, destVal reflect.Value, m *Model) error { + var scanDest []any + for _, col := range cols { + f := m.FieldsByColumnName[col] + if f != nil && f.ColumnType != "" && !f.isAnonymous() { + scanDest = append(scanDest, destVal.FieldByIndex(f.Original.Index).Addr().Interface()) + } + } + for kcol := range anonymousCols { + f := m.FieldsByColumnName[kcol] + if f != nil { + for _, ef := range f.embeddedFields { + scanDest = append(scanDest, destVal.FieldByIndex(f.Original.Index).FieldByName(ef.Name).Addr().Interface()) } } } - return pp.Interface() -} -// fillSlice - note that it's the caller's responsibility to indirect -// the type in `t`. -func fillSlice(rows []map[string]any, t reflect.Type, mm *ModelMap) any { - pslice := reflect.MakeSlice(reflect.SliceOf(t), 0, 0) - rt := t - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() - } - for _, row := range rows { - pp := fillNested(row, rt, mm, 0, 10) - pslice = reflect.Append(pslice, reflect.ValueOf(pp)) - } - return pslice.Interface() + return row.Scan(scanDest...) } diff --git a/test_main.go b/test_main.go index bea347d..2535b83 100644 --- a/test_main.go +++ b/test_main.go @@ -2,13 +2,15 @@ package orm import "fmt" -const do_bootstrap = true +const do_bootstrap = false func TestMain() { e, err := Open("postgres://testbed_user:123@localhost/testbed_i_think") if err != nil { panic(err) } + u := author() + s := iti_multi(u) e.Models(user{}, story{}, band{}, role{}) if do_bootstrap { @@ -16,38 +18,45 @@ func TestMain() { if err != nil { panic(err) } - s := iti_multi() - u := &author - u.Favs.Authors = append(u.Favs.Authors, friend) - _, err = e.Model(&user{}).Create(&friend) + f := friend() + err = e.Model(&user{}).Create(&f) if err != nil { panic(err) } - _, err = e.Model(&user{}).Create(u) + err = e.Model(&user{}).Create(&u) if err != nil { panic(err) } - _, err = e.Model(&band{}).Create(&bodom) + u.Favs.Authors = append(u.Favs.Authors, f) + err = e.Model(&user{}).Save(&u) if err != nil { panic(err) } - _, err = e.Model(&band{}).Create(&diamondHead) + err = e.Model(&band{}).Create(&bodom) if err != nil { panic(err) } - _, err = e.Model(&story{}).Create(s) - if err != nil { - panic(err) - } - s.Downloads = s.Downloads + 1 - err = e.Model(&story{}).Update(s, nil) + err = e.Model(&band{}).Create(&diamondHead) if err != nil { panic(err) } } - ns, err := e.Model(&story{}).Where(&story{ - ID: 1, - }).Populate(PopulateAll).Find() + + err = e.Model(&user{}).Where("ID = ?", s.Author.ID).Find(&s.Author) + if err != nil { + panic(err) + } + err = e.Model(&story{}).Save(s) + if err != nil { + panic(err) + } + s.Downloads = s.Downloads + 1 + err = e.Model(&story{}).Save(s) + if err != nil { + panic(err) + } + var ns story + err = e.Model(&story{}).Where("ID = ?", 1).Populate(PopulateAll, "Chapters.Bands").Find(&ns) if err != nil { panic(err) } diff --git a/testing.go b/testing.go index 401177c..c9afd35 100644 --- a/testing.go +++ b/testing.go @@ -2,8 +2,11 @@ package orm import ( "fmt" + "github.com/stretchr/testify/assert" "math/rand/v2" + "slices" "strings" + "testing" "time" "github.com/go-loremipsum/loremipsum" @@ -18,7 +21,7 @@ type chapter struct { Genre []string `json:"genre" form:"genre" d:"type:text[]"` Bands []band `json:"bands" ref:"band,bands"` Characters []string `json:"characters" form:"characters" d:"type:text[]"` - Relationships [][]string `json:"relationships" form:"relationships" d:"type:text[][]"` + Relationships [][]string `json:"relationships" form:"relationships" d:"type:jsonb"` Adult bool `json:"adult" form:"adult"` Summary string `json:"summary" form:"summary"` Hidden bool `json:"hidden" form:"hidden"` @@ -40,16 +43,17 @@ type user struct { ID int64 `json:"_id" d:"pk;"` Username string `json:"username"` Favs favs `json:"favs" ref:"user"` - Roles []role + Roles []role `d:"m2m:user_roles"` } type role struct { ID int64 `d:"pk"` Name string - Users []user + Users []user `d:"m2m:user_roles"` } type favs struct { + ID int64 `d:"pk"` Stories []story Authors []user } @@ -73,14 +77,17 @@ type somethingWithNestedChapters struct { NestedText string `json:"text" gridfs:"nested_text,/nested/{{.ID}}.txt"` } -var friend = user{ - Username: "DarQuiel7", - ID: 83378, +func friend() user { + return user{ + Username: "DarQuiel7", + ID: 83378, + } } - -var author = user{ - Username: "tablet.exe", - ID: 85783, +func author() user { + return user{ + Username: "tablet.exe", + ID: 85783, + } } func genChaps(single bool, aceil int) []chapter { @@ -91,37 +98,84 @@ func genChaps(single bool, aceil int) []chapter { } else { ceil = aceil } - emptyRel := make([][]string, 0) - emptyRel = append(emptyRel, make([]string, 0)) - relMap := [][][]string{ + + relMap := make([][][]string, 0) + bands := make([][]band, 0) + charMap := make([][]string, 0) + for i := range ceil { + curChars := make([]string, 0) + curBands := make([]band, 0) + curBands = append(curBands, diamondHead) + curChars = append(curChars, diamondHead.Characters...) { - {"Sean Harris", "Brian Tatler"}, - }, + randMin := max(i+1, 1) + randMax := min(i+1, randMin+1) + mod1 := max(rand.IntN(randMin), 1) + mod2 := max(rand.IntN(randMax+1), 1) + if (mod1%mod2 == 0 || (mod1%mod2) == 2) && i > 0 { + curBands = append(curBands, bodom) + curChars = append(curChars, bodom.Characters...) + } + } + crel := make([][]string, 0) + numRels := rand.IntN(3) + seenRels := make(map[string]bool) + for len(crel) <= numRels { + arel := make([]string, 0) + randRelChars := rand.IntN(3) + numRelChars := 0 + if randRelChars == 1 { + numRelChars = 3 + } else if randRelChars == 2 { + numRelChars = 2 + } + if numRelChars == 0 { + continue + } + seen := make(map[string]bool) + for len(arel) < numRelChars { + char := diamondHead.Characters[rand.IntN(len(diamondHead.Characters))] + if !seen[char] { + arel = append(arel, char) + seen[char] = true + } + + } + slices.Sort(arel) + maybeSeen := strings.Join(arel, "/") + if maybeSeen != "" && !seenRels[maybeSeen] { + seenRels[maybeSeen] = true + crel = append(crel, arel) + } + } { - {"Sean Harris", "Brian Tatler"}, - {"Duncan Scott", "Colin Kimberley"}, - }, - { - {"Duncan Scott", "Colin Kimberley"}, - }, - emptyRel, - { - {"Sean Harris", "Colin Kimberley", "Brian Tatler"}, - }, + numChars := rand.IntN(len(curChars)-1) + 1 + seen := make(map[string]bool) + cchars := make([]string, 0) + for len(cchars) <= numChars { + char := curChars[rand.IntN(len(curChars))] + if !seen[char] { + cchars = append(cchars, char) + seen[char] = true + } + } + charMap = append(charMap, cchars) + } + relMap = append(relMap, crel) + bands = append(bands, curBands) } l := loremipsum.New() - for i := 0; i < ceil; i++ { + for i := range ceil { spf := fmt.Sprintf("%d.md", i+1) c := chapter{ - ChapterID: int64(i + 1), Title: fmt.Sprintf("-%d-", i+1), Index: i + 1, Words: 50, Notes: "notenotenote !!!", Genre: []string{"Slash"}, - Bands: []band{diamondHead}, - Characters: []string{"Sean Harris", "Brian Tatler", "Duncan Scott", "Colin Kimberley"}, + Bands: bands[i], + Characters: charMap[i], Relationships: relMap[i], Adult: true, Summary: l.Paragraph(), @@ -131,15 +185,6 @@ func genChaps(single bool, aceil int) []chapter { Text: strings.Join(l.ParagraphList(10), "\n\n"), Posted: time.Now().Add(time.Hour * time.Duration(int64(24*7*i))), } - { - randMin := max(i+1, 1) - randMax := min(i+1, randMin+1) - mod1 := max(rand.IntN(randMin), 1) - mod2 := max(rand.IntN(randMax+1), 1) - if (mod1%mod2 == 0 || (mod1%mod2) == 2) && i > 0 { - c.Bands = append(c.Bands, bodom) - } - } ret = append(ret, c) } @@ -155,26 +200,26 @@ func doSomethingWithNested() somethingWithNestedChapters { } return swnc } -func iti_single() *story { +func iti_single(a user) *story { return &story{ Title: "title", Completed: true, - Author: author, - Chapters: genChaps(true, 0), + Author: a, + Chapters: genChaps(true, 1), } } -func iti_multi() *story { +func iti_multi(a user) *story { return &story{ Title: "Brian Tatler Fucked and Abused Sean Harris", Completed: false, - Author: author, + Author: a, Chapters: genChaps(false, 5), } } -func iti_blank() *story { - t := iti_single() +func iti_blank(a user) *story { + t := iti_single(a) t.Chapters = make([]chapter, 0) return t } diff --git a/utils.go b/utils.go index 0439454..f819574 100644 --- a/utils.go +++ b/utils.go @@ -2,6 +2,7 @@ package orm import ( "fmt" + sb "github.com/henvic/pgq" "reflect" "regexp" "strings" @@ -59,6 +60,11 @@ func isZero(v reflect.Value) bool { } return v.IsZero() } + +func checkInsertable(v reflect.Value) { + +} + func reflectSet(f reflect.Value, v any) { if !f.CanSet() || v == nil { return @@ -82,14 +88,46 @@ func reflectSet(f reflect.Value, v any) { } } -func logTrunc(v any, length int) string { +func logTrunc(length int, v []any) []any { if length < 5 { length = 5 } - str := fmt.Sprintf("%+v", v) - trunced := str[:min(length, len(str))] - if len(trunced) < len(str) { - trunced += "..." + trunced := make([]any, 0) + for _, it := range v { + if str, ok := it.(string); ok { + ntrunc := str[:min(length, len(str))] + if len(ntrunc) < len(str) { + ntrunc += "..." + } + trunced = append(trunced, ntrunc) + } else { + trunced = append(trunced, it) + } } + return trunced } + +func isSliceOfStructs(rv reflect.Value) bool { + return rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() == reflect.Struct +} + +// MakePlaceholders - generates a string with `count` +// occurences of a placeholder (`?`), delimited by a +// comma and a space +func MakePlaceholders(count int) string { + if count < 1 { + return "" + } + var ph []string + for range count { + ph = append(ph, "?") + } + return strings.Join(ph, ", ") +} + +func wrapQueryIn(s sb.SelectBuilder, idName string) sb.SelectBuilder { + return s.Prefix( + fmt.Sprintf("%s in (", + idName)).Suffix(")") +}