diff --git a/cache_store.go b/cache_store.go index f9e1088e..57fe3cde 100644 --- a/cache_store.go +++ b/cache_store.go @@ -3,6 +3,8 @@ package gorm import ( "fmt" "os" + "reflect" + "sort" "strconv" "sync" "time" @@ -20,7 +22,10 @@ type cache struct { size int highWaterMark int enabled bool + idMapMutex sync.RWMutex + idMapping map[modelId][]string database map[string]*cacheItem + mutex sync.RWMutex } func (c *cache) Enable() { @@ -38,43 +43,167 @@ func (c *cache) Enable() { c.size, _ = strconv.Atoi(size) c.highWaterMark, _ = strconv.Atoi(highWaterMark) - c.database = make(map[string]*cacheItem, c.size) + fmt.Println("Cache High Water Mark: ", c.highWaterMark) + fmt.Println("Cache Size: ", c.size) + + c.database = make(map[string]*cacheItem, c.size*2) // Size is larger to allow for temporary bursting + c.idMapping = make(map[modelId][]string, 100) + + ticker := time.NewTicker(5 * time.Second) + + go func() { + for { + select { + case <-ticker.C: + c.Empty() + } + } + }() c.enabled = true } +type KeyValue struct { + Key string + Value *cacheItem +} + +func (c cache) Empty() { + if len(c.database) > c.size { + fmt.Println("Over the limit. Running cleanup") + + var s []KeyValue + c.mutex.RLock() + for k, v := range c.database { + s = append(s, KeyValue{ + Key: k, + Value: v, + }) + } + c.mutex.RUnlock() + + // Sort the results + sort.Slice(s, func(i, j int) bool { + return s[i].Value.accessCount < s[j].Value.accessCount + }) + + // Go through the end of the results list and knock those keys off + c.mutex.Lock() + for _, res := range s[c.highWaterMark : len(s)-1] { + fmt.Println("Cleaned up query " + res.Key + " having only " + strconv.Itoa(int(res.Value.accessCount)) + " accesses.") + delete(c.database, res.Key) + } + c.mutex.Unlock() + } +} + func (c cache) GetItem(key string, offset int64) interface{} { - fmt.Println("Getting item " + key) + fmt.Print("Getting item " + key + " ... ") + c.mutex.RLock() if item, ok := c.database[key]; ok { - item.dataMutex.RLock() item.accessMutex.Lock() - - defer item.dataMutex.RUnlock() - defer item.accessMutex.Unlock() - item.accessCount++ + item.accessMutex.Unlock() - if (item.created+offset < time.Now().Unix()) || offset == -1 { + item.dataMutex.RLock() + defer item.dataMutex.RUnlock() + + if (item.created+offset > time.Now().Unix()) || offset == -1 { + fmt.Print("Found \n") + c.mutex.RUnlock() return item.data } + + fmt.Print("Expired \n") + } else { + fmt.Print("Not found \n") } + c.mutex.RUnlock() return nil } +type modelId struct { + model string + id string +} + func (c *cache) StoreItem(key string, data interface{}) { fmt.Println("Storing item " + key) - if _, ok := c.database[key]; !ok { - c.database[key] = &cacheItem{ - data: data, - created: time.Now().Unix(), + // Affected IDs + affectedIDs := make([]string, 0, 100) + var model string + + // Go through the IDs in the interface and add them and the model to the + switch reflect.TypeOf(data).Kind() { + case reflect.Slice: + // Loop through each of the items and get the primary key or "ID" value + s := reflect.ValueOf(data) + model = reflect.TypeOf(data).Elem().String() + + for i := 0; i < s.Len(); i++ { + affectedIDs = append(affectedIDs, getID(s.Index(i).Interface())) } + + case reflect.Struct: + model = reflect.TypeOf(data).String() + affectedIDs = []string{getID(data)} + } + + if _, ok := c.database[key]; !ok { + c.mutex.Lock() + c.database[key] = &cacheItem{ + created: time.Now().UnixNano(), + accessCount: 1, + data: data, + } + c.mutex.Unlock() } else { + c.mutex.RLock() c.database[key].dataMutex.Lock() c.database[key].data = data - c.database[key].created = time.Now().Unix() + c.database[key].created = time.Now().UnixNano() c.database[key].dataMutex.Unlock() + c.mutex.RUnlock() } + + // Store the query selector agains the relevent IDs + c.idMapMutex.Lock() + for _, id := range affectedIDs { + sel := modelId{model: model, id: id} + + if _, ok := c.idMapping[sel]; !ok { + // We need to create the array + c.idMapping[sel] = []string{key} + } else { + c.idMapping[sel] = append(c.idMapping[sel], key) + } + } + c.idMapMutex.Unlock() +} + +func (c *cache) Expireitem(model, id string) { + // Get the relevent cache items + sel := modelId{model: model, id: id} + c.idMapMutex.Lock() + items := c.idMapping[sel] + delete(c.idMapping, sel) + c.idMapMutex.Unlock() + + // Delete the items from the cache + c.mutex.Lock() + for _, key := range items { + fmt.Println("Expiring item " + key + "(based on " + model + "/" + id) + delete(c.database, key) + } + c.mutex.Unlock() +} + +func getID(data interface{}) string { + d := reflect.ValueOf(data) + idField := d.FieldByName("ID") + + return fmt.Sprint(idField.Interface()) } diff --git a/callback_query.go b/callback_query.go index 3dcc703c..82d85e30 100644 --- a/callback_query.go +++ b/callback_query.go @@ -78,15 +78,19 @@ func queryCallback(scope *Scope) { cacheResults := scope.CacheStore().GetItem(key, *cacheOperation) if cacheResults != nil { results.Set(reflect.ValueOf(cacheResults)) + fmt.Println("Cache HIT") readFromDB = false } else { readFromDB = true + fmt.Println() writeToCache = true } } else { readFromDB = true writeToCache = true } + } else { + fmt.Println("Cache NOT") } if readFromDB { diff --git a/go.mod b/go.mod index b327ea24..43b06f3b 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.12 require ( github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-redis/redis v6.15.2+incompatible github.com/go-sql-driver/mysql v1.4.1 github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 diff --git a/main.go b/main.go index a8a13e3f..a0f3aba7 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,7 @@ type DB struct { logMode logModeValue logger logger search *search + cache *cache values sync.Map // global db @@ -161,6 +162,21 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// EnableCache turns on query caching +func (s *DB) EnableCache() *DB { + fmt.Println("Enabling caching...") + s.cache = new(cache) + s.cache.Enable() + + return s +} + +// FlushCacheItem takes a model with its ID field set and searches the caches for it. If found, that cache item is deleted +func (s *DB) FlushCacheItem(model, id string) *DB { + s.cache.Expireitem(model, id) + return s +} + // SetNowFuncOverride set the function to be used when creating a new timestamp func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { s.nowFuncOverride = nowFuncOverride diff --git a/model_struct.go b/model_struct.go index 5234b287..573d9cef 100644 --- a/model_struct.go +++ b/model_struct.go @@ -52,6 +52,18 @@ func (s *ModelStruct) TableName(db *DB) string { return DefaultTableNameHandler(db, s.defaultTableName) } +// TableName returns model's table name +func (s *ModelStruct) Cache(db *DB) *int64 { + if db != nil && s.ModelType != nil { + // Set default table name + if cache, ok := reflect.New(s.ModelType).Interface().(cacher); ok { + return cache.Cache() + } + } + + return nil +} + // StructField model field's struct definition type StructField struct { DBName string diff --git a/scope.go b/scope.go index d20cb9d5..cb2e5d46 100644 --- a/scope.go +++ b/scope.go @@ -24,7 +24,6 @@ type Scope struct { skipLeft bool fields *[]*Field selectAttrs *[]string - cacheStore *cache } // IndirectValue return scope's reflect value's indirect value @@ -48,7 +47,7 @@ func (scope *Scope) DB() *DB { // CacheStore returns scope's cache store func (scope *Scope) CacheStore() *cache { - return scope.cacheStore + return scope.db.parent.cache } // NewDB create a new DB without search information @@ -338,10 +337,12 @@ type dbTabler interface { } func (scope *Scope) Cache() *int64 { - if scope.cacheStore.enabled { - if cacher, ok := scope.Value.(cacher); ok { - return cacher.Cache() + if scope.CacheStore() != nil && scope.CacheStore().enabled { + if cache, ok := scope.Value.(cacher); ok { + return cache.Cache() } + + return scope.GetModelStruct().Cache(scope.db.Model(scope.Value)) } return nil