From fefa6950637812b1d9344f259b839bb6496d55c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=98=99=E2=97=A6=20The=20Tablet=20=E2=9D=80=20GamerGirla?= =?UTF-8?q?ndCo=20=E2=97=A6=E2=9D=A7?= Date: Thu, 5 Sep 2024 13:57:35 -0400 Subject: [PATCH] some changes - move internal model helper funcs into model_internals.go - add Swap method to Model - modify `getNested` util to support indexing slice fields such as `Abc[0].Def` --- model.go | 292 ++++++--------------------------------------- model_internals.go | 263 ++++++++++++++++++++++++++++++++++++++++ model_test.go | 14 +++ util.go | 50 +++++--- 4 files changed, 347 insertions(+), 272 deletions(-) create mode 100644 model_internals.go diff --git a/model.go b/model.go index 56354d6..ba02c2d 100644 --- a/model.go +++ b/model.go @@ -6,7 +6,6 @@ import ( "reflect" "time" - "github.com/fatih/structtag" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" @@ -54,6 +53,7 @@ type HasIDSlice []HasID type IModel interface { Append(field string, a ...interface{}) error + Delete() error Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error) FindByID(id interface{}) (*Query, error) @@ -62,6 +62,8 @@ type IModel interface { getColl() *mongo.Collection getIdxs() []*mongo.IndexModel getParsedIdxs() map[string][]InternalIndex + Pull(field string, a ...any) error + Remove() error Save() error serializeToStore() any setTypeName(str string) @@ -195,19 +197,8 @@ func (m *Model) FindOne(query interface{}, options ...*options.FindOneOptions) ( m.self = qq.doc return qq, err } -func doDelete(m *Model, arg interface{}) error { - self, ok := arg.(HasID) - if !ok { - return fmt.Errorf("Object '%s' does not implement HasID", nameOf(arg)) - } - c := m.getColl() - _, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()}) - if err == nil { - m.exists = false - } - return err -} +// Delete - deletes a model instance from the database func (m *Model) Delete() error { var err error val := valueOf(m.self) @@ -224,96 +215,9 @@ func (m *Model) Delete() error { } } -func incrementTagged(item interface{}) interface{} { - rv := reflect.ValueOf(item) - rt := reflect.TypeOf(item) - if rv.Kind() != reflect.Pointer { - rv = makeSettable(rv, item) - } - if rt.Kind() == reflect.Pointer { - rt = rt.Elem() - } - if rt.Kind() != reflect.Struct { - if rt.Kind() == reflect.Slice { - for i := 0; i < rv.Elem().Len(); i++ { - incrementTagged(rv.Elem().Index(i).Addr().Interface()) - } - } else { - return item - } - } - for i := 0; i < rt.NumField(); i++ { - structField := rt.Field(i) - cur := rv.Elem().Field(i) - tags, err := structtag.Parse(string(structField.Tag)) - if err != nil { - continue - } - incTag, err := tags.Get("autoinc") - if err != nil { - continue - } - - nid := getLastInColl(incTag.Name, cur.Interface()) - if cur.IsZero() { - coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur) - if coerced != nil { - cur.Set(reflect.ValueOf(coerced)) - } else { - cur.Set(reflect.ValueOf(incrementInterface(nid))) - } - } - counterColl := DB.Collection(COUNTER_COL) - counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.Update().SetUpsert(true)) - - } - return rv.Elem().Interface() -} - -func incrementAll(item interface{}) { - if item == nil { - return - } - vp := reflect.ValueOf(item) - el := vp - if vp.Kind() == reflect.Pointer { - el = vp.Elem() - } - if vp.Kind() == reflect.Pointer && vp.IsNil() { - return - } - vt := el.Type() - switch el.Kind() { - case reflect.Struct: - incrementTagged(item) - for i := 0; i < el.NumField(); i++ { - fv := el.Field(i) - fst := vt.Field(i) - if !fst.IsExported() { - continue - } - incrementAll(fv.Interface()) - } - case reflect.Slice: - for i := 0; i < el.Len(); i++ { - incd := incrementTagged(el.Index(i).Addr().Interface()) - if reflect.ValueOf(incd).Kind() == reflect.Pointer { - el.Index(i).Set(reflect.ValueOf(incd).Elem()) - } else { - el.Index(i).Set(reflect.ValueOf(incd)) - } - } - } -} - -func checkStruct(ref reflect.Value) error { - if ref.Kind() == reflect.Slice { - return fmt.Errorf("Cannot append to multiple documents!") - } - if ref.Kind() != reflect.Struct { - return fmt.Errorf("Current object is not a struct!") - } - return nil +// Remove - alias for Delete +func (m *Model) Remove() error { + return m.Delete() } // Append appends one or more items to `field`. @@ -351,6 +255,7 @@ func (m *Model) Append(field string, a ...interface{}) error { return nil } +// Pull - removes elements from the subdocument slice stored in `field`. func (m *Model) Pull(field string, a ...any) error { rv := reflect.ValueOf(m.self) selfRef := rv @@ -387,48 +292,42 @@ outer: return nil } -func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { - var err error - m, ok := arg.(IModel) - if !ok { - return fmt.Errorf("type '%s' is not a model", nameOf(arg)) +// Swap - swaps the elements at indexes `i` and `j` in the +// slice stored at `field` +func (m *Model) Swap(field string, i, j int) error { + rv := reflect.ValueOf(m.self) + selfRef := rv + rt := reflect.TypeOf(m.self) + if selfRef.Kind() == reflect.Pointer { + selfRef = selfRef.Elem() + rt = rt.Elem() } - m.setSelf(m) - now := time.Now() - selfo := reflect.ValueOf(m) - vp := selfo - if vp.Kind() != reflect.Ptr { - vp = reflect.New(selfo.Type()) - vp.Elem().Set(selfo) + if err := checkStruct(selfRef); err != nil { + return err } - var asHasId = vp.Interface().(HasID) - if isNew { - m.setCreated(now) + _, origV, err := getNested(field, selfRef) + if err != nil { + return err } - m.setModified(now) - idxs := m.getIdxs() - for _, i := range idxs { - _, err = c.Indexes().CreateOne(context.TODO(), *i) - if err != nil { - return err - } - } - if isNew { - nid := getLastInColl(c.Name(), asHasId.Id()) - pnid := incrementInterface(nid) - if reflect.ValueOf(asHasId.Id()).IsZero() { - (asHasId).SetId(pnid) - } - incrementAll(asHasId) - _, err = c.InsertOne(context.TODO(), m.serializeToStore()) - if err == nil { - m.setExists(true) - } - } else { - _, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.(HasID).Id()}}, m.serializeToStore()) + origRef := makeSettable(*origV, (*origV).Interface()) + fv := origRef + if fv.Kind() == reflect.Pointer { + fv = fv.Elem() } - return err + if err = checkSlice(fv); err != nil { + return err + } + if i >= fv.Len() || j >= fv.Len() { + return fmt.Errorf("index(es) out of bounds") + } + oi := fv.Index(i).Interface() + oj := fv.Index(j).Interface() + + fv.Index(i).Set(reflect.ValueOf(oj)) + fv.Index(j).Set(reflect.ValueOf(oi)) + + return nil } func (m *Model) Save() error { @@ -450,119 +349,6 @@ func (m *Model) serializeToStore() any { return serializeIDs((m).self) } -func serializeIDs(input interface{}) interface{} { - - vp := reflect.ValueOf(input) - mt := reflect.TypeOf(input) - var ret interface{} - if vp.Kind() != reflect.Ptr { - if vp.CanAddr() { - vp = vp.Addr() - } else { - vp = makeSettable(vp, input) - } - } - - if mt.Kind() == reflect.Pointer { - mt = mt.Elem() - } - getID := func(bbb interface{}) interface{} { - mptr := reflect.ValueOf(bbb) - if mptr.Kind() != reflect.Pointer { - mptr = makeSettable(mptr, bbb) - } - ifc, ok := mptr.Interface().(HasID) - if ok { - return ifc.Id() - } else { - return nil - } - } - /*var itagged interface{} - if reflect.ValueOf(itagged).Kind() != reflect.Pointer { - itagged = incrementTagged(&input) - } else { - itagged = incrementTagged(input) - } - taggedVal := reflect.ValueOf(reflect.ValueOf(itagged).Interface()).Elem() - if vp.Kind() == reflect.Ptr { - tmp := reflect.ValueOf(taggedVal.Interface()) - if tmp.Kind() == reflect.Pointer { - vp.Elem().Set(tmp.Elem()) - } else { - vp.Elem().Set(tmp) - } - }*/ - switch vp.Elem().Kind() { - case reflect.Struct: - ret0 := bson.M{} - for i := 0; i < vp.Elem().NumField(); i++ { - fv := vp.Elem().Field(i) - ft := mt.Field(i) - tag, err := structtag.Parse(string(ft.Tag)) - panik(err) - bbson, err := tag.Get("bson") - if err != nil || bbson.Name == "-" { - continue - } - if bbson.Name == "" { - marsh, _ := bson.Marshal(fv.Interface()) - unmarsh := bson.M{} - bson.Unmarshal(marsh, &unmarsh) - for k, v := range unmarsh { - ret0[k] = v - /*if t, ok := v.(primitive.DateTime); ok { - ret0 - } else { - }*/ - } - } else { - _, terr := tag.Get("ref") - if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer { - vp1 := reflect.New(fv.Type()) - vp1.Elem().Set(reflect.ValueOf(fv.Interface())) - fv.Set(vp1.Elem()) - } - if terr == nil { - ifc, ok := fv.Interface().(HasID) - if fv.Kind() == reflect.Slice { - rarr := bson.A{} - for j := 0; j < fv.Len(); j++ { - rarr = append(rarr, getID(fv.Index(j).Interface())) - } - ret0[bbson.Name] = rarr - /*ret0[bbson.Name] = serializeIDs(fv.Interface()) - break*/ - } else if !ok { - panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name)) - } else { - if reflect.ValueOf(ifc).IsNil() { - ret0[bbson.Name] = nil - } else { - ret0[bbson.Name] = ifc.Id() - } - } - - } else { - ret0[bbson.Name] = serializeIDs(fv.Interface()) - } - } - - ret = ret0 - } - case reflect.Slice: - ret0 := bson.A{} - for i := 0; i < vp.Elem().Len(); i++ { - - ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface())) - } - ret = ret0 - default: - ret = vp.Elem().Interface() - } - return ret -} - // Create creates a new instance of a given model // and returns a pointer to it. func Create(d any) any { diff --git a/model_internals.go b/model_internals.go new file mode 100644 index 0000000..e215a1f --- /dev/null +++ b/model_internals.go @@ -0,0 +1,263 @@ +package orm + +import ( + "context" + "fmt" + "github.com/fatih/structtag" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "reflect" + "time" +) + +func serializeIDs(input interface{}) interface{} { + + vp := reflect.ValueOf(input) + mt := reflect.TypeOf(input) + var ret interface{} + if vp.Kind() != reflect.Ptr { + if vp.CanAddr() { + vp = vp.Addr() + } else { + vp = makeSettable(vp, input) + } + } + + if mt.Kind() == reflect.Pointer { + mt = mt.Elem() + } + getID := func(bbb interface{}) interface{} { + mptr := reflect.ValueOf(bbb) + if mptr.Kind() != reflect.Pointer { + mptr = makeSettable(mptr, bbb) + } + ifc, ok := mptr.Interface().(HasID) + if ok { + return ifc.Id() + } else { + return nil + } + } + /*var itagged interface{} + if reflect.ValueOf(itagged).Kind() != reflect.Pointer { + itagged = incrementTagged(&input) + } else { + itagged = incrementTagged(input) + } + taggedVal := reflect.ValueOf(reflect.ValueOf(itagged).Interface()).Elem() + if vp.Kind() == reflect.Ptr { + tmp := reflect.ValueOf(taggedVal.Interface()) + if tmp.Kind() == reflect.Pointer { + vp.Elem().Set(tmp.Elem()) + } else { + vp.Elem().Set(tmp) + } + }*/ + switch vp.Elem().Kind() { + case reflect.Struct: + ret0 := bson.M{} + for i := 0; i < vp.Elem().NumField(); i++ { + fv := vp.Elem().Field(i) + ft := mt.Field(i) + tag, err := structtag.Parse(string(ft.Tag)) + panik(err) + bbson, err := tag.Get("bson") + if err != nil || bbson.Name == "-" { + continue + } + if bbson.Name == "" { + marsh, _ := bson.Marshal(fv.Interface()) + unmarsh := bson.M{} + bson.Unmarshal(marsh, &unmarsh) + for k, v := range unmarsh { + ret0[k] = v + /*if t, ok := v.(primitive.DateTime); ok { + ret0 + } else { + }*/ + } + } else { + _, terr := tag.Get("ref") + if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer { + vp1 := reflect.New(fv.Type()) + vp1.Elem().Set(reflect.ValueOf(fv.Interface())) + fv.Set(vp1.Elem()) + } + if terr == nil { + ifc, ok := fv.Interface().(HasID) + if fv.Kind() == reflect.Slice { + rarr := bson.A{} + for j := 0; j < fv.Len(); j++ { + rarr = append(rarr, getID(fv.Index(j).Interface())) + } + ret0[bbson.Name] = rarr + /*ret0[bbson.Name] = serializeIDs(fv.Interface()) + break*/ + } else if !ok { + panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name)) + } else { + if reflect.ValueOf(ifc).IsNil() { + ret0[bbson.Name] = nil + } else { + ret0[bbson.Name] = ifc.Id() + } + } + + } else { + ret0[bbson.Name] = serializeIDs(fv.Interface()) + } + } + + ret = ret0 + } + case reflect.Slice: + ret0 := bson.A{} + for i := 0; i < vp.Elem().Len(); i++ { + + ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface())) + } + ret = ret0 + default: + ret = vp.Elem().Interface() + } + return ret +} +func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { + var err error + m, ok := arg.(IModel) + if !ok { + return fmt.Errorf("type '%s' is not a model", nameOf(arg)) + } + m.setSelf(m) + now := time.Now() + selfo := reflect.ValueOf(m) + vp := selfo + if vp.Kind() != reflect.Ptr { + vp = reflect.New(selfo.Type()) + vp.Elem().Set(selfo) + } + var asHasId = vp.Interface().(HasID) + if isNew { + m.setCreated(now) + } + m.setModified(now) + idxs := m.getIdxs() + for _, i := range idxs { + _, err = c.Indexes().CreateOne(context.TODO(), *i) + if err != nil { + return err + } + } + if isNew { + nid := getLastInColl(c.Name(), asHasId.Id()) + pnid := incrementInterface(nid) + if reflect.ValueOf(asHasId.Id()).IsZero() { + (asHasId).SetId(pnid) + } + incrementAll(asHasId) + + _, err = c.InsertOne(context.TODO(), m.serializeToStore()) + if err == nil { + m.setExists(true) + } + } else { + _, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.(HasID).Id()}}, m.serializeToStore()) + } + return err +} +func doDelete(m *Model, arg interface{}) error { + self, ok := arg.(HasID) + + if !ok { + return fmt.Errorf("Object '%s' does not implement HasID", nameOf(arg)) + } + c := m.getColl() + _, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()}) + if err == nil { + m.exists = false + } + return err +} +func incrementTagged(item interface{}) interface{} { + rv := reflect.ValueOf(item) + rt := reflect.TypeOf(item) + if rv.Kind() != reflect.Pointer { + rv = makeSettable(rv, item) + } + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + if rt.Kind() == reflect.Slice { + for i := 0; i < rv.Elem().Len(); i++ { + incrementTagged(rv.Elem().Index(i).Addr().Interface()) + } + } else { + return item + } + } + for i := 0; i < rt.NumField(); i++ { + structField := rt.Field(i) + cur := rv.Elem().Field(i) + tags, err := structtag.Parse(string(structField.Tag)) + if err != nil { + continue + } + incTag, err := tags.Get("autoinc") + if err != nil { + continue + } + + nid := getLastInColl(incTag.Name, cur.Interface()) + if cur.IsZero() { + coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur) + if coerced != nil { + cur.Set(reflect.ValueOf(coerced)) + } else { + cur.Set(reflect.ValueOf(incrementInterface(nid))) + } + } + counterColl := DB.Collection(COUNTER_COL) + counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.Update().SetUpsert(true)) + + } + return rv.Elem().Interface() +} + +func incrementAll(item interface{}) { + if item == nil { + return + } + vp := reflect.ValueOf(item) + el := vp + if vp.Kind() == reflect.Pointer { + el = vp.Elem() + } + if vp.Kind() == reflect.Pointer && vp.IsNil() { + return + } + vt := el.Type() + switch el.Kind() { + case reflect.Struct: + incrementTagged(item) + for i := 0; i < el.NumField(); i++ { + fv := el.Field(i) + fst := vt.Field(i) + if !fst.IsExported() { + continue + } + incrementAll(fv.Interface()) + } + case reflect.Slice: + for i := 0; i < el.Len(); i++ { + incd := incrementTagged(el.Index(i).Addr().Interface()) + if reflect.ValueOf(incd).Kind() == reflect.Pointer { + el.Index(i).Set(reflect.ValueOf(incd).Elem()) + } else { + el.Index(i).Set(reflect.ValueOf(incd)) + } + } + default: + } +} diff --git a/model_test.go b/model_test.go index d6f592a..a17112e 100644 --- a/model_test.go +++ b/model_test.go @@ -146,3 +146,17 @@ func TestModel_Pull(t *testing.T) { query.Exec(fin) assert.Equal(t, 4, len(fin.Chapters)) } + +func TestModel_Swap(t *testing.T) { + initTest() + iti_single.Author = &author + storyDoc := Create(iti_single).(*story) + storyDoc.Chapters[0].Bands = append(storyDoc.Chapters[0].Bands, metallica) + assert.Equal(t, 2, len(storyDoc.Chapters[0].Bands)) + err := storyDoc.Swap("Chapters[0].Bands", 0, 1) + assert.Nil(t, err) + c := storyDoc.Chapters[0].Bands + assert.Equal(t, metallica.ID, c[0].ID) + assert.Equal(t, dh.ID, c[1].ID) + saveDoc(t, storyDoc) +} diff --git a/util.go b/util.go index bf4fbf7..cd0ce1a 100644 --- a/util.go +++ b/util.go @@ -4,6 +4,8 @@ import ( "fmt" "go.mongodb.org/mongo-driver/bson/primitive" "reflect" + "regexp" + "strconv" "strings" ) @@ -37,22 +39,6 @@ func valueOf(i interface{}) reflect.Value { return v } -func iFace(input interface{}) interface{} { - return reflect.ValueOf(input).Interface() -} - -func iFaceSlice(input interface{}) []interface{} { - ret := make([]interface{}, 0) - fv := reflect.ValueOf(input) - if fv.Type().Kind() != reflect.Slice { - return ret - } - for i := 0; i < fv.Len(); i++ { - ret = append(ret, fv.Index(i).Interface()) - } - return ret -} - func coerceInt(input reflect.Value, dst reflect.Value) interface{} { if input.Type().Kind() == reflect.Pointer { input = input.Elem() @@ -66,20 +52,29 @@ func coerceInt(input reflect.Value, dst reflect.Value) interface{} { return nil } +var arrRegex, _ = regexp.Compile(`\[(?P\d+)]$`) + func getNested(field string, value reflect.Value) (*reflect.StructField, *reflect.Value, error) { if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") { return nil, nil, fmt.Errorf("Malformed field name %s passed", field) } dots := strings.Split(field, ".") - if value.Kind() != reflect.Struct { + if value.Kind() != reflect.Struct && arrRegex.FindString(dots[0]) == "" { return nil, nil, fmt.Errorf("This value is not a struct!") } ref := value if ref.Kind() == reflect.Pointer { ref = ref.Elem() } - fv := ref.FieldByName(dots[0]) - ft, _ := ref.Type().FieldByName(dots[0]) + var fv reflect.Value = ref.FieldByName(arrRegex.ReplaceAllString(dots[0], "")) + if arrRegex.FindString(dots[0]) != "" && fv.Kind() == reflect.Slice { + matches := arrRegex.FindStringSubmatch(dots[0]) + ridx, _ := strconv.Atoi(matches[0]) + idx := int(ridx) + fv = fv.Index(idx) + } + + ft, _ := ref.Type().FieldByName(arrRegex.ReplaceAllString(dots[0], "")) if len(dots) > 1 { return getNested(strings.Join(dots[1:], "."), fv) } else { @@ -129,3 +124,20 @@ func pull(s reflect.Value, idx int, typ reflect.Type) reflect.Value { } return retI.Elem() } + +func checkStruct(ref reflect.Value) error { + if ref.Kind() == reflect.Slice { + return fmt.Errorf("Cannot append to multiple documents!") + } + if ref.Kind() != reflect.Struct { + return fmt.Errorf("Current object is not a struct!") + } + return nil +} + +func checkSlice(ref reflect.Value) error { + if ref.Kind() != reflect.Slice { + return fmt.Errorf("Current field is not a slice!") + } + return nil +}