refactor some things, add Delete method, improve Save method

This commit is contained in:
parent 3f8fded6e5
commit f721b32e01
Signed by: tablet
GPG Key ID: 924A5F6AF051E87C
5 changed files with 152 additions and 78 deletions

@ -1,3 +1,4 @@
go 1.23.0 go 1.23.0
use . use .
use muck

121
model.go

@ -24,6 +24,22 @@ type Model struct {
exists bool `bson:"-"` exists bool `bson:"-"`
} }
func (m *Model) getCreated() time.Time {
return m.Created
}
func (m *Model) setCreated(Created time.Time) {
m.Created = Created
}
func (m *Model) getModified() time.Time {
return m.Modified
}
func (m *Model) setModified(Modified time.Time) {
m.Modified = Modified
}
// HasID is a simple interface that you must implement // HasID is a simple interface that you must implement
// in your models, using a pointer receiver. // in your models, using a pointer receiver.
// This allows for more flexibility in cases where // This allows for more flexibility in cases where
@ -34,9 +50,11 @@ type HasID interface {
Id() any Id() any
SetId(id any) SetId(id any)
} }
type HasIDSlice []HasID type HasIDSlice []HasID
type IModel interface { type IModel interface {
Append(field string, a ...interface{}) error
Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) Find(query interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error)
FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error) FindAll(query interface{}, opts ...*options.FindOptions) (*Query, error)
FindByID(id interface{}) (*Query, error) FindByID(id interface{}) (*Query, error)
@ -50,12 +68,21 @@ type IModel interface {
setTypeName(str string) setTypeName(str string)
getExists() bool getExists() bool
setExists(n bool) setExists(n bool)
setModified(Modified time.Time)
setCreated(Modified time.Time)
getModified() time.Time
getCreated() time.Time
setSelf(arg interface{})
} }
func (m *Model) setTypeName(str string) { func (m *Model) setTypeName(str string) {
m.typeName = str m.typeName = str
} }
func (m *Model) setSelf(arg interface{}) {
m.self = arg
}
func (m *Model) getExists() bool { func (m *Model) getExists() bool {
return m.exists return m.exists
} }
@ -169,22 +196,32 @@ func (m *Model) FindOne(query interface{}, options ...*options.FindOneOptions) (
m.self = qq.doc m.self = qq.doc
return qq, err return qq, err
} }
func doDelete(m *Model, arg interface{}) error {
self, ok := arg.(HasID)
func (m *Model) DeleteOne() error {
c := m.getColl()
if valueOf(m.self).Kind() == reflect.Slice {
}
id, ok := m.self.(HasID)
if !ok { if !ok {
id2, ok2 := m.self.(HasIDSlice) return fmt.Errorf("Object '%s' does not implement HasID", nameOf(arg))
if !ok2 { }
return fmt.Errorf("model '%s' is not registered", m.typeName) c := m.getColl()
_, err := c.DeleteOne(context.TODO(), bson.M{"_id": self.Id()})
if err == nil {
m.exists = false
}
return err
}
func (m *Model) Delete() error {
var err error
val := valueOf(m.self)
if val.Kind() == reflect.Slice {
for i := 0; i < val.Len(); i++ {
cur := val.Index(i)
if err = doDelete(m, cur.Interface()); err != nil {
return err
}
} }
_, err := c.DeleteOne(context.TODO(), bson.M{"_id": id2[0].Id()}) return nil
return err
} else { } else {
_, err := c.DeleteOne(context.TODO(), bson.M{"_id": id.Id()}) return doDelete(m, m.self)
return err
} }
} }
@ -194,24 +231,24 @@ func (m *Model) DeleteOne() error {
// slice. // slice.
func (m *Model) Append(field string, a ...interface{}) error { func (m *Model) Append(field string, a ...interface{}) error {
rv := reflect.ValueOf(m.self) rv := reflect.ValueOf(m.self)
ref := rv selfRef := rv
rt := reflect.TypeOf(m.self) rt := reflect.TypeOf(m.self)
if ref.Kind() == reflect.Pointer { if selfRef.Kind() == reflect.Pointer {
ref = ref.Elem() selfRef = selfRef.Elem()
rt = rt.Elem() rt = rt.Elem()
} }
if ref.Kind() == reflect.Slice { if selfRef.Kind() == reflect.Slice {
return fmt.Errorf("Cannot append to multiple documents!") return fmt.Errorf("Cannot append to multiple documents!")
} }
if ref.Kind() != reflect.Struct { if selfRef.Kind() != reflect.Struct {
return fmt.Errorf("Current object is not a struct!") return fmt.Errorf("Current object is not a struct!")
} }
_, ofv, err := getNested(field, ref) _, origV, err := getNested(field, selfRef)
if err != nil { if err != nil {
return err return err
} }
oofv := makeSettable(*ofv, (*ofv).Interface()) origRef := makeSettable(*origV, (*origV).Interface())
fv := oofv fv := origRef
if fv.Kind() == reflect.Pointer { if fv.Kind() == reflect.Pointer {
fv = fv.Elem() fv = fv.Elem()
} }
@ -220,29 +257,31 @@ func (m *Model) Append(field string, a ...interface{}) error {
} }
for _, b := range a { for _, b := range a {
va := reflect.ValueOf(b) val := reflect.ValueOf(b)
fv.Set(reflect.Append(fv, va)) fv.Set(reflect.Append(fv, val))
} }
return nil return nil
} }
func (m *Model) Save() error { func doSave(c *mongo.Collection, isNew bool, arg interface{}) error {
var err error var err error
c := m.getColl() m, ok := arg.(IModel)
if !ok {
return fmt.Errorf("type '%s' is not a model", nameOf(arg))
}
m.setSelf(m)
now := time.Now() now := time.Now()
selfo := reflect.ValueOf(m.self) selfo := reflect.ValueOf(m)
vp := selfo vp := selfo
if vp.Kind() != reflect.Ptr { if vp.Kind() != reflect.Ptr {
vp = reflect.New(selfo.Type()) vp = reflect.New(selfo.Type())
vp.Elem().Set(selfo) vp.Elem().Set(selfo)
} }
var asHasId = vp.Interface().(HasID) var asHasId = vp.Interface().(HasID)
(asHasId).Id() if isNew {
isNew := reflect.ValueOf(asHasId.Id()).IsZero() && !m.exists m.setCreated(now)
if isNew || !m.exists {
m.Created = now
} }
m.Modified = now m.setModified(now)
idxs := m.getIdxs() idxs := m.getIdxs()
for _, i := range idxs { for _, i := range idxs {
_, err = c.Indexes().CreateOne(context.TODO(), *i) _, err = c.Indexes().CreateOne(context.TODO(), *i)
@ -250,7 +289,7 @@ func (m *Model) Save() error {
return err return err
} }
} }
if isNew || !m.exists { if isNew {
nid := getLastInColl(c.Name(), asHasId.Id()) nid := getLastInColl(c.Name(), asHasId.Id())
switch pnid := nid.(type) { switch pnid := nid.(type) {
case uint: case uint:
@ -276,17 +315,31 @@ func (m *Model) Save() error {
(asHasId).SetId(nid) (asHasId).SetId(nid)
} }
m.self = asHasId
_, err = c.InsertOne(context.TODO(), m.serializeToStore()) _, err = c.InsertOne(context.TODO(), m.serializeToStore())
if err == nil { if err == nil {
m.exists = true m.setExists(true)
} }
} else { } else {
_, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.self.(HasID).Id()}}, m.serializeToStore()) _, err = c.ReplaceOne(context.TODO(), bson.D{{Key: "_id", Value: m.(HasID).Id()}}, m.serializeToStore())
} }
return err return err
} }
func (m *Model) Save() error {
val := valueOf(m.self)
if val.Kind() == reflect.Slice {
for i := 0; i < val.Len(); i++ {
cur := val.Index(i)
if err := doSave(m.getColl(), !m.exists, cur.Interface()); err != nil {
return err
}
}
return nil
} else {
return doSave(m.getColl(), !m.exists, m.self)
}
}
func (m *Model) serializeToStore() bson.M { func (m *Model) serializeToStore() bson.M {
return serializeIDs((*m).self) return serializeIDs((*m).self)
} }

@ -36,13 +36,11 @@ func TestPopulate(t *testing.T) {
bandDoc := Create(iti_single.Chapters[0].Bands[0]).(*band) bandDoc := Create(iti_single.Chapters[0].Bands[0]).(*band)
storyDoc := Create(iti_single).(*story) storyDoc := Create(iti_single).(*story)
author := Create(author).(*user) mauthor := Create(author).(*user)
saveDoc(t, mauthor)
err := bandDoc.Save() saveDoc(t, bandDoc)
assert.Equal(t, nil, err) storyDoc.Author = mauthor
storyDoc.Author = author saveDoc(t, storyDoc)
err = storyDoc.Save()
assert.Equal(t, nil, err)
assert.Greater(t, storyDoc.ID, int64(0)) assert.Greater(t, storyDoc.ID, int64(0))
smodel := Create(story{}).(*story) smodel := Create(story{}).(*story)
@ -54,32 +52,21 @@ func TestPopulate(t *testing.T) {
j, _ := q.JSON() j, _ := q.JSON()
fmt.Printf("%s\n", j) fmt.Printf("%s\n", j)
}) })
assert.NotZero(t, storyDoc.Chapters[0].Bands[0].Name)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
initTest() initTest()
nb := Create(band{ nb := Create(metallica).(*band)
ID: 1, saveDoc(t, nb)
Name: "Metallica",
Characters: []string{
"James Hetfield",
"Lars Ulrich",
"Kirk Hammett",
"Cliff Burton",
},
Locked: false,
}).(*band)
err := nb.Save()
assert.Equal(t, nil, err)
nb.Locked = true nb.Locked = true
err = nb.Save() saveDoc(t, nb)
assert.Equal(t, nil, err)
foundM := Create(band{}).(*band) foundM := Create(band{}).(*band)
q, err := foundM.FindByID(int64(1)) q, err := foundM.FindByID(int64(1))
assert.Equal(t, nil, err)
found := &band{} found := &band{}
q.Exec(found) q.Exec(found)
assert.Equal(t, nil, err)
assert.Equal(t, int64(1), found.ID) assert.Equal(t, int64(1), found.ID)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, true, found.Locked) assert.Equal(t, true, found.Locked)
@ -87,6 +74,7 @@ func TestUpdate(t *testing.T) {
func TestModel_FindAll(t *testing.T) { func TestModel_FindAll(t *testing.T) {
initTest() initTest()
createAndSave(t, &iti_multi)
smodel := Create(story{}).(*story) smodel := Create(story{}).(*story)
query, err := smodel.FindAll(bson.M{}, options.Find()) query, err := smodel.FindAll(bson.M{}, options.Find())
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
@ -97,30 +85,43 @@ func TestModel_FindAll(t *testing.T) {
func TestModel_PopulateMulti(t *testing.T) { func TestModel_PopulateMulti(t *testing.T) {
initTest() initTest()
bandDoc := Create(iti_single.Chapters[0].Bands[0]).(*band)
saveDoc(t, bandDoc)
createAndSave(t, &iti_multi)
smodel := Create(story{}).(*story) smodel := Create(story{}).(*story)
query, err := smodel.FindAll(bson.M{}, options.Find()) query, err := smodel.FindAll(bson.M{}, options.Find())
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
final := make([]story, 0) final := make([]story, 0)
query.Populate("Author", "Chapters.Bands").Exec(&final) query.Populate("Author", "Chapters.Bands").Exec(&final)
assert.Greater(t, len(final), 0) assert.Greater(t, len(final), 0)
for _, s := range final {
assert.NotZero(t, s.Chapters[0].Bands[0].Name)
}
} }
func TestModel_Append(t *testing.T) { func TestModel_Append(t *testing.T) {
initTest() initTest()
bandDoc := Create(metallica).(*band)
saveDoc(t, bandDoc)
bmodel := Create(band{}).(*band) bmodel := Create(band{}).(*band)
query, err := bmodel.FindByID(int64(1)) query, err := bmodel.FindByID(int64(1))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
fin := &band{} fin := &band{}
query.Exec(fin) query.Exec(fin)
assert.Greater(t, fin.ID, int64(0)) assert.Greater(t, fin.ID, int64(0))
err = bmodel.Append("Characters", "Robert Trujillo") err = bmodel.Append("Characters", "Robert Trujillo")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
err = bmodel.Save() saveDoc(t, bmodel)
assert.Equal(t, nil, err)
fin = &band{} fin = &band{}
query, _ = bmodel.FindByID(int64(1)) query, _ = bmodel.FindByID(int64(1))
query.Exec(fin) query.Exec(fin)
assert.Greater(t, len(fin.Characters), 4) assert.Greater(t, len(fin.Characters), 4)
}
func TestModel_Delete(t *testing.T) {
initTest()
bandDoc := Create(metallica).(*band)
saveDoc(t, bandDoc)
err := bandDoc.Delete()
assert.Nil(t, err)
} }

@ -3,6 +3,8 @@ package orm
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"testing"
"time" "time"
"github.com/go-loremipsum/loremipsum" "github.com/go-loremipsum/loremipsum"
@ -15,7 +17,7 @@ import (
type chapter struct { type chapter struct {
ID primitive.ObjectID `bson:"_id" json:"_id"` ID primitive.ObjectID `bson:"_id" json:"_id"`
Title string `bson:"chapterTitle" json:"chapterTitle" form:"chapterTitle"` Title string `bson:"chapterTitle" json:"chapterTitle" form:"chapterTitle"`
ChapterID int `bson:"id" json:"chapterID"` ChapterID int `bson:"id" json:"chapterID" autoinc:"chapters"`
Index int `bson:"index" json:"index" form:"index"` Index int `bson:"index" json:"index" form:"index"`
Words int `bson:"words" json:"words"` Words int `bson:"words" json:"words"`
Notes string `bson:"notes" json:"notes" form:"notes"` Notes string `bson:"notes" json:"notes" form:"notes"`
@ -112,12 +114,7 @@ func genChaps(single bool) []chapter {
{"Sean Harris", "Colin Kimberley", "Brian Tatler"}, {"Sean Harris", "Colin Kimberley", "Brian Tatler"},
}, },
} }
var b band = band{
Name: "Diamond Head",
Locked: false,
Characters: []string{"Brian Tatler", "Sean Harris", "Duncan Scott", "Colin Kimberley"},
}
b.ID = 503
for i := 0; i < ceil; i++ { for i := 0; i < ceil; i++ {
spf := fmt.Sprintf("%d.md", i+1) spf := fmt.Sprintf("%d.md", i+1)
ret = append(ret, chapter{ ret = append(ret, chapter{
@ -126,7 +123,7 @@ func genChaps(single bool) []chapter {
Words: 50, Words: 50,
Notes: "notenotenote !!!", Notes: "notenotenote !!!",
Genre: []string{"Slash"}, Genre: []string{"Slash"},
Bands: []band{b}, Bands: []band{dh},
Characters: []string{"Sean Harris", "Brian Tatler", "Duncan Scott", "Colin Kimberley"}, Characters: []string{"Sean Harris", "Brian Tatler", "Duncan Scott", "Colin Kimberley"},
Relationships: relMap[i], Relationships: relMap[i],
Adult: true, Adult: true,
@ -156,6 +153,7 @@ func initTest() {
uri := "mongodb://127.0.0.1:27017" uri := "mongodb://127.0.0.1:27017"
db := "rockfic_ormTest" db := "rockfic_ormTest"
ic, _ := mongo.Connect(context.TODO(), options.Client().ApplyURI(uri)) ic, _ := mongo.Connect(context.TODO(), options.Client().ApplyURI(uri))
ic.Database(db).Drop(context.TODO())
colls, _ := ic.Database(db).ListCollectionNames(context.TODO(), bson.M{}) colls, _ := ic.Database(db).ListCollectionNames(context.TODO(), bson.M{})
if len(colls) < 1 { if len(colls) < 1 {
mdb := ic.Database(db) mdb := ic.Database(db)
@ -172,3 +170,32 @@ func after() {
err := DBClient.Disconnect(context.TODO()) err := DBClient.Disconnect(context.TODO())
panik(err) panik(err)
} }
var metallica = band{
ID: 1,
Name: "Metallica",
Characters: []string{
"James Hetfield",
"Lars Ulrich",
"Kirk Hammett",
"Cliff Burton",
},
Locked: false,
}
var dh = band{
ID: 503,
Name: "Diamond Head",
Locked: false,
Characters: []string{"Brian Tatler", "Sean Harris", "Duncan Scott", "Colin Kimberley"},
}
func saveDoc(t *testing.T, doc IModel) {
err := doc.Save()
assert.Nil(t, err)
}
func createAndSave(t *testing.T, doc IModel) {
mdl := Create(doc).(IModel)
saveDoc(t, mdl)
}

10
util.go

@ -30,15 +30,7 @@ func nameOf(i interface{}) string {
func valueOf(i interface{}) reflect.Value { func valueOf(i interface{}) reflect.Value {
v := reflect.ValueOf(i) v := reflect.ValueOf(i)
if v.Type().Kind() == reflect.Slice || v.Type().Kind() == reflect.Map { if v.Type().Kind() == reflect.Pointer {
in := v.Type().Elem()
switch in.Kind() {
case reflect.Pointer:
v = reflect.New(in.Elem()).Elem()
default:
v = reflect.New(in).Elem()
}
} else if v.Type().Kind() == reflect.Pointer {
v = valueOf(reflect.Indirect(v).Interface()) v = valueOf(reflect.Indirect(v).Interface())
} }
return v return v