From 924b3fc9d2c5239bdd3f8670623fe8900d020780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=98=99=E2=97=A6=20The=20Tablet=20=E2=9D=80=20GamerGirla?= =?UTF-8?q?ndCo=20=E2=97=A6=E2=9D=A7?= Date: Wed, 4 Sep 2024 19:43:12 -0400 Subject: [PATCH] many additions! - add support for `autoinc` struct tag - add Pull method to Model - add utility func to increment a value of an unknown type by 1 - idcounter.go no longer increments the counter value (it is now up to the caller to increment the returned value themselves) --- idcounter.go | 30 ----- model.go | 350 +++++++++++++++++++++++++++++++++----------------- model_test.go | 23 +++- query.go | 9 +- testing.go | 7 + util.go | 40 +++++- 6 files changed, 303 insertions(+), 156 deletions(-) diff --git a/idcounter.go b/idcounter.go index d91c523..ca2ab18 100644 --- a/idcounter.go +++ b/idcounter.go @@ -39,36 +39,6 @@ func getLastInColl(cname string, id interface{}) interface{} { Current: id, } cnt.Collection = cname - switch nini := cnt.Current.(type) { - case int: - if nini == 0 { - cnt.Current = nini + 1 - } - case int32: - if nini == int32(0) { - cnt.Current = nini + 1 - } - case int64: - if nini == int64(0) { - cnt.Current = nini + 1 - } - case uint: - if nini == uint(0) { - cnt.Current = nini + 1 - } - case uint32: - if nini == uint32(0) { - cnt.Current = nini + 1 - } - case uint64: - if nini == uint64(0) { - cnt.Current = nini + 1 - } - case string: - cnt.Current = NextStringID() - case primitive.ObjectID: - cnt.Current = primitive.NewObjectID() - } } return cnt.Current } diff --git a/model.go b/model.go index d304d07..56354d6 100644 --- a/model.go +++ b/model.go @@ -8,7 +8,6 @@ import ( "github.com/fatih/structtag" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -64,7 +63,7 @@ type IModel interface { getIdxs() []*mongo.IndexModel getParsedIdxs() map[string][]InternalIndex Save() error - serializeToStore() primitive.M + serializeToStore() any setTypeName(str string) getExists() bool setExists(n bool) @@ -225,6 +224,98 @@ func (m *Model) Delete() error { } } +func incrementTagged(item interface{}) interface{} { + rv := reflect.ValueOf(item) + rt := reflect.TypeOf(item) + if rv.Kind() != reflect.Pointer { + rv = makeSettable(rv, item) + } + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + if rt.Kind() == reflect.Slice { + for i := 0; i < rv.Elem().Len(); i++ { + incrementTagged(rv.Elem().Index(i).Addr().Interface()) + } + } else { + return item + } + } + for i := 0; i < rt.NumField(); i++ { + structField := rt.Field(i) + cur := rv.Elem().Field(i) + tags, err := structtag.Parse(string(structField.Tag)) + if err != nil { + continue + } + incTag, err := tags.Get("autoinc") + if err != nil { + continue + } + + nid := getLastInColl(incTag.Name, cur.Interface()) + if cur.IsZero() { + coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur) + if coerced != nil { + cur.Set(reflect.ValueOf(coerced)) + } else { + cur.Set(reflect.ValueOf(incrementInterface(nid))) + } + } + counterColl := DB.Collection(COUNTER_COL) + counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.Update().SetUpsert(true)) + + } + return rv.Elem().Interface() +} + +func incrementAll(item interface{}) { + if item == nil { + return + } + vp := reflect.ValueOf(item) + el := vp + if vp.Kind() == reflect.Pointer { + el = vp.Elem() + } + if vp.Kind() == reflect.Pointer && vp.IsNil() { + return + } + vt := el.Type() + switch el.Kind() { + case reflect.Struct: + incrementTagged(item) + for i := 0; i < el.NumField(); i++ { + fv := el.Field(i) + fst := vt.Field(i) + if !fst.IsExported() { + continue + } + incrementAll(fv.Interface()) + } + case reflect.Slice: + for i := 0; i < el.Len(); i++ { + incd := incrementTagged(el.Index(i).Addr().Interface()) + if reflect.ValueOf(incd).Kind() == reflect.Pointer { + el.Index(i).Set(reflect.ValueOf(incd).Elem()) + } else { + el.Index(i).Set(reflect.ValueOf(incd)) + } + } + } +} + +func checkStruct(ref reflect.Value) error { + if ref.Kind() == reflect.Slice { + return fmt.Errorf("Cannot append to multiple documents!") + } + if ref.Kind() != reflect.Struct { + return fmt.Errorf("Current object is not a struct!") + } + return nil +} + // Append appends one or more items to `field`. // will error if this Model contains a reference // to multiple documents, or if `field` is not a @@ -237,16 +328,14 @@ func (m *Model) Append(field string, a ...interface{}) error { selfRef = selfRef.Elem() rt = rt.Elem() } - if selfRef.Kind() == reflect.Slice { - return fmt.Errorf("Cannot append to multiple documents!") - } - if selfRef.Kind() != reflect.Struct { - return fmt.Errorf("Current object is not a struct!") + if err := checkStruct(selfRef); err != nil { + return err } _, origV, err := getNested(field, selfRef) if err != nil { return err } + origRef := makeSettable(*origV, (*origV).Interface()) fv := origRef if fv.Kind() == reflect.Pointer { @@ -255,14 +344,49 @@ func (m *Model) Append(field string, a ...interface{}) error { if fv.Kind() != reflect.Slice { return fmt.Errorf("Current object is not a slice!") } - for _, b := range a { - val := reflect.ValueOf(b) + val := reflect.ValueOf(incrementTagged(b)) fv.Set(reflect.Append(fv, val)) } return nil } +func (m *Model) Pull(field string, a ...any) error { + rv := reflect.ValueOf(m.self) + selfRef := rv + rt := reflect.TypeOf(m.self) + if selfRef.Kind() == reflect.Pointer { + selfRef = selfRef.Elem() + rt = rt.Elem() + } + if err := checkStruct(selfRef); err != nil { + return err + } + _, origV, err := getNested(field, selfRef) + if err != nil { + return err + } + + origRef := makeSettable(*origV, (*origV).Interface()) + fv := origRef + if fv.Kind() == reflect.Pointer { + fv = fv.Elem() + } + if fv.Kind() != reflect.Slice { + return fmt.Errorf("Current object is not a slice!") + } +outer: + for _, b := range a { + for i := 0; i < fv.Len(); i++ { + if reflect.DeepEqual(b, fv.Index(i).Interface()) { + fv.Set(pull(fv, i, fv.Index(i).Type())) + break outer + } + } + } + return nil +} + func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { var err error m, ok := arg.(IModel) @@ -291,29 +415,11 @@ func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { } if isNew { nid := getLastInColl(c.Name(), asHasId.Id()) - switch pnid := nid.(type) { - case uint: - nid = pnid + 1 - case uint32: - nid = pnid + 1 - case uint64: - nid = pnid + 1 - case int: - nid = pnid + 1 - case int32: - nid = pnid + 1 - case int64: - nid = pnid + 1 - case string: - nid = NextStringID() - case primitive.ObjectID: - nid = primitive.NewObjectID() - default: - panic("unknown or unsupported id type") - } + pnid := incrementInterface(nid) if reflect.ValueOf(asHasId.Id()).IsZero() { - (asHasId).SetId(nid) + (asHasId).SetId(pnid) } + incrementAll(asHasId) _, err = c.InsertOne(context.TODO(), m.serializeToStore()) if err == nil { @@ -340,115 +446,123 @@ func (m *Model) Save() error { } } -func (m *Model) serializeToStore() bson.M { - return serializeIDs((*m).self) +func (m *Model) serializeToStore() any { + return serializeIDs((m).self) } -func serializeIDs(input interface{}) bson.M { - ret := bson.M{} - mv := reflect.ValueOf(input) - mt := reflect.TypeOf(input) +func serializeIDs(input interface{}) interface{} { - vp := mv + vp := reflect.ValueOf(input) + mt := reflect.TypeOf(input) + var ret interface{} if vp.Kind() != reflect.Ptr { - vp = reflect.New(mv.Type()) - vp.Elem().Set(mv) - } - if mv.Kind() == reflect.Pointer { - mv = mv.Elem() + if vp.CanAddr() { + vp = vp.Addr() + } else { + vp = makeSettable(vp, input) + } } + if mt.Kind() == reflect.Pointer { mt = mt.Elem() } - for i := 0; i < mv.NumField(); i++ { - fv := mv.Field(i) - ft := mt.Field(i) - var dr = fv - tag, err := structtag.Parse(string(mt.Field(i).Tag)) - panik(err) - bbson, err := tag.Get("bson") - if err != nil || bbson.Name == "-" { - continue + getID := func(bbb interface{}) interface{} { + mptr := reflect.ValueOf(bbb) + if mptr.Kind() != reflect.Pointer { + mptr = makeSettable(mptr, bbb) } - _, terr := tag.Get("ref") - switch dr.Type().Kind() { - case reflect.Slice: - rarr := make([]interface{}, 0) - intArr := iFaceSlice(fv.Interface()) - for _, idHaver := range intArr { - if terr == nil { - if reflect.ValueOf(idHaver).Type().Kind() != reflect.Pointer { - vp := reflect.New(reflect.ValueOf(idHaver).Type()) - vp.Elem().Set(reflect.ValueOf(idHaver)) - idHaver = vp.Interface() - } - ifc, ok := idHaver.(HasID) - if !ok { - panic(fmt.Sprintf("referenced model slice '%s' does not implement HasID", ft.Name)) - } - rarr = append(rarr, ifc.Id()) - } else if reflect.ValueOf(idHaver).Kind() == reflect.Struct { - rarr = append(rarr, serializeIDs(idHaver)) - } else { - if reflect.ValueOf(idHaver).Kind() == reflect.Slice { - rarr = append(rarr, serializeIDSlice(iFaceSlice(idHaver))) - } else { - rarr = append(rarr, idHaver) - } - } - ret[bbson.Name] = rarr + ifc, ok := mptr.Interface().(HasID) + if ok { + return ifc.Id() + } else { + return nil + } + } + /*var itagged interface{} + if reflect.ValueOf(itagged).Kind() != reflect.Pointer { + itagged = incrementTagged(&input) + } else { + itagged = incrementTagged(input) + } + taggedVal := reflect.ValueOf(reflect.ValueOf(itagged).Interface()).Elem() + if vp.Kind() == reflect.Ptr { + tmp := reflect.ValueOf(taggedVal.Interface()) + if tmp.Kind() == reflect.Pointer { + vp.Elem().Set(tmp.Elem()) + } else { + vp.Elem().Set(tmp) + } + }*/ + switch vp.Elem().Kind() { + case reflect.Struct: + ret0 := bson.M{} + for i := 0; i < vp.Elem().NumField(); i++ { + fv := vp.Elem().Field(i) + ft := mt.Field(i) + tag, err := structtag.Parse(string(ft.Tag)) + panik(err) + bbson, err := tag.Get("bson") + if err != nil || bbson.Name == "-" { + continue } - case reflect.Pointer: - dr = fv.Elem() - fallthrough - case reflect.Struct: - if bbson.Name != "" { - if terr == nil { - idHaver := fv.Interface() - ifc, ok := idHaver.(HasID) - if !ok { - panic(fmt.Sprintf("referenced model '%s' does not implement HasID", ft.Name)) - } - if !fv.IsNil() { - ret[bbson.Name] = ifc.Id() + if bbson.Name == "" { + marsh, _ := bson.Marshal(fv.Interface()) + unmarsh := bson.M{} + bson.Unmarshal(marsh, &unmarsh) + for k, v := range unmarsh { + ret0[k] = v + /*if t, ok := v.(primitive.DateTime); ok { + ret0 } else { - ret[bbson.Name] = nil - } - - } else if reflect.TypeOf(fv.Interface()) != reflect.TypeOf(time.Now()) { - ret[bbson.Name] = serializeIDs(fv.Interface()) - } else { - ret[bbson.Name] = fv.Interface() + }*/ } } else { - for k, v := range serializeIDs(fv.Interface()) { - ret[k] = v + _, terr := tag.Get("ref") + if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer { + vp1 := reflect.New(fv.Type()) + vp1.Elem().Set(reflect.ValueOf(fv.Interface())) + fv.Set(vp1.Elem()) + } + if terr == nil { + ifc, ok := fv.Interface().(HasID) + if fv.Kind() == reflect.Slice { + rarr := bson.A{} + for j := 0; j < fv.Len(); j++ { + rarr = append(rarr, getID(fv.Index(j).Interface())) + } + ret0[bbson.Name] = rarr + /*ret0[bbson.Name] = serializeIDs(fv.Interface()) + break*/ + } else if !ok { + panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name)) + } else { + if reflect.ValueOf(ifc).IsNil() { + ret0[bbson.Name] = nil + } else { + ret0[bbson.Name] = ifc.Id() + } + } + + } else { + ret0[bbson.Name] = serializeIDs(fv.Interface()) } } - default: - ret[bbson.Name] = fv.Interface() + ret = ret0 } + case reflect.Slice: + ret0 := bson.A{} + for i := 0; i < vp.Elem().Len(); i++ { + + ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface())) + } + ret = ret0 + default: + ret = vp.Elem().Interface() } return ret } -func serializeIDSlice(input []interface{}) bson.A { - var a bson.A - for _, in := range input { - fv := reflect.ValueOf(in) - switch fv.Type().Kind() { - case reflect.Slice: - a = append(a, serializeIDSlice(iFaceSlice(fv.Interface()))) - case reflect.Struct: - a = append(a, serializeIDs(fv.Interface())) - default: - a = append(a, fv.Interface()) - } - } - return a -} - // Create creates a new instance of a given model // and returns a pointer to it. func Create(d any) any { diff --git a/model_test.go b/model_test.go index 4611b68..d6f592a 100644 --- a/model_test.go +++ b/model_test.go @@ -29,6 +29,9 @@ func TestSave(t *testing.T) { assert.Equal(t, nil, serr) assert.Less(t, int64(0), storyDoc.ID) assert.Less(t, int64(0), lauthor.ID) + for _, c := range storyDoc.Chapters { + assert.NotZero(t, c.ChapterID) + } } func TestPopulate(t *testing.T) { @@ -52,7 +55,9 @@ func TestPopulate(t *testing.T) { j, _ := q.JSON() fmt.Printf("%s\n", j) }) - assert.NotZero(t, storyDoc.Chapters[0].Bands[0].Name) + for _, c := range storyDoc.Chapters { + assert.NotZero(t, c.Bands[0].Name) + } } func TestUpdate(t *testing.T) { @@ -125,3 +130,19 @@ func TestModel_Delete(t *testing.T) { err := bandDoc.Delete() assert.Nil(t, err) } + +func TestModel_Pull(t *testing.T) { + initTest() + storyDoc := Create(iti_multi).(*story) + smodel := Create(story{}).(*story) + saveDoc(t, storyDoc) + err := storyDoc.Pull("Chapters", storyDoc.Chapters[4]) + assert.Nil(t, err) + assert.NotZero(t, storyDoc.ID) + saveDoc(t, storyDoc) + fin := &story{} + query, err := smodel.FindByID(storyDoc.ID) + assert.Nil(t, err) + query.Exec(fin) + assert.Equal(t, 4, len(fin.Chapters)) +} diff --git a/query.go b/query.go index b3b9eca..7c8f489 100644 --- a/query.go +++ b/query.go @@ -292,9 +292,6 @@ func rerere(input interface{}, resType reflect.Type) interface{} { if input == nil { return nil } - if v.CanAddr() && v.IsNil() { - return nil - } if v.Type().Kind() == reflect.Pointer { v = v.Elem() } @@ -346,12 +343,14 @@ func rerere(input interface{}, resType reflect.Type) interface{} { if err != nil { tmp := rerere(intermediate, ft.Type) fuck := reflect.ValueOf(tmp) - if !fuck.IsZero() { + if tmp != nil { if fuck.Type().Kind() == reflect.Pointer { fuck = fuck.Elem() } + fv.Set(fuck) + } else { + fv.Set(reflect.Zero(ft.Type)) } - fv.Set(fuck) shouldBreak = true } else { tt := ft.Type diff --git a/testing.go b/testing.go index 818c4da..e92e906 100644 --- a/testing.go +++ b/testing.go @@ -118,6 +118,7 @@ func genChaps(single bool) []chapter { for i := 0; i < ceil; i++ { spf := fmt.Sprintf("%d.md", i+1) ret = append(ret, chapter{ + ID: primitive.NewObjectID(), Title: fmt.Sprintf("-%d-", i+1), Index: int(i + 1), Words: 50, @@ -149,6 +150,12 @@ var iti_multi story = story{ Chapters: genChaps(false), } +func iti_blank() story { + t := iti_single + t.Chapters = make([]chapter, 0) + return t +} + func initTest() { uri := "mongodb://127.0.0.1:27017" db := "rockfic_ormTest" diff --git a/util.go b/util.go index 6e77761..bf4fbf7 100644 --- a/util.go +++ b/util.go @@ -2,6 +2,7 @@ package orm import ( "fmt" + "go.mongodb.org/mongo-driver/bson/primitive" "reflect" "strings" ) @@ -65,7 +66,7 @@ func coerceInt(input reflect.Value, dst reflect.Value) interface{} { return nil } -func getNested(field string, value reflect.Value) (*reflect.Type, *reflect.Value, error) { +func getNested(field string, value reflect.Value) (*reflect.StructField, *reflect.Value, error) { if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") { return nil, nil, fmt.Errorf("Malformed field name %s passed", field) } @@ -78,7 +79,7 @@ func getNested(field string, value reflect.Value) (*reflect.Type, *reflect.Value ref = ref.Elem() } fv := ref.FieldByName(dots[0]) - ft := fv.Type() + ft, _ := ref.Type().FieldByName(dots[0]) if len(dots) > 1 { return getNested(strings.Join(dots[1:], "."), fv) } else { @@ -93,3 +94,38 @@ func makeSettable(rval reflect.Value, value interface{}) reflect.Value { } return rval } + +func incrementInterface(t interface{}) interface{} { + switch pt := t.(type) { + case uint: + t = pt + 1 + case uint32: + t = pt + 1 + case uint64: + t = pt + 1 + case int: + t = pt + 1 + case int32: + t = pt + 1 + case int64: + t = pt + 1 + case string: + t = NextStringID() + case primitive.ObjectID: + t = primitive.NewObjectID() + default: + panic("unknown or unsupported id type") + } + return t +} + +func pull(s reflect.Value, idx int, typ reflect.Type) reflect.Value { + retI := reflect.New(reflect.SliceOf(typ)) + for i := 0; i < s.Len(); i++ { + if i == idx { + continue + } + retI.Elem().Set(reflect.Append(retI.Elem(), s.Index(i))) + } + return retI.Elem() +}