package orm import ( "context" "fmt" "reflect" "time" "github.com/fatih/structtag" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) // Model - "base" struct for all queryable models type Model struct { // Created time. updated/added automatically. Created time.Time `bson:"createdAt" json:"createdAt"` // Modified time. updated/added automatically. Modified time.Time `bson:"updatedAt" json:"updatedAt"` typeName string `bson:"-"` self any `bson:"-"` exists bool `bson:"-"` } func (m *Model) getCreated() time.Time { return m.Created } func (m *Model) setCreated(Created time.Time) { m.Created = Created } func (m *Model) getModified() time.Time { return m.Modified } func (m *Model) setModified(Modified time.Time) { m.Modified = Modified } // HasID is a simple interface that you must implement // in your models, using a pointer receiver. // This allows for more flexibility in cases where // your ID isn't an ObjectID (e.g., int, uint, string...). // // and yes, those darn ugly ObjectIDs are supported :) type HasID interface { Id() any SetId(id any) } type HasIDSlice []HasID type IModel interface { Append(field string, a ...interface{}) error Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error) FindByID(id interface{}) (*Query, error) FindOne(query interface{}, options ...*options.FindOneOptions) (*Query, error) FindPaged(query interface{}, page int64, perPage int64, options ...*options.FindOptions) (*Query, error) getColl() *mongo.Collection getIdxs() []*mongo.IndexModel getParsedIdxs() map[string][]InternalIndex Save() error serializeToStore() primitive.M setTypeName(str string) getExists() bool setExists(n bool) setModified(Modified time.Time) setCreated(Modified time.Time) getModified() time.Time getCreated() time.Time setSelf(arg interface{}) } func (m *Model) setTypeName(str string) { m.typeName = str } func (m *Model) setSelf(arg interface{}) { m.self = arg } func (m *Model) getExists() bool { return m.exists } func (m *Model) setExists(n bool) { m.exists = n } func (m *Model) getColl() *mongo.Collection { _, ri, ok := ModelRegistry.HasByName(m.typeName) if !ok { panic(fmt.Sprintf("the model '%s' has not been registered", m.typeName)) } return DB.Collection(ri.Collection) } func (m *Model) getIdxs() []*mongo.IndexModel { mi := make([]*mongo.IndexModel, 0) if mpi := m.getParsedIdxs(); mpi != nil { for _, v := range mpi { for _, i := range v { mi = append(mi, buildIndex(i)) } } return mi } return nil } func (m *Model) getParsedIdxs() map[string][]InternalIndex { _, ri, ok := ModelRegistry.HasByName(m.typeName) if !ok { panic(fmt.Sprintf("model '%s' not registered", m.typeName)) } return ri.Indexes } func (m *Model) Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) { coll := m.getColl() cursor, err := coll.Find(context.TODO(), query, opts...) return cursor, err } func (m *Model) FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error) { qqn := ModelRegistry.new_(m.typeName) qqv := reflect.New(reflect.SliceOf(reflect.TypeOf(qqn).Elem())) qqv.Elem().Set(reflect.Zero(qqv.Elem().Type())) qq := &Query{ Model: *m, Collection: m.getColl(), doc: qqv.Interface(), Op: OP_FIND_ALL, } q, err := m.Find(query, opts...) if err == nil { rawRes := bson.A{} err = q.All(context.TODO(), &rawRes) if err == nil { m.exists = true } qq.rawDoc = rawRes err = q.All(context.TODO(), &qq.doc) if err != nil { qq.reOrganize() err = nil } } return qq, err } func (m *Model) FindPaged(query interface{}, page int64, perPage int64, opts ...*options.FindOptions) (*Query, error) { skipAmt := perPage * (page - 1) if skipAmt < 0 { skipAmt = 0 } if len(opts) > 0 { opts[0].SetSkip(skipAmt).SetLimit(perPage) } else { opts = append(opts, options.Find().SetSkip(skipAmt).SetLimit(perPage)) } q, err := m.FindAll(query, opts...) q.Op = OP_FIND_PAGED return q, err } func (m *Model) FindByID(id interface{}) (*Query, error) { return m.FindOne(bson.D{{"_id", id}}) } func (m *Model) FindOne(query interface{}, options ...*options.FindOneOptions) (*Query, error) { coll := m.getColl() rip := coll.FindOne(context.TODO(), query, options...) raw := bson.M{} err := rip.Decode(&raw) panik(err) m.exists = true qq := &Query{ Collection: m.getColl(), rawDoc: raw, doc: ModelRegistry.new_(m.typeName), Op: OP_FIND_ONE, Model: *m, } qq.rawDoc = raw err = rip.Decode(qq.doc) if err != nil { qq.reOrganize() err = nil } m.self = qq.doc return qq, err } func doDelete(m *Model, arg interface{}) error { self, ok := arg.(HasID) if !ok { return fmt.Errorf("Object '%s' does not implement HasID", nameOf(arg)) } c := m.getColl() _, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()}) if err == nil { m.exists = false } return err } func (m *Model) Delete() error { var err error val := valueOf(m.self) if val.Kind() == reflect.Slice { for i := 0; i < val.Len(); i++ { cur := val.Index(i) if err = doDelete(m, cur.Interface()); err != nil { return err } } return nil } else { return doDelete(m, m.self) } } // Append appends one or more items to `field`. // will error if this Model contains a reference // to multiple documents, or if `field` is not a // slice. func (m *Model) Append(field string, a ...interface{}) error { rv := reflect.ValueOf(m.self) selfRef := rv rt := reflect.TypeOf(m.self) if selfRef.Kind() == reflect.Pointer { selfRef = selfRef.Elem() rt = rt.Elem() } if selfRef.Kind() == reflect.Slice { return fmt.Errorf("Cannot append to multiple documents!") } if selfRef.Kind() != reflect.Struct { return fmt.Errorf("Current object is not a struct!") } _, origV, err := getNested(field, selfRef) if err != nil { return err } origRef := makeSettable(*origV, (*origV).Interface()) fv := origRef if fv.Kind() == reflect.Pointer { fv = fv.Elem() } if fv.Kind() != reflect.Slice { return fmt.Errorf("Current object is not a slice!") } for _, b := range a { val := reflect.ValueOf(b) fv.Set(reflect.Append(fv, val)) } return nil } func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { var err error m, ok := arg.(IModel) if !ok { return fmt.Errorf("type '%s' is not a model", nameOf(arg)) } m.setSelf(m) now := time.Now() selfo := reflect.ValueOf(m) vp := selfo if vp.Kind() != reflect.Ptr { vp = reflect.New(selfo.Type()) vp.Elem().Set(selfo) } var asHasId = vp.Interface().(HasID) if isNew { m.setCreated(now) } m.setModified(now) idxs := m.getIdxs() for _, i := range idxs { _, err = c.Indexes().CreateOne(context.TODO(), *i) if err != nil { return err } } if isNew { nid := getLastInColl(c.Name(), asHasId.Id()) switch pnid := nid.(type) { case uint: nid = pnid + 1 case uint32: nid = pnid + 1 case uint64: nid = pnid + 1 case int: nid = pnid + 1 case int32: nid = pnid + 1 case int64: nid = pnid + 1 case string: nid = NextStringID() case primitive.ObjectID: nid = primitive.NewObjectID() default: panic("unknown or unsupported id type") } if reflect.ValueOf(asHasId.Id()).IsZero() { (asHasId).SetId(nid) } _, err = c.InsertOne(context.TODO(), m.serializeToStore()) if err == nil { m.setExists(true) } } else { _, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.(HasID).Id()}}, m.serializeToStore()) } return err } func (m *Model) Save() error { val := valueOf(m.self) if val.Kind() == reflect.Slice { for i := 0; i < val.Len(); i++ { cur := val.Index(i) if err := doSave(m.getColl(), !m.exists, cur.Interface()); err != nil { return err } } return nil } else { return doSave(m.getColl(), !m.exists, m.self) } } func (m *Model) serializeToStore() bson.M { return serializeIDs((*m).self) } func serializeIDs(input interface{}) bson.M { ret := bson.M{} mv := reflect.ValueOf(input) mt := reflect.TypeOf(input) vp := mv if vp.Kind() != reflect.Ptr { vp = reflect.New(mv.Type()) vp.Elem().Set(mv) } if mv.Kind() == reflect.Pointer { mv = mv.Elem() } if mt.Kind() == reflect.Pointer { mt = mt.Elem() } for i := 0; i < mv.NumField(); i++ { fv := mv.Field(i) ft := mt.Field(i) var dr = fv tag, err := structtag.Parse(string(mt.Field(i).Tag)) panik(err) bbson, err := tag.Get("bson") if err != nil || bbson.Name == "-" { continue } _, terr := tag.Get("ref") switch dr.Type().Kind() { case reflect.Slice: rarr := make([]interface{}, 0) intArr := iFaceSlice(fv.Interface()) for _, idHaver := range intArr { if terr == nil { if reflect.ValueOf(idHaver).Type().Kind() != reflect.Pointer { vp := reflect.New(reflect.ValueOf(idHaver).Type()) vp.Elem().Set(reflect.ValueOf(idHaver)) idHaver = vp.Interface() } ifc, ok := idHaver.(HasID) if !ok { panic(fmt.Sprintf("referenced model slice '%s' does not implement HasID", ft.Name)) } rarr = append(rarr, ifc.Id()) } else if reflect.ValueOf(idHaver).Kind() == reflect.Struct { rarr = append(rarr, serializeIDs(idHaver)) } else { if reflect.ValueOf(idHaver).Kind() == reflect.Slice { rarr = append(rarr, serializeIDSlice(iFaceSlice(idHaver))) } else { rarr = append(rarr, idHaver) } } ret[bbson.Name] = rarr } case reflect.Pointer: dr = fv.Elem() fallthrough case reflect.Struct: if bbson.Name != "" { if terr == nil { idHaver := fv.Interface() ifc, ok := idHaver.(HasID) if !ok { panic(fmt.Sprintf("referenced model '%s' does not implement HasID", ft.Name)) } if !fv.IsNil() { ret[bbson.Name] = ifc.Id() } else { ret[bbson.Name] = nil } } else if reflect.TypeOf(fv.Interface()) != reflect.TypeOf(time.Now()) { ret[bbson.Name] = serializeIDs(fv.Interface()) } else { ret[bbson.Name] = fv.Interface() } } else { for k, v := range serializeIDs(fv.Interface()) { ret[k] = v } } default: ret[bbson.Name] = fv.Interface() } } return ret } func serializeIDSlice(input []interface{}) bson.A { var a bson.A for _, in := range input { fv := reflect.ValueOf(in) switch fv.Type().Kind() { case reflect.Slice: a = append(a, serializeIDSlice(iFaceSlice(fv.Interface()))) case reflect.Struct: a = append(a, serializeIDs(fv.Interface())) default: a = append(a, fv.Interface()) } } return a } // Create creates a new instance of a given model // and returns a pointer to it. func Create(d any) any { var n string var ri *InternalModel var ok bool n, ri, ok = ModelRegistry.HasByName(nameOf(d)) if !ok { ModelRegistry.Model(d) n, ri, _ = ModelRegistry.Has(d) } t := ri.Type v := valueOf(d) i := ModelRegistry.Index(n) r := reflect.New(t) r.Elem().Set(v) df := r.Elem().Field(i) dm := df.Interface().(Model) if reflect.ValueOf(d).Kind() == reflect.Pointer { r.Elem().Set(reflect.ValueOf(d).Elem()) } else { r.Elem().Set(reflect.ValueOf(d)) } dm.typeName = n what := r.Interface() dm.self = what df.Set(reflect.ValueOf(dm)) return what } func (m *Model) PrintMe() { fmt.Printf("My name is %s !\n", nameOf(m)) }