package orm import ( "bytes" "context" "errors" "fmt" "github.com/fatih/structtag" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/gridfs" "go.mongodb.org/mongo-driver/mongo/options" "html/template" "io" "reflect" "strings" ) type GridFSFile struct { ID primitive.ObjectID `bson:"_id"` Name string `bson:"filename"` Length int `bson:"length"` } func parseFmt(format string, value any) string { tmpl, err := template.New("filename").Parse(format) panik(err) w := new(strings.Builder) err = tmpl.Execute(w, value) panik(err) return w.String() } func bucket(gfsRef gridFSReference) *gridfs.Bucket { b, _ := gridfs.NewBucket(DB, options.GridFSBucket().SetName(gfsRef.BucketName)) return b } func gridFsLoad(val any, g gridFSReference, field string) any { doc := reflect.ValueOf(val) rdoc := reflect.ValueOf(val) if doc.Kind() != reflect.Pointer { doc = reflect.New(reflect.TypeOf(val)) doc.Elem().Set(reflect.ValueOf(val)) } var next string if len(strings.Split(field, ".")) > 1 { next = strings.Join(strings.Split(field, ".")[1:], ".") field = strings.Split(field, ".")[0] } else { next = field } _, rfield, ferr := getNested(field, rdoc) if ferr != nil { return nil } switch rfield.Kind() { case reflect.Slice: for i := 0; i < rfield.Len(); i++ { cur := rfield.Index(i) if cur.Kind() != reflect.Pointer { tmp := reflect.New(cur.Type()) tmp.Elem().Set(cur) cur = tmp } intermediate := gridFsLoad(cur.Interface(), g, next) if intermediate == nil { continue } ival := reflect.ValueOf(intermediate) if ival.Kind() == reflect.Pointer { ival = ival.Elem() } rfield.Index(i).Set(ival) } case reflect.Struct: intermediate := gridFsLoad(rfield.Interface(), g, next) if intermediate != nil { rfield.Set(reflect.ValueOf(intermediate)) } default: b := bucket(g) var found GridFSFile cursor, err := b.Find(bson.M{"filename": parseFmt(g.FilenameFmt, val)}) if err != nil { return nil } cursor.Next(context.TODO()) _ = cursor.Decode(&found) bb := bytes.NewBuffer(nil) _, err = b.DownloadToStream(found.ID, bb) if err != nil { return nil } if rfield.Type().AssignableTo(reflect.TypeFor[[]byte]()) { rfield.Set(reflect.ValueOf(bb.Bytes())) } else if rfield.Type().AssignableTo(reflect.TypeFor[string]()) { rfield.Set(reflect.ValueOf(bb.String())) } } if rdoc.Kind() != reflect.Pointer { return doc.Elem().Interface() } return doc.Interface() } func gridFsSave(val any, imodel Model) error { var rerr error v := reflect.ValueOf(val) el := v if v.Kind() == reflect.Pointer { el = el.Elem() } switch el.Kind() { case reflect.Struct: for i := 0; i < el.NumField(); i++ { ft := el.Type().Field(i) fv := el.Field(i) if !ft.IsExported() { continue } _, err := structtag.Parse(string(ft.Tag)) panik(err) var gfsRef *gridFSReference for kk, vv := range imodel.gridFSReferences { if strings.HasPrefix(kk, ft.Name) { gfsRef = &vv break } } var inner = func(b *gridfs.Bucket, it reflect.Value) error { filename := parseFmt(gfsRef.FilenameFmt, it.Interface()) contents := GridFSFile{} curs, err2 := b.Find(bson.M{"filename": filename}) if !errors.Is(err2, mongo.ErrNoDocuments) { _ = curs.Decode(&contents) if !reflect.ValueOf(contents).IsZero() { _ = b.Delete(contents.ID) } } c := it.Field(gfsRef.Idx) var rdr io.Reader if c.Type().AssignableTo(reflect.TypeOf([]byte{})) { rdr = bytes.NewReader(c.Interface().([]byte)) } else if c.Type().AssignableTo(reflect.TypeOf("")) { rdr = strings.NewReader(c.Interface().(string)) } else { return fmt.Errorf("gridfs loader type '%s' not supported", c.Type().String()) } _, err = b.UploadFromStream(filename, rdr) return err } if gfsRef != nil { b := bucket(*gfsRef) if fv.Kind() == reflect.Slice { for j := 0; j < fv.Len(); j++ { lerr := inner(b, fv.Index(j)) if lerr != nil { return lerr } } } else if fv.Kind() == reflect.Struct { lerr := inner(b, fv) if lerr != nil { return lerr } } else { lerr := inner(b, el) if lerr != nil { return lerr } } } err = gridFsSave(fv.Interface(), imodel) if err != nil { return err } } case reflect.Slice: for i := 0; i < el.Len(); i++ { rerr = gridFsSave(el.Index(i).Interface(), imodel) if rerr != nil { return rerr } } default: break } return rerr }