many additions!

- add support for `autoinc` struct tag
- add Pull method to Model
- add utility func to increment a value of an unknown type by 1
- idcounter.go no longer increments the counter value (it is now up to the caller to increment the returned value themselves)
This commit is contained in:
parent b4b18680a8
commit 924b3fc9d2
Signed by: tablet
GPG Key ID: 924A5F6AF051E87C
6 changed files with 303 additions and 156 deletions

@ -39,36 +39,6 @@ func getLastInColl(cname string, id interface{}) interface{} {
Current: id, Current: id,
} }
cnt.Collection = cname cnt.Collection = cname
switch nini := cnt.Current.(type) {
case int:
if nini == 0 {
cnt.Current = nini + 1
}
case int32:
if nini == int32(0) {
cnt.Current = nini + 1
}
case int64:
if nini == int64(0) {
cnt.Current = nini + 1
}
case uint:
if nini == uint(0) {
cnt.Current = nini + 1
}
case uint32:
if nini == uint32(0) {
cnt.Current = nini + 1
}
case uint64:
if nini == uint64(0) {
cnt.Current = nini + 1
}
case string:
cnt.Current = NextStringID()
case primitive.ObjectID:
cnt.Current = primitive.NewObjectID()
}
} }
return cnt.Current return cnt.Current
} }

344
model.go

@ -8,7 +8,6 @@ import (
"github.com/fatih/structtag" "github.com/fatih/structtag"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
) )
@ -64,7 +63,7 @@ type IModel interface {
getIdxs() []*mongo.IndexModel getIdxs() []*mongo.IndexModel
getParsedIdxs() map[string][]InternalIndex getParsedIdxs() map[string][]InternalIndex
Save() error Save() error
serializeToStore() primitive.M serializeToStore() any
setTypeName(str string) setTypeName(str string)
getExists() bool getExists() bool
setExists(n bool) setExists(n bool)
@ -225,6 +224,98 @@ func (m *Model) Delete() error {
} }
} }
func incrementTagged(item interface{}) interface{} {
rv := reflect.ValueOf(item)
rt := reflect.TypeOf(item)
if rv.Kind() != reflect.Pointer {
rv = makeSettable(rv, item)
}
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
}
if rt.Kind() != reflect.Struct {
if rt.Kind() == reflect.Slice {
for i := 0; i < rv.Elem().Len(); i++ {
incrementTagged(rv.Elem().Index(i).Addr().Interface())
}
} else {
return item
}
}
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
cur := rv.Elem().Field(i)
tags, err := structtag.Parse(string(structField.Tag))
if err != nil {
continue
}
incTag, err := tags.Get("autoinc")
if err != nil {
continue
}
nid := getLastInColl(incTag.Name, cur.Interface())
if cur.IsZero() {
coerced := coerceInt(reflect.ValueOf(incrementInterface(nid)), cur)
if coerced != nil {
cur.Set(reflect.ValueOf(coerced))
} else {
cur.Set(reflect.ValueOf(incrementInterface(nid)))
}
}
counterColl := DB.Collection(COUNTER_COL)
counterColl.UpdateOne(context.TODO(), bson.M{"collection": incTag.Name}, bson.M{"$set": bson.M{"collection": incTag.Name, "current": cur.Interface()}}, options.Update().SetUpsert(true))
}
return rv.Elem().Interface()
}
func incrementAll(item interface{}) {
if item == nil {
return
}
vp := reflect.ValueOf(item)
el := vp
if vp.Kind() == reflect.Pointer {
el = vp.Elem()
}
if vp.Kind() == reflect.Pointer && vp.IsNil() {
return
}
vt := el.Type()
switch el.Kind() {
case reflect.Struct:
incrementTagged(item)
for i := 0; i < el.NumField(); i++ {
fv := el.Field(i)
fst := vt.Field(i)
if !fst.IsExported() {
continue
}
incrementAll(fv.Interface())
}
case reflect.Slice:
for i := 0; i < el.Len(); i++ {
incd := incrementTagged(el.Index(i).Addr().Interface())
if reflect.ValueOf(incd).Kind() == reflect.Pointer {
el.Index(i).Set(reflect.ValueOf(incd).Elem())
} else {
el.Index(i).Set(reflect.ValueOf(incd))
}
}
}
}
func checkStruct(ref reflect.Value) error {
if ref.Kind() == reflect.Slice {
return fmt.Errorf("Cannot append to multiple documents!")
}
if ref.Kind() != reflect.Struct {
return fmt.Errorf("Current object is not a struct!")
}
return nil
}
// Append appends one or more items to `field`. // Append appends one or more items to `field`.
// will error if this Model contains a reference // will error if this Model contains a reference
// to multiple documents, or if `field` is not a // to multiple documents, or if `field` is not a
@ -237,16 +328,14 @@ func (m *Model) Append(field string, a ...interface{}) error {
selfRef = selfRef.Elem() selfRef = selfRef.Elem()
rt = rt.Elem() rt = rt.Elem()
} }
if selfRef.Kind() == reflect.Slice { if err := checkStruct(selfRef); err != nil {
return fmt.Errorf("Cannot append to multiple documents!") return err
}
if selfRef.Kind() != reflect.Struct {
return fmt.Errorf("Current object is not a struct!")
} }
_, origV, err := getNested(field, selfRef) _, origV, err := getNested(field, selfRef)
if err != nil { if err != nil {
return err return err
} }
origRef := makeSettable(*origV, (*origV).Interface()) origRef := makeSettable(*origV, (*origV).Interface())
fv := origRef fv := origRef
if fv.Kind() == reflect.Pointer { if fv.Kind() == reflect.Pointer {
@ -255,14 +344,49 @@ func (m *Model) Append(field string, a ...interface{}) error {
if fv.Kind() != reflect.Slice { if fv.Kind() != reflect.Slice {
return fmt.Errorf("Current object is not a slice!") return fmt.Errorf("Current object is not a slice!")
} }
for _, b := range a { for _, b := range a {
val := reflect.ValueOf(b) val := reflect.ValueOf(incrementTagged(b))
fv.Set(reflect.Append(fv, val)) fv.Set(reflect.Append(fv, val))
} }
return nil return nil
} }
func (m *Model) Pull(field string, a ...any) error {
rv := reflect.ValueOf(m.self)
selfRef := rv
rt := reflect.TypeOf(m.self)
if selfRef.Kind() == reflect.Pointer {
selfRef = selfRef.Elem()
rt = rt.Elem()
}
if err := checkStruct(selfRef); err != nil {
return err
}
_, origV, err := getNested(field, selfRef)
if err != nil {
return err
}
origRef := makeSettable(*origV, (*origV).Interface())
fv := origRef
if fv.Kind() == reflect.Pointer {
fv = fv.Elem()
}
if fv.Kind() != reflect.Slice {
return fmt.Errorf("Current object is not a slice!")
}
outer:
for _, b := range a {
for i := 0; i < fv.Len(); i++ {
if reflect.DeepEqual(b, fv.Index(i).Interface()) {
fv.Set(pull(fv, i, fv.Index(i).Type()))
break outer
}
}
}
return nil
}
func doSave(c *mongo.Collection, isNew bool, arg interface{}) error { func doSave(c *mongo.Collection, isNew bool, arg interface{}) error {
var err error var err error
m, ok := arg.(IModel) m, ok := arg.(IModel)
@ -291,29 +415,11 @@ func doSave(c *mongo.Collection, isNew bool, arg interface{}) error {
} }
if isNew { if isNew {
nid := getLastInColl(c.Name(), asHasId.Id()) nid := getLastInColl(c.Name(), asHasId.Id())
switch pnid := nid.(type) { pnid := incrementInterface(nid)
case uint:
nid = pnid + 1
case uint32:
nid = pnid + 1
case uint64:
nid = pnid + 1
case int:
nid = pnid + 1
case int32:
nid = pnid + 1
case int64:
nid = pnid + 1
case string:
nid = NextStringID()
case primitive.ObjectID:
nid = primitive.NewObjectID()
default:
panic("unknown or unsupported id type")
}
if reflect.ValueOf(asHasId.Id()).IsZero() { if reflect.ValueOf(asHasId.Id()).IsZero() {
(asHasId).SetId(nid) (asHasId).SetId(pnid)
} }
incrementAll(asHasId)
_, err = c.InsertOne(context.TODO(), m.serializeToStore()) _, err = c.InsertOne(context.TODO(), m.serializeToStore())
if err == nil { if err == nil {
@ -340,115 +446,123 @@ func (m *Model) Save() error {
} }
} }
func (m *Model) serializeToStore() bson.M { func (m *Model) serializeToStore() any {
return serializeIDs((*m).self) return serializeIDs((m).self)
} }
func serializeIDs(input interface{}) bson.M { func serializeIDs(input interface{}) interface{} {
ret := bson.M{}
mv := reflect.ValueOf(input)
mt := reflect.TypeOf(input)
vp := mv vp := reflect.ValueOf(input)
mt := reflect.TypeOf(input)
var ret interface{}
if vp.Kind() != reflect.Ptr { if vp.Kind() != reflect.Ptr {
vp = reflect.New(mv.Type()) if vp.CanAddr() {
vp.Elem().Set(mv) vp = vp.Addr()
} else {
vp = makeSettable(vp, input)
} }
if mv.Kind() == reflect.Pointer {
mv = mv.Elem()
} }
if mt.Kind() == reflect.Pointer { if mt.Kind() == reflect.Pointer {
mt = mt.Elem() mt = mt.Elem()
} }
for i := 0; i < mv.NumField(); i++ { getID := func(bbb interface{}) interface{} {
fv := mv.Field(i) mptr := reflect.ValueOf(bbb)
if mptr.Kind() != reflect.Pointer {
mptr = makeSettable(mptr, bbb)
}
ifc, ok := mptr.Interface().(HasID)
if ok {
return ifc.Id()
} else {
return nil
}
}
/*var itagged interface{}
if reflect.ValueOf(itagged).Kind() != reflect.Pointer {
itagged = incrementTagged(&input)
} else {
itagged = incrementTagged(input)
}
taggedVal := reflect.ValueOf(reflect.ValueOf(itagged).Interface()).Elem()
if vp.Kind() == reflect.Ptr {
tmp := reflect.ValueOf(taggedVal.Interface())
if tmp.Kind() == reflect.Pointer {
vp.Elem().Set(tmp.Elem())
} else {
vp.Elem().Set(tmp)
}
}*/
switch vp.Elem().Kind() {
case reflect.Struct:
ret0 := bson.M{}
for i := 0; i < vp.Elem().NumField(); i++ {
fv := vp.Elem().Field(i)
ft := mt.Field(i) ft := mt.Field(i)
var dr = fv tag, err := structtag.Parse(string(ft.Tag))
tag, err := structtag.Parse(string(mt.Field(i).Tag))
panik(err) panik(err)
bbson, err := tag.Get("bson") bbson, err := tag.Get("bson")
if err != nil || bbson.Name == "-" { if err != nil || bbson.Name == "-" {
continue continue
} }
if bbson.Name == "" {
marsh, _ := bson.Marshal(fv.Interface())
unmarsh := bson.M{}
bson.Unmarshal(marsh, &unmarsh)
for k, v := range unmarsh {
ret0[k] = v
/*if t, ok := v.(primitive.DateTime); ok {
ret0
} else {
}*/
}
} else {
_, terr := tag.Get("ref") _, terr := tag.Get("ref")
switch dr.Type().Kind() { if reflect.ValueOf(fv.Interface()).Type().Kind() != reflect.Pointer {
vp1 := reflect.New(fv.Type())
vp1.Elem().Set(reflect.ValueOf(fv.Interface()))
fv.Set(vp1.Elem())
}
if terr == nil {
ifc, ok := fv.Interface().(HasID)
if fv.Kind() == reflect.Slice {
rarr := bson.A{}
for j := 0; j < fv.Len(); j++ {
rarr = append(rarr, getID(fv.Index(j).Interface()))
}
ret0[bbson.Name] = rarr
/*ret0[bbson.Name] = serializeIDs(fv.Interface())
break*/
} else if !ok {
panic(fmt.Sprintf("referenced model slice at '%s.%s' does not implement HasID", nameOf(input), ft.Name))
} else {
if reflect.ValueOf(ifc).IsNil() {
ret0[bbson.Name] = nil
} else {
ret0[bbson.Name] = ifc.Id()
}
}
} else {
ret0[bbson.Name] = serializeIDs(fv.Interface())
}
}
ret = ret0
}
case reflect.Slice: case reflect.Slice:
rarr := make([]interface{}, 0) ret0 := bson.A{}
intArr := iFaceSlice(fv.Interface()) for i := 0; i < vp.Elem().Len(); i++ {
for _, idHaver := range intArr {
if terr == nil {
if reflect.ValueOf(idHaver).Type().Kind() != reflect.Pointer {
vp := reflect.New(reflect.ValueOf(idHaver).Type())
vp.Elem().Set(reflect.ValueOf(idHaver))
idHaver = vp.Interface()
}
ifc, ok := idHaver.(HasID)
if !ok {
panic(fmt.Sprintf("referenced model slice '%s' does not implement HasID", ft.Name))
}
rarr = append(rarr, ifc.Id())
} else if reflect.ValueOf(idHaver).Kind() == reflect.Struct {
rarr = append(rarr, serializeIDs(idHaver))
} else {
if reflect.ValueOf(idHaver).Kind() == reflect.Slice {
rarr = append(rarr, serializeIDSlice(iFaceSlice(idHaver)))
} else {
rarr = append(rarr, idHaver)
}
}
ret[bbson.Name] = rarr
}
case reflect.Pointer:
dr = fv.Elem()
fallthrough
case reflect.Struct:
if bbson.Name != "" {
if terr == nil {
idHaver := fv.Interface()
ifc, ok := idHaver.(HasID)
if !ok {
panic(fmt.Sprintf("referenced model '%s' does not implement HasID", ft.Name))
}
if !fv.IsNil() {
ret[bbson.Name] = ifc.Id()
} else {
ret[bbson.Name] = nil
}
} else if reflect.TypeOf(fv.Interface()) != reflect.TypeOf(time.Now()) { ret0 = append(ret0, serializeIDs(vp.Elem().Index(i).Addr().Interface()))
ret[bbson.Name] = serializeIDs(fv.Interface())
} else {
ret[bbson.Name] = fv.Interface()
} }
} else { ret = ret0
for k, v := range serializeIDs(fv.Interface()) {
ret[k] = v
}
}
default: default:
ret[bbson.Name] = fv.Interface() ret = vp.Elem().Interface()
}
} }
return ret return ret
} }
func serializeIDSlice(input []interface{}) bson.A {
var a bson.A
for _, in := range input {
fv := reflect.ValueOf(in)
switch fv.Type().Kind() {
case reflect.Slice:
a = append(a, serializeIDSlice(iFaceSlice(fv.Interface())))
case reflect.Struct:
a = append(a, serializeIDs(fv.Interface()))
default:
a = append(a, fv.Interface())
}
}
return a
}
// Create creates a new instance of a given model // Create creates a new instance of a given model
// and returns a pointer to it. // and returns a pointer to it.
func Create(d any) any { func Create(d any) any {

@ -29,6 +29,9 @@ func TestSave(t *testing.T) {
assert.Equal(t, nil, serr) assert.Equal(t, nil, serr)
assert.Less(t, int64(0), storyDoc.ID) assert.Less(t, int64(0), storyDoc.ID)
assert.Less(t, int64(0), lauthor.ID) assert.Less(t, int64(0), lauthor.ID)
for _, c := range storyDoc.Chapters {
assert.NotZero(t, c.ChapterID)
}
} }
func TestPopulate(t *testing.T) { func TestPopulate(t *testing.T) {
@ -52,7 +55,9 @@ func TestPopulate(t *testing.T) {
j, _ := q.JSON() j, _ := q.JSON()
fmt.Printf("%s\n", j) fmt.Printf("%s\n", j)
}) })
assert.NotZero(t, storyDoc.Chapters[0].Bands[0].Name) for _, c := range storyDoc.Chapters {
assert.NotZero(t, c.Bands[0].Name)
}
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
@ -125,3 +130,19 @@ func TestModel_Delete(t *testing.T) {
err := bandDoc.Delete() err := bandDoc.Delete()
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestModel_Pull(t *testing.T) {
initTest()
storyDoc := Create(iti_multi).(*story)
smodel := Create(story{}).(*story)
saveDoc(t, storyDoc)
err := storyDoc.Pull("Chapters", storyDoc.Chapters[4])
assert.Nil(t, err)
assert.NotZero(t, storyDoc.ID)
saveDoc(t, storyDoc)
fin := &story{}
query, err := smodel.FindByID(storyDoc.ID)
assert.Nil(t, err)
query.Exec(fin)
assert.Equal(t, 4, len(fin.Chapters))
}

@ -292,9 +292,6 @@ func rerere(input interface{}, resType reflect.Type) interface{} {
if input == nil { if input == nil {
return nil return nil
} }
if v.CanAddr() && v.IsNil() {
return nil
}
if v.Type().Kind() == reflect.Pointer { if v.Type().Kind() == reflect.Pointer {
v = v.Elem() v = v.Elem()
} }
@ -346,12 +343,14 @@ func rerere(input interface{}, resType reflect.Type) interface{} {
if err != nil { if err != nil {
tmp := rerere(intermediate, ft.Type) tmp := rerere(intermediate, ft.Type)
fuck := reflect.ValueOf(tmp) fuck := reflect.ValueOf(tmp)
if !fuck.IsZero() { if tmp != nil {
if fuck.Type().Kind() == reflect.Pointer { if fuck.Type().Kind() == reflect.Pointer {
fuck = fuck.Elem() fuck = fuck.Elem()
} }
}
fv.Set(fuck) fv.Set(fuck)
} else {
fv.Set(reflect.Zero(ft.Type))
}
shouldBreak = true shouldBreak = true
} else { } else {
tt := ft.Type tt := ft.Type

@ -118,6 +118,7 @@ func genChaps(single bool) []chapter {
for i := 0; i < ceil; i++ { for i := 0; i < ceil; i++ {
spf := fmt.Sprintf("%d.md", i+1) spf := fmt.Sprintf("%d.md", i+1)
ret = append(ret, chapter{ ret = append(ret, chapter{
ID: primitive.NewObjectID(),
Title: fmt.Sprintf("-%d-", i+1), Title: fmt.Sprintf("-%d-", i+1),
Index: int(i + 1), Index: int(i + 1),
Words: 50, Words: 50,
@ -149,6 +150,12 @@ var iti_multi story = story{
Chapters: genChaps(false), Chapters: genChaps(false),
} }
func iti_blank() story {
t := iti_single
t.Chapters = make([]chapter, 0)
return t
}
func initTest() { func initTest() {
uri := "mongodb://127.0.0.1:27017" uri := "mongodb://127.0.0.1:27017"
db := "rockfic_ormTest" db := "rockfic_ormTest"

40
util.go

@ -2,6 +2,7 @@ package orm
import ( import (
"fmt" "fmt"
"go.mongodb.org/mongo-driver/bson/primitive"
"reflect" "reflect"
"strings" "strings"
) )
@ -65,7 +66,7 @@ func coerceInt(input reflect.Value, dst reflect.Value) interface{} {
return nil return nil
} }
func getNested(field string, value reflect.Value) (*reflect.Type, *reflect.Value, error) { func getNested(field string, value reflect.Value) (*reflect.StructField, *reflect.Value, error) {
if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") { if strings.HasPrefix(field, ".") || strings.HasSuffix(field, ".") {
return nil, nil, fmt.Errorf("Malformed field name %s passed", field) return nil, nil, fmt.Errorf("Malformed field name %s passed", field)
} }
@ -78,7 +79,7 @@ func getNested(field string, value reflect.Value) (*reflect.Type, *reflect.Value
ref = ref.Elem() ref = ref.Elem()
} }
fv := ref.FieldByName(dots[0]) fv := ref.FieldByName(dots[0])
ft := fv.Type() ft, _ := ref.Type().FieldByName(dots[0])
if len(dots) > 1 { if len(dots) > 1 {
return getNested(strings.Join(dots[1:], "."), fv) return getNested(strings.Join(dots[1:], "."), fv)
} else { } else {
@ -93,3 +94,38 @@ func makeSettable(rval reflect.Value, value interface{}) reflect.Value {
} }
return rval return rval
} }
func incrementInterface(t interface{}) interface{} {
switch pt := t.(type) {
case uint:
t = pt + 1
case uint32:
t = pt + 1
case uint64:
t = pt + 1
case int:
t = pt + 1
case int32:
t = pt + 1
case int64:
t = pt + 1
case string:
t = NextStringID()
case primitive.ObjectID:
t = primitive.NewObjectID()
default:
panic("unknown or unsupported id type")
}
return t
}
func pull(s reflect.Value, idx int, typ reflect.Type) reflect.Value {
retI := reflect.New(reflect.SliceOf(typ))
for i := 0; i < s.Len(); i++ {
if i == idx {
continue
}
retI.Elem().Set(reflect.Append(retI.Elem(), s.Index(i)))
}
return retI.Elem()
}