package orm

import (
	"context"
	"fmt"
	"reflect"
)

// fetchJoinTableChildren returns the set of child IDs that the join table of
// childRel currently links to the parent identified by pid.
//
// It selects the "<Model.TableName>_<snake_case(FieldName)>" column from
// childRel.JoinTable() filtered on "<RelatedModel.TableName>_id = pid".
// In dry-run mode no query is issued and an empty set is returned.
//
// NOTE(review): the filter column is derived from RelatedModel while pid is
// the parent's key — confirm this matches the naming used by
// joinInsert/joinDelete.
func fetchJoinTableChildren(e *Engine, ctx context.Context, childRel *Relationship, pid any) (map[any]struct{}, error) {
	rsql := fmt.Sprintf("SELECT %s_%s FROM %s where %s_id = $1",
		childRel.Model.TableName, pascalToSnakeCase(childRel.FieldName),
		childRel.JoinTable(), childRel.RelatedModel.TableName)
	res := make(map[any]struct{})
	if e.dryRun {
		return res, nil
	}

	rows, err := e.conn.Query(ctx, rsql, pid)
	if err != nil {
		return nil, err
	}
	// Register Close only once Query has succeeded; on error rows may be
	// unusable (or nil).
	defer rows.Close()

	for rows.Next() {
		var id any
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		res[id] = struct{}{}
	}
	// Next returns false both at the end of the result set and on failure;
	// without this check a partially-read set would be returned as complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}

// fetchChildren loads the has-many children of the parent identified by pid.
// It returns them keyed by the value of the related model's ID field.
//
// The foreign-key column comes from the related model's relationship back to
// childRel.Model. The map values are the rows exactly as produced by Find; a
// pointer row stays a pointer. Only the ID lookup dereferences them.
func fetchChildren(e *Engine, childRel *Relationship, pid any) (map[any]reflect.Value, error) {
	related := childRel.RelatedModel

	// Validate the back-reference before dereferencing it; a missing entry
	// previously caused a nil-pointer panic (or an empty column name).
	rrel, ok := related.Relationships[childRel.Model.Name]
	if !ok {
		return nil, fmt.Errorf("relationship %q: model %q has no relationship back to %q",
			childRel.FieldName, related.TableName, childRel.Model.Name)
	}
	rfield, ok := related.Fields[rrel.FieldName]
	if !ok {
		return nil, fmt.Errorf("relationship %q: model %q has no field %q",
			childRel.FieldName, related.TableName, rrel.FieldName)
	}

	qq := e.Model(reflect.New(related.Type).Elem().Interface())
	rawRows, err := qq.Where(fmt.Sprintf("%s.%s = $1", related.TableName, rfield.ColumnName), pid).Find()
	if err != nil {
		return nil, err
	}

	res := make(map[any]reflect.Value)
	rows := reflect.ValueOf(rawRows)
	for i := range rows.Len() {
		// Round-trip through Interface so interface-typed elements are
		// unwrapped to their concrete value.
		v := reflect.ValueOf(rows.Index(i).Interface())
		// Read the ID from the dereferenced struct: FieldByName panics on
		// pointer values.
		idv := derefDiffValue(v).FieldByName(related.IDField)
		if !idv.IsValid() {
			return nil, fmt.Errorf("model %q has no ID field %q", related.TableName, related.IDField)
		}
		res[idv.Interface()] = v
	}
	return res, nil
}

// preDiff resolves the registered Model for value's type. Pointer types are
// dereferenced first, and the lookup key is the type's name.
func preDiff(e *Engine, value reflect.Value) (*Model, error) {
	ptype := value.Type()
	for ptype.Kind() == reflect.Pointer {
		ptype = ptype.Elem()
	}
	model, ok := e.modelMap.Map[ptype.Name()]
	if !ok {
		return nil, fmt.Errorf("model '%s' not found", ptype.Name())
	}
	return model, nil
}

// derefDiffValue follows pointers until it reaches a non-pointer value, so
// struct field lookups work whether the caller holds T or *T.
func derefDiffValue(v reflect.Value) reflect.Value {
	for v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	return v
}

// diffChildrenByPK indexes the elements of the slice fv by the primary key
// that rel.RelatedModel.getPrimaryKey reports for each one. Elements with a
// nil key are skipped; later duplicates overwrite earlier ones.
func diffChildrenByPK(rel *Relationship, fv reflect.Value) map[any]reflect.Value {
	res := make(map[any]reflect.Value, fv.Len())
	for i := range fv.Len() {
		child := fv.Index(i)
		if _, pk := rel.RelatedModel.getPrimaryKey(child); pk != nil {
			res[pk] = child
		}
	}
	return res
}

// diffManySlices reconciles a has-many slice field of value with the
// database:
//   - children in the database but absent from the slice are deleted (in q.tx);
//   - children whose primary key is nil or zero are inserted;
//   - all other children are updated.
//
// Unless the parent model is embeddedIsh, insert and update receive the
// parent's primary key under the parent's ID column name.
func diffManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
	model, err := preDiff(e, value)
	if err != nil {
		return err
	}
	_, ppk := model.getPrimaryKey(value)

	dbChildren, err := fetchChildren(e, rel, ppk)
	if err != nil {
		return err
	}
	fv := derefDiffValue(value).FieldByName(rel.FieldName)
	memChildren := diffChildrenByPK(rel, fv)

	// deletions: rows still in the database but no longer in memory.
	table := rel.RelatedModel.TableName
	idField := rel.RelatedModel.Fields[rel.RelatedModel.IDField]
	for pk := range dbChildren {
		if _, found := memChildren[pk]; found {
			continue
		}
		delSQL := fmt.Sprintf("DELETE FROM %s where %s = $1", table, idField.ColumnName)
		if _, err := q.tx.Exec(q.ctx, delSQL, pk); err != nil {
			return err
		}
	}

	parentKeys := map[string]any{}
	if !model.embeddedIsh {
		mField := model.Fields[model.IDField]
		parentKeys[mField.ColumnName] = ppk
	}

	// update || insert
	for i := range fv.Len() {
		cur := fv.Index(i)
		_, cpk := rel.RelatedModel.getPrimaryKey(cur)
		if cpk == nil || reflect.ValueOf(cpk).IsZero() {
			if _, err := rel.RelatedModel.insert(cur, q, parentKeys); err != nil {
				return err
			}
			continue
		}
		if err := rel.RelatedModel.update(cur, q, parentKeys); err != nil {
			return err
		}
	}
	return nil
}

// diffManyToManySlices reconciles a many-to-many slice field of value with
// its join table. It adds join rows (rel.joinInsert) for in-memory children
// the table does not yet link. It removes join rows (rel.joinDelete) whose
// child is no longer in the slice.
func diffManyToManySlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
	model, err := preDiff(e, value)
	if err != nil {
		return err
	}
	_, ppk := model.getPrimaryKey(value)

	dbIDs, err := fetchJoinTableChildren(e, q.ctx, rel, ppk)
	if err != nil {
		return err
	}
	memIDs := diffChildrenByPK(rel, derefDiffValue(value).FieldByName(rel.FieldName))

	for id, child := range memIDs {
		if _, found := dbIDs[id]; found {
			continue
		}
		if err := rel.joinInsert(child, q, ppk); err != nil {
			return err
		}
	}
	for id := range dbIDs {
		if _, found := memIDs[id]; found {
			continue
		}
		if err := rel.joinDelete(ppk, id, q); err != nil {
			return err
		}
	}
	return nil
}

// diffSlices dispatches slice reconciliation on rel's kind. Many-to-many (or
// m2m-like) relationships go through the join table; has-many relationships
// go through the child table. Any other relationship type is a no-op.
func diffSlices(e *Engine, q *Query, value reflect.Value, rel *Relationship) error {
	switch {
	case rel.Type == ManyToMany || rel.m2mIsh():
		return diffManyToManySlices(e, q, value, rel)
	case rel.Type == HasMany:
		return diffManySlices(e, q, value, rel)
	default:
		return nil
	}
}