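// diamond-orm/model.go
//
// (Residue from a web source viewer, kept as a comment so the file parses:
// "372 lines · 8.9 KiB · Go · Raw | Normal View | History",
// timestamp 2024-09-01 16:17:48 -04:00.)
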
package orm
import (
"context"
"fmt"
"reflect"
"time"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// Model is the "base" struct for all queryable models. Model structs contain
// it as a field; Create fills in its unexported bookkeeping fields
// (typeName and self).
type Model struct {
// Created is the creation timestamp. Save sets it to the current time when
// it treats the document as new (zero ID, or not known to exist yet).
Created time.Time `bson:"createdAt" json:"createdAt"`
// Modified is the last-update timestamp. Save sets it on every call.
Modified time.Time `bson:"updatedAt" json:"updatedAt"`
// typeName is the name this model type is registered under in
// ModelRegistry; getColl/getParsedIdxs look it up to find the collection
// and index definitions.
typeName string `bson:"-" json:"-"`
// self holds the full model value that contains this Model (Create sets it
// to the pointer it returns); serializeToStore, and therefore Save,
// serializes it.
self any `bson:"-" json:"-"`
// exists records whether the document is known to be stored. FindOne and
// FindAll set it on their receiver and Save sets it after inserting; Save
// reads it to choose between insert and replace.
exists bool `bson:"-"`
}
// HasID is a simple interface that every model must implement so the ORM can
// read and assign its primary key. Save type-asserts the model to HasID
// (panicking if it does not implement it) and uses Id() as the "_id" filter
// when replacing a stored document.
//
// This allows for more flexibility if your ID isn't an ObjectID (e.g., int,
// uint, string...). When saving a new document, Save can generate IDs of type
// int, int32, int64, uint, uint32, uint64, string and primitive.ObjectID;
// any other ID type makes it panic.
//
// And yes, those darn ugly ObjectIDs are supported :)
type HasID interface {
// Id returns the model's primary key. Save treats a zero value as
// "not assigned yet".
Id() any
// SetId assigns the model's primary key. Save calls it with a generated ID
// for a new document whose Id() is still zero.
SetId(id any)
}
// IModel is the method set provided by *Model, and therefore by a pointer to
// any struct that embeds Model. The unexported methods are used internally by
// the ORM.
type IModel interface {
// Find runs query against the model's collection and returns the raw cursor.
Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error)
// FindAll returns every matching document wrapped in a *Query.
FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error)
// FindByID returns the document whose _id equals id.
FindByID(id interface{}) (*Query, error)
// FindOne returns the first matching document wrapped in a *Query.
FindOne(query interface{}, options ...*options.FindOneOptions) (*Query, error)
// FindPaged returns one page (1-based) of matching documents.
FindPaged(query interface{}, page int64, perPage int64, options ...*options.FindOptions) (*Query, error)
// getColl returns the collection registered for this model type.
getColl() *mongo.Collection
// getIdxs builds driver index models from the registered index definitions.
getIdxs() []*mongo.IndexModel
// getParsedIdxs returns the registered index definitions for this model type.
getParsedIdxs() map[string][]InternalIndex
// Save inserts the model's document or replaces the stored one.
Save() error
// serializeToStore converts the model into the document Save writes.
serializeToStore() primitive.M
// setTypeName records the registry name of the model type.
setTypeName(str string)
}
// setTypeName records the ModelRegistry name of the concrete model type; it
// is the key used to resolve the model's collection and indexes.
func (m *Model) setTypeName(name string) {
	m.typeName = name
}
// getColl returns the MongoDB collection registered for this model's type
// name. It panics when the type has not been registered in ModelRegistry.
func (m *Model) getColl() *mongo.Collection {
	_, info, registered := ModelRegistry.HasByName(m.typeName)
	if registered {
		return DB.Collection(info.Collection)
	}
	panic(fmt.Sprintf("the model '%s' has not been registered", m.typeName))
}
// getIdxs converts every registered index definition for this model into a
// driver IndexModel. It returns nil when the registry holds no index map for
// the model, and a (possibly empty) non-nil slice otherwise. Map iteration
// order means the resulting order is unspecified.
func (m *Model) getIdxs() []*mongo.IndexModel {
	parsed := m.getParsedIdxs()
	if parsed == nil {
		return nil
	}
	built := []*mongo.IndexModel{}
	for _, group := range parsed {
		for _, def := range group {
			built = append(built, BuildIndex(def))
		}
	}
	return built
}
// getParsedIdxs returns the index definitions stored in ModelRegistry for this
// model's type name. It panics when the type has not been registered.
func (m *Model) getParsedIdxs() map[string][]InternalIndex {
	if _, info, registered := ModelRegistry.HasByName(m.typeName); registered {
		return info.Indexes
	}
	panic(fmt.Sprintf("model '%s' not registered", m.typeName))
}
// Find runs query (with optional FindOptions) against the model's registered
// collection and hands back the driver cursor unchanged. The caller iterates
// it and should close it when done.
func (m *Model) Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) {
	return m.getColl().Find(context.TODO(), query, opts...)
}
// FindAll runs query (with optional FindOptions) and returns every matching
// document wrapped in a *Query with Op = OP_FIND_ALL. The Query's doc is a
// pointer to a slice of the registered model type, and the undecoded
// documents are kept in rawDoc.
//
// Errors from starting the query or from reading the cursor are returned to
// the caller together with the (partially initialised) Query.
func (m *Model) FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error) {
	// Decode target: a *[]T, where T is the registered model type.
	elemType := reflect.TypeOf(ModelRegistry.new_(m.typeName)).Elem()
	qq := &Query{
		Model:      *m,
		Collection: m.getColl(),
		doc:        reflect.New(reflect.SliceOf(elemType)).Interface(),
		Op:         OP_FIND_ALL,
	}

	cursor, err := m.Find(query, opts...)
	if err != nil {
		return qq, err
	}

	rawRes := bson.A{}
	if err := cursor.All(context.TODO(), &rawRes); err != nil {
		// Surface cursor read failures instead of silently discarding them.
		return qq, err
	}
	m.exists = true
	qq.rawDoc = rawRes

	// NOTE(review): cursor.All drains and closes the cursor, so this second
	// call cannot yield documents. In practice it errors and reOrganize
	// (presumably rebuilding doc from rawDoc) does the real work. Kept as-is
	// to preserve that fallback; confirm and simplify.
	if err := cursor.All(context.TODO(), &qq.doc); err != nil {
		qq.reOrganize()
	}
	return qq, nil
}
// FindPaged returns one page of documents matching query, wrapped in a *Query
// with Op = OP_FIND_PAGED.
//
// page is 1-based and perPage is the page size; a negative skip (page < 1)
// is clamped to 0, i.e. the first page. Any caller-supplied FindOptions still
// apply, but their skip/limit are overridden by the paging values. The
// caller's option structs are not modified.
func (m *Model) FindPaged(query interface{}, page int64, perPage int64, opts ...*options.FindOptions) (*Query, error) {
	skipAmt := perPage * (page - 1)
	if skipAmt < 0 {
		skipAmt = 0
	}

	// The driver merges FindOptions in order with later values winning, so
	// appending the paging options last makes them take effect even when the
	// caller passed no options (previously paging was silently skipped then).
	paged := make([]*options.FindOptions, 0, len(opts)+1)
	paged = append(paged, opts...)
	paged = append(paged, options.Find().SetSkip(skipAmt).SetLimit(perPage))

	q, err := m.FindAll(query, paged...)
	if q != nil {
		q.Op = OP_FIND_PAGED
	}
	return q, err
}
// FindByID returns the document whose _id equals id by delegating to FindOne
// with an {_id: id} filter.
func (m *Model) FindByID(id interface{}) (*Query, error) {
	// Keyed bson.E fields: unkeyed composite literals of an imported struct
	// type are flagged by go vet.
	return m.FindOne(bson.D{{Key: "_id", Value: id}})
}
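// Scratch note (purpose undocumented; it appears to line up 1-based item
// positions with 0-based offsets, as in FindPaged's skip computation; the
// caret markers below have lost their original column alignment):
// 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
// ^ ^
// ^ ^
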
// FindOne runs query against the model's collection and returns the first
// matching document wrapped in a *Query with Op = OP_FIND_ONE.
//
// If nothing matches, the driver's mongo.ErrNoDocuments is returned, as is
// any other decode error, with a nil *Query. Previously these errors panicked.
// On success the receiver's exists flag is set and its self is pointed at the
// decoded document.
func (m *Model) FindOne(query interface{}, opts ...*options.FindOneOptions) (*Query, error) {
	coll := m.getColl()
	res := coll.FindOne(context.TODO(), query, opts...)

	raw := bson.M{}
	if err := res.Decode(&raw); err != nil {
		return nil, err
	}
	m.exists = true

	qq := &Query{
		Collection: coll,
		rawDoc:     raw,
		doc:        ModelRegistry.new_(m.typeName),
		Op:         OP_FIND_ONE,
		Model:      *m,
	}
	// Try decoding straight into the model. If that fails (e.g. fields whose
	// stored form differs from the Go type), fall back to reOrganize, which
	// works from rawDoc.
	if err := res.Decode(qq.doc); err != nil {
		qq.reOrganize()
	}
	m.self = qq.doc
	return qq, nil
}
// Save persists the model.
//
// The document is treated as new when its Id() is the zero value or the
// model is not known to exist yet. New documents get Created set. Every save
// sets Modified. Before writing, every registered index is created
// (CreateOne per index, on every call).
//
// New documents are inserted. If their ID is still zero, one is generated
// first:
//   - integer ID types: getLastInColl's result + 1 (presumably the highest
//     existing ID; confirm);
//   - string IDs: NextStringID();
//   - primitive.ObjectID: a fresh ObjectID.
//
// Existing documents are replaced by _id.
//
// Index-creation, insert and replace errors are returned. Save panics if the
// model does not implement HasID or uses an unsupported ID type.
func (m *Model) Save() error {
	c := m.getColl()
	now := time.Now()

	// Work through a pointer so SetId can mutate the document.
	selfVal := reflect.ValueOf(m.self)
	vp := selfVal
	if vp.Kind() != reflect.Ptr {
		vp = reflect.New(selfVal.Type())
		vp.Elem().Set(selfVal)
	}
	asHasID := vp.Interface().(HasID)

	isNew := reflect.ValueOf(asHasID.Id()).IsZero() || !m.exists
	if isNew {
		m.Created = now
	}
	m.Modified = now

	for _, idx := range m.getIdxs() {
		if _, err := c.Indexes().CreateOne(context.TODO(), *idx); err != nil {
			return err
		}
	}

	if !isNew {
		_, err := c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.self.(HasID).Id()}}, m.serializeToStore())
		return err
	}

	nid := getLastInColl(c.Name(), asHasID.Id())
	switch pnid := nid.(type) {
	case uint:
		nid = pnid + 1
	case uint32:
		nid = pnid + 1
	case uint64:
		nid = pnid + 1
	case int:
		nid = pnid + 1
	case int32:
		nid = pnid + 1
	case int64:
		nid = pnid + 1
	case string:
		nid = NextStringID()
	case primitive.ObjectID:
		nid = primitive.NewObjectID()
	default:
		panic("unknown or unsupported id type")
	}
	// Only fill in the ID when the caller did not supply one.
	if reflect.ValueOf(asHasID.Id()).IsZero() {
		asHasID.SetId(nid)
	}
	m.self = asHasID

	// Previously the insert error was ignored, so a failed insert was reported
	// as success and the model marked as existing.
	if _, err := c.InsertOne(context.TODO(), m.serializeToStore()); err != nil {
		return err
	}
	m.exists = true
	return nil
}
// serializeToStore produces the document Save writes: the full model value
// held in self, run through serializeIDs so referenced models are stored as
// their IDs.
func (m *Model) serializeToStore() bson.M {
	doc := m.self
	return serializeIDs(doc)
}
// serializeIDs converts a model value (a struct or a pointer to one) into a
// bson.M keyed by each field's bson tag name. Per field:
//   - fields without a bson tag, or tagged bson:"-", are skipped;
//   - fields carrying a `ref` tag are stored as the referenced model's Id()
//     (a slice of IDs for slice fields); the referenced type must implement
//     HasID;
//   - a struct/pointer field with an empty bson name (e.g. bson:",inline") is
//     flattened into the parent document (a nil pointer contributes nothing);
//   - other nested structs, except time.Time, are serialized recursively;
//   - nil pointers are stored as null, and pointers to non-struct values or
//     to time.Time are stored as-is for the bson encoder;
//   - slices are always stored, so an empty or nil slice becomes [];
//   - everything else is stored as-is.
//
// Tag-parse errors are passed to panik. A ref field whose type does not
// implement HasID causes a panic.
func serializeIDs(input interface{}) bson.M {
	ret := bson.M{}
	timeType := reflect.TypeOf(time.Time{})

	mv := reflect.ValueOf(input)
	mt := reflect.TypeOf(input)
	if mv.Kind() == reflect.Pointer {
		mv = mv.Elem()
	}
	if mt.Kind() == reflect.Pointer {
		mt = mt.Elem()
	}

	for i := 0; i < mv.NumField(); i++ {
		fv := mv.Field(i)
		ft := mt.Field(i)
		tag, err := structtag.Parse(string(ft.Tag))
		panik(err)
		bbson, err := tag.Get("bson")
		if err != nil || bbson.Name == "-" {
			continue
		}
		_, terr := tag.Get("ref")
		isRef := terr == nil

		switch fv.Kind() {
		case reflect.Slice:
			items := iFaceSlice(fv.Interface())
			rarr := make([]interface{}, 0, len(items))
			for _, item := range items {
				iv := reflect.ValueOf(item)
				switch {
				case isRef:
					if iv.Kind() != reflect.Pointer {
						// Take a pointer to a copy so pointer-receiver HasID
						// implementations are found.
						vp := reflect.New(iv.Type())
						vp.Elem().Set(iv)
						item = vp.Interface()
					}
					ifc, ok := item.(HasID)
					if !ok {
						panic(fmt.Sprintf("referenced model slice '%s' does not implement HasID", ft.Name))
					}
					rarr = append(rarr, ifc.Id())
				case iv.Kind() == reflect.Struct && iv.Type() != timeType:
					rarr = append(rarr, serializeIDs(item))
				case iv.Kind() == reflect.Slice:
					rarr = append(rarr, serializeIDSlice(iFaceSlice(item)))
				default:
					rarr = append(rarr, item)
				}
			}
			// Assigned after the loop so empty slices are stored rather than
			// dropped from the document.
			ret[bbson.Name] = rarr

		case reflect.Pointer, reflect.Struct:
			isNilPtr := fv.Kind() == reflect.Pointer && fv.IsNil()
			switch {
			case bbson.Name == "" && isNilPtr:
				// Inline field with nothing behind it: contributes no keys.
			case bbson.Name == "":
				for k, v := range serializeIDs(fv.Interface()) {
					ret[k] = v
				}
			case isRef:
				ifc, ok := fv.Interface().(HasID)
				if !ok {
					panic(fmt.Sprintf("referenced model '%s' does not implement HasID", ft.Name))
				}
				// IsNil is only valid on pointers; struct-valued refs always
				// have an ID to store.
				if isNilPtr {
					ret[bbson.Name] = nil
				} else {
					ret[bbson.Name] = ifc.Id()
				}
			case isNilPtr:
				ret[bbson.Name] = nil
			default:
				target := fv
				if fv.Kind() == reflect.Pointer {
					target = fv.Elem()
				}
				// Recurse only into real sub-documents; scalars behind
				// pointers and time values are left to the bson encoder.
				if target.Kind() == reflect.Struct && target.Type() != timeType {
					ret[bbson.Name] = serializeIDs(fv.Interface())
				} else {
					ret[bbson.Name] = fv.Interface()
				}
			}

		default:
			ret[bbson.Name] = fv.Interface()
		}
	}
	return ret
}
// serializeIDSlice serializes the elements of a nested slice:
//   - nested slices are handled recursively;
//   - structs other than time.Time go through serializeIDs;
//   - everything else, including nil elements and time values, is kept as-is.
//
// The result is never nil, so an empty input is stored as [] rather than
// null.
func serializeIDSlice(input []interface{}) bson.A {
	timeType := reflect.TypeOf(time.Time{})
	a := make(bson.A, 0, len(input))
	for _, in := range input {
		fv := reflect.ValueOf(in)
		// fv.Kind() is Invalid for nil elements, which fall through to the
		// default branch. The old fv.Type() call panicked on them.
		switch fv.Kind() {
		case reflect.Slice:
			a = append(a, serializeIDSlice(iFaceSlice(fv.Interface())))
		case reflect.Struct:
			if fv.Type() == timeType {
				a = append(a, in)
			} else {
				a = append(a, serializeIDs(fv.Interface()))
			}
		default:
			a = append(a, in)
		}
	}
	return a
}
// Create creates a new instance of a given model.
// returns a pointer to the newly created model.
func Create(d any) any {
var n string
var ri *InternalModel
var ok bool
n, ri, ok = ModelRegistry.HasByName(nameOf(d))
if !ok {
ModelRegistry.Model(d)
n, ri, _ = ModelRegistry.Has(d)
}
t := ri.Type
v := valueOf(d)
i := ModelRegistry.Index(n)
r := reflect.New(t)
r.Elem().Set(v)
df := r.Elem().Field(i)
dm := df.Interface().(Model)
if reflect.ValueOf(d).Kind() == reflect.Pointer {
r.Elem().Set(reflect.ValueOf(d).Elem())
} else {
r.Elem().Set(reflect.ValueOf(d))
}
dm.typeName = n
what := r.Interface()
dm.self = what
df.Set(reflect.ValueOf(dm))
return what
}
// PrintMe is a debugging helper that prints the name nameOf reports for the
// receiver.
func (m *Model) PrintMe() {
	who := nameOf(m)
	fmt.Printf("My name is %s !\n", who)
}