diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml new file mode 100644 index 0000000..5792f72 --- /dev/null +++ b/.idea/dataSources.xml @@ -0,0 +1,15 @@ + + + + + postgresql + true + org.postgresql.Driver + jdbc:postgresql://localhost:5432/testbed_i_think + + + + $ProjectFileDir$ + + + \ No newline at end of file diff --git a/diamond.go b/diamond.go new file mode 100644 index 0000000..b3f9337 --- /dev/null +++ b/diamond.go @@ -0,0 +1,81 @@ +package orm + +import ( + "context" + "github.com/jackc/pgx/v5/pgxpool" + "time" +) + +type Engine struct { + modelMap *ModelMap + conn *pgxpool.Pool + m2mSeen map[string]bool + dryRun bool + cfg *pgxpool.Config + ctx context.Context +} + +func (e *Engine) Models(v ...any) { + e.modelMap = makeModelMap(v...) +} + +func (e *Engine) Model(val any) *Query { + qq := &Query{ + engine: e, + ctx: context.Background(), + wheres: make(map[string][]any), + joins: make(map[*Relationship][3]string), + orders: make([]string, 0), + relatedModels: make(map[string]*Model), + } + return qq.Model(val) +} + +func (e *Engine) Migrate() error { + failedMigrations := make(map[string]*Model) + var err error + for mk, m := range e.modelMap.Map { + err = m.migrate(e) + if err != nil { + failedMigrations[mk] = m + } + } + + for len(failedMigrations) > 0 { + e.m2mSeen = make(map[string]bool) + for mk, m := range failedMigrations { + err = m.migrate(e) + if err == nil { + delete(failedMigrations, mk) + } + } + } + return err +} + +func Open(connString string) (*Engine, error) { + e := &Engine{ + modelMap: &ModelMap{ + Map: make(map[string]*Model), + }, + m2mSeen: make(map[string]bool), + dryRun: connString == "", + ctx: context.Background(), + } + if connString != "" { + var err error + e.cfg, err = pgxpool.ParseConfig(connString) + e.cfg.MinConns = 5 + e.cfg.MaxConns = 10 + e.cfg.MaxConnIdleTime = time.Minute * 2 + if err != nil { + return nil, err + } + e.conn, err = pgxpool.NewWithConfig(e.ctx, e.cfg) + if err != nil { + return nil, err + } + + } + return e, nil +} diff --git a/field.go b/field.go new file mode 100644 index 0000000..85f24e8 --- /dev/null +++ b/field.go @@ -0,0 +1,133 @@ +package orm + +import ( + "fmt" + "net" + "reflect" + "time" +) + +type Field struct { + Name string + ColumnName string + ColumnType string + Type reflect.Type + Original reflect.StructField + Model *Model + Index int + AutoIncrement bool + PrimaryKey bool + Nullable bool + isForeignKey bool + fk *Relationship +} + +func (f *Field) alias() (string, string) { + columnName := f.Model.Fields[f.Model.IDField].ColumnName + if f.ColumnType != "" { + columnName = f.ColumnName + } + return fmt.Sprintf("%s.%s", f.Model.TableName, columnName), fmt.Sprintf("%s_%s", f.Model.TableName, f.ColumnName) +} + +func (f *Field) aliasWith(a string) (string, string) { + first := fmt.Sprintf("%s.%s", a, f.ColumnName) + second := fmt.Sprintf("%s_%s", a, f.ColumnName) + return first, second +} + +func (f *Field) key() string { + return fmt.Sprintf("%s.%s", f.Model.Name, f.Name) +} + +func columnType(ty reflect.Type, isPk, isAutoInc bool) string { + it := ty + switch it.Kind() { + case reflect.Ptr: + for it.Kind() == reflect.Ptr { + it = it.Elem() + } + case reflect.Int32, reflect.Uint32: + if isPk || isAutoInc { + return "serial" + } else { + return "int" + } + case reflect.Int64, reflect.Uint64, reflect.Int, reflect.Uint: + if isPk || isAutoInc { + return "bigserial" + } else { + return "bigint" + } + case reflect.String: + return "text" + case reflect.Float32: + return "float4" + case reflect.Float64: + return "double precision" + case reflect.Bool: + return "boolean" + case reflect.Struct: + if canConvertTo[time.Time](ty) { + return "timestamptz" + } + if canConvertTo[net.IP](ty) { + return "inet" + } + if canConvertTo[net.IPNet](ty) { + return "cidr" + } + + default: + return "" + } + return "" +} +func parseField(f reflect.StructField, minfo *Model, modelMap map[string]*Model, i int) *Field { + field := &Field{ + Name: f.Name, + Original: f, + Index: i, + } + tags := parseTags(f.Tag.Get("d")) + if tags["-"] != "" { + return nil + } + field.PrimaryKey = tags["pk"] != "" || tags["primarykey"] != "" || field.Name == "ID" + field.AutoIncrement = tags["autoinc"] != "" + field.Nullable = tags["nullable"] != "" + field.ColumnType = tags["type"] + if field.ColumnType == "" { + field.ColumnType = columnType(f.Type, field.PrimaryKey, field.AutoIncrement) + } + field.ColumnName = tags["column"] + if field.ColumnName == "" { + field.ColumnName = pascalToSnakeCase(field.Name) + } + if field.PrimaryKey { + minfo.IDField = field.Name + } + elem := f.Type + for elem.Kind() == reflect.Ptr { + if !field.Nullable { + field.Nullable = true + } + elem = elem.Elem() + } + field.Type = elem + + switch elem.Kind() { + case reflect.Array, reflect.Slice: + elem = elem.Elem() + fallthrough + case reflect.Struct: + if canConvertTo[Document](elem) && f.Anonymous { + minfo.TableName = tags["table"] + return nil + } else if field.ColumnType == "" { + minfo.Relationships[field.Name] = parseRelationship(f, modelMap, minfo.Type, i) + } + } + + return field +} diff --git a/go.mod b/go.mod index 25a0f46..0fbb4f4 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/henvic/pgq v0.0.4 // indirect github.com/huandu/xstrings v1.4.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/go.sum b/go.sum index ca7aa10..1f0a329 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-loremipsum/loremipsum v1.1.4 h1:RJaJlJwX4y9A2+CMgKIyPcjuFHFKTmaNMhxbL+sI6Vg= github.com/go-loremipsum/loremipsum v1.1.4/go.mod h1:whNWskGoefTakPnCu2CO23v5Y7RwiG4LMOEtTDaBeOY= +github.com/henvic/pgq v0.0.4 h1:BgLnxofZJSWWs+9VOf19Gr9uBkSVbHWGiu8wix1nsIY= +github.com/henvic/pgq v0.0.4/go.mod h1:k0FMvOgmQ45MQ3TgCLe8I3+sDKy9lPAiC2m9gg37pVA= github.com/huandu/go-assert v1.1.6 h1:oaAfYxq9KNDi9qswn/6aE0EydfxSa+tWZC1KabNitYs= github.com/huandu/go-assert v1.1.6/go.mod h1:JuIfbmYG9ykwvuxoJ3V8TB5QP+3+ajIA54Y44TmkMxs= github.com/huandu/go-sqlbuilder v1.35.1 h1:znTuAksxq3T1rYfr3nsD4P0brWDY8qNzdZnI6+vtia4= diff --git a/model.go b/model.go new file mode 100644 index 0000000..4500b83 --- /dev/null +++ b/model.go @@ -0,0 +1,297 @@ +package orm + +import ( + "fmt" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "reflect" + "strings" +) + +type Model struct { + Name string + Type reflect.Type + Relationships map[string]*Relationship + IDField string + Fields map[string]*Field + FieldsByColumnName map[string]*Field + TableName string + embeddedIsh bool +} + +func (m *Model) addField(field *Field) { + field.Model = m + m.Fields[field.Name] = field + m.FieldsByColumnName[field.ColumnName] = field +} + +func (m *Model) getAliasFields() map[string]string { + fields := make(map[string]string) + for _, f := range m.Fields { + if f.fk != nil { + continue + } + fields[f.Name] = fmt.Sprintf("%s.%s as %s_%s", m.TableName, f.ColumnName, strings.ToLower(m.Name), f.ColumnName) + } + return fields +} + +func (m *Model) getPrimaryKey(val reflect.Value) (string, any) { + colField := m.Fields[m.IDField] + if colField == nil { + return "", nil + } + colName := colField.ColumnName + idField := val.FieldByName(m.IDField) + if idField.IsValid() { + return colName, idField.Interface() + } + return "", nil +} + +func (m *Model) insert(v reflect.Value, e *Query, parentFks map[string]any) (any, error) { + var isTopLevel bool + var err error + var cn *pgxpool.Conn + if e.tx == nil && !e.engine.dryRun { + isTopLevel = true + var tx pgx.Tx + cn, err = e.engine.conn.Acquire(e.ctx) + if err != nil { + return nil, err + } + tx, err = cn.Begin(e.ctx) + if err != nil { + return nil, err + } + e.tx = tx + defer func() { + if err != nil { + e.tx.Rollback(e.ctx) + } + fmt.Printf("[DBG] discarding tx @ %p\n", e.tx) + fmt.Printf("[DBG] discarding conn @ %p\n", e.tx.Conn()) + fmt.Printf("[DBG] discarding outer conn @ %p\n", cn) + e.tx = nil + cn.Release() + }() + } + for v.Kind() == reflect.Pointer { + v = v.Elem() + } + t := v.Type() + var cols []string + var placeholders []string + var args []any + var returningField *Field + for k, vv := range parentFks { + cols = append(cols, k) + placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) + args = append(args, vv) + } + for _, ff := range m.Fields { + //ft := t.Field(ff.Index) + var fv reflect.Value + if ff.Index > -1 { + fv = v.Field(ff.Index) + } + + mfk := ff.fk + if mfk == nil { + mfk = m.Relationships[ff.Name] + } + if mfk != nil { + af := v.FieldByName(mfk.FieldName) + if af.Kind() == reflect.Struct { + idField := af.FieldByName(ff.fk.RelatedModel.IDField) + cols = append(cols, ff.ColumnName) + placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) + if !idField.IsValid() { + var nid any + nid, err = ff.fk.RelatedModel.insert(af, e, make(map[string]any)) + if err != nil { + return 0, err + } + args = append(args, nid) + } else { + args = append(args, idField.Interface()) + } + } + continue + } + col := ff.ColumnName + if ff.PrimaryKey { + returningField = ff + if fv.IsValid() && !fv.IsZero() { + cols = append(cols, col) + placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) + args = append(args, fv.Interface()) + } + continue + } + if fv.IsValid() { + cols = append(cols, col) + placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1)) + args = append(args, fv.Interface()) + } + } + var rfc = "id" + if returningField != nil { + rfc = returningField.ColumnName + } + scols := fmt.Sprintf("(%s)", strings.Join(cols, ", ")) + svals := fmt.Sprintf("VALUES (%s) ", strings.Join(placeholders, ", ")) + sql := fmt.Sprintf("INSERT INTO %s ", + m.TableName, + ) + if len(cols) > 0 { + sql += scols + sql += " " + sql += svals + } else { + sql += "DEFAULT VALUES " + } + sql += fmt.Sprintf("RETURNING %s", rfc) + fmt.Printf("[INSERT] %s { %s }\n", sql, logTrunc(args, 200)) + var id any + { + if v.FieldByName(m.IDField).IsValid() { + id = v.FieldByName(m.IDField).Interface() + } + } + if !e.engine.dryRun { + qr := e.engine.conn.QueryRow(e.ctx, sql, args...) + err = qr.Scan(&id) + if err != nil { + return 0, err + } + } + { + retField := v.FieldByName(returningField.Name) + if retField.IsValid() { + retField.Set(reflect.ValueOf(id)) + } + } + for i := range t.NumField() { + ft := t.Field(i) + fv := v.Field(i) + if ft.Type.Kind() == reflect.Slice { + for j := range fv.Len() { + child := fv.Index(j).Addr() + cm := e.engine.modelMap.Map[child.Type().Elem().Name()] + if cm != nil && cm.embeddedIsh { + cfk := map[string]any{ + m.TableName + "_id": id, + } + if m.Relationships[ft.Name] != nil { + cfk = map[string]any{ + pascalToSnakeCase(m.Relationships[ft.Name].JoinField()): id, + } + } + _, err = cm.insert(child, e, cfk) + if err != nil { + return 0, err + } + } else if cm != nil { + rel := m.Relationships[ft.Name] + if rel != nil { + err = rel.joinInsert(child, e, id) + if err != nil { + return 0, err + } + } + } + } + } + } + if isTopLevel && !e.engine.dryRun { + err = e.tx.Commit(e.ctx) + } + return id, err +} + +func (m *Model) update(val reflect.Value, q *Query, parentFks map[string]any) error { + var isTopLevel bool + var err error + if q.tx == nil && !q.engine.dryRun { + isTopLevel = true + var tx pgx.Tx + tx, err = q.engine.conn.Begin(q.ctx) + if err != nil { + return err + } + q.tx = tx + defer func() { + if err != nil { + q.tx.Rollback(q.ctx) + q.tx = nil + } + }() + } + cnt := 1 + sets := make([]string, 0) + vals := make([]any, 0) + for val.Kind() == reflect.Pointer { + val = val.Elem() + } + for _, field := range m.Fields { + if field.fk != nil || field.ColumnType == "" || field.Name == m.IDField { + continue + } + if field.Index > -1 && field.Index < val.NumField() && field.fk == nil { + f := val.Field(field.Index) + sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt)) + vals = append(vals, f.Interface()) + cnt++ + } else if _, ok := parentFks[field.ColumnName]; ok { + sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt)) + vals = append(vals, parentFks[field.ColumnName]) + cnt++ + } + } + mcol, mpk := m.getPrimaryKey(val) + sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s = %v", m.TableName, strings.Join(sets, ", "), mcol, mpk) + fmt.Printf("[UPDATE] %s { %s }\n", sql, logTrunc(vals, 200)) + if !q.engine.dryRun { + if _, err = q.tx.Exec(q.ctx, sql, vals...); err != nil { + return err + } + } + for _, rel := range m.Relationships { + if rel.Idx > -1 && rel.Idx < val.NumField() { + f := val.Field(rel.Idx) + if f.Kind() == reflect.Slice || + f.Kind() == reflect.Array { + err = diffSlices(q.engine, q, val, rel) + if err != nil { + return err + } + } else if rel.RelatedModel.embeddedIsh && !m.embeddedIsh { + elemish := f.Type() + for elemish.Kind() == reflect.Ptr { + elemish = elemish.Elem() + } + if elemish.Kind() == reflect.Struct { + err = rel.RelatedModel.update(f, q, map[string]any{ + pascalToSnakeCase(rel.JoinField()): mpk, + }) + if err != nil { + return err + } + } + } + } + } + var finerr error + if isTopLevel && !q.engine.dryRun { + finerr = q.tx.Commit(q.ctx) + } + return finerr +} + +/*func (m *Model) ensure(q *Query, val reflect.Value) error { + if _, pk := m.getPrimaryKey(val); pk != nil && !reflect.ValueOf(pk).IsZero() { + return nil + } + +}*/ diff --git a/model_internals.go b/model_internals.go new file mode 100644 index 0000000..67d3f2c --- /dev/null +++ b/model_internals.go @@ -0,0 +1,106 @@ +package orm + +import ( + "reflect" + "strings" +) + +func parseModel(model any) *Model { + t := reflect.TypeOf(model) + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + minfo := &Model{ + Name: t.Name(), + Relationships: make(map[string]*Relationship), + Fields: make(map[string]*Field), + FieldsByColumnName: make(map[string]*Field), + Type: t, + } + for i := range t.NumField() { + f := t.Field(i) + if !f.IsExported() { + continue + } + //minfo.Fields[f.Name] = parseField(f, minfo, i) + } + if minfo.TableName == "" { + minfo.TableName = pascalToSnakeCase(t.Name()) + } + return minfo +} + +func parseModelFields(model *Model, modelMap map[string]*Model) { + t := model.Type + for i := range t.NumField() { + f := t.Field(i) + fi := parseField(f, model, modelMap, i) + if fi != nil { + model.addField(fi) + } + } +} +func makeModelMap(models ...any) *ModelMap { + modelMap := &ModelMap{ + Map: make(map[string]*Model), + } + //modelMap := make(map[string]*Model) + for _, model := range models { + minfo := parseModel(model) + // modelMap.Mux.Lock() + modelMap.Map[minfo.Name] = minfo + // modelMap.Mux.Unlock() + } + for _, model := range modelMap.Map { + // modelMap.Mux.Lock() + parseModelFields(model, modelMap.Map) + // modelMap.Mux.Unlock() + } + tagManyToMany(modelMap) + for _, model := range modelMap.Map { + // modelMap.Mux.Lock() + for _, ref := range model.Relationships { + if ref.Type != ManyToMany && ref.Idx != -1 { + addForeignKeyFields(ref) + } + } + // modelMap.Mux.Unlock() + } + return modelMap +} + +func tagManyToMany(models *ModelMap) { + hasManys := make(map[string]*Relationship) + for _, model := range models.Map { + for relName := range model.Relationships { + hasManys[model.Name+"."+relName] = model.Relationships[relName] + } + } + for _, model := range models.Map { + // models.Mux.Lock() + for relName := range model.Relationships { + + mb := model.Relationships[relName].RelatedModel + var name string + for n, reltmp := range hasManys { + if !strings.HasPrefix(n, mb.Name) || reltmp.Type != HasMany { + continue + } + if reltmp.RelatedType == model.Type { + name = reltmp.FieldName + break + } + } + if rel2, ok := mb.Relationships[name]; ok { + if name < relName && + rel2.Type == HasMany && model.Relationships[relName].Type == HasMany { + mb.Relationships[name].Type = ManyToMany + mb.Relationships[name].m2mInverse = model.Relationships[relName] + model.Relationships[relName].Type = ManyToMany + model.Relationships[relName].m2mInverse = mb.Relationships[name] + } + } + } + // models.Mux.Unlock() + } +} diff --git a/model_map.go b/model_map.go new file mode 100644 index 0000000..17effd3 --- /dev/null +++ b/model_map.go @@ -0,0 +1,10 @@ +package orm + +import ( + "sync" +) + +type ModelMap struct { + Map map[string]*Model + Mux sync.RWMutex +} diff --git a/model_migration.go b/model_migration.go new file mode 100644 index 0000000..e1612f6 --- /dev/null +++ b/model_migration.go @@ -0,0 +1,110 @@ +package orm + +import ( + "fmt" + "reflect" + "strings" +) + +func (m *Model) createTableSql() string { + var fields []string + var fks []string + for _, field := range m.Fields { + isStructOrSliceOfStructs := field.Type.Kind() == reflect.Struct || + ((field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) && + field.Type.Elem().Kind() == reflect.Struct) + if field.PrimaryKey { + fields = append(fields, fmt.Sprintf("%s %s PRIMARY KEY", field.ColumnName, field.ColumnType)) + } else if (field.fk != nil && field.fk.Type != HasMany && field.fk.Type != ManyToMany) && field.isForeignKey { + colType := serialToRegular(field.ColumnType) + if !field.Nullable { + colType += " NOT NULL " + } + ffk := field.fk.RelatedModel.Fields[field.fk.RelatedModel.IDField] + if ffk != nil { + fks = append(fks, fmt.Sprintf("%s %s REFERENCES %s(%s)", + field.ColumnName, colType, + field.fk.RelatedModel.TableName, + field.fk.RelatedModel.Fields[field.fk.RelatedModel.IDField].ColumnName)) + } + } else if !isStructOrSliceOfStructs || field.ColumnType != "" { + lalala := fmt.Sprintf("%s %s", field.ColumnName, field.ColumnType) + if !field.Nullable { + lalala += " NOT NULL" + } + fields = append(fields, lalala) + } + } + inter := strings.Join(fields, ", ") + if len(fks) > 0 { + inter += ", " + inter += strings.Join(fks, ", ") + } + return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);", + m.TableName, inter) +} + +func (m *Model) createJoinTableSql(relName string) string { + ref, ok := m.Relationships[relName] + if !ok { + return "" + } + aTable := m.TableName + joinTableName := ref.JoinTable() + fct := serialToRegular(ref.Model.Fields[ref.Model.IDField].ColumnType) + rct := serialToRegular(ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnType) + pkSection := fmt.Sprintf(",\nPRIMARY KEY (%s, %s_id)", + fmt.Sprintf("%s_%s", + aTable, pascalToSnakeCase(ref.FieldName), + ), + ref.RelatedModel.TableName, + ) + if ref.Type == HasMany || ref.Type == ManyToMany { + pkSection = "" + } + return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( +%s %s REFERENCES %s(%s), +%s_id %s REFERENCES %s(%s)%s +);`, + joinTableName, + fmt.Sprintf("%s_%s", + aTable, pascalToSnakeCase(ref.FieldName), + ), + fct, + ref.Model.TableName, ref.Model.Fields[ref.Model.IDField].ColumnName, + ref.RelatedModel.TableName, + rct, + ref.RelatedModel.TableName, ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnName, + pkSection, + ) +} + +func (m *Model) migrate(engine *Engine) error { + sql := m.createTableSql() + fmt.Println(sql) + if !engine.dryRun { + _, err := engine.conn.Exec(engine.ctx, sql) + if err != nil { + return err + } + } + for relName, rel := range m.Relationships { + relkey := rel.Model.Name + if (rel.Type == ManyToMany && !engine.m2mSeen[relkey]) || + (rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh && rel.Type == HasMany) { + if rel.Type == ManyToMany { + engine.m2mSeen[relkey] = true + engine.m2mSeen[rel.RelatedModel.Name] = true + } + jtsql := m.createJoinTableSql(relName) + fmt.Println(jtsql) + if !engine.dryRun { + _, err := engine.conn.Exec(engine.ctx, jtsql) + if err != nil { + return err + } + } + } + } + return nil +} diff --git a/model_misc.go b/model_misc.go new file mode 100644 index 0000000..a6abba4 --- /dev/null +++ b/model_misc.go @@ -0,0 +1,174 @@ +package orm + +import ( + "context" + "fmt" + "reflect" +) + +func fetchJoinTableChildren(e *Engine, ctx context.Context, childRel *Relationship, pid any) (map[any]struct{}, error) { + rsql := fmt.Sprintf("SELECT %s_%s FROM %s where %s_id = $1", + childRel.Model.TableName, pascalToSnakeCase(childRel.FieldName), + childRel.JoinTable(), childRel.RelatedModel.TableName) + res := make(map[any]struct{}) + if !e.dryRun { + rows, err := e.conn.Query(ctx, rsql, pid) + defer rows.Close() + if err != nil { + return nil, err + } + for rows.Next() { + var id any + if err = rows.Scan(&id); err != nil { + return nil, err + } + res[id] = struct{}{} + } + } + return res, nil +} + +func fetchChildren(e *Engine, childRel *Relationship, pid any) (map[any]reflect.Value, error) { + res := make(map[any]reflect.Value) + qq := e.Model(reflect.New(childRel.RelatedModel.Type).Elem().Interface()) + rrel := childRel.RelatedModel.Relationships[childRel.Model.Name] + rfield := childRel.RelatedModel.Fields[rrel.FieldName] + /*if rrel == nil { + return res, fmt.Errorf("please report this, it shouldn't have happened :(") + }*/ + rawRows, err := qq.Where(fmt.Sprintf("%s.%s = $1", childRel.RelatedModel.TableName, rfield.ColumnName), pid).Find() + if err != nil { + return nil, err + } + inter := make([]any, 0) + rrv := reflect.ValueOf(rawRows) + for i := range rrv.Len() { + inter = append(inter, rrv.Index(i).Interface()) + } + for _, row := range inter { + v := reflect.ValueOf(row) + bv := v + for bv.Kind() == reflect.Ptr { + bv = bv.Elem() + } + id := v.FieldByName(childRel.RelatedModel.IDField).Interface() + res[id] = v + } + + return res, nil +} + +func preDiff(e *Engine, value reflect.Value) (*Model, error) { + ptype := value.Type() + for ptype.Kind() == reflect.Pointer { + ptype = ptype.Elem() + } + model, ok := e.modelMap.Map[ptype.Name()] + if !ok { + return nil, fmt.Errorf("model '%s' not found", ptype.Name()) + } + return model, nil +} + +func diffManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error { + model, err := preDiff(e, value) + if err != nil { + return err + } + _, ppk := model.getPrimaryKey(value) + dbChildren, err := fetchChildren(e, rel, ppk) + if err != nil { + return err + } + memChildren := make(map[any]reflect.Value) + fv := value.FieldByName(rel.FieldName) + for i := range fv.Len() { + child := fv.Index(i) + _, cpk := rel.RelatedModel.getPrimaryKey(child) + if cpk != nil { + memChildren[cpk] = child + } + } + // deletions // + for pk := range dbChildren { + if _, found := memChildren[pk]; !found { + table := rel.RelatedModel.TableName + idField := rel.RelatedModel.Fields[rel.RelatedModel.IDField] + _, err = q.tx.Exec(q.ctx, fmt.Sprintf("DELETE FROM %s where %s = $1", table, idField.ColumnName), pk) + if err != nil { + return err + } + } + } + mField := model.Fields[model.IDField] + mpks := map[string]any{} + if !model.embeddedIsh { + mpks[mField.ColumnName] = ppk + } + // update || insert // + for i := range fv.Len() { + cur := fv.Index(i) + _, cpk := rel.RelatedModel.getPrimaryKey(cur) + if cpk == nil || reflect.ValueOf(cpk).IsZero() { + + _, err = rel.RelatedModel.insert(cur, q, mpks) + if err != nil { + return err + } + } else { + err = rel.RelatedModel.update(cur, q, mpks) + if err != nil { + return err + } + } + } + return nil +} + +func diffManyToManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error { + model, err := preDiff(e, value) + if err != nil { + return err + } + _, ppk := model.getPrimaryKey(value) + ids, err := fetchJoinTableChildren(e, q.ctx, rel, ppk) + if err != nil { + return err + } + memIds := make(map[any]reflect.Value) + fv := value.FieldByName(rel.FieldName) + for i := range fv.Len() { + child := fv.Index(i) + _, cpk := rel.RelatedModel.getPrimaryKey(child) + if cpk != nil { + memIds[cpk] = child + } + } + for memId := range memIds { + if _, found := ids[memId]; !found { + err = rel.joinInsert(memIds[memId], q, ppk) + if err != nil { + return err + } + } + } + for id := range ids { + if _, found := memIds[id]; !found { + err = rel.joinDelete(ppk, id, q) + if err != nil { + return err + } + } + } + return nil +} + +func diffSlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error { + if rel.Type == ManyToMany || rel.m2mIsh() { + return diffManyToManySlices(e, q, value, rel) + } + if rel.Type == HasMany { + return diffManySlices(e, q, value, rel) + } + return nil +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..fed34ef --- /dev/null +++ b/query.go @@ -0,0 +1,211 @@ +package orm + +import ( + "context" + "fmt" + sb "github.com/henvic/pgq" + "github.com/jackc/pgx/v5" + "reflect" + "strings" +) + +type Query struct { + model *Model + relatedModels map[string]*Model + wheres map[string][]any + orders []string + limit int + offset int + joins map[*Relationship][3]string + engine *Engine + ctx context.Context + tx pgx.Tx +} + +func (q *Query) totalWheres() int { + total := 0 + for _, w := range q.wheres { + total += len(w) + } + return total +} + +func (q *Query) Model(val any) *Query { + tt := reflect.TypeOf(val) + for tt.Kind() == reflect.Ptr { + tt = tt.Elem() + } + q.model = q.engine.modelMap.Map[tt.Name()] + return q +} + +func (q *Query) Where(cond any, args ...any) *Query { + switch v := cond.(type) { + case string: + q.wheres[strings.ReplaceAll(v, "$?", fmt.Sprintf("$%d", q.totalWheres()+1))] = args + default: + rv := reflect.ValueOf(cond) + for rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + rt := rv.Type() + for i := range rv.NumField() { + field := rt.Field(i) + fieldValue := rv.Field(i) + if isZero(fieldValue) { + continue + } + mm, ok := q.engine.modelMap.Map[rv.Type().Name()] + if !ok { + continue + } + ff, ok := mm.Fields[field.Name] + if !ok || ff.ColumnType == "" { + continue + } + whereClause := fmt.Sprintf("%s.%s = ?", mm.TableName, ff.ColumnName /*, q.totalWheres()+1*/) + args = append(args, fieldValue.Interface()) + q.wheres[whereClause] = args + } + } + return q +} + +func (q *Query) Order(order string) *Query { + q.orders = append(q.orders, order) + return q +} + +func (q *Query) Limit(limit int) *Query { + q.limit = limit + return q +} + +func (q *Query) Offset(offset int) *Query { + q.offset = offset + return q +} + +func (q *Query) buildSelect() sb.SelectBuilder { + var fields []string + + for _, f := range q.model.Fields { + if f.ColumnType == "" { + continue + } + tn, a := f.alias() + fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) + } + seenModels := make(map[string]bool) + processField := func(f *Field, m *Model, pfk *Relationship) { + if f.ColumnType == "" { + if rel, ok := m.Relationships[f.Name]; ok { + data, ok2 := q.joins[rel] + if ok2 && f.ColumnType != "" { + tn, a := f.aliasWith(data[0]) + fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) + } + } + return + } + { + fk := f.fk + if fk == nil { + fk = m.Relationships[f.Name] + } + if fk == nil { + fk = pfk + } + if fk != nil && fk.FieldName == f.Name { + data, ok2 := q.joins[fk] + if ok2 { + var ( + tn, a string + ) + if fk.Type == HasOne { + tn, a = fk.RelatedModel.Fields[fk.RelatedModel.IDField].aliasWith(data[0]) + } else { + tn, a = f.aliasWith(data[0]) + } + fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) + } + return + } + } + if f.Name == pfk.FieldName { + f.aliasWith(q.joins[pfk][0]) + } else { + tn, a := f.alias() + fields = append(fields, fmt.Sprintf("%s AS %s", tn, a)) + } + + } + for r := range q.joins { + if !seenModels[r.aliasThingy()] { + seenModels[r.aliasThingy()] = true + for _, f := range r.Model.Fields { + processField(f, r.Model, r) + } + } + if !seenModels[r.relatedAlias()] { + seenModels[r.relatedAlias()] = true + for _, f := range r.RelatedModel.Fields { + processField(f, r.RelatedModel, r) + } + } + } + return sb.Select(fields...) +} + +func (q *Query) buildSQL() (string, []any) { + sqlb := q.buildSelect().From(q.model.TableName) + whereargs := make([]any, 0) + if len(q.wheres) > 0 { + cnt := 0 + for w, where := range q.wheres { + sqlb = sqlb.Where(w, where...) + + whereargs = append(whereargs, where...) + cnt++ + } + } + for _, j := range q.joins { + sqlb = sqlb.LeftJoin(fmt.Sprintf("%s as %s ON %s", j[1], j[0], j[2])) + } +ool: + for _, o := range q.orders { + ac, ok := q.model.Fields[o] + if !ok { + var rel = ac.fk + if ac.ColumnType == "" || rel == nil { + rel = q.model.Relationships[o] + } + if rel != nil { + if strings.Contains(o, ".") { + split := strings.Split(strings.TrimSuffix(strings.TrimPrefix(o, "."), "."), ".") + cm := rel.Model + for i, s := range split { + if rel != nil { + cm = rel.RelatedModel + } else if i == len(split)-1 { + break + } else { + continue ool + } + rel = cm.Relationships[s] + } + lf := split[len(split)-1] + ac, ok = cm.Fields[lf] + if !ok { + continue + } + } + } + } + sqlb = sqlb.OrderBy(ac.ColumnName) + } + if q.limit > 0 { + sqlb = sqlb.Limit(uint64(q.limit)) + } + return sqlb.MustSQL() +} diff --git a/query_populate.go b/query_populate.go new file mode 100644 index 0000000..e6f7e0f --- /dev/null +++ b/query_populate.go @@ -0,0 +1,165 @@ +package orm + +import ( + "fmt" + "strings" +) + +const PopulateAll = "~~~ALL~~~" + +func join(r *Relationship) (string, string, string) { + rtable := r.RelatedModel.TableName + field := r.Model.Fields[r.FieldName] + var fk, pk, alias string + if !r.RelatedModel.embeddedIsh && !r.Model.embeddedIsh { + alias = pascalToSnakeCase(field.Name) + fk = fmt.Sprintf("%s.%s", alias, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName) + pk = fmt.Sprintf("%s.%s", r.Model.TableName, field.ColumnName) + alias = pascalToSnakeCase(field.Name) + } else if !r.Model.embeddedIsh { + alias = pascalToSnakeCase(r.FieldName) + sid := strings.TrimSuffix(r.JoinField(), "ID") + fk = fmt.Sprintf("%s.%s", alias, r.RelatedModel.Fields[sid].ColumnName) + pk = fmt.Sprintf("%s.%s", r.Model.TableName, r.Model.Fields[r.Model.IDField].ColumnName) + } + return alias, rtable, fmt.Sprintf("%s = %s", fk, pk) +} + +func m2mJoin(r *Relationship) [][3]string { + result := make([][3]string, 0) + jt := r.JoinTable() + first := [3]string{ + pascalToSnakeCase(r.FieldName), + jt, + fmt.Sprintf("%s = %s", + fmt.Sprintf("%s.%s", + jt, + r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName, + ), + fmt.Sprintf("%s.%s", + r.Model.TableName, + r.Model.Fields[r.Model.IDField].ColumnName, + ), + ), + } + second := [3]string{ + pascalToSnakeCase(r.m2mInverse.FieldName), + r.RelatedModel.TableName, + fmt.Sprintf("%s = %s", + fmt.Sprintf("%s.%s", r.RelatedModel.TableName, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName), + fmt.Sprintf("%s.%s", jt, r.m2mInverse.JoinField()), + ), + } + + /*first := fmt.Sprintf("%s AS %s ON %s = %s", + jt, + fmt.Sprintf("%s.%s", + jt, + r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName, + ), + fmt.Sprintf("%s.%s", + r.Model.TableName, + r.Model.Fields[r.Model.IDField].ColumnName, + ), + ) + second := fmt.Sprintf("%s ON %s = %s", + r.RelatedModel.TableName, + fmt.Sprintf("%s.%s", r.RelatedModel.TableName, r.RelatedModel.Fields[r.RelatedModel.IDField].ColumnName), + fmt.Sprintf("%s.%s", jt, r.m2mInverse.JoinField()), + )*/ + result = append(result, first, second) + return result +} + +func nestedJoin(m *Model, path string) (joins [][3]string, ree []*Relationship) { + splitPath := strings.Split(path, ".") + prevModel := m + for _, f := range splitPath { + rel, ok := m.Relationships[f] + if !ok { + break + } + var ( + fk, pk string + ) + if !rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh { + pk = prevModel.Fields[rel.FieldName].ColumnName + fk = rel.RelatedModel.Fields[rel.RelatedModel.IDField].ColumnName + } else if !rel.Model.embeddedIsh { + pk = prevModel.Fields[prevModel.IDField].ColumnName + fk = rel.RelatedModel.Fields[strings.TrimSuffix(rel.JoinField(), "ID")].ColumnName + } + ree = append(ree, rel) + j2 := [3]string{ + rel.Model.Fields[rel.FieldName].ColumnName, + rel.RelatedModel.TableName, + fmt.Sprintf("%s.%s = %s.%s", + rel.RelatedModel.TableName, fk, + prevModel.TableName, pk), + } + + /*j2 := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", + rel.RelatedModel.TableName, + rel.Model.Fields[rel.FieldName].ColumnName, + rel.RelatedModel.TableName, fk, + prevModel.TableName, pk, + )*/ + joins = append(joins, j2) + prevModel = rel.RelatedModel + } + return +} + +func (q *Query) Populate(relation string) *Query { + if relation == PopulateAll { + for _, rel := range q.model.Relationships { + if rel.Type == ManyToMany { + mjs := m2mJoin(rel) + for _, ajoin := range mjs { + q.joins[rel] = [3]string{ + ajoin[0], + ajoin[1], + ajoin[2], + } + } + //q.joins = append(q.joins, m2mJoin(rel)...) + } else { + alias, tn, cond := join(rel) + q.joins[rel] = [3]string{ + alias, + tn, + cond, + } + //q.joins = append(q.joins, tn) + } + q.relatedModels[rel.RelatedModel.Name] = rel.RelatedModel + } + } else { + if strings.Contains(relation, ".") { + njs, rmodels := nestedJoin(q.model, relation) + for i, m := range rmodels { + curTuple := njs[i] + q.joins[m] = [3]string{ + curTuple[0], + curTuple[1], + curTuple[2], + } + q.relatedModels[m.Model.Name] = m.Model + } + //q.joins = append(q.joins, njs...) + } else { + rel, ok := q.model.Relationships[relation] + if ok { + alias, tn, j := join(rel) + q.joins[rel] = [3]string{ + alias, + tn, + j, + } + //q.joins = append(q.joins, tn) + q.relatedModels[rel.RelatedModel.Name] = rel.RelatedModel + } + } + } + return q +} diff --git a/query_tail.go b/query_tail.go new file mode 100644 index 0000000..0b33e4c --- /dev/null +++ b/query_tail.go @@ -0,0 +1,42 @@ +package orm + +import ( + "errors" + "fmt" + "reflect" +) + +func (q *Query) Create(val any) (any, error) { + _, err := q.model.insert(reflect.ValueOf(val), q, make(map[string]any)) + return val, err +} + +func (q *Query) Find() (any, error) { + sql, args := q.buildSQL() + fmt.Printf("[FIND] %s { %+v }\n", sql, args) + if !q.engine.dryRun { + rows, err := q.engine.conn.Query(q.ctx, sql, args...) + defer rows.Close() + if err != nil { + return nil, err + } + rmaps, err := rowsToMaps(rows) + if err != nil { + return nil, err + } + wtype := q.model.Type + for wtype.Kind() == reflect.Pointer { + wtype = wtype.Elem() + } + return fillSlice(rmaps, wtype, q.engine.modelMap), nil + } + return make([]any, 0), nil +} + +func (q *Query) Update(val any, cond any, args ...any) error { + if q.model != nil { + return q.model.update(reflect.ValueOf(val), q, make(map[string]any)) + } else { + return errors.New("Please select a model") + } +} diff --git a/relationship.go b/relationship.go new file mode 100644 index 0000000..10565ff --- /dev/null +++ b/relationship.go @@ -0,0 +1,209 @@ +package orm + +import ( + "fmt" + "reflect" + "strings" +) + +type RelationshipType int + +const ( + HasOne RelationshipType = iota + HasMany + ManyToMany +) + +type Relationship struct { + Type RelationshipType + Model *Model + FieldName string + Idx int + RelatedType reflect.Type + RelatedModel *Model + Kind reflect.Kind // field kind (struct, slice, ...) + m2mInverse *Relationship +} + +func (r *Relationship) JoinTable() string { + return r.Model.TableName + "_" + r.RelatedModel.TableName +} + +func (r *Relationship) JoinField() string { + isMany := r.Type == HasMany //|| r.Type == ManyToMany + if isMany && r.Model.embeddedIsh { + return r.Model.Name + r.FieldName + "ID" + } else if isMany && !r.Model.embeddedIsh && r.m2mInverse == nil { + return r.Model.Name + "ID" + } else if r.Type == ManyToMany && !r.Model.embeddedIsh { + return r.RelatedModel.Name + "ID" + } + return r.FieldName + "ID" +} + +func (r *Relationship) aliasThingy() string { + return pascalToSnakeCase(r.Model.Name + "." + r.FieldName) +} + +func (r *Relationship) relatedAlias() string { + return pascalToSnakeCase(r.RelatedModel.Name + "." + r.FieldName) +} + +func (r *Relationship) m2mIsh() bool { + return r.Model.embeddedIsh && !r.RelatedModel.embeddedIsh && r.Type == HasMany +} + +func (r *Relationship) joinInsert(v reflect.Value, e *Query, pfk any) error { + if r.Type != ManyToMany && + !r.m2mIsh() { + return nil + } + ichild := v + for ichild.Kind() == reflect.Ptr { + ichild = ichild.Elem() + } + if ichild.Kind() == reflect.Struct { + jtable := r.JoinTable() + jargs := make([]any, 0) + jcols := make([]string, 0) + jcols = append(jcols, fmt.Sprintf("%s_%s", + r.Model.TableName, pascalToSnakeCase(r.FieldName), + )) + jargs = append(jargs, pfk) + + jcols = append(jcols, r.RelatedModel.TableName+"_id") + jargs = append(jargs, ichild.FieldByName(r.RelatedModel.IDField).Interface()) + var ecnt int + e.tx.QueryRow(e.ctx, + fmt.Sprintf("SELECT count(*) from %s where %s = $1 and %s = $2", r.JoinTable(), jcols[0], jcols[1]), jargs...).Scan(&ecnt) + if ecnt > 0 { + return nil + } + jsql := fmt.Sprintf("INSERT INTO %s (%s) VALUES ($1, $2)", jtable, strings.Join(jcols, ", ")) + fmt.Printf("[INSERT/JOIN] %s { %s }\n", jsql, logTrunc(jargs, 200)) + if !e.engine.dryRun { + _ = e.tx.QueryRow(e.ctx, jsql, jargs...).Scan() + } + } + return nil +} + +func (r *Relationship) joinDelete(pk, fk any, q *Query) error { + jc := fmt.Sprintf("%s_%s", r.Model.TableName, pascalToSnakeCase(r.FieldName)) + ds := fmt.Sprintf("DELETE FROM %s where %s = $1 and %s = $2", + r.JoinTable(), jc, r.RelatedModel.TableName+"_id") + fmt.Printf("[DELETE/JOIN] %s { %s }\n", ds, logTrunc([]any{pk, fk}, 200)) + if !q.engine.dryRun { + _, err := q.tx.Exec(q.ctx, ds, pk, fk) + return err + } + return nil +} + +func parseRelationship(field reflect.StructField, modelMap map[string]*Model, outerType reflect.Type, idx int) *Relationship { + rel := &Relationship{ + Model: modelMap[outerType.Name()], + RelatedModel: modelMap[field.Type.Name()], + RelatedType: field.Type, + Idx: idx, + Kind: field.Type.Kind(), + FieldName: field.Name, + } + if rel.RelatedType.Kind() == reflect.Slice || rel.RelatedType.Kind() == reflect.Array { + rel.RelatedType = rel.RelatedType.Elem() + } + if rel.RelatedModel == nil { + if rel.RelatedType.Name() == "" { + rt := rel.RelatedType + for rt.Kind() == reflect.Ptr || rt.Kind() == reflect.Slice || rt.Kind() == reflect.Array { + rel.RelatedType = rel.RelatedType.Elem() + rt = rel.RelatedType + } + } + rel.RelatedModel = modelMap[rel.RelatedType.Name()] + if _, ok := modelMap[rel.RelatedType.Name()]; !ok { + rel.RelatedModel = parseModel(reflect.New(rel.RelatedType).Interface()) + modelMap[rel.RelatedType.Name()] = rel.RelatedModel + parseModelFields(rel.RelatedModel, modelMap) + rel.RelatedModel.embeddedIsh = true + } + } + switch field.Type.Kind() { + case reflect.Struct: + rel.Type = HasOne + case reflect.Slice: + rel.Type = HasMany + } + return rel +} +func addForeignKeyFields(ref *Relationship) { + rf := ref.RelatedModel.Fields[ref.RelatedModel.IDField] + if rf != nil { + if !ref.RelatedModel.embeddedIsh && !ref.Model.embeddedIsh { + ff := ref.Model.Fields[ref.FieldName] + ff.ColumnType = rf.ColumnType + ff.ColumnName = pascalToSnakeCase(ref.JoinField()) + ff.isForeignKey = true + ff.fk = ref + } else if !ref.Model.embeddedIsh { + sid := strings.TrimSuffix(ref.JoinField(), "ID") + ref.RelatedModel.Relationships[sid] = &Relationship{ + FieldName: sid, + Type: HasOne, + RelatedModel: ref.Model, + Model: ref.RelatedModel, + Kind: ref.RelatedModel.Type.Kind(), + Idx: -1, + RelatedType: ref.Model.Type, + } + ref.RelatedModel.addField(&Field{ + ColumnType: rf.ColumnType, + ColumnName: pascalToSnakeCase(ref.JoinField()), + Name: sid, + isForeignKey: true, + Type: rf.Type, + Index: -1, + fk: ref.RelatedModel.Relationships[sid], + }) + } else if ref.Model.embeddedIsh && !ref.RelatedModel.embeddedIsh { + + } + } else { + ref.RelatedModel.addField(&Field{ + Name: "ID", + ColumnName: "id", + ColumnType: "bigserial", + PrimaryKey: true, + Type: ref.RelatedType, + Index: -1, + AutoIncrement: true, + }) + ff := ref.Model.Fields[ref.FieldName] + ff.ColumnType = "bigint" + ff.ColumnName = pascalToSnakeCase(ref.RelatedModel.Name + "ID") + ff.isForeignKey = true + ff.fk = ref + ref.RelatedModel.IDField = "ID" + /* + nn := ref.Model.Name + "ID" + ref.RelatedModel.Relationships[ref.Model.Name] = &Relationship{ + Type: HasOne, + RelatedModel: ref.Model, + Model: ref.RelatedModel, + Idx: 65536, + Kind: ref.RelatedModel.Type.Kind(), + RelatedType: ref.RelatedModel.Type, + FieldName: nn, + } + ref.RelatedModel.addField(&Field{ + Name: nn, + Type: ref.Model.Type, + fk: ref.RelatedModel.Relationships[nn], + ColumnName: pascalToSnakeCase(nn), + Index: -1, + ColumnType: ref.Model.Fields[ref.Model.IDField].ColumnType, + isForeignKey: true, + })*/ + + } +} diff --git a/scan.go b/scan.go new file mode 100644 index 0000000..084f0a1 --- /dev/null +++ b/scan.go @@ -0,0 +1,85 @@ +package orm + +import ( + "github.com/jackc/pgx/v5" + "reflect" +) + +func rowsToMaps(rows pgx.Rows) ([]map[string]any, error) { + var result []map[string]any + fieldDescs := rows.FieldDescriptions() + for rows.Next() { + m := make(map[string]any) + scanArgs := make([]any, len(fieldDescs)) + for i := range fieldDescs { + var v any + scanArgs[i] = &v + } + if err := rows.Scan(scanArgs...); err != nil { + return nil, err + } + for i, fd := range fieldDescs { + name := fd.Name + m[name] = *(scanArgs[i].(*any)) + } + result = append(result, m) + } + return result, rows.Err() +} + +func fillNested(row map[string]any, t reflect.Type, mm *ModelMap, depth, maxDepth int) any { + cm := mm.Map[t.Name()] + pp := reflect.New(cm.Type).Elem() + for _, field := range cm.Fields { + _, alias := field.alias() + if v, ok := row[alias]; ok && !field.isForeignKey { + reflectSet(pp.Field(field.Index), v) + } + } + for _, rel := range cm.Relationships { + if rel.Idx > -1 && rel.Idx < pp.NumField() { + relType := rel.RelatedModel.Type + for relType.Kind() == reflect.Pointer { + relType = relType.Elem() + } + nv := reflect.New(relType) + if rel.Kind == reflect.Struct || rel.Kind == reflect.Pointer { + if depth < maxDepth { + nv = reflect.ValueOf(fillNested(row, relType, mm, depth+1, maxDepth)) + } + if rel.Kind != reflect.Pointer && nv.Kind() == reflect.Pointer { + nv = nv.Elem() + } + reflectSet(pp.Field(rel.Idx), nv) + } else if rel.Kind == reflect.Slice || rel.Kind == reflect.Array { + relType2 := relType + for relType2.Kind() == reflect.Slice || relType2.Kind() == reflect.Array { + relType2 = relType2.Elem() + } + if depth < maxDepth { + nv = reflect.ValueOf(fillNested(row, relType2, mm, depth+1, maxDepth)) + } + if nv.Kind() == reflect.Pointer { + nv = nv.Elem() + } + reflectSet(pp.Field(rel.Idx), reflect.Append(pp.Field(rel.Idx), nv)) + } + } + } + return pp.Interface() +} + +// fillSlice - note that it's the caller's responsibility to indirect +// the type in `t`. +func fillSlice(rows []map[string]any, t reflect.Type, mm *ModelMap) any { + pslice := reflect.MakeSlice(reflect.SliceOf(t), 0, 0) + rt := t + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + for _, row := range rows { + pp := fillNested(row, rt, mm, 0, 10) + pslice = reflect.Append(pslice, reflect.ValueOf(pp)) + } + return pslice.Interface() +} diff --git a/test_main.go b/test_main.go new file mode 100644 index 0000000..bea347d --- /dev/null +++ b/test_main.go @@ -0,0 +1,55 @@ +package orm + +import "fmt" + +const do_bootstrap = true + +func TestMain() { + e, err := Open("postgres://testbed_user:123@localhost/testbed_i_think") + if err != nil { + panic(err) + } + e.Models(user{}, story{}, band{}, role{}) + if do_bootstrap { + + err = e.Migrate() + if err != nil { + panic(err) + } + s := iti_multi() + u := &author + u.Favs.Authors = append(u.Favs.Authors, friend) + _, err = e.Model(&user{}).Create(&friend) + if err != nil { + panic(err) + } + _, err = e.Model(&user{}).Create(u) + if err != nil { + panic(err) + } + _, err = e.Model(&band{}).Create(&bodom) + if err != nil { + panic(err) + } + _, err = e.Model(&band{}).Create(&diamondHead) + if err != nil { + panic(err) + } + _, err = e.Model(&story{}).Create(s) + if err != nil { + panic(err) + } + s.Downloads = s.Downloads + 1 + err = e.Model(&story{}).Update(s, nil) + if err != nil { + panic(err) + } + } + ns, err := e.Model(&story{}).Where(&story{ + ID: 1, + }).Populate(PopulateAll).Find() + if err != nil { + panic(err) + } + fmt.Printf("%+v", ns) +} diff --git a/testing.go b/testing.go index 3ee8694..401177c 100644 --- a/testing.go +++ b/testing.go @@ -1,48 +1,52 @@ package orm import ( - "context" "fmt" - "github.com/stretchr/testify/assert" + "math/rand/v2" "strings" - "testing" "time" "github.com/go-loremipsum/loremipsum" ) type chapter struct { - ID bson.ObjectID `json:"_id"` - Title string `json:"chapterTitle" form:"chapterTitle"` - ChapterID int `json:"chapterID" autoinc:"chapters"` - Index int `json:"index" form:"index"` - Words int `json:"words"` - Notes string `json:"notes" form:"notes"` - Genre []string `json:"genre" form:"genre"` - Bands []band `json:"bands" ref:"band,bands"` - Characters []string `json:"characters" form:"characters"` - Relationships [][]string `json:"relationships" form:"relationships"` - Adult bool `json:"adult" form:"adult"` - Summary string `json:"summary" form:"summary"` - Hidden bool `json:"hidden" form:"hidden"` - LoggedInOnly bool `json:"loggedInOnly" form:"loggedInOnly"` - Posted time.Time `json:"datePosted"` - FileName string `json:"fileName"` - Text string `json:"text" gridfs:"story_text,/stories/{{.ChapterID}}.txt"` + ChapterID int64 `json:"chapterID" d:"pk:t;"` + Title string `json:"chapterTitle" form:"chapterTitle"` + Index int `json:"index" form:"index"` + Words int `json:"words"` + Notes string `json:"notes" form:"notes"` + Genre []string `json:"genre" form:"genre" d:"type:text[]"` + Bands []band `json:"bands" ref:"band,bands"` + Characters []string `json:"characters" form:"characters" d:"type:text[]"` + Relationships [][]string `json:"relationships" form:"relationships" d:"type:text[][]"` + Adult bool `json:"adult" form:"adult"` + Summary string `json:"summary" form:"summary"` + Hidden bool `json:"hidden" form:"hidden"` + LoggedInOnly bool `json:"loggedInOnly" form:"loggedInOnly"` + Posted time.Time `json:"datePosted"` + FileName string `json:"fileName" d:"-"` + Text string `json:"text" d:"column:content" gridfs:"story_text,/stories/{{.ChapterID}}.txt"` } type band struct { - ID int64 `json:"_id"` Document `json:",inline" d:"table:bands"` + ID int64 `json:"_id" d:"pk;"` Name string `json:"name" form:"name"` Locked bool `json:"locked" form:"locked"` - Characters []string `json:"characters" form:"characters"` + Characters []string `json:"characters" form:"characters" d:"type:text[]"` } type user struct { - ID int64 `json:"_id"` Document `json:",inline" d:"table:users"` + ID int64 `json:"_id" d:"pk;"` Username string `json:"username"` Favs favs `json:"favs" ref:"user"` + Roles []role +} + +type role struct { + ID int64 `d:"pk"` + Name string + Users []user } type favs struct { @@ -50,10 +54,10 @@ type favs struct { Authors []user } type story struct { - ID int64 `json:"_id"` - Document `json:",inline" coll:"stories"` + Document `json:",inline" d:"table:stories"` + ID int64 `json:"_id" d:"pk;"` Title string `json:"title" form:"title"` - Author *user `json:"author" ref:"user"` + Author user `json:"author" ref:"user"` CoAuthor *user `json:"coAuthor" ref:"user"` Chapters []chapter `json:"chapters"` Recs int `json:"recs"` @@ -69,54 +73,23 @@ type somethingWithNestedChapters struct { NestedText string `json:"text" gridfs:"nested_text,/nested/{{.ID}}.txt"` } -func (s *somethingWithNestedChapters) Id() any { - return s.ID -} - -func (s *somethingWithNestedChapters) SetId(id any) { - s.ID = id.(int64) -} - -func (s *story) Id() any { - return s.ID -} - -func (s *band) Id() any { - return s.ID -} -func (s *user) Id() any { - return s.ID -} - -func (s *story) SetId(id any) { - s.ID = id.(int64) - //var t IDocument =s -} - -func (s *band) SetId(id any) { - s.ID = id.(int64) -} - -func (s *user) SetId(id any) { - s.ID = id.(int64) +var friend = user{ + Username: "DarQuiel7", + ID: 83378, } var author = user{ Username: "tablet.exe", - Favs: []user{ - { - Username: "DarQuiel7", - }, - }, + ID: 85783, } -func genChaps(single bool) []chapter { +func genChaps(single bool, aceil int) []chapter { var ret []chapter var ceil int if single { ceil = 1 } else { - ceil = 5 + ceil = aceil } emptyRel := make([][]string, 0) emptyRel = append(emptyRel, make([]string, 0)) @@ -140,8 +113,8 @@ func genChaps(single bool) []chapter { for i := 0; i < ceil; i++ { spf := fmt.Sprintf("%d.md", i+1) - ret = append(ret, chapter{ - ID: bson.NewObjectID(), + c := chapter{ + ChapterID: int64(i + 1), Title: fmt.Sprintf("-%d-", i+1), Index: i + 1, Words: 50, @@ -156,7 +129,19 @@ func genChaps(single bool) []chapter { LoggedInOnly: true, FileName: spf, Text: strings.Join(l.ParagraphList(10), "\n\n"), - }) + Posted: time.Now().Add(time.Hour * time.Duration(int64(24*7*i))), + } + { + randMin := max(i+1, 1) + randMax := min(i+1, randMin+1) + mod1 := max(rand.IntN(randMin), 1) + mod2 := max(rand.IntN(randMax+1), 1) + if (mod1%mod2 == 0 || (mod1%mod2) == 2) && i > 0 { + c.Bands = append(c.Bands, bodom) + } + } + + ret = append(ret, c) } return ret @@ -165,51 +150,35 @@ func genChaps(single bool) []chapter { func doSomethingWithNested() somethingWithNestedChapters { l := loremipsum.New() swnc := somethingWithNestedChapters{ - Chapters: genChaps(false), + Chapters: genChaps(false, 7), NestedText: strings.Join(l.ParagraphList(15), "\n\n"), } return swnc } -func iti_single() story { - return story{ +func iti_single() *story { + return &story{ Title: "title", Completed: true, - Chapters: genChaps(true), + Author: author, + Chapters: genChaps(true, 0), } } -func iti_multi() story { - return story{ +func iti_multi() *story { + return &story{ Title: "Brian Tatler Fucked and Abused Sean Harris", Completed: false, - Chapters: genChaps(false), + Author: author, + Chapters: genChaps(false, 5), } } -func iti_blank() story { +func iti_blank() *story { t := iti_single() t.Chapters = make([]chapter, 0) return t } -func initTest() { - uri := "mongodb://127.0.0.1:27017" - db := "rockfic_ormTest" - ic, _ := mongo.Connect(options.Client().ApplyURI(uri)) - ic.Database(db).Drop(context.TODO()) - colls, _ := ic.Database(db).ListCollectionNames(context.TODO(), bson.M{}) - if len(colls) < 1 { - mdb := ic.Database(db) - mdb.CreateCollection(context.TODO(), "bands") - mdb.CreateCollection(context.TODO(), "stories") - mdb.CreateCollection(context.TODO(), "users") - } - defer ic.Disconnect(context.TODO()) - Connect(uri, db) - author.ID = 696969 - ModelRegistry.Model(band{}, user{}, story{}) -} - var metallica = band{ ID: 1, Name: "Metallica", @@ -247,13 +216,3 @@ var bodom = band{ "Alexander Kuoppala", }, } - -func saveDoc(t *testing.T, doc IDocument) { - err := doc.Save() - assert.Nil(t, err) -} - -func createAndSave(t *testing.T, doc IDocument) { - mdl := Create(doc).(IDocument) - saveDoc(t, mdl) -} diff --git a/tests/main.go b/tests/main.go new file mode 100644 index 0000000..21fde97 --- /dev/null +++ b/tests/main.go @@ -0,0 +1,7 @@ +package main + +import "rockfic.com/orm" + +func main() { + orm.TestMain() +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..0439454 --- /dev/null +++ b/utils.go @@ -0,0 +1,95 @@ +package orm + +import ( + "fmt" + "reflect" + "regexp" + "strings" +) + +var pascalRegex = regexp.MustCompile(`(?P[a-z])(?P[A-Z])`) +var nonWordRegex = regexp.MustCompile(`[^a-zA-Z0-9_]`) + +func pascalToSnakeCase(str string) string { + step1 := pascalRegex.ReplaceAllString(str, `${lowercase}_${uppercase}`) + step2 := nonWordRegex.ReplaceAllString(step1, "_") + return strings.ToLower(step2) +} + +func canConvertTo[T any](thisType reflect.Type) bool { + return thisType.ConvertibleTo(reflect.TypeFor[T]()) || + thisType.ConvertibleTo(reflect.TypeFor[*T]()) || + strings.TrimPrefix(thisType.Name(), "*") == strings.TrimPrefix(reflect.TypeFor[T]().Name(), "*") +} + +func parseTags(t string) map[string]string { + tags := strings.Split(t, ";") + m := make(map[string]string) + for _, tag := range tags { + field := strings.Split(tag, ":") + if len(field) < 2 { + m[strings.ToLower(field[0])] = "t" + } else { + m[strings.ToLower(field[0])] = field[1] + } + } + return m +} + +func capitalizeFirst(str string) string { + firstChar := strings.ToUpper(string([]byte{str[0]})) + return firstChar + string(str[1:]) +} + +func serialToRegular(str string) string { + return strings.ReplaceAll(strings.ToLower(str), "serial", "int") +} +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return v.String() == "" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Ptr, reflect.Interface: + return v.IsNil() + } + return v.IsZero() +} +func reflectSet(f reflect.Value, v any) { + if !f.CanSet() || v == nil { + return + } + switch f.Kind() { + case reflect.Int, reflect.Int64: + switch val := v.(type) { + case int64: + f.SetInt(val) + case int32: + f.SetInt(int64(val)) + case int: + f.SetInt(int64(val)) + case uint64: + f.SetInt(int64(val)) + } + case reflect.String: + if s, ok := v.(string); ok { + f.SetString(s) + } + } +} + +func logTrunc(v any, length int) string { + if length < 5 { + length = 5 + } + str := fmt.Sprintf("%+v", v) + trunced := str[:min(length, len(str))] + if len(trunced) < len(str) { + trunced += "..." + } + return trunced +}