diff --git a/gridfs.go b/gridfs.go
new file mode 100644
index 0000000..ee41ecf
--- /dev/null
+++ b/gridfs.go
@@ -0,0 +1,197 @@
+package orm
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"github.com/fatih/structtag"
+	"go.mongodb.org/mongo-driver/bson"
+	"go.mongodb.org/mongo-driver/bson/primitive"
+	"go.mongodb.org/mongo-driver/mongo"
+	"go.mongodb.org/mongo-driver/mongo/gridfs"
+	"go.mongodb.org/mongo-driver/mongo/options"
+	"io"
+	"reflect"
+	"strings"
+	"text/template"
+)
+
+type GridFSFile struct {
+	ID     primitive.ObjectID `bson:"_id"`
+	Name   string             `bson:"filename"`
+	Length int                `bson:"length"`
+}
+
+func parseFmt(format string, value any) string {
+	tmpl, err := template.New("filename").Parse(format)
+	panik(err)
+	w := new(strings.Builder)
+	err = tmpl.Execute(w, value)
+	panik(err)
+	return w.String()
+}
+
+func bucket(gfsRef GridFSReference) *gridfs.Bucket {
+	b, _ := gridfs.NewBucket(DB, options.GridFSBucket().SetName(gfsRef.BucketName))
+	return b
+}
+
+func gridFsLoad(val any, g GridFSReference, field string) any {
+	doc := reflect.ValueOf(val)
+	rdoc := reflect.ValueOf(val)
+	if doc.Kind() != reflect.Pointer {
+		doc = reflect.New(reflect.TypeOf(val))
+		doc.Elem().Set(reflect.ValueOf(val))
+	}
+	var next string
+	if len(strings.Split(field, ".")) > 1 {
+		next = strings.Join(strings.Split(field, ".")[1:], ".")
+		field = strings.Split(field, ".")[0]
+	} else {
+		next = field
+	}
+	_, rfield, ferr := getNested(field, rdoc)
+	if ferr != nil {
+		return nil
+	}
+	switch rfield.Kind() {
+	case reflect.Slice:
+		for i := 0; i < rfield.Len(); i++ {
+			cur := rfield.Index(i)
+			if cur.Kind() != reflect.Pointer {
+				tmp := reflect.New(cur.Type())
+				tmp.Elem().Set(cur)
+				cur = tmp
+			}
+			intermediate := gridFsLoad(cur.Interface(), g, next)
+			if intermediate == nil {
+				continue
+			}
+			ival := reflect.ValueOf(intermediate)
+			if ival.Kind() == reflect.Pointer {
+				ival = ival.Elem()
+			}
+			rfield.Index(i).Set(ival)
+		}
+	case reflect.Struct:
+
+		intermediate := gridFsLoad(rfield.Interface(), g, next)
+		if intermediate != nil {
+			rfield.Set(reflect.ValueOf(intermediate))
+		}
+	default:
+		b := bucket(g)
+		var found GridFSFile
+		cursor, err := b.Find(bson.M{"filename": parseFmt(g.FilenameFmt, val)})
+		if err != nil {
+			return nil
+		}
+		cursor.Next(context.TODO())
+		_ = cursor.Decode(&found)
+		bb := bytes.NewBuffer(nil)
+		_, err = b.DownloadToStream(found.ID, bb)
+		if err != nil {
+			return nil
+		}
+		if rfield.Type().AssignableTo(reflect.TypeFor[[]byte]()) {
+			rfield.Set(reflect.ValueOf(bb.Bytes()))
+		} else if rfield.Type().AssignableTo(reflect.TypeFor[string]()) {
+			rfield.Set(reflect.ValueOf(bb.String()))
+		}
+
+	}
+	if rdoc.Kind() != reflect.Pointer {
+		return doc.Elem().Interface()
+	}
+	return doc.Interface()
+}
+
+func gridFsSave(val any, imodel InternalModel) error {
+	var rerr error
+	v := reflect.ValueOf(val)
+	el := v
+	if v.Kind() == reflect.Pointer {
+		el = el.Elem()
+	}
+
+	switch el.Kind() {
+	case reflect.Struct:
+		for i := 0; i < el.NumField(); i++ {
+			ft := el.Type().Field(i)
+			fv := el.Field(i)
+			if !ft.IsExported() {
+				continue
+			}
+			_, err := structtag.Parse(string(ft.Tag))
+			panik(err)
+			var gfsRef *GridFSReference
+			for kk, vv := range imodel.GridFSReferences {
+				if strings.HasPrefix(kk, ft.Name) {
+					gfsRef = &vv
+					break
+				}
+			}
+			var inner = func(b *gridfs.Bucket, it reflect.Value) error {
+				filename := parseFmt(gfsRef.FilenameFmt, it.Interface())
+				contents := GridFSFile{}
+				curs, err2 := b.Find(bson.M{"filename": filename})
+
+				if !errors.Is(err2, mongo.ErrNoDocuments) {
+					_ = curs.Decode(&contents)
+					if !reflect.ValueOf(contents).IsZero() {
+						_ = b.Delete(contents.ID)
+					}
+				}
+				c := it.Field(gfsRef.Idx)
+				var rdr io.Reader
+
+				if c.Type().AssignableTo(reflect.TypeOf([]byte{})) {
+					rdr = bytes.NewReader(c.Interface().([]byte))
+				} else if c.Type().AssignableTo(reflect.TypeOf("")) {
+					rdr = strings.NewReader(c.Interface().(string))
+				} else {
+					return fmt.Errorf("gridfs loader type '%s' not supported", c.Type().String())
+				}
+				_, err = b.UploadFromStream(filename, rdr)
+				return err
+			}
+
+			if gfsRef != nil {
+				b := bucket(*gfsRef)
+				if fv.Kind() == reflect.Slice {
+					for j := 0; j < fv.Len(); j++ {
+						lerr := inner(b, fv.Index(j))
+						if lerr != nil {
+							return lerr
+						}
+					}
+				} else if fv.Kind() == reflect.Struct {
+					lerr := inner(b, fv)
+					if lerr != nil {
+						return lerr
+					}
+				} else {
+					lerr := inner(b, el)
+					if lerr != nil {
+						return lerr
+					}
+				}
+			}
+			err = gridFsSave(fv.Interface(), imodel)
+			if err != nil {
+				return err
+			}
+		}
+	case reflect.Slice:
+		for i := 0; i < el.Len(); i++ {
+			rerr = gridFsSave(el.Index(i).Interface(), imodel)
+			if rerr != nil {
+				return rerr
+			}
+		}
+	default:
+		break
+	}
+	return rerr
+}
diff --git a/model_internals.go b/model_internals.go
index 95c0532..2cca2f5 100644
--- a/model_internals.go
+++ b/model_internals.go
@@ -138,6 +138,7 @@ func doSave(c *mongo.Collection, isNew bool, arg interface{}) error {
 		vp.Elem().Set(selfo)
 	}
 	var asHasId = vp.Interface().(HasID)
+	var asModel = vp.Interface().(IModel)
 	if isNew {
 		m.setCreated(now)
 	}
@@ -156,6 +157,8 @@ func doSave(c *mongo.Collection, isNew bool, arg interface{}) error {
 		(asHasId).SetId(pnid)
 	}
 	incrementAll(asHasId)
+	_, im, _ := ModelRegistry.HasByName(asModel.getTypeName())
+	_ = gridFsSave(asHasId, *im)
 
 	_, err = c.InsertOne(context.TODO(), m.serializeToStore())
 	if err == nil {
diff --git a/model_test.go b/model_test.go
index fb77ce0..cd0bf26 100644
--- a/model_test.go
+++ b/model_test.go
@@ -209,3 +209,65 @@ func TestModel_Swap(t *testing.T) {
 	assert.Equal(t, diamondHead.ID, c[1].ID)
 	saveDoc(t, storyDoc)
 }
+
+func TestModel_GridFSLoad(t *testing.T) {
+	initTest()
+	ModelRegistry.Model(somethingWithNestedChapters{})
+	model := Create(somethingWithNestedChapters{}).(*somethingWithNestedChapters)
+	thingDoc := Create(doSomethingWithNested()).(*somethingWithNestedChapters)
+	found := &somethingWithNestedChapters{}
+
+	saveDoc(t, thingDoc)
+	assert.NotZero(t, thingDoc.ID)
+	fq, err := model.FindByID(thingDoc.ID)
+	assert.Nil(t, err)
+	fq.LoadFile("NestedText", "Chapters.Text").Exec(found)
+	assert.NotZero(t, found.NestedText)
+	assert.NotZero(t, len(found.Chapters))
+	for _, c := range found.Chapters {
+		assert.NotZero(t, c.Text)
+	}
+}
+
+func TestModel_GridFSLoad_Complex(t *testing.T) {
+	initTest()
+	model := Create(story{}).(*story)
+	bandDoc := Create(iti_single.Chapters[0].Bands[0]).(*band)
+	thingDoc := Create(iti_multi).(*story)
+	mauthor := Create(author).(*user)
+	found := &story{}
+	saveDoc(t, mauthor)
+	saveDoc(t, bandDoc)
+	thingDoc.Author = mauthor
+	saveDoc(t, thingDoc)
+	assert.NotZero(t, thingDoc.ID)
+	fq, err := model.FindByID(thingDoc.ID)
+	assert.Nil(t, err)
+	fq.Populate("Author", "Chapters.Bands").LoadFile("Chapters.Text").Exec(found)
+	assert.NotZero(t, len(found.Chapters))
+	for _, c := range found.Chapters {
+		assert.NotZero(t, c.Text)
+		assert.NotZero(t, c.Bands[0].Name)
+	}
+	j, _ := fq.JSON()
+	fmt.Printf("%s\n", j)
+}
+
+func TestModel_GridFSLoad_Chained(t *testing.T) {
+	initTest()
+	ModelRegistry.Model(somethingWithNestedChapters{})
+	model := Create(somethingWithNestedChapters{}).(*somethingWithNestedChapters)
+	thingDoc := Create(doSomethingWithNested()).(*somethingWithNestedChapters)
+	found := &somethingWithNestedChapters{}
+
+	saveDoc(t, thingDoc)
+	assert.NotZero(t, thingDoc.ID)
+	fq, err := model.FindByID(thingDoc.ID)
+	assert.Nil(t, err)
+	fq.LoadFile("NestedText").LoadFile("Chapters.Text").Exec(found)
+	assert.NotZero(t, found.NestedText)
+	assert.NotZero(t, len(found.Chapters))
+	for _, c := range found.Chapters {
+		assert.NotZero(t, c.Text)
+	}
+}
diff --git a/query.go b/query.go
index 13aea6f..052f6b1 100644
--- a/query.go
+++ b/query.go
@@ -184,6 +184,38 @@ func populate(r Reference, rcoll string, rawDoc interface{}, d string, src inter
 	return src
 }
 
+// LoadFile - loads the contents of one or more files
+// stored in GridFS into the fields named by `fields`.
+//
+// GridFS fields can be either a `string` or a `[]byte`, and are
+// tagged with `gridfs:"BUCKET,FILE_FORMAT"`,
+// where:
+// - `BUCKET` is the name of the bucket where the files are stored
+// - `FILE_FORMAT` is any valid Go template string that resolves to
+// the unique file name.
+// All exported values and methods present in the surrounding
+// struct can be used in this template.
+func (q *Query) LoadFile(fields ...string) *Query {
+	_, cm, _ := ModelRegistry.HasByName(q.model.typeName)
+	if cm != nil {
+		for _, field := range fields {
+			var r GridFSReference
+			hasAnnotated := false
+			for k2, v := range cm.GridFSReferences {
+				if strings.HasPrefix(k2, field) {
+					r = v
+					hasAnnotated = true
+					break
+				}
+			}
+			if hasAnnotated {
+				q.doc = gridFsLoad(q.doc, r, field)
+			}
+		}
+	}
+	return q
+}
+
 // Populate populates document references via reflection
 func (q *Query) Populate(fields ...string) *Query {
 	_, cm, _ := ModelRegistry.HasByName(q.model.typeName)
@@ -465,7 +497,7 @@ func handleAnon(raw interface{}, rtype reflect.Type, rval reflect.Value) reflect
 
 // JSON - marshals this Query's results into json format
 func (q *Query) JSON() (string, error) {
-	res, err := json.MarshalIndent(q.doc, "\n", "\t")
+	res, err := json.MarshalIndent(q.doc, "", "\t")
 	if err != nil {
 		return "", err
 	}
diff --git a/registry.go b/registry.go
index b0021d3..56edb18 100644
--- a/registry.go
+++ b/registry.go
@@ -19,11 +19,12 @@ import (
 // InternalModel, as the name suggests, is used
 // internally by the model registry
 type InternalModel struct {
-	Idx        int
-	Type       reflect.Type
-	Collection string
-	References map[string]Reference
-	Indexes    map[string][]InternalIndex
+	Idx              int
+	Type             reflect.Type
+	Collection       string
+	References       map[string]Reference
+	Indexes          map[string][]InternalIndex
+	GridFSReferences map[string]GridFSReference
 }
 
 // Reference stores a typed document reference
@@ -43,6 +44,13 @@ type Reference struct {
 	Exists bool
 }
 
+type GridFSReference struct {
+	BucketName  string
+	FilenameFmt string
+	LoadType    reflect.Type
+	Idx         int
+}
+
 type TModelRegistry map[string]*InternalModel
 
 // ModelRegistry - the ModelRegistry stores a map containing
@@ -82,6 +90,36 @@ func getRawTypeFromTag(tagOpt string, slice bool) reflect.Type {
 	return t
 }
 
+func makeGfsRef(tag *structtag.Tag, idx int) GridFSReference {
+	opts := tag.Options
+	var ffmt string
+	if len(opts) < 1 {
+		ffmt = "%s"
+	} else {
+		ffmt = opts[0]
+	}
+	var typ reflect.Type
+	if len(opts) < 2 {
+		typ = reflect.TypeOf("")
+	} else {
+		switch opts[1] {
+		case "bytes":
+			typ = reflect.TypeOf([]byte{})
+		case "string":
+			typ = reflect.TypeOf("")
+		default:
+			typ = reflect.TypeOf("")
+		}
+	}
+
+	return GridFSReference{
+		FilenameFmt: ffmt,
+		BucketName:  tag.Name,
+		LoadType:    typ,
+		Idx:         idx,
+	}
+}
+
 func makeRef(idx int, modelName string, fieldName string, ht reflect.Type) Reference {
 	if modelName != "" {
 		if ModelRegistry.Index(modelName) != -1 {
@@ -106,10 +144,11 @@ func makeRef(idx int, modelName string, fieldName string, ht reflect.Type) Refer
 	panic("model name was empty")
 }
 
-func parseTags(t reflect.Type, v reflect.Value) (map[string][]InternalIndex, map[string]Reference, string) {
+func parseTags(t reflect.Type, v reflect.Value) (map[string][]InternalIndex, map[string]Reference, map[string]GridFSReference, string) {
 	coll := ""
-	refs := make(map[string]Reference, 0)
-	idcs := make(map[string][]InternalIndex, 0)
+	refs := make(map[string]Reference)
+	idcs := make(map[string][]InternalIndex)
+	gfsRefs := make(map[string]GridFSReference)
 
 	for i := 0; i < v.NumField(); i++ {
 		sft := t.Field(i)
@@ -126,15 +165,17 @@ func parseTags(t reflect.Type, v reflect.Value) (map[string][]InternalIndex, map
 			ft = ft.Elem()
 			if _, ok := tags.Get("ref"); ok != nil {
 				if ft.Kind() == reflect.Struct {
-					ii2, rr2, _ := parseTags(ft, reflect.New(ft).Elem())
-					for k, v := range ii2 {
-						idcs[sft.Name+"."+k] = v
+					ii2, rr2, gg2, _ := parseTags(ft, reflect.New(ft).Elem())
+					for k, vv := range ii2 {
+						idcs[sft.Name+"."+k] = vv
 					}
-					for k, v := range rr2 {
-						refs[sft.Name+"."+k] = v
+					for k, vv := range rr2 {
+						refs[sft.Name+"."+k] = vv
+					}
+					for k, vv := range gg2 {
+						gfsRefs[sft.Name+"."+k] = vv
 					}
 				}
-
 			}
 			continue
 		case reflect.Pointer:
@@ -160,18 +201,26 @@ func parseTags(t reflect.Type, v reflect.Value) (map[string][]InternalIndex, map
 				sname := sft.Name + "@" + refTag.Name
 				refs[sname] = makeRef(i, refTag.Name, sft.Name, sft.Type)
 			}
+			if gtag, ok := tags.Get("gridfs"); ok == nil {
+				sname := sft.Name + "@" + gtag.Name
+				gfsRefs[sname] = makeGfsRef(gtag, i)
+			}
 			fallthrough
 		default:
 			idxTag, err := tags.Get("idx")
 			if err == nil {
 				idcs[sft.Name] = scanIndex(idxTag.Value())
 			}
+			if gtag, ok := tags.Get("gridfs"); ok == nil {
+				sname := sft.Name + "@" + gtag.Name
+				gfsRefs[sname] = makeGfsRef(gtag, i)
+			}
 			shouldContinue = false
 		}
 		}
 	}
 
-	return idcs, refs, coll
+	return idcs, refs, gfsRefs, coll
 }
 
 // Has returns the model typename and InternalModel instance corresponding
@@ -191,7 +240,7 @@ func (r TModelRegistry) Has(i interface{}) (string, *InternalModel, bool) {
 
 // HasByName functions almost identically to Has,
 // except that it takes a string as its argument.
-func (t TModelRegistry) HasByName(n string) (string, *InternalModel, bool) {
+func (r TModelRegistry) HasByName(n string) (string, *InternalModel, bool) {
 	if t, ok := ModelRegistry[n]; ok {
 		return n, t, true
 	}
@@ -206,7 +255,7 @@ func (r TModelRegistry) Index(n string) int {
 	return -1
 }
 
-func (t TModelRegistry) new_(n string) interface{} {
+func (r TModelRegistry) new_(n string) interface{} {
 	if name, m, ok := ModelRegistry.HasByName(n); ok {
 		v := reflect.New(m.Type)
 		df := v.Elem().Field(m.Idx)
@@ -262,16 +311,17 @@ func (r TModelRegistry) Model(mdl ...any) {
 		if idx < 0 {
 			panic("A model must embed the Model struct!")
 		}
-		inds, refs, coll := parseTags(t, v)
+		inds, refs, gfs, coll := parseTags(t, v)
 		if coll == "" {
 			panic(fmt.Sprintf("a model needs to be given a collection name! (passed type: %s)", n))
 		}
 		ModelRegistry[n] = &InternalModel{
-			Idx:        idx,
-			Type:       t,
-			Collection: coll,
-			Indexes:    inds,
-			References: refs,
+			Idx:              idx,
+			Type:             t,
+			Collection:       coll,
+			Indexes:          inds,
+			References:       refs,
+			GridFSReferences: gfs,
 		}
 	}
 	for k, v := range ModelRegistry {
diff --git a/testing.go b/testing.go
index 0807f61..a20da77 100644
--- a/testing.go
+++ b/testing.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"github.com/stretchr/testify/assert"
+	"strings"
 	"testing"
 	"time"
 
@@ -31,7 +32,7 @@ type chapter struct {
 	LoggedInOnly bool      `bson:"loggedInOnly" json:"loggedInOnly" form:"loggedInOnly"`
 	Posted       time.Time `bson:"datePosted,omitempty" json:"datePosted"`
 	FileName     string    `json:"fileName" bson:"-"`
-	Text         string    `json:"text" bson:"-"`
+	Text         string    `json:"text" bson:"-" gridfs:"story_text,/stories/{{.ChapterID}}.txt"`
 }
 
 type band struct {
@@ -59,6 +60,20 @@ type story struct {
 	Completed bool `bson:"completed" json:"completed" form:"completed"`
 	Downloads int  `bson:"downloads" json:"downloads"`
 }
+type somethingWithNestedChapters struct {
+	ID         int64     `bson:"_id" json:"_id"`
+	Model      `bson:",inline" json:",inline" coll:"nested_stuff"`
+	Chapters   []chapter `bson:"chapters" json:"chapters"`
+	NestedText string    `json:"text" bson:"-" gridfs:"nested_text,/nested/{{.ID}}.txt"`
+}
+
+func (s *somethingWithNestedChapters) Id() any {
+	return s.ID
+}
+
+func (s *somethingWithNestedChapters) SetId(id any) {
+	s.ID = id.(int64)
+}
 
 func (s *story) Id() any {
 	return s.ID
@@ -114,6 +129,7 @@ func genChaps(single bool) []chapter {
 			{"Sean Harris", "Colin Kimberley", "Brian Tatler"},
 		},
 	}
+	l := loremipsum.New()
 	for i := 0; i < ceil; i++ {
 		spf := fmt.Sprintf("%d.md", i+1)
 
@@ -124,20 +140,30 @@ func genChaps(single bool) []chapter {
 			Words:         50,
 			Notes:         "notenotenote !!!",
 			Genre:         []string{"Slash"},
-			Bands:         []band{dh},
+			Bands:         []band{diamondHead},
 			Characters:    []string{"Sean Harris", "Brian Tatler", "Duncan Scott", "Colin Kimberley"},
 			Relationships: relMap[i],
 			Adult:         true,
-			Summary:       loremipsum.New().Paragraph(),
+			Summary:       l.Paragraph(),
 			Hidden:        false,
 			LoggedInOnly:  true,
 			FileName:      spf,
+			Text:          strings.Join(l.ParagraphList(10), "\n\n"),
 		})
 	}
 	return ret
 }
 
+func doSomethingWithNested() somethingWithNestedChapters {
+	l := loremipsum.New()
+	swnc := somethingWithNestedChapters{
+		Chapters:   genChaps(false),
+		NestedText: strings.Join(l.ParagraphList(15), "\n\n"),
+	}
+	return swnc
+}
+
 var iti_single = story{
 	Title:     "title",
 	Completed: true,
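
Reviewer note — a minimal usage sketch, not part of the diff above: it shows how the comma-separated segments of the `gridfs` tag map onto GridFSReference in makeGfsRef (bucket name, filename template, optional "bytes"/"string" load type) and how a saved field comes back through Query.LoadFile. The attachment model, its bucket name, and TestModel_GridFSSketch are hypothetical; the helper calls (initTest, Create, saveDoc, FindByID, Exec) simply mirror the new tests.

// attachment is a hypothetical model. Body is not stored in the document
// (bson:"-"); it lives in the "attachments_text" bucket under a filename
// rendered from the template segment, and the optional third segment
// ("string") is what makeGfsRef records as the GridFSReference load type.
type attachment struct {
	ID    int64  `bson:"_id" json:"_id"`
	Model `bson:",inline" coll:"attachments"`
	Body  string `json:"body" bson:"-" gridfs:"attachments_text,/att/{{.ID}}.txt,string"`
}

func (a *attachment) Id() any      { return a.ID }
func (a *attachment) SetId(id any) { a.ID = id.(int64) }

func TestModel_GridFSSketch(t *testing.T) {
	initTest()
	ModelRegistry.Model(attachment{})

	// Saving the document also uploads Body to the bucket via gridFsSave.
	doc := Create(attachment{Body: "hello gridfs"}).(*attachment)
	saveDoc(t, doc)

	// LoadFile("Body") renders the filename template again for the found
	// document and streams the stored file back into the field.
	model := Create(attachment{}).(*attachment)
	found := &attachment{}
	fq, err := model.FindByID(doc.ID)
	assert.Nil(t, err)
	fq.LoadFile("Body").Exec(found)
	assert.Equal(t, doc.Body, found.Body)
}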