diamond-orm/gridfs.go

198 lines
4.6 KiB
Go
Raw Permalink Normal View History

package orm
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/gridfs"
"go.mongodb.org/mongo-driver/mongo/options"
"html/template"
"io"
"reflect"
"strings"
)
// GridFSFile holds the fields this package decodes from a GridFS file
// document returned by Bucket.Find: the document's "_id", "filename" and
// "length" values.
type GridFSFile struct {
// ID is the file document's "_id"; it is the value passed to
// DownloadToStream and Delete.
ID primitive.ObjectID `bson:"_id"`
// Name is the stored "filename", the key gridFsLoad and gridFsSave query by.
Name string `bson:"filename"`
// Length is the document's "length" value.
// NOTE(review): GridFS usually stores this as a 64-bit integer; decoding into
// int assumes a 64-bit platform (or small files) — confirm.
Length int `bson:"length"`
}
// parseFmt renders format as a Go template, with value as the template's data
// (dot), and returns the produced text. Parse and execution errors are both
// handed to panik.
//
// NOTE(review): the template package in scope is html/template, so
// interpolated values are HTML-escaped (for example "&" renders as "&amp;").
// text/template is probably what a filename wants, but switching would change
// the names under which existing files were stored.
func parseFmt(format string, value any) string {
	var out strings.Builder
	tmpl, parseErr := template.New("filename").Parse(format)
	panik(parseErr)
	panik(tmpl.Execute(&out, value))
	return out.String()
}
// bucket opens the GridFS bucket named gfsRef.BucketName on the package-level
// DB handle.
//
// A bucket-creation error is passed to panik instead of being discarded.
// Discarding it returned a nil *gridfs.Bucket, and every caller would then
// fail later with an unexplained nil dereference.
func bucket(gfsRef GridFSReference) *gridfs.Bucket {
	b, err := gridfs.NewBucket(DB, options.GridFSBucket().SetName(gfsRef.BucketName))
	panik(err)
	return b
}
// gridFsLoad fills a GridFS-backed field of val from bucket g and returns the
// updated value.
//
// field is a dotted path such as "Items.Content". Its first segment is
// resolved with getNested. What happens next depends on the kind of the
// resolved field:
//   - slice: every element is loaded recursively with the rest of the path;
//   - struct: the struct is loaded recursively with the rest of the path;
//   - anything else: g.FilenameFmt is rendered against val, the most recently
//     uploaded file with that name is downloaded, and its contents are stored
//     in the field when the field's type accepts a []byte or a string.
//
// When field contains no dot, the same name is reused for the nested lookup.
//
// A pointer val is updated in place and returned unchanged. A non-pointer val
// is copied, and the updated copy is returned. nil is returned when val is
// nil, the path cannot be resolved, or the file cannot be found or downloaded.
// The recursive callers treat nil as "leave this value unchanged".
func gridFsLoad(val any, g GridFSReference, field string) any {
	rdoc := reflect.ValueOf(val)
	if !rdoc.IsValid() || (rdoc.Kind() == reflect.Pointer && rdoc.IsNil()) {
		return nil
	}

	// Resolve fields through a pointer so they are settable. getNested already
	// receives a pointer whenever the caller passes one; a value input is
	// copied into a fresh pointer instead of being passed unaddressable.
	doc := rdoc
	if rdoc.Kind() != reflect.Pointer {
		doc = reflect.New(rdoc.Type())
		doc.Elem().Set(rdoc)
	}

	head, rest, dotted := strings.Cut(field, ".")
	next := field
	if dotted {
		next = rest
	}

	_, rfield, err := getNested(head, doc)
	if err != nil || !rfield.IsValid() {
		return nil
	}

	switch rfield.Kind() {
	case reflect.Slice:
		for i := 0; i < rfield.Len(); i++ {
			elem := rfield.Index(i)
			arg := elem
			if elem.Kind() == reflect.Pointer {
				if elem.IsNil() {
					continue
				}
			} else {
				// Load into a pointer to a copy so the recursive call can set fields.
				arg = reflect.New(elem.Type())
				arg.Elem().Set(elem)
			}
			loaded := gridFsLoad(arg.Interface(), g, next)
			if loaded == nil {
				continue
			}
			lv := reflect.ValueOf(loaded)
			// Value elements take the dereferenced copy back. Pointer elements were
			// updated in place and receive the same pointer.
			if elem.Kind() != reflect.Pointer && lv.Kind() == reflect.Pointer {
				lv = lv.Elem()
			}
			elem.Set(lv)
		}
	case reflect.Struct:
		if loaded := gridFsLoad(rfield.Interface(), g, next); loaded != nil {
			rfield.Set(reflect.ValueOf(loaded))
		}
	default:
		b := bucket(g)
		ctx := context.TODO()
		filename := parseFmt(g.FilenameFmt, val)
		// Several revisions can share a filename; read the newest one.
		newestFirst := options.GridFSFind().
			SetSort(bson.D{{Key: "uploadDate", Value: -1}}).
			SetLimit(1)
		cursor, err := b.Find(bson.M{"filename": filename}, newestFirst)
		if err != nil {
			return nil
		}
		defer func() { _ = cursor.Close(ctx) }()

		var found GridFSFile
		if !cursor.Next(ctx) || cursor.Decode(&found) != nil {
			return nil
		}
		var buf bytes.Buffer
		if _, err := b.DownloadToStream(found.ID, &buf); err != nil {
			return nil
		}
		switch t := rfield.Type(); {
		case t.AssignableTo(reflect.TypeFor[[]byte]()):
			rfield.Set(reflect.ValueOf(buf.Bytes()))
		case t.AssignableTo(reflect.TypeFor[string]()):
			rfield.Set(reflect.ValueOf(buf.String()))
		}
	}

	if rdoc.Kind() != reflect.Pointer {
		return doc.Elem().Interface()
	}
	return doc.Interface()
}
// gridFsSave uploads every GridFS-backed field found in val to its bucket.
//
// val may be a struct, a pointer to a struct, or a slice of either; anything
// else is ignored. For each exported struct field, every entry of
// imodel.GridFSReferences whose key is either the field name itself or a
// dotted path beneath it ("Items.Content" for field "Items") is uploaded:
//   - slice field: each element is uploaded, with the reference's Idx
//     selecting the content field;
//   - struct field: the field's struct is uploaded, Idx selecting the content;
//   - any other field: the enclosing struct is uploaded, Idx selecting the field.
//
// Every exported field is then visited recursively with the same imodel.
// Struct-tag parse errors are passed to panik. The first lookup, upload, or
// delete error is returned.
func gridFsSave(val any, imodel InternalModel) error {
	el := reflect.ValueOf(val)
	if el.Kind() == reflect.Pointer {
		el = el.Elem()
	}
	switch el.Kind() {
	case reflect.Struct:
		for i := 0; i < el.NumField(); i++ {
			ft := el.Type().Field(i)
			if !ft.IsExported() {
				continue
			}
			// The parsed tags are not used; parsing only rejects malformed tags.
			_, err := structtag.Parse(string(ft.Tag))
			panik(err)

			fv := el.Field(i)
			for key, ref := range imodel.GridFSReferences {
				// Match the field exactly or as the root of a dotted path. A bare
				// prefix test would let field "File" claim the key "Files".
				if key != ft.Name && !strings.HasPrefix(key, ft.Name+".") {
					continue
				}
				if err := gridFsUploadField(ref, fv, el); err != nil {
					return err
				}
			}
			if err := gridFsSave(fv.Interface(), imodel); err != nil {
				return err
			}
		}
	case reflect.Slice:
		for i := 0; i < el.Len(); i++ {
			if err := gridFsSave(el.Index(i).Interface(), imodel); err != nil {
				return err
			}
		}
	}
	return nil
}

// gridFsUploadField uploads the content that ref selects for field fv of the
// struct parent. It dispatches on fv's kind as described on gridFsSave.
func gridFsUploadField(ref GridFSReference, fv, parent reflect.Value) error {
	b := bucket(ref)
	switch fv.Kind() {
	case reflect.Slice:
		for j := 0; j < fv.Len(); j++ {
			if err := gridFsReplaceFile(b, ref, fv.Index(j)); err != nil {
				return err
			}
		}
		return nil
	case reflect.Struct:
		return gridFsReplaceFile(b, ref, fv)
	default:
		return gridFsReplaceFile(b, ref, parent)
	}
}

// gridFsReplaceFile stores the []byte or string field at index ref.Idx of
// holder (a struct or a pointer to one) in bucket b. The filename comes from
// rendering ref.FilenameFmt against holder. Files already stored under that
// name are deleted only after the new upload succeeds, so a failed upload
// keeps the previous content. A nil pointer holder is skipped.
func gridFsReplaceFile(b *gridfs.Bucket, ref GridFSReference, holder reflect.Value) error {
	if holder.Kind() == reflect.Pointer && holder.IsNil() {
		return nil
	}
	filename := parseFmt(ref.FilenameFmt, holder.Interface())

	st := reflect.Indirect(holder)
	if st.Kind() != reflect.Struct || ref.Idx < 0 || ref.Idx >= st.NumField() {
		return fmt.Errorf("gridfs: field index %d not available on %s", ref.Idx, holder.Type())
	}
	c := st.Field(ref.Idx)
	var rdr io.Reader
	switch {
	case c.Type().AssignableTo(reflect.TypeOf([]byte{})):
		// Bytes also accepts named []byte types, where .([]byte) would panic.
		rdr = bytes.NewReader(c.Bytes())
	case c.Type().AssignableTo(reflect.TypeOf("")):
		rdr = strings.NewReader(c.String())
	default:
		return fmt.Errorf("gridfs loader type '%s' not supported", c.Type().String())
	}

	ctx := context.TODO()
	var stale []GridFSFile
	cur, err := b.Find(bson.M{"filename": filename})
	switch {
	case errors.Is(err, mongo.ErrNoDocuments):
		// Nothing stored under this name yet.
	case err != nil:
		return fmt.Errorf("gridfs: finding %q: %w", filename, err)
	default:
		// All decodes every match and closes the cursor.
		if err := cur.All(ctx, &stale); err != nil {
			return fmt.Errorf("gridfs: reading existing %q: %w", filename, err)
		}
	}

	if _, err := b.UploadFromStream(filename, rdr); err != nil {
		return err
	}
	for _, old := range stale {
		if err := b.Delete(old.ID); err != nil && !errors.Is(err, gridfs.ErrFileNotFound) {
			return fmt.Errorf("gridfs: deleting old %q (%s): %w", filename, old.ID.Hex(), err)
		}
	}
	return nil
}