package orm

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

// Model is the reflected mapping between a Go struct type and a database table.
type Model struct {
	// Name identifies the model; getAliasFields lower-cases it to prefix
	// column aliases.
	Name string
	// Type is the reflect.Type associated with the model.
	Type reflect.Type
	// Relationships maps a Go struct field name to its relationship metadata.
	Relationships map[string]*Relationship
	// IDField is the Go struct field name of the primary key.
	IDField string
	// Fields maps Go struct field names to field metadata.
	Fields map[string]*Field
	// FieldsByColumnName maps column names to field metadata.
	FieldsByColumnName map[string]*Field
	// TableName is the SQL table the model's rows are written to.
	TableName string
	// embeddedIsh marks models whose rows are owned by a parent: when such a
	// model appears in a parent's slice field, insert writes each element
	// directly with the parent's key as a column value instead of linking it
	// through the relationship's joinInsert.
	embeddedIsh bool
}

// addField registers field on m, indexing it by both its Go field name and
// its column name, and points the field back at m.
func (m *Model) addField(field *Field) {
	field.Model = m
	m.Fields[field.Name] = field
	m.FieldsByColumnName[field.ColumnName] = field
}

// getAliasFields returns, keyed by Go field name, a select expression of the
// form "table.column as model_column" for every field that is not a foreign
// key (fields with a non-nil fk are skipped).
func (m *Model) getAliasFields() map[string]string {
	fields := make(map[string]string, len(m.Fields))
	for _, f := range m.Fields {
		if f.fk != nil {
			continue
		}
		fields[f.Name] = fmt.Sprintf("%s.%s as %s_%s", m.TableName, f.ColumnName, strings.ToLower(m.Name), f.ColumnName)
	}
	return fields
}

// getPrimaryKey returns the primary-key column name and the value of the
// corresponding field in val. It returns ("", nil) when the model has no
// registered IDField or val has no field of that name. val must be a struct
// value (reflect's FieldByName panics on other kinds).
func (m *Model) getPrimaryKey(val reflect.Value) (string, any) {
	colField := m.Fields[m.IDField]
	if colField == nil {
		return "", nil
	}
	colName := colField.ColumnName
	idField := val.FieldByName(m.IDField)
	if idField.IsValid() {
		return colName, idField.Interface()
	}
	return "", nil
}

// insert writes v (a struct, or a chain of pointers to one) as a new row of m
// and returns the value produced by the statement's RETURNING column.
//
// parentFks supplies extra column/value pairs, e.g. the owning parent's key
// when inserting embedded children. A parent value wins over a non-key struct
// field mapped to the same column, because the struct field may not yet hold
// the parent's newly generated key.
//
// Relationship fields contribute their related struct's key as the column
// value. If the related struct has no field named after its model's IDField,
// it is inserted first (recursively) and the new key is used.
// NOTE(review): a related struct whose ID field exists but is zero is NOT
// inserted; its zero value is written as the FK — confirm that is intended.
//
// Slice fields whose element type is a registered model are written after
// the row: embedded models via a recursive insert carrying this row's key,
// other models via the relationship's joinInsert.
//
// A top-level call (no open transaction on e, not a dry run) acquires a pooled
// connection and runs everything, including the nested inserts, in a single
// transaction: committed on success, rolled back on error, and cleared from
// e afterwards. Nested calls reuse e.tx. In dry-run mode nothing is executed
// and the struct's current key value is returned.
//
// The statement and its (truncated) arguments are printed to stdout. After
// the insert, the returned key is stored back into the primary-key field,
// converting between compatible types (e.g. int64 -> int); an incompatible
// type is reported as an error.
func (m *Model) insert(v reflect.Value, e *Query, parentFks map[string]any) (any, error) {
	var (
		err        error
		isTopLevel bool
	)
	if e.tx == nil && !e.engine.dryRun {
		isTopLevel = true
		var cn *pgxpool.Conn
		cn, err = e.engine.conn.Acquire(e.ctx)
		if err != nil {
			return nil, err
		}
		var tx pgx.Tx
		tx, err = cn.Begin(e.ctx)
		if err != nil {
			// The deferred cleanup is not registered yet; hand the
			// connection back to the pool here or it leaks.
			cn.Release()
			return nil, err
		}
		e.tx = tx
		// Every error path below assigns the outer err, which is what
		// decides between rollback and keeping the committed result.
		defer func() {
			if err != nil {
				// Best effort: the original error is what the caller needs.
				_ = e.tx.Rollback(e.ctx)
			}
			e.tx = nil
			cn.Release()
		}()
	}

	for v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	t := v.Type()

	var (
		cols           []string
		placeholders   []string
		args           []any
		returningField *Field
	)
	// addCol appends one column together with its positional placeholder.
	addCol := func(col string, arg any) {
		cols = append(cols, col)
		args = append(args, arg)
		placeholders = append(placeholders, fmt.Sprintf("$%d", len(args)))
	}

	for col, pv := range parentFks {
		addCol(col, pv)
	}

	for _, ff := range m.Fields {
		// Listing a parent-supplied column a second time would make
		// PostgreSQL reject the statement.
		if _, dup := parentFks[ff.ColumnName]; dup && !ff.PrimaryKey {
			continue
		}

		var fv reflect.Value
		if ff.Index > -1 && ff.Index < v.NumField() {
			fv = v.Field(ff.Index)
		}

		rel := ff.fk
		if rel == nil {
			rel = m.Relationships[ff.Name]
		}
		if rel != nil {
			related := v.FieldByName(rel.FieldName)
			if related.Kind() != reflect.Struct {
				continue
			}
			// Use rel rather than ff.fk: ff.fk is nil when the relationship
			// was found in m.Relationships.
			idField := related.FieldByName(rel.RelatedModel.IDField)
			var fkVal any
			if idField.IsValid() {
				fkVal = idField.Interface()
			} else {
				fkVal, err = rel.RelatedModel.insert(related, e, make(map[string]any))
				if err != nil {
					return nil, err
				}
			}
			addCol(ff.ColumnName, fkVal)
			continue
		}

		if ff.PrimaryKey {
			returningField = ff
			// A zero key is left out, presumably so a column default
			// (e.g. a sequence) supplies it.
			if fv.IsValid() && !fv.IsZero() {
				addCol(ff.ColumnName, fv.Interface())
			}
			continue
		}
		if fv.IsValid() {
			addCol(ff.ColumnName, fv.Interface())
		}
	}

	retCol := "id"
	if returningField != nil {
		retCol = returningField.ColumnName
	}

	sql := "INSERT INTO " + m.TableName + " "
	if len(cols) > 0 {
		sql += "(" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") "
	} else {
		sql += "DEFAULT VALUES "
	}
	sql += "RETURNING " + retCol
	fmt.Printf("[INSERT] %s { %s }\n", sql, logTrunc(args, 200))

	// Start from the struct's current key so a dry run still returns it.
	var id any
	if cur := v.FieldByName(m.IDField); cur.IsValid() {
		id = cur.Interface()
	}
	if !e.engine.dryRun {
		// Execute on the transaction, not the pool. On the pool the row would
		// be written on another connection, outside the transaction that a
		// later failure rolls back. With a single-connection pool it would
		// also block on the connection this transaction holds.
		if err = e.tx.QueryRow(e.ctx, sql, args...).Scan(&id); err != nil {
			return nil, err
		}
	}

	if returningField != nil && id != nil {
		if dst := v.FieldByName(returningField.Name); dst.IsValid() && dst.CanSet() {
			src := reflect.ValueOf(id)
			// reflect allows int -> string as a rune conversion, which would
			// silently store garbage; only strings may convert into strings.
			badString := dst.Kind() == reflect.String && src.Kind() != reflect.String
			switch {
			case src.Type().AssignableTo(dst.Type()):
				dst.Set(src)
			case !badString && src.Type().ConvertibleTo(dst.Type()):
				// e.g. pgx decodes int8 as int64 but the field is an int.
				dst.Set(src.Convert(dst.Type()))
			default:
				err = fmt.Errorf("orm: insert into %s: cannot store returned %s of type %T in field %s of type %s",
					m.TableName, retCol, id, returningField.Name, dst.Type())
				return nil, err
			}
		}
	}

	// NOTE(review): the model lookup uses the element type's name, so slices
	// of pointers ([]*T) are not recognised — confirm that is intended.
	for i := range t.NumField() {
		sf := t.Field(i)
		if sf.Type.Kind() != reflect.Slice {
			continue
		}
		items := v.Field(i)
		for j := range items.Len() {
			child := items.Index(j).Addr()
			cm := e.engine.modelMap.Map[child.Type().Elem().Name()]
			if cm == nil {
				continue
			}
			if cm.embeddedIsh {
				fks := map[string]any{m.TableName + "_id": id}
				if r := m.Relationships[sf.Name]; r != nil {
					fks = map[string]any{pascalToSnakeCase(r.JoinField()): id}
				}
				if _, err = cm.insert(child, e, fks); err != nil {
					return nil, err
				}
				continue
			}
			if r := m.Relationships[sf.Name]; r != nil {
				if err = r.joinInsert(child, e, id); err != nil {
					return nil, err
				}
			}
		}
	}

	if isTopLevel {
		if err = e.tx.Commit(e.ctx); err != nil {
			return nil, err
		}
	}
	return id, nil
}

// update writes val's column values to its existing row, matched on the
// primary key.
//
// Fields are skipped when they are foreign keys, have no ColumnType, or are
// the ID field. Each remaining field takes its value from the struct field
// at its Index. If that Index is out of range, the value comes from
// parentFks (looked up by column name) when present.
//
// Relationships are then propagated:
//   - slice/array relationships are reconciled with diffSlices;
//   - a single embedded related struct (when m itself is not embedded) is
//     updated recursively with this row's key as its parent key. Nil
//     pointers are skipped.
//
// A top-level call (no open transaction on q, not a dry run) opens a
// transaction, commits it on success, rolls it back on failure, and always
// clears q.tx afterwards so the Query can be reused. The statement and its
// (truncated) arguments are printed to stdout; in dry-run mode nothing is
// executed.
func (m *Model) update(val reflect.Value, q *Query, parentFks map[string]any) error {
	var (
		err        error
		isTopLevel bool
	)
	if q.tx == nil && !q.engine.dryRun {
		isTopLevel = true
		var tx pgx.Tx
		if tx, err = q.engine.conn.Begin(q.ctx); err != nil {
			return err
		}
		q.tx = tx
		defer func() {
			if err != nil {
				// Best effort: the original error is what the caller needs.
				_ = q.tx.Rollback(q.ctx)
			}
			// Clear on success as well: a committed transaction left here
			// would be picked up (and rejected) by the next operation on q.
			q.tx = nil
		}()
	}

	for val.Kind() == reflect.Pointer {
		val = val.Elem()
	}

	var (
		sets []string
		vals []any
	)
	for _, field := range m.Fields {
		if field.fk != nil || field.ColumnType == "" || field.Name == m.IDField {
			continue
		}
		var v any
		if field.Index > -1 && field.Index < val.NumField() {
			v = val.Field(field.Index).Interface()
		} else if pv, ok := parentFks[field.ColumnName]; ok {
			v = pv
		} else {
			continue
		}
		vals = append(vals, v)
		sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, len(vals)))
	}

	mcol, mpk := m.getPrimaryKey(val)
	if mcol == "" {
		err = fmt.Errorf("orm: update %s: no primary key field to match the row on", m.TableName)
		return err
	}

	// With nothing to SET the statement would be invalid SQL; relationships
	// below may still need updating.
	if len(sets) > 0 {
		// The key is bound as a parameter rather than formatted into the SQL,
		// which broke string/UUID keys and allowed injection.
		vals = append(vals, mpk)
		sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s = $%d", m.TableName, strings.Join(sets, ", "), mcol, len(vals))
		fmt.Printf("[UPDATE] %s { %s }\n", sql, logTrunc(vals, 200))
		if !q.engine.dryRun {
			if _, err = q.tx.Exec(q.ctx, sql, vals...); err != nil {
				return err
			}
		}
	}

	for _, rel := range m.Relationships {
		if rel.Idx < 0 || rel.Idx >= val.NumField() {
			continue
		}
		f := val.Field(rel.Idx)
		switch {
		case f.Kind() == reflect.Slice || f.Kind() == reflect.Array:
			if err = diffSlices(q.engine, q, val, rel); err != nil {
				return err
			}
		case rel.RelatedModel.embeddedIsh && !m.embeddedIsh:
			// Dereference down to the struct. A nil pointer has no row to
			// update, and recursing into it would panic.
			target := f
			for target.Kind() == reflect.Pointer && !target.IsNil() {
				target = target.Elem()
			}
			if target.Kind() != reflect.Struct {
				continue
			}
			if err = rel.RelatedModel.update(target, q, map[string]any{
				pascalToSnakeCase(rel.JoinField()): mpk,
			}); err != nil {
				return err
			}
		}
	}

	if isTopLevel {
		err = q.tx.Commit(q.ctx)
	}
	return err
}

// ensure is an unfinished draft (it is missing a final return) kept for
// reference; it is not compiled.
/*func (m *Model) ensure(q *Query, val reflect.Value) error {
	if _, pk := m.getPrimaryKey(val); pk != nil && !reflect.ValueOf(pk).IsZero() {
		return nil
	}
}*/