298 lines
7.1 KiB
Go
298 lines
7.1 KiB
Go
package orm
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
// Model is the ORM's metadata for one mapped Go struct type and its table.
type Model struct {
	// Name is the model's name; getAliasFields lowercases it to build column
	// aliases of the form "<name>_<column>". Presumably the Go type name —
	// TODO confirm against the code that builds models.
	Name string
	// Type is the reflected Go type of the model (not read in this file).
	Type reflect.Type
	// Relationships maps a Go struct-field name to its relationship metadata.
	Relationships map[string]*Relationship
	// IDField is the Go struct-field name holding the primary key; it is also
	// that field's key in Fields.
	IDField string
	// Fields maps Go struct-field names to field metadata (see addField).
	Fields map[string]*Field
	// FieldsByColumnName maps database column names to field metadata.
	FieldsByColumnName map[string]*Field
	// TableName is the table the model is stored in; it is interpolated
	// directly into generated SQL.
	TableName string
	// embeddedIsh marks models whose rows are written as children of a
	// parent row: insert gives them the parent's ID as a foreign key, and
	// update updates them together with a non-embedded parent.
	embeddedIsh bool
}
|
|
|
|
func (m *Model) addField(field *Field) {
|
|
field.Model = m
|
|
m.Fields[field.Name] = field
|
|
m.FieldsByColumnName[field.ColumnName] = field
|
|
}
|
|
|
|
func (m *Model) getAliasFields() map[string]string {
|
|
fields := make(map[string]string)
|
|
for _, f := range m.Fields {
|
|
if f.fk != nil {
|
|
continue
|
|
}
|
|
fields[f.Name] = fmt.Sprintf("%s.%s as %s_%s", m.TableName, f.ColumnName, strings.ToLower(m.Name), f.ColumnName)
|
|
}
|
|
return fields
|
|
}
|
|
|
|
func (m *Model) getPrimaryKey(val reflect.Value) (string, any) {
|
|
colField := m.Fields[m.IDField]
|
|
if colField == nil {
|
|
return "", nil
|
|
}
|
|
colName := colField.ColumnName
|
|
idField := val.FieldByName(m.IDField)
|
|
if idField.IsValid() {
|
|
return colName, idField.Interface()
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (m *Model) insert(v reflect.Value, e *Query, parentFks map[string]any) (any, error) {
|
|
var isTopLevel bool
|
|
var err error
|
|
var cn *pgxpool.Conn
|
|
if e.tx == nil && !e.engine.dryRun {
|
|
isTopLevel = true
|
|
var tx pgx.Tx
|
|
cn, err = e.engine.conn.Acquire(e.ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tx, err = cn.Begin(e.ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e.tx = tx
|
|
defer func() {
|
|
if err != nil {
|
|
e.tx.Rollback(e.ctx)
|
|
}
|
|
fmt.Printf("[DBG] discarding tx @ %p\n", e.tx)
|
|
fmt.Printf("[DBG] discarding conn @ %p\n", e.tx.Conn())
|
|
fmt.Printf("[DBG] discarding outer conn @ %p\n", cn)
|
|
e.tx = nil
|
|
cn.Release()
|
|
}()
|
|
}
|
|
for v.Kind() == reflect.Pointer {
|
|
v = v.Elem()
|
|
}
|
|
t := v.Type()
|
|
var cols []string
|
|
var placeholders []string
|
|
var args []any
|
|
var returningField *Field
|
|
for k, vv := range parentFks {
|
|
cols = append(cols, k)
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
|
|
args = append(args, vv)
|
|
}
|
|
for _, ff := range m.Fields {
|
|
//ft := t.Field(ff.Index)
|
|
var fv reflect.Value
|
|
if ff.Index > -1 {
|
|
fv = v.Field(ff.Index)
|
|
}
|
|
|
|
mfk := ff.fk
|
|
if mfk == nil {
|
|
mfk = m.Relationships[ff.Name]
|
|
}
|
|
if mfk != nil {
|
|
af := v.FieldByName(mfk.FieldName)
|
|
if af.Kind() == reflect.Struct {
|
|
idField := af.FieldByName(ff.fk.RelatedModel.IDField)
|
|
cols = append(cols, ff.ColumnName)
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
|
|
if !idField.IsValid() {
|
|
var nid any
|
|
nid, err = ff.fk.RelatedModel.insert(af, e, make(map[string]any))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
args = append(args, nid)
|
|
} else {
|
|
args = append(args, idField.Interface())
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
col := ff.ColumnName
|
|
if ff.PrimaryKey {
|
|
returningField = ff
|
|
if fv.IsValid() && !fv.IsZero() {
|
|
cols = append(cols, col)
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
|
|
args = append(args, fv.Interface())
|
|
}
|
|
continue
|
|
}
|
|
if fv.IsValid() {
|
|
cols = append(cols, col)
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", len(placeholders)+1))
|
|
args = append(args, fv.Interface())
|
|
}
|
|
}
|
|
var rfc = "id"
|
|
if returningField != nil {
|
|
rfc = returningField.ColumnName
|
|
}
|
|
scols := fmt.Sprintf("(%s)", strings.Join(cols, ", "))
|
|
svals := fmt.Sprintf("VALUES (%s) ", strings.Join(placeholders, ", "))
|
|
sql := fmt.Sprintf("INSERT INTO %s ",
|
|
m.TableName,
|
|
)
|
|
if len(cols) > 0 {
|
|
sql += scols
|
|
sql += " "
|
|
sql += svals
|
|
} else {
|
|
sql += "DEFAULT VALUES "
|
|
}
|
|
sql += fmt.Sprintf("RETURNING %s", rfc)
|
|
fmt.Printf("[INSERT] %s { %s }\n", sql, logTrunc(args, 200))
|
|
var id any
|
|
{
|
|
if v.FieldByName(m.IDField).IsValid() {
|
|
id = v.FieldByName(m.IDField).Interface()
|
|
}
|
|
}
|
|
if !e.engine.dryRun {
|
|
qr := e.engine.conn.QueryRow(e.ctx, sql, args...)
|
|
err = qr.Scan(&id)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
{
|
|
retField := v.FieldByName(returningField.Name)
|
|
if retField.IsValid() {
|
|
retField.Set(reflect.ValueOf(id))
|
|
}
|
|
}
|
|
for i := range t.NumField() {
|
|
ft := t.Field(i)
|
|
fv := v.Field(i)
|
|
if ft.Type.Kind() == reflect.Slice {
|
|
for j := range fv.Len() {
|
|
child := fv.Index(j).Addr()
|
|
cm := e.engine.modelMap.Map[child.Type().Elem().Name()]
|
|
if cm != nil && cm.embeddedIsh {
|
|
cfk := map[string]any{
|
|
m.TableName + "_id": id,
|
|
}
|
|
if m.Relationships[ft.Name] != nil {
|
|
cfk = map[string]any{
|
|
pascalToSnakeCase(m.Relationships[ft.Name].JoinField()): id,
|
|
}
|
|
}
|
|
_, err = cm.insert(child, e, cfk)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
} else if cm != nil {
|
|
rel := m.Relationships[ft.Name]
|
|
if rel != nil {
|
|
err = rel.joinInsert(child, e, id)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if isTopLevel && !e.engine.dryRun {
|
|
err = e.tx.Commit(e.ctx)
|
|
}
|
|
return id, err
|
|
}
|
|
|
|
func (m *Model) update(val reflect.Value, q *Query, parentFks map[string]any) error {
|
|
var isTopLevel bool
|
|
var err error
|
|
if q.tx == nil && !q.engine.dryRun {
|
|
isTopLevel = true
|
|
var tx pgx.Tx
|
|
tx, err = q.engine.conn.Begin(q.ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
q.tx = tx
|
|
defer func() {
|
|
if err != nil {
|
|
q.tx.Rollback(q.ctx)
|
|
q.tx = nil
|
|
}
|
|
}()
|
|
}
|
|
cnt := 1
|
|
sets := make([]string, 0)
|
|
vals := make([]any, 0)
|
|
for val.Kind() == reflect.Pointer {
|
|
val = val.Elem()
|
|
}
|
|
for _, field := range m.Fields {
|
|
if field.fk != nil || field.ColumnType == "" || field.Name == m.IDField {
|
|
continue
|
|
}
|
|
if field.Index > -1 && field.Index < val.NumField() && field.fk == nil {
|
|
f := val.Field(field.Index)
|
|
sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt))
|
|
vals = append(vals, f.Interface())
|
|
cnt++
|
|
} else if _, ok := parentFks[field.ColumnName]; ok {
|
|
sets = append(sets, fmt.Sprintf("%s = $%d", field.ColumnName, cnt))
|
|
vals = append(vals, parentFks[field.ColumnName])
|
|
cnt++
|
|
}
|
|
}
|
|
mcol, mpk := m.getPrimaryKey(val)
|
|
sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s = %v", m.TableName, strings.Join(sets, ", "), mcol, mpk)
|
|
fmt.Printf("[UPDATE] %s { %s }\n", sql, logTrunc(vals, 200))
|
|
if !q.engine.dryRun {
|
|
if _, err = q.tx.Exec(q.ctx, sql, vals...); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, rel := range m.Relationships {
|
|
if rel.Idx > -1 && rel.Idx < val.NumField() {
|
|
f := val.Field(rel.Idx)
|
|
if f.Kind() == reflect.Slice ||
|
|
f.Kind() == reflect.Array {
|
|
err = diffSlices(q.engine, q, val, rel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if rel.RelatedModel.embeddedIsh && !m.embeddedIsh {
|
|
elemish := f.Type()
|
|
for elemish.Kind() == reflect.Ptr {
|
|
elemish = elemish.Elem()
|
|
}
|
|
if elemish.Kind() == reflect.Struct {
|
|
err = rel.RelatedModel.update(f, q, map[string]any{
|
|
pascalToSnakeCase(rel.JoinField()): mpk,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
var finerr error
|
|
if isTopLevel && !q.engine.dryRun {
|
|
finerr = q.tx.Commit(q.ctx)
|
|
}
|
|
return finerr
|
|
}
|
|
|
|
/*func (m *Model) ensure(q *Query, val reflect.Value) error {
|
|
if _, pk := m.getPrimaryKey(val); pk != nil && !reflect.ValueOf(pk).IsZero() {
|
|
return nil
|
|
}
|
|
|
|
}*/
|