diamond-orm/model.go

298 lines
7.1 KiB
Go

package orm
import (
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"reflect"
"strings"
)
// Model describes how one Go struct type maps onto a database table: its
// columns (Fields), its relationships to other models, and which struct
// field holds the primary key.
type Model struct {
	// Name is the model's name; its lowercased form prefixes column aliases
	// in getAliasFields. Presumably the Go type name — TODO confirm.
	Name string
	// Type is the reflected Go type backing this model (assumed to be the
	// struct type, not a pointer — confirm against the registration code).
	Type reflect.Type
	// Relationships maps a Go struct field name to its relationship metadata.
	Relationships map[string]*Relationship
	// IDField is the Go struct field name that holds the primary key.
	IDField string
	// Fields indexes column metadata by Go field name (see addField).
	Fields map[string]*Field
	// FieldsByColumnName indexes the same fields by SQL column name.
	FieldsByColumnName map[string]*Field
	// TableName is the SQL table used in generated INSERT/UPDATE statements.
	TableName string
	// embeddedIsh marks models whose rows are written directly with a foreign
	// key to their parent when they appear in a parent's slice field, instead
	// of through a relationship join insert.
	embeddedIsh bool
}
// addField attaches field to m. It sets field.Model to m, then indexes the
// field in m.Fields by its Go name and in m.FieldsByColumnName by its column
// name, replacing any previous entry under either key. Both maps must already
// be initialized; writing to a nil map panics.
func (m *Model) addField(field *Field) {
	field.Model = m
	byName, byColumn := m.Fields, m.FieldsByColumnName
	byName[field.Name] = field
	byColumn[field.ColumnName] = field
}
// getAliasFields builds aliased SELECT expressions for m's own columns.
// Each entry maps a field's Go name to
//
//	<table>.<column> as <model>_<column>
//
// where <model> is m.Name lowercased. Fields that carry a foreign key
// (fk != nil) are omitted.
func (m *Model) getAliasFields() map[string]string {
	prefix := strings.ToLower(m.Name)
	aliases := make(map[string]string)
	for _, fld := range m.Fields {
		if fld.fk == nil {
			aliases[fld.Name] = m.TableName + "." + fld.ColumnName + " as " + prefix + "_" + fld.ColumnName
		}
	}
	return aliases
}
// getPrimaryKey returns the primary-key column name of m together with the
// value held by val's m.IDField field. val must be a struct value, because
// reflect's FieldByName panics on other kinds.
// It returns ("", nil) when m has no registered field named m.IDField, or
// when val has no field of that name.
func (m *Model) getPrimaryKey(val reflect.Value) (string, any) {
	pkField := m.Fields[m.IDField]
	if pkField == nil {
		return "", nil
	}
	fv := val.FieldByName(m.IDField)
	if !fv.IsValid() {
		return "", nil
	}
	return pkField.ColumnName, fv.Interface()
}
// insert writes the struct held by v (dereferencing any number of pointers)
// into m.TableName. It returns the value of the RETURNING column: the
// primary-key column, or "id" when the model declares no primary-key field.
//
// The statement is assembled as follows:
//   - Every parentFks entry becomes a column/value pair. Recursive calls use
//     this to give embedded children a foreign key to their parent row.
//   - A field with a foreign key / relationship takes the related struct's ID.
//     When that struct has no m.IDField-named field, the related struct is
//     inserted first and its returned key is used instead.
//   - The primary key is sent only when it is non-zero, so the database can
//     assign one otherwise.
//
// After the row is written:
//   - The returned key is stored back into the primary-key field. Numeric
//     values are converted when the scanned type differs from the field type.
//   - Each slice field whose element type is a registered model is walked.
//     Elements of embedded models are inserted with a foreign key back to this
//     row. Other models go through the field's relationship joinInsert.
//
// When e has no transaction yet and the engine is not in dry-run mode, this
// call is the top-level one. It acquires a pooled connection and opens a
// transaction that nested calls share through e.tx. It commits on success,
// and rolls back on any error or panic. In dry-run mode the SQL is only
// logged. On error the returned ID is 0.
func (m *Model) insert(v reflect.Value, e *Query, parentFks map[string]any) (any, error) {
	var err error
	var committed bool
	isTopLevel := e.tx == nil && !e.engine.dryRun
	if isTopLevel {
		var cn *pgxpool.Conn
		var tx pgx.Tx
		if cn, err = e.engine.conn.Acquire(e.ctx); err != nil {
			return nil, err
		}
		if tx, err = cn.Begin(e.ctx); err != nil {
			cn.Release() // don't leak the pooled connection when Begin fails
			return nil, err
		}
		e.tx = tx
		defer func() {
			// Covers every early return and panics too; after a successful
			// Commit the transaction is already closed and must not be touched.
			if !committed {
				_ = tx.Rollback(e.ctx) // best effort: the triggering error is what callers need
			}
			e.tx = nil
			cn.Release()
		}()
	}

	for v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	t := v.Type()

	var cols, placeholders []string
	var args []any
	// bind appends one column/value pair with the next positional placeholder.
	bind := func(col string, val any) {
		cols = append(cols, col)
		args = append(args, val)
		placeholders = append(placeholders, fmt.Sprintf("$%d", len(args)))
	}

	var returningField *Field
	for col, val := range parentFks {
		bind(col, val)
	}
	for _, ff := range m.Fields {
		var fv reflect.Value
		if ff.Index > -1 {
			fv = v.Field(ff.Index)
		}

		rel := ff.fk
		if rel == nil {
			rel = m.Relationships[ff.Name]
		}
		if rel != nil {
			af := v.FieldByName(rel.FieldName)
			if af.Kind() != reflect.Struct {
				continue
			}
			// Use rel (not ff.fk): rel may have come from m.Relationships,
			// in which case ff.fk is nil.
			idField := af.FieldByName(rel.RelatedModel.IDField)
			if idField.IsValid() {
				bind(ff.ColumnName, idField.Interface())
				continue
			}
			// NOTE(review): the related row is inserted only when it has no ID
			// field at all; a present-but-zero ID is passed through as-is.
			// Confirm whether IsZero was intended here.
			var nid any
			if nid, err = rel.RelatedModel.insert(af, e, make(map[string]any)); err != nil {
				return 0, err
			}
			bind(ff.ColumnName, nid)
			continue
		}

		if ff.PrimaryKey {
			returningField = ff
			// A zero key is left out so the database assigns one.
			if fv.IsValid() && !fv.IsZero() {
				bind(ff.ColumnName, fv.Interface())
			}
			continue
		}
		if fv.IsValid() {
			bind(ff.ColumnName, fv.Interface())
		}
	}

	returning := "id"
	if returningField != nil {
		returning = returningField.ColumnName
	}
	sql := "INSERT INTO " + m.TableName + " "
	if len(cols) > 0 {
		sql += "(" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") "
	} else {
		sql += "DEFAULT VALUES "
	}
	sql += "RETURNING " + returning
	fmt.Printf("[INSERT] %s { %s }\n", sql, logTrunc(args, 200))

	var id any
	if idv := v.FieldByName(m.IDField); idv.IsValid() {
		id = idv.Interface()
	}
	if !e.engine.dryRun {
		// Run on the shared transaction so a failure anywhere in the cascade
		// also undoes this row.
		if err = e.tx.QueryRow(e.ctx, sql, args...).Scan(&id); err != nil {
			return 0, err
		}
	}

	// Write the key back into the struct. pgx scans into `any` using its own
	// Go type (e.g. int32/int64), which may differ from the field's type.
	if returningField != nil && id != nil {
		if dst := v.FieldByName(returningField.Name); dst.IsValid() && dst.CanSet() {
			src := reflect.ValueOf(id)
			switch {
			case src.Type().AssignableTo(dst.Type()):
				dst.Set(src)
			case src.CanConvert(dst.Type()) && (dst.Kind() != reflect.String || src.Kind() == reflect.String):
				// Numeric -> string conversion is excluded: it yields a rune, not digits.
				dst.Set(src.Convert(dst.Type()))
			}
		}
	}

	for i := range t.NumField() {
		ft := t.Field(i)
		fv := v.Field(i)
		if ft.Type.Kind() != reflect.Slice || fv.Len() == 0 {
			continue
		}
		// Every element shares the slice's element type, so resolve its model once.
		cm := e.engine.modelMap.Map[ft.Type.Elem().Name()]
		if cm == nil {
			continue
		}
		rel := m.Relationships[ft.Name]
		if cm.embeddedIsh {
			fkCol := m.TableName + "_id"
			if rel != nil {
				fkCol = pascalToSnakeCase(rel.JoinField())
			}
			for j := range fv.Len() {
				if _, err = cm.insert(fv.Index(j).Addr(), e, map[string]any{fkCol: id}); err != nil {
					return 0, err
				}
			}
		} else if rel != nil {
			for j := range fv.Len() {
				if err = rel.joinInsert(fv.Index(j).Addr(), e, id); err != nil {
					return 0, err
				}
			}
		}
	}

	if isTopLevel {
		if err = e.tx.Commit(e.ctx); err != nil {
			return 0, err
		}
		committed = true
	}
	return id, nil
}
// update writes the column values of the struct held by val (dereferencing
// any pointers) to the row of m.TableName identified by its primary key. It
// then cascades into relationships:
//   - Slice/array relationships are reconciled with diffSlices.
//   - A single embedded model is updated recursively, with its join column
//     set to this row's key. This only applies when m itself is not embedded.
//
// parentFks supplies values for columns that have no backing struct field,
// keyed by column name. Recursive calls use it to pass the parent's key.
// The primary key is bound as a query parameter, never interpolated into the
// SQL text. The UPDATE statement is skipped when there is nothing to SET.
//
// When q has no transaction and the engine is not in dry-run mode, this call
// is top-level. It begins a transaction that nested calls share through
// q.tx. It commits on success and rolls back on any error or panic. q.tx is
// always cleared afterwards, so later operations on q start a fresh
// transaction. In dry-run mode the SQL is only logged.
func (m *Model) update(val reflect.Value, q *Query, parentFks map[string]any) error {
	var err error
	var committed bool
	isTopLevel := q.tx == nil && !q.engine.dryRun
	if isTopLevel {
		var tx pgx.Tx
		if tx, err = q.engine.conn.Begin(q.ctx); err != nil {
			return err
		}
		q.tx = tx
		defer func() {
			if !committed {
				_ = tx.Rollback(q.ctx) // best effort: the triggering error is what callers need
			}
			// Clear on success too; a committed tx must not be reused as "nested".
			q.tx = nil
		}()
	}

	for val.Kind() == reflect.Pointer {
		val = val.Elem()
	}
	if !val.IsValid() {
		return fmt.Errorf("orm: cannot update nil %s", m.Name)
	}

	var sets []string
	var vals []any
	for _, field := range m.Fields {
		if field.fk != nil || field.ColumnType == "" || field.Name == m.IDField {
			continue
		}
		var v any
		if field.Index > -1 && field.Index < val.NumField() {
			v = val.Field(field.Index).Interface()
		} else if pv, ok := parentFks[field.ColumnName]; ok {
			v = pv
		} else {
			continue
		}
		vals = append(vals, v)
		sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, len(vals)))
	}

	mcol, mpk := m.getPrimaryKey(val)
	if mcol == "" {
		return fmt.Errorf("orm: model %s has no usable primary key field %q", m.Name, m.IDField)
	}
	if len(sets) > 0 {
		vals = append(vals, mpk)
		sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s = $%d", m.TableName, strings.Join(sets, ", "), mcol, len(vals))
		fmt.Printf("[UPDATE] %s { %s }\n", sql, logTrunc(vals, 200))
		if !q.engine.dryRun {
			if _, err = q.tx.Exec(q.ctx, sql, vals...); err != nil {
				return err
			}
		}
	}

	for _, rel := range m.Relationships {
		if rel.Idx < 0 || rel.Idx >= val.NumField() {
			continue
		}
		f := val.Field(rel.Idx)
		switch {
		case f.Kind() == reflect.Slice || f.Kind() == reflect.Array:
			if err = diffSlices(q.engine, q, val, rel); err != nil {
				return err
			}
		case rel.RelatedModel.embeddedIsh && !m.embeddedIsh:
			elem := f
			for elem.Kind() == reflect.Pointer && !elem.IsNil() {
				elem = elem.Elem()
			}
			// Non-struct types and nil pointers have nothing to cascade into.
			if elem.Kind() != reflect.Struct {
				continue
			}
			err = rel.RelatedModel.update(elem, q, map[string]any{
				pascalToSnakeCase(rel.JoinField()): mpk,
			})
			if err != nil {
				return err
			}
		}
	}

	if isTopLevel {
		if err = q.tx.Commit(q.ctx); err != nil {
			return err
		}
		committed = true
	}
	return nil
}
/*func (m *Model) ensure(q *Query, val reflect.Value) error {
if _, pk := m.getPrimaryKey(val); pk != nil && !reflect.ValueOf(pk).IsZero() {
return nil
}
}*/