some changes

- move internal model helper funcs into model_internals.go
- add Swap method to Model
- modify `getNested` util to support indexing slice fields such as `Abc[0].Def`
This commit is contained in:
parent 924b3fc9d2
commit fefa695063
Signed by: tablet
GPG Key ID: 924A5F6AF051E87C
4 changed files with 347 additions and 272 deletions

286
model.go

@ -6,7 +6,6 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
@ -54,6 +53,7 @@ type HasIDSlice []HasID
type IModel interface { type IModel interface {
Append(field string, a ...interface{}) error Append(field string, a ...interface{}) error
Delete() error
Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error)
FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error) FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error)
FindByID(id interface{}) (*Query, error) FindByID(id interface{}) (*Query, error)
@ -62,6 +62,8 @@ type IModel interface {
getColl() *mongo.Collection getColl() *mongo.Collection
getIdxs() []*mongo.IndexModel getIdxs() []*mongo.IndexModel
getParsedIdxs() map[string][]InternalIndex getParsedIdxs() map[string][]InternalIndex
Pull(field string, a ...any) error
Remove() error
Save() error Save() error
serializeToStore() any serializeToStore() any
setTypeName(str string) setTypeName(str string)
@ -195,19 +197,8 @@ func (m *Model) FindOne(query interface{}, options ...*options.FindOneOptions) (
m.self = qq.doc m.self = qq.doc
return qq, err return qq, err
} }
func doDelete(m *Model, arg interface{}) error {
self, ok := arg.(HasID)
if !ok { // Delete - deletes a model instance from the database
return fmt.Errorf("Object '%s' does not implement HasID", nameOf(arg))
}
c := m.getColl()
_, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()})
if err == nil {
m.exists = false
}
return err
}
func (m *Model) Delete() error { func (m *Model) Delete() error {
var err error var err error
val := valueOf(m.self) val := valueOf(m.self)
@ -224,96 +215,9 @@ func (m *Model) Delete() error {
} }
} }
func incrementTagged(item interface{}) interface{} { // Remove - alias for Delete
rv := reflect.ValueOf(item) func (m *Model) Remove() error {
rt := reflect.TypeOf(item) return m.Delete()
if rv.Kind() != reflect.Pointer {
rv = makeSettable(rv, item)
}
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
}
if rt.Kind() != reflect.Struct {
if rt.Kind() == reflect.Slice {
for i := 0; i < rv.Elem().Len(); i++ {
incrementTagged(rv.Elem().Index(i).Addr().Interface())
}
} else {
return item
}
}
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
cur := rv.Elem().Field(i)
tags, err := structtag.Parse(string(structField.Tag))
if err != nil {
continue
}
incTag, err := tags.Get("autoinc")
if err != nil {
continue
}
nid := getLastInColl(incTag.Name, cur.Interface())
if cur.IsZero() {
coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur)
if coerced != nil {
cur.Set(reflect.ValueOf(coerced))
} else {
cur.Set(reflect.ValueOf(incrementInterface(nid)))
}
}
counterColl := DB.Collection(COUNTER_COL)
counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.Update().SetUpsert(true))
}
return rv.Elem().Interface()
}
func incrementAll(item interface{}) {
if item == nil {
return
}
vp := reflect.ValueOf(item)
el := vp
if vp.Kind() == reflect.Pointer {
el = vp.Elem()
}
if vp.Kind() == reflect.Pointer && vp.IsNil() {
return
}
vt := el.Type()
switch el.Kind() {
case reflect.Struct:
incrementTagged(item)
for i := 0; i < el.NumField(); i++ {
fv := el.Field(i)
fst := vt.Field(i)
if !fst.IsExported() {
continue
}
incrementAll(fv.Interface())
}
case reflect.Slice:
for i := 0; i < el.Len(); i++ {
incd := incrementTagged(el.Index(i).Addr().Interface())
if reflect.ValueOf(incd).Kind() == reflect.Pointer {
el.Index(i).Set(reflect.ValueOf(incd).Elem())
} else {
el.Index(i).Set(reflect.ValueOf(incd))
}
}
}
}
func checkStruct(ref reflect.Value) error {
if ref.Kind() == reflect.Slice {
return fmt.Errorf("Cannot append to multiple documents!")
}
if ref.Kind() != reflect.Struct {
return fmt.Errorf("Current object is not a struct!")
}
return nil
} }
// Append appends one or more items to `field`. // Append appends one or more items to `field`.
@ -351,6 +255,7 @@ func (m *Model) Append(field string, a ...interface{}) error {
return nil return nil
} }
// Pull - removes elements from the subdocument slice stored in `field`.
func (m *Model) Pull(field string, a ...any) error { func (m *Model) Pull(field string, a ...any) error {
rv := reflect.ValueOf(m.self) rv := reflect.ValueOf(m.self)
selfRef := rv selfRef := rv
@ -387,48 +292,42 @@ outer:
return nil return nil
} }
func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { // Swap - swaps the elements at indexes `i` and `j` in the
var err error // slice stored at `field`
m, ok := arg.(IModel) func (m *Model) Swap(field string, i, j int) error {
if !ok { rv := reflect.ValueOf(m.self)
return fmt.Errorf("type '%s' is not a model", nameOf(arg)) selfRef := rv
rt := reflect.TypeOf(m.self)
if selfRef.Kind() == reflect.Pointer {
selfRef = selfRef.Elem()
rt = rt.Elem()
} }
m.setSelf(m) if err := checkStruct(selfRef); err != nil {
now := time.Now() return err
selfo := reflect.ValueOf(m)
vp := selfo
if vp.Kind() != reflect.Ptr {
vp = reflect.New(selfo.Type())
vp.Elem().Set(selfo)
} }
var asHasId = vp.Interface().(HasID) _, origV, err := getNested(field, selfRef)
if isNew {
m.setCreated(now)
}
m.setModified(now)
idxs := m.getIdxs()
for _, i := range idxs {
_, err = c.Indexes().CreateOne(context.TODO(), *i)
if err != nil { if err != nil {
return err return err
} }
}
if isNew {
nid := getLastInColl(c.Name(), asHasId.Id())
pnid := incrementInterface(nid)
if reflect.ValueOf(asHasId.Id()).IsZero() {
(asHasId).SetId(pnid)
}
incrementAll(asHasId)
_, err = c.InsertOne(context.TODO(), m.serializeToStore()) origRef := makeSettable(*origV, (*origV).Interface())
if err == nil { fv := origRef
m.setExists(true) if fv.Kind() == reflect.Pointer {
} fv = fv.Elem()
} else {
_, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.(HasID).Id()}}, m.serializeToStore())
} }
if err = checkSlice(fv); err != nil {
return err return err
}
if i >= fv.Len() || j >= fv.Len() {
return fmt.Errorf("index(es) out of bounds")
}
oi := fv.Index(i).Interface()
oj := fv.Index(j).Interface()
fv.Index(i).Set(reflect.ValueOf(oj))
fv.Index(j).Set(reflect.ValueOf(oi))
return nil
} }
func (m *Model) Save() error { func (m *Model) Save() error {
@ -450,119 +349,6 @@ func (m *Model) serializeToStore() any {
return serializeIDs((m).self) return serializeIDs((m).self)
} }
func serializeIDs(input interface{}) interface{} {
vp := reflect.ValueOf(input)
mt := reflect.TypeOf(input)
var ret interface{}
if vp.Kind() != reflect.Ptr {
if vp.CanAddr() {
vp = vp.Addr()
} else {
vp = makeSettable(vp, input)
}
}
if mt.Kind() == reflect.Pointer {
mt = mt.Elem()
}
getID := func(bbb interface{}) interface{} {
mptr := reflect.ValueOf(bbb)
if mptr.Kind() != reflect.Pointer {
mptr = makeSettable(mptr, bbb)
}
ifc, ok := mptr.Interface().(HasID)
if ok {
return ifc.Id()
} else {
return nil
}
}
/*var itagged interface{}
if reflect.ValueOf(itagged).Kind() != reflect.Pointer {
itagged = incrementTagged(&input)
} else {
itagged = incrementTagged(input)
}
taggedVal := reflect.ValueOf(reflect.ValueOf(itagged).Interface()).Elem()
if vp.Kind() == reflect.Ptr {
tmp := reflect.ValueOf(taggedVal.Interface())
if tmp.Kind() == reflect.Pointer {
vp.Elem().Set(tmp.Elem())
} else {
vp.Elem().Set(tmp)
}
}*/
switch vp.Elem().Kind() {
case reflect.Struct:
ret0 := bson.M{}
for i := 0; i < vp.Elem().NumField(); i++ {
fv := vp.Elem().Field(i)
ft := mt.Field(i)
tag, err := structtag.Parse(string(ft.Tag))
panik(err)
bbson, err := tag.Get("bson")
if err != nil || bbson.Name == "-" {
continue
}
if bbson.Name == "" {
marsh, _ := bson.Marshal(fv.Interface())
unmarsh := bson.M{}
bson.Unmarshal(marsh, &unmarsh)
for k, v := range unmarsh {
ret0[k] = v
/*if t, ok := v.(primitive.DateTime); ok {
ret0
} else {
}*/
}
} else {
_, terr := tag.Get("ref")
if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer {
vp1 := reflect.New(fv.Type())
vp1.Elem().Set(reflect.ValueOf(fv.Interface()))
fv.Set(vp1.Elem())
}
if terr == nil {
ifc, ok := fv.Interface().(HasID)
if fv.Kind() == reflect.Slice {
rarr := bson.A{}
for j := 0; j < fv.Len(); j++ {
rarr = append(rarr, getID(fv.Index(j).Interface()))
}
ret0[bbson.Name] = rarr
/*ret0[bbson.Name] = serializeIDs(fv.Interface())
break*/
} else if !ok {
panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name))
} else {
if reflect.ValueOf(ifc).IsNil() {
ret0[bbson.Name] = nil
} else {
ret0[bbson.Name] = ifc.Id()
}
}
} else {
ret0[bbson.Name] = serializeIDs(fv.Interface())
}
}
ret = ret0
}
case reflect.Slice:
ret0 := bson.A{}
for i := 0; i < vp.Elem().Len(); i++ {
ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface()))
}
ret = ret0
default:
ret = vp.Elem().Interface()
}
return ret
}
// Create creates a new instance of a given model // Create creates a new instance of a given model
// and returns a pointer to it. // and returns a pointer to it.
func Create(d any) any { func Create(d any) any {

263
model_internals.go Normal file

@ -0,0 +1,263 @@
package orm
import (
"context"
"fmt"
"github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"reflect"
"time"
)
func serializeIDs(input interface{}) interface{} {
vp := reflect.ValueOf(input)
mt := reflect.TypeOf(input)
var ret interface{}
if vp.Kind() != reflect.Ptr {
if vp.CanAddr() {
vp = vp.Addr()
} else {
vp = makeSettable(vp, input)
}
}
if mt.Kind() == reflect.Pointer {
mt = mt.Elem()
}
getID := func(bbb interface{}) interface{} {
mptr := reflect.ValueOf(bbb)
if mptr.Kind() != reflect.Pointer {
mptr = makeSettable(mptr, bbb)
}
ifc, ok := mptr.Interface().(HasID)
if ok {
return ifc.Id()
} else {
return nil
}
}
/*var itagged interface{}
if reflect.ValueOf(itagged).Kind() != reflect.Pointer {
itagged = incrementTagged(&input)
} else {
itagged = incrementTagged(input)
}
taggedVal := reflect.ValueOf(reflect.ValueOf(itagged).Interface()).Elem()
if vp.Kind() == reflect.Ptr {
tmp := reflect.ValueOf(taggedVal.Interface())
if tmp.Kind() == reflect.Pointer {
vp.Elem().Set(tmp.Elem())
} else {
vp.Elem().Set(tmp)
}
}*/
switch vp.Elem().Kind() {
case reflect.Struct:
ret0 := bson.M{}
for i := 0; i < vp.Elem().NumField(); i++ {
fv := vp.Elem().Field(i)
ft := mt.Field(i)
tag, err := structtag.Parse(string(ft.Tag))
panik(err)
bbson, err := tag.Get("bson")
if err != nil || bbson.Name == "-" {
continue
}
if bbson.Name == "" {
marsh, _ := bson.Marshal(fv.Interface())
unmarsh := bson.M{}
bson.Unmarshal(marsh, &unmarsh)
for k, v := range unmarsh {
ret0[k] = v
/*if t, ok := v.(primitive.DateTime); ok {
ret0
} else {
}*/
}
} else {
_, terr := tag.Get("ref")
if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer {
vp1 := reflect.New(fv.Type())
vp1.Elem().Set(reflect.ValueOf(fv.Interface()))
fv.Set(vp1.Elem())
}
if terr == nil {
ifc, ok := fv.Interface().(HasID)
if fv.Kind() == reflect.Slice {
rarr := bson.A{}
for j := 0; j < fv.Len(); j++ {
rarr = append(rarr, getID(fv.Index(j).Interface()))
}
ret0[bbson.Name] = rarr
/*ret0[bbson.Name] = serializeIDs(fv.Interface())
break*/
} else if !ok {
panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name))
} else {
if reflect.ValueOf(ifc).IsNil() {
ret0[bbson.Name] = nil
} else {
ret0[bbson.Name] = ifc.Id()
}
}
} else {
ret0[bbson.Name] = serializeIDs(fv.Interface())
}
}
ret = ret0
}
case reflect.Slice:
ret0 := bson.A{}
for i := 0; i < vp.Elem().Len(); i++ {
ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface()))
}
ret = ret0
default:
ret = vp.Elem().Interface()
}
return ret
}
func doSave(c *mongo.Collection, isNew bool, arg interface{}) error {
var err error
m, ok := arg.(IModel)
if !ok {
return fmt.Errorf("type '%s' is not a model", nameOf(arg))
}
m.setSelf(m)
now := time.Now()
selfo := reflect.ValueOf(m)
vp := selfo
if vp.Kind() != reflect.Ptr {
vp = reflect.New(selfo.Type())
vp.Elem().Set(selfo)
}
var asHasId = vp.Interface().(HasID)
if isNew {
m.setCreated(now)
}
m.setModified(now)
idxs := m.getIdxs()
for _, i := range idxs {
_, err = c.Indexes().CreateOne(context.TODO(), *i)
if err != nil {
return err
}
}
if isNew {
nid := getLastInColl(c.Name(), asHasId.Id())
pnid := incrementInterface(nid)
if reflect.ValueOf(asHasId.Id()).IsZero() {
(asHasId).SetId(pnid)
}
incrementAll(asHasId)
_, err = c.InsertOne(context.TODO(), m.serializeToStore())
if err == nil {
m.setExists(true)
}
} else {
_, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.(HasID).Id()}}, m.serializeToStore())
}
return err
}
func doDelete(m *Model, arg interface{}) error {
self, ok := arg.(HasID)
if !ok {
return fmt.Errorf("Object '%s' does not implement HasID", nameOf(arg))
}
c := m.getColl()
_, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()})
if err == nil {
m.exists = false
}
return err
}
func incrementTagged(item interface{}) interface{} {
rv := reflect.ValueOf(item)
rt := reflect.TypeOf(item)
if rv.Kind() != reflect.Pointer {
rv = makeSettable(rv, item)
}
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
}
if rt.Kind() != reflect.Struct {
if rt.Kind() == reflect.Slice {
for i := 0; i < rv.Elem().Len(); i++ {
incrementTagged(rv.Elem().Index(i).Addr().Interface())
}
} else {
return item
}
}
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
cur := rv.Elem().Field(i)
tags, err := structtag.Parse(string(structField.Tag))
if err != nil {
continue
}
incTag, err := tags.Get("autoinc")
if err != nil {
continue
}
nid := getLastInColl(incTag.Name, cur.Interface())
if cur.IsZero() {
coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur)
if coerced != nil {
cur.Set(reflect.ValueOf(coerced))
} else {
cur.Set(reflect.ValueOf(incrementInterface(nid)))
}
}
counterColl := DB.Collection(COUNTER_COL)
counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.Update().SetUpsert(true))
}
return rv.Elem().Interface()
}
func incrementAll(item interface{}) {
if item == nil {
return
}
vp := reflect.ValueOf(item)
el := vp
if vp.Kind() == reflect.Pointer {
el = vp.Elem()
}
if vp.Kind() == reflect.Pointer && vp.IsNil() {
return
}
vt := el.Type()
switch el.Kind() {
case reflect.Struct:
incrementTagged(item)
for i := 0; i < el.NumField(); i++ {
fv := el.Field(i)
fst := vt.Field(i)
if !fst.IsExported() {
continue
}
incrementAll(fv.Interface())
}
case reflect.Slice:
for i := 0; i < el.Len(); i++ {
incd := incrementTagged(el.Index(i).Addr().Interface())
if reflect.ValueOf(incd).Kind() == reflect.Pointer {
el.Index(i).Set(reflect.ValueOf(incd).Elem())
} else {
el.Index(i).Set(reflect.ValueOf(incd))
}
}
default:
}
}

@ -146,3 +146,17 @@ func TestModel_Pull(t *testing.T) {
query.Exec(fin) query.Exec(fin)
assert.Equal(t, 4, len(fin.Chapters)) assert.Equal(t, 4, len(fin.Chapters))
} }
func TestModel_Swap(t *testing.T) {
initTest()
iti_single.Author = &author
storyDoc := Create(iti_single).(*story)
storyDoc.Chapters[0].Bands = append(storyDoc.Chapters[0].Bands, metallica)
assert.Equal(t, 2, len(storyDoc.Chapters[0].Bands))
err := storyDoc.Swap("Chapters[0].Bands", 0, 1)
assert.Nil(t, err)
c := storyDoc.Chapters[0].Bands
assert.Equal(t, metallica.ID, c[0].ID)
assert.Equal(t, dh.ID, c[1].ID)
saveDoc(t, storyDoc)
}

50
util.go

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"reflect" "reflect"
"regexp"
"strconv"
"strings" "strings"
) )
@ -37,22 +39,6 @@ func valueOf(i interface{}) reflect.Value {
return v return v
} }
func iFace(input interface{}) interface{} {
return reflect.ValueOf(input).Interface()
}
func iFaceSlice(input interface{}) []interface{} {
ret := make([]interface{}, 0)
fv := reflect.ValueOf(input)
if fv.Type().Kind() != reflect.Slice {
return ret
}
for i := 0; i < fv.Len(); i++ {
ret = append(ret, fv.Index(i).Interface())
}
return ret
}
func coerceInt(input reflect.Value, dst reflect.Value) interface{} { func coerceInt(input reflect.Value, dst reflect.Value) interface{} {
if input.Type().Kind() == reflect.Pointer { if input.Type().Kind() == reflect.Pointer {
input = input.Elem() input = input.Elem()
@ -66,20 +52,29 @@ func coerceInt(input reflect.Value, dst reflect.Value) interface{} {
return nil return nil
} }
var arrRegex, _ = regexp.Compile(`\[(?P<index>\d+)]$`)
func getNested(field string, value reflect.Value) (*reflect.StructField, *reflect.Value, error) { func getNested(field string, value reflect.Value) (*reflect.StructField, *reflect.Value, error) {
if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") { if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") {
return nil, nil, fmt.Errorf("Malformed field name %s passed", field) return nil, nil, fmt.Errorf("Malformed field name %s passed", field)
} }
dots := strings.Split(field, ".") dots := strings.Split(field, ".")
if value.Kind() != reflect.Struct { if value.Kind() != reflect.Struct && arrRegex.FindString(dots[0]) == "" {
return nil, nil, fmt.Errorf("This value is not a struct!") return nil, nil, fmt.Errorf("This value is not a struct!")
} }
ref := value ref := value
if ref.Kind() == reflect.Pointer { if ref.Kind() == reflect.Pointer {
ref = ref.Elem() ref = ref.Elem()
} }
fv := ref.FieldByName(dots[0]) var fv reflect.Value = ref.FieldByName(arrRegex.ReplaceAllString(dots[0], ""))
ft, _ := ref.Type().FieldByName(dots[0]) if arrRegex.FindString(dots[0]) != "" && fv.Kind() == reflect.Slice {
matches := arrRegex.FindStringSubmatch(dots[0])
ridx, _ := strconv.Atoi(matches[0])
idx := int(ridx)
fv = fv.Index(idx)
}
ft, _ := ref.Type().FieldByName(arrRegex.ReplaceAllString(dots[0], ""))
if len(dots) > 1 { if len(dots) > 1 {
return getNested(strings.Join(dots[1:], "."), fv) return getNested(strings.Join(dots[1:], "."), fv)
} else { } else {
@ -129,3 +124,20 @@ func pull(s reflect.Value, idx int, typ reflect.Type) reflect.Value {
} }
return retI.Elem() return retI.Elem()
} }
func checkStruct(ref reflect.Value) error {
if ref.Kind() == reflect.Slice {
return fmt.Errorf("Cannot append to multiple documents!")
}
if ref.Kind() != reflect.Struct {
return fmt.Errorf("Current object is not a struct!")
}
return nil
}
func checkSlice(ref reflect.Value) error {
if ref.Kind() != reflect.Slice {
return fmt.Errorf("Current field is not a slice!")
}
return nil
}