diff --git a/model.go b/model.go index 954d2d2..7ee7d2f 100644 --- a/model.go +++ b/model.go @@ -59,12 +59,12 @@ type IModel interface { FindByID(id interface{}) (*Query, error) FindOne(query interface{}, options ...*options.FindOneOptions) (*Query, error) FindPaged(query interface{}, page int64, perPage int64, options ...*options.FindOptions) (*Query, error) - getColl() *mongo.Collection - getIdxs() []*mongo.IndexModel - getParsedIdxs() map[string][]InternalIndex Pull(field string, a ...any) error Remove() error Save() error + getColl() *mongo.Collection + getIdxs() []*mongo.IndexModel + getParsedIdxs() map[string][]InternalIndex serializeToStore() any setTypeName(str string) getExists() bool @@ -131,8 +131,9 @@ func (m *Model) FindRaw(query interface{}, opts ...*options.FindOptions) (*mongo // returns a pointer to a Query for further chaining. func (m *Model) Find(query interface{}, opts ...*options.FindOptions) (*Query, error) { qqn := ModelRegistry.new_(m.typeName) - qqv := reflect.New(reflect.SliceOf(reflect.TypeOf(qqn).Elem())) - qqv.Elem().Set(reflect.Zero(qqv.Elem().Type())) + qqt := reflect.SliceOf(reflect.TypeOf(qqn)) + qqv := reflect.New(qqt) + qqv.Elem().Set(reflect.MakeSlice(qqt, 0, 0)) qq := &Query{ Model: *m, Collection: m.getColl(), @@ -190,10 +191,13 @@ func (m *Model) FindOne(query interface{}, options ...*options.FindOneOptions) ( err := rip.Decode(&raw) panik(err) m.exists = true + qqn := ModelRegistry.new_(m.typeName) + v := reflect.New(reflect.TypeOf(qqn)) + v.Elem().Set(reflect.ValueOf(qqn)) qq := &Query{ Collection: m.getColl(), rawDoc: raw, - doc: ModelRegistry.new_(m.typeName), + doc: v.Elem().Interface(), Op: OP_FIND_ONE, Model: *m, } @@ -360,9 +364,7 @@ func (m *Model) serializeToStore() any { return serializeIDs((m).self) } -// Create creates a new instance of a given model -// and returns a pointer to it. -func Create(d any) any { +func createBase(d any) (reflect.Value, int, string) { var n string var ri *InternalModel var ok bool @@ -381,13 +383,21 @@ func Create(d any) any { r.Elem().Set(v) - df := r.Elem().Field(i) - dm := df.Interface().(Model) if reflect.ValueOf(d).Kind() == reflect.Pointer { r.Elem().Set(reflect.ValueOf(d).Elem()) } else { r.Elem().Set(reflect.ValueOf(d)) } + + return r, i, n +} + +// Create creates a new instance of a given model +// and returns a pointer to it. +func Create(d any) any { + r, i, n := createBase(d) + df := r.Elem().Field(i) + dm := df.Interface().(Model) dm.typeName = n what := r.Interface() dm.self = what @@ -395,6 +405,15 @@ func Create(d any) any { return what } +func CreateSlice[T any](d T) []*T { + r, _, _ := createBase(d) + rtype := r.Type() + rslice := reflect.SliceOf(rtype) + newItem := reflect.New(rslice) + newItem.Elem().Set(reflect.MakeSlice(rslice, 0, 0)) + return newItem.Elem().Interface().([]*T) +} + func (m *Model) PrintMe() { fmt.Printf("My name is %s !\n", nameOf(m)) } diff --git a/model_test.go b/model_test.go index 98413a3..bb54da3 100644 --- a/model_test.go +++ b/model_test.go @@ -83,7 +83,7 @@ func TestModel_FindAll(t *testing.T) { smodel := Create(story{}).(*story) query, err := smodel.Find(bson.M{}, options.Find()) assert.Equal(t, nil, err) - final := make([]story, 0) + final := CreateSlice(story{}) query.Exec(&final) assert.Greater(t, len(final), 0) } @@ -92,11 +92,14 @@ func TestModel_PopulateMulti(t *testing.T) { initTest() bandDoc := Create(iti_single.Chapters[0].Bands[0]).(*band) saveDoc(t, bandDoc) + mauthor := Create(author).(*user) + saveDoc(t, mauthor) + iti_multi.Author = mauthor createAndSave(t, &iti_multi) smodel := Create(story{}).(*story) query, err := smodel.Find(bson.M{}, options.Find()) assert.Equal(t, nil, err) - final := make([]story, 0) + final := CreateSlice(story{}) query.Populate("Author", "Chapters.Bands").Exec(&final) assert.Greater(t, len(final), 0) for _, s := range final { @@ -151,6 +154,7 @@ func TestModel_Swap(t *testing.T) { initTest() iti_single.Author = &author storyDoc := Create(iti_single).(*story) + saveDoc(t, storyDoc) storyDoc.Chapters[0].Bands = append(storyDoc.Chapters[0].Bands, metallica) assert.Equal(t, 2, len(storyDoc.Chapters[0].Bands)) err := storyDoc.Swap("Chapters[0].Bands", 0, 1) diff --git a/query.go b/query.go index 7c8f489..308cbac 100644 --- a/query.go +++ b/query.go @@ -226,7 +226,7 @@ func (q *Query) Populate(fields ...string) *Query { _, refColl, _ := ModelRegistry.HasByName(tto.Name()) var tmp1 interface{} if arr, ok := rawDoc.(bson.A); ok { - typ := cm.Type + typ := reflect.PointerTo(cm.Type) slic := reflect.New(reflect.SliceOf(typ)) for i, val2 := range arr { ref := reflect.ValueOf(q.doc) @@ -236,7 +236,7 @@ func (q *Query) Populate(fields ...string) *Query { src := ref.Index(i).Interface() inter := populate(r, refColl.Collection, val2, field, src) if reflect.ValueOf(inter).Kind() == reflect.Pointer { - slic.Elem().Set(reflect.Append(slic.Elem(), reflect.ValueOf(inter).Elem())) + slic.Elem().Set(reflect.Append(slic.Elem(), reflect.ValueOf(inter))) } else { slic.Elem().Set(reflect.Append(slic, reflect.ValueOf(inter))) } @@ -255,35 +255,46 @@ func (q *Query) reOrganize() { var trvo reflect.Value if arr, ok := q.rawDoc.(bson.A); ok { typ := ModelRegistry[q.Model.typeName].Type + if typ.Kind() != reflect.Pointer { + typ = reflect.PointerTo(typ) + } slic := reflect.New(reflect.SliceOf(typ)) for _, v2 := range arr { inter := reflect.ValueOf(rerere(v2, typ)) - if inter.Kind() == reflect.Pointer { + /*if inter.Kind() == reflect.Pointer { inter = inter.Elem() - } + }*/ slic.Elem().Set(reflect.Append(slic.Elem(), inter)) } trvo = slic.Elem() } else { trvo = reflect.ValueOf(rerere(q.rawDoc, reflect.TypeOf(q.doc))) - for { + /*for { if trvo.Kind() == reflect.Pointer { trvo = trvo.Elem() } else { break } - } + }*/ } resV := reflect.ValueOf(q.doc) for { if resV.Kind() == reflect.Pointer { - resV = resV.Elem() + if resV.Elem().Kind() == reflect.Slice { + resV = resV.Elem() + } else { + break + } } else { break } } - resV.Set(trvo) + if resV.CanSet() { + resV.Set(trvo) + } else { + resV.Elem().Set(trvo.Elem()) + } } func rerere(input interface{}, resType reflect.Type) interface{} { @@ -303,7 +314,7 @@ func rerere(input interface{}, resType reflect.Type) interface{} { resType = resType.Elem() } resV := reflect.New(resType) - var rve reflect.Value = resV + var rve = resV if rve.Kind() == reflect.Pointer { rve = resV.Elem() } @@ -472,6 +483,17 @@ func (q *Query) Exec(result interface{}) { if q.done { panic("Exec() has already been called!") } + doc := reflect.ValueOf(q.doc) + if doc.Elem().Kind() == reflect.Slice { + for i := 0; i < doc.Elem().Len(); i++ { + cur := doc.Elem().Index(i) + imodel, ok := cur.Interface().(IModel) + if ok { + imodel.setExists(true) + doc.Elem().Index(i).Set(reflect.ValueOf(imodel)) + } + } + } reflect.ValueOf(result).Elem().Set(reflect.ValueOf(q.doc).Elem()) q.Model.self = q.doc q.done = true diff --git a/registry.go b/registry.go index 34eb399..b0021d3 100644 --- a/registry.go +++ b/registry.go @@ -72,7 +72,7 @@ func getRawTypeFromTag(tagOpt string, slice bool) reflect.Type { var v uint = 0 t = reflect.TypeOf(v) case "string": - var v string = "0" + var v = "0" t = reflect.TypeOf(v) } @@ -207,11 +207,11 @@ func (r TModelRegistry) Index(n string) int { } func (t TModelRegistry) new_(n string) interface{} { - if n, m, ok := ModelRegistry.HasByName(n); ok { + if name, m, ok := ModelRegistry.HasByName(n); ok { v := reflect.New(m.Type) df := v.Elem().Field(m.Idx) d := df.Interface().(Model) - d.typeName = n + d.typeName = name df.Set(reflect.ValueOf(d)) return v.Interface() }