diamond-orm/model_misc.go

175 lines
4.3 KiB
Go

package orm
import (
"context"
"fmt"
"reflect"
)
// fetchJoinTableChildren returns the set of ids currently stored in childRel's
// join table for the parent row identified by pid.
//
// The selected column is "<Model.TableName>_<snake_case(FieldName)>" and rows are
// filtered on "<RelatedModel.TableName>_id = $1" within childRel.JoinTable().
// When the engine is in dry-run mode no query is issued and an empty, non-nil
// set is returned.
//
// An error is returned if the query fails, if a row cannot be scanned, or if
// iteration stops early because of an error reported by the row cursor.
func fetchJoinTableChildren(e *Engine, ctx context.Context, childRel *Relationship, pid any) (map[any]struct{}, error) {
	rsql := fmt.Sprintf("SELECT %s_%s FROM %s where %s_id = $1",
		childRel.Model.TableName, pascalToSnakeCase(childRel.FieldName),
		childRel.JoinTable(), childRel.RelatedModel.TableName)
	res := make(map[any]struct{})
	if e.dryRun {
		return res, nil
	}

	rows, err := e.conn.Query(ctx, rsql, pid)
	if err != nil {
		return nil, err
	}
	// Register Close only after the error check so a failed query never
	// leads to Close being called on an unusable result.
	defer rows.Close()

	for rows.Next() {
		var id any
		if err = rows.Scan(&id); err != nil {
			return nil, err
		}
		res[id] = struct{}{}
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// without this check a partial id set would be reported as success.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}
// fetchChildren loads every row of childRel.RelatedModel whose back-reference
// column equals pid and returns them keyed by the value of the related model's
// IDField.
//
// The back-reference column is found by looking up the relationship named
// childRel.Model.Name on the related model and then the field that
// relationship points at. The map values are the rows as produced by Find
// (pointer rows stay pointers); only the id lookup dereferences them.
//
// An error is returned if the back-reference relationship or field is not
// registered, if the query fails, or if a returned row is not a struct (or a
// pointer chain ending in one).
func fetchChildren(e *Engine, childRel *Relationship, pid any) (map[any]reflect.Value, error) {
	related := childRel.RelatedModel
	qq := e.Model(reflect.New(related.Type).Elem().Interface())

	// Validate the lookups before touching their fields; a missing entry
	// would otherwise surface as a nil dereference or a malformed query.
	rrel, ok := related.Relationships[childRel.Model.Name]
	if !ok {
		return nil, fmt.Errorf("relationship '%s' not found on table '%s'", childRel.Model.Name, related.TableName)
	}
	rfield, ok := related.Fields[rrel.FieldName]
	if !ok {
		return nil, fmt.Errorf("field '%s' not found on table '%s'", rrel.FieldName, related.TableName)
	}

	rawRows, err := qq.Where(fmt.Sprintf("%s.%s = $1", related.TableName, rfield.ColumnName), pid).Find()
	if err != nil {
		return nil, err
	}

	rrv := reflect.ValueOf(rawRows)
	res := make(map[any]reflect.Value, rrv.Len())
	for i := range rrv.Len() {
		// Round-trip through Interface() so that interface-typed slice
		// elements are unwrapped to their concrete value.
		v := reflect.ValueOf(rrv.Index(i).Interface())
		bv := v
		for bv.Kind() == reflect.Ptr {
			bv = bv.Elem()
		}
		if bv.Kind() != reflect.Struct {
			return nil, fmt.Errorf("unexpected row kind '%s' for table '%s'", bv.Kind(), related.TableName)
		}
		// FieldByName panics on non-struct values, so the id must be read
		// from the dereferenced value rather than from v itself.
		id := bv.FieldByName(related.IDField).Interface()
		res[id] = v
	}
	return res, nil
}
// preDiff resolves the Model registered for value's type. Pointer types are
// unwrapped first, and the lookup in the engine's model map is keyed by the
// resulting type name. An error is returned when no model is registered under
// that name.
func preDiff(e *Engine, value reflect.Value) (*Model, error) {
	t := value.Type()
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	name := t.Name()
	if m, found := e.modelMap.Map[name]; found {
		return m, nil
	}
	return nil, fmt.Errorf("model '%s' not found", name)
}
// diffManySlices reconciles the slice held in value's rel.FieldName field with
// the rows of rel.RelatedModel currently stored for that parent.
//
// It runs in two phases, in this order:
//  1. every stored child whose primary key is absent from the in-memory slice
//     is deleted through q's transaction;
//  2. every in-memory child is inserted when its primary key is nil or the zero
//     value, and updated otherwise.
//
// Unless the parent model is flagged embeddedIsh, the parent's id column and
// primary key are passed along to insert/update as extra column values.
// The first error encountered aborts the operation and is returned.
func diffManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
	model, err := preDiff(e, value)
	if err != nil {
		return err
	}
	_, parentPK := model.getPrimaryKey(value)
	persisted, err := fetchChildren(e, rel, parentPK)
	if err != nil {
		return err
	}

	// Index the in-memory children by primary key.
	inMemory := make(map[any]reflect.Value)
	children := value.FieldByName(rel.FieldName)
	for i := range children.Len() {
		item := children.Index(i)
		if _, key := rel.RelatedModel.getPrimaryKey(item); key != nil {
			inMemory[key] = item
		}
	}

	// Phase 1: remove stored rows that no longer appear in memory.
	for pk := range persisted {
		if _, kept := inMemory[pk]; kept {
			continue
		}
		idColumn := rel.RelatedModel.Fields[rel.RelatedModel.IDField].ColumnName
		stmt := fmt.Sprintf("DELETE FROM %s where %s = $1", rel.RelatedModel.TableName, idColumn)
		if _, err = q.tx.Exec(q.ctx, stmt, pk); err != nil {
			return err
		}
	}

	// Extra column values handed to insert/update so children point at the parent.
	parentCols := map[string]any{}
	if !model.embeddedIsh {
		parentCols[model.Fields[model.IDField].ColumnName] = parentPK
	}

	// Phase 2: insert children without a usable key, update the rest.
	for i := range children.Len() {
		item := children.Index(i)
		_, key := rel.RelatedModel.getPrimaryKey(item)
		if key == nil || reflect.ValueOf(key).IsZero() {
			if _, err = rel.RelatedModel.insert(item, q, parentCols); err != nil {
				return err
			}
			continue
		}
		if err = rel.RelatedModel.update(item, q, parentCols); err != nil {
			return err
		}
	}
	return nil
}
// diffManyToManySlices reconciles the join-table rows of rel for the parent in
// value with the slice held in value's rel.FieldName field.
//
// In-memory children whose primary key is not yet present in the join table
// are linked via rel.joinInsert; join-table ids with no matching in-memory
// child are unlinked via rel.joinDelete. Children whose primary key is nil are
// ignored. The first error encountered is returned.
func diffManyToManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
	model, err := preDiff(e, value)
	if err != nil {
		return err
	}
	_, parentPK := model.getPrimaryKey(value)
	linked, err := fetchJoinTableChildren(e, q.ctx, rel, parentPK)
	if err != nil {
		return err
	}

	// Index the in-memory children by primary key.
	wanted := make(map[any]reflect.Value)
	children := value.FieldByName(rel.FieldName)
	for i := range children.Len() {
		item := children.Index(i)
		if _, key := rel.RelatedModel.getPrimaryKey(item); key != nil {
			wanted[key] = item
		}
	}

	// Link children that are in memory but missing from the join table.
	for key, item := range wanted {
		if _, present := linked[key]; present {
			continue
		}
		if err = rel.joinInsert(item, q, parentPK); err != nil {
			return err
		}
	}

	// Unlink join-table entries that are no longer in memory.
	for id := range linked {
		if _, present := wanted[id]; present {
			continue
		}
		if err = rel.joinDelete(parentPK, id, q); err != nil {
			return err
		}
	}
	return nil
}
// diffSlices dispatches slice reconciliation according to the relationship:
// ManyToMany relationships (or ones for which rel.m2mIsh() reports true) go to
// diffManyToManySlices, HasMany relationships go to diffManySlices, and any
// other relationship type is a no-op that returns nil.
func diffSlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
	switch {
	case rel.Type == ManyToMany || rel.m2mIsh():
		return diffManyToManySlices(e, q, value, rel)
	case rel.Type == HasMany:
		return diffManySlices(e, q, value, rel)
	default:
		return nil
	}
}