175 lines
4.3 KiB
Go
175 lines
4.3 KiB
Go
package orm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
)
|
|
|
|
func fetchJoinTableChildren(e *Engine, ctx context.Context, childRel *Relationship, pid any) (map[any]struct{}, error) {
|
|
rsql := fmt.Sprintf("SELECT %s_%s FROM %s where %s_id = $1",
|
|
childRel.Model.TableName, pascalToSnakeCase(childRel.FieldName),
|
|
childRel.JoinTable(), childRel.RelatedModel.TableName)
|
|
res := make(map[any]struct{})
|
|
if !e.dryRun {
|
|
rows, err := e.conn.Query(ctx, rsql, pid)
|
|
defer rows.Close()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for rows.Next() {
|
|
var id any
|
|
if err = rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
res[id] = struct{}{}
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func fetchChildren(e *Engine, childRel *Relationship, pid any) (map[any]reflect.Value, error) {
|
|
res := make(map[any]reflect.Value)
|
|
qq := e.Model(reflect.New(childRel.RelatedModel.Type).Elem().Interface())
|
|
rrel := childRel.RelatedModel.Relationships[childRel.Model.Name]
|
|
rfield := childRel.RelatedModel.Fields[rrel.FieldName]
|
|
/*if rrel == nil {
|
|
return res, fmt.Errorf("please report this, it shouldn't have happened :(")
|
|
}*/
|
|
rawRows, err := qq.Where(fmt.Sprintf("%s.%s = $1", childRel.RelatedModel.TableName, rfield.ColumnName), pid).Find()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
inter := make([]any, 0)
|
|
rrv := reflect.ValueOf(rawRows)
|
|
for i := range rrv.Len() {
|
|
inter = append(inter, rrv.Index(i).Interface())
|
|
}
|
|
for _, row := range inter {
|
|
v := reflect.ValueOf(row)
|
|
bv := v
|
|
for bv.Kind() == reflect.Ptr {
|
|
bv = bv.Elem()
|
|
}
|
|
id := v.FieldByName(childRel.RelatedModel.IDField).Interface()
|
|
res[id] = v
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func preDiff(e *Engine, value reflect.Value) (*Model, error) {
|
|
ptype := value.Type()
|
|
for ptype.Kind() == reflect.Pointer {
|
|
ptype = ptype.Elem()
|
|
}
|
|
model, ok := e.modelMap.Map[ptype.Name()]
|
|
if !ok {
|
|
return nil, fmt.Errorf("model '%s' not found", ptype.Name())
|
|
}
|
|
return model, nil
|
|
}
|
|
|
|
func diffManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
|
|
model, err := preDiff(e, value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, ppk := model.getPrimaryKey(value)
|
|
dbChildren, err := fetchChildren(e, rel, ppk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
memChildren := make(map[any]reflect.Value)
|
|
fv := value.FieldByName(rel.FieldName)
|
|
for i := range fv.Len() {
|
|
child := fv.Index(i)
|
|
_, cpk := rel.RelatedModel.getPrimaryKey(child)
|
|
if cpk != nil {
|
|
memChildren[cpk] = child
|
|
}
|
|
}
|
|
// deletions //
|
|
for pk := range dbChildren {
|
|
if _, found := memChildren[pk]; !found {
|
|
table := rel.RelatedModel.TableName
|
|
idField := rel.RelatedModel.Fields[rel.RelatedModel.IDField]
|
|
_, err = q.tx.Exec(q.ctx, fmt.Sprintf("DELETE FROM %s where %s = $1", table, idField.ColumnName), pk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
mField := model.Fields[model.IDField]
|
|
mpks := map[string]any{}
|
|
if !model.embeddedIsh {
|
|
mpks[mField.ColumnName] = ppk
|
|
}
|
|
// update || insert //
|
|
for i := range fv.Len() {
|
|
cur := fv.Index(i)
|
|
_, cpk := rel.RelatedModel.getPrimaryKey(cur)
|
|
if cpk == nil || reflect.ValueOf(cpk).IsZero() {
|
|
|
|
_, err = rel.RelatedModel.insert(cur, q, mpks)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
err = rel.RelatedModel.update(cur, q, mpks)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func diffManyToManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
|
|
model, err := preDiff(e, value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, ppk := model.getPrimaryKey(value)
|
|
ids, err := fetchJoinTableChildren(e, q.ctx, rel, ppk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
memIds := make(map[any]reflect.Value)
|
|
fv := value.FieldByName(rel.FieldName)
|
|
for i := range fv.Len() {
|
|
child := fv.Index(i)
|
|
_, cpk := rel.RelatedModel.getPrimaryKey(child)
|
|
if cpk != nil {
|
|
memIds[cpk] = child
|
|
}
|
|
}
|
|
for memId := range memIds {
|
|
if _, found := ids[memId]; !found {
|
|
err = rel.joinInsert(memIds[memId], q, ppk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
for id := range ids {
|
|
if _, found := memIds[id]; !found {
|
|
err = rel.joinDelete(ppk, id, q)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func diffSlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
|
|
if rel.Type == ManyToMany || rel.m2mIsh() {
|
|
return diffManyToManySlices(e, q, value, rel)
|
|
}
|
|
if rel.Type == HasMany {
|
|
return diffManySlices(e, q, value, rel)
|
|
}
|
|
return nil
|
|
}
|