cacheinterface/driver/memory/memory.go

200 lines
3.2 KiB
Go
Raw Normal View History

package memory
import (
	"errors"
	"fmt"
	"reflect"
	"time"

	"github.com/patrickmn/go-cache"
	"meow.tf/go/cacheinterface/v2/encoder"
)
// ErrNotExist is returned by Get and GetBytes when the requested key is not
// present in the cache.
var ErrNotExist = errors.New("item does not exist")
type Options struct {
2023-02-05 01:31:09 +00:00
Encoder encoder.Encoder `default:"msgpack"`
DefaultExpiration time.Duration `query:"defaultExpiration" default:"1m"`
CleanupTime time.Duration `query:"cleanupTime" default:"5m"`
}
type Cache struct {
options Options
c *cache.Cache
}
func New(options Options) (*Cache, error) {
2023-02-05 01:31:09 +00:00
c := cache.New(options.DefaultExpiration, options.CleanupTime)
return &Cache{
options: options,
c: c,
}, nil
}
// Has reports whether the underlying go-cache store returns an entry for key.
func (mc *Cache) Has(key string) bool {
	if _, found := mc.c.Get(key); found {
		return true
	}
	return false
}
// Get looks up key and copies the stored value into dst using CacheGet.
// dst must be a pointer to a variable of the stored value's type.
// It returns ErrNotExist when key is absent.
func (mc *Cache) Get(key string, dst any) error {
	if item, found := mc.c.Get(key); found {
		return CacheGet(item, dst)
	}
	return ErrNotExist
}
// GetBytes looks up key and returns the stored value as bytes, converting it
// with CacheGetBytes and the configured Options.Encoder.
// It returns ErrNotExist when key is absent.
func (mc *Cache) GetBytes(key string) ([]byte, error) {
	if item, found := mc.c.Get(key); found {
		return CacheGetBytes(mc.options.Encoder, item)
	}
	return nil, ErrNotExist
}
// Set stores val under key. val is kept as-is (no encoding), and ttl is
// forwarded unchanged to go-cache's Set. In go-cache, 0 presumably selects the
// cache's default expiration (confirm). The returned error is always nil.
func (mc *Cache) Set(key string, val any, ttl time.Duration) error {
	store := mc.c
	store.Set(key, val, ttl)
	return nil
}
// Del removes key from the underlying store. The returned error is always
// nil, including when key was not present.
func (mc *Cache) Del(key string) error {
	store := mc.c
	store.Delete(key)
	return nil
}
// CacheGetBytes returns a cached item as raw bytes.
//
// A string item is converted with []byte(...). A []byte item is returned
// without copying, so the result aliases the cached value. Any other item is
// serialized with enc.Marshal.
//
// If enc is nil and the item needs encoding, an error is returned instead of
// the panic a method call on a nil interface would cause. This can happen when
// Options.Encoder is left unset.
func CacheGetBytes(enc encoder.Encoder, item any) ([]byte, error) {
	switch val := item.(type) {
	case string:
		return []byte(val), nil
	case []byte:
		return val, nil
	}
	if enc == nil {
		return nil, errors.New("no encoder configured for non-byte item")
	}
	return enc.Marshal(item)
}
// ErrTypeMismatch is wrapped by the error CacheGet returns when the cached
// item's dynamic type cannot be stored in the destination.
var ErrTypeMismatch = errors.New("cached item type does not match destination")

// CacheGet stores item into the variable that v points to.
//
// v must be a non-nil pointer. The following destinations are handled without
// reflection: *string, *[]byte, all *int/*uint widths, *bool, *float32,
// *float64, *[]string, *map[string]string, *map[string]any,
// *time.Duration and *time.Time. Any other pointer type falls back to
// reflect.
//
// No conversion is attempted. On the fast path the item's dynamic type must
// equal the destination type exactly. On the reflection path it must be
// assignable to the pointed-to type. For example, an int item cannot be read
// into an *int64.
//
// Errors:
//   - A type mismatch, including a nil item on the reflection path, returns an
//     error wrapping ErrTypeMismatch. The previous code panicked here.
//   - A nil v, a non-pointer v, or a nil pointer v returns a plain error.
//
// Slices and maps are assigned by reference, so the destination shares
// storage with the cached value.
func CacheGet(item any, v any) error {
	switch dst := v.(type) {
	case *string:
		return assignCached(dst, item)
	case *[]byte:
		return assignCached(dst, item)
	case *int:
		return assignCached(dst, item)
	case *int8:
		return assignCached(dst, item)
	case *int16:
		return assignCached(dst, item)
	case *int32:
		return assignCached(dst, item)
	case *int64:
		return assignCached(dst, item)
	case *uint:
		return assignCached(dst, item)
	case *uint8:
		return assignCached(dst, item)
	case *uint16:
		return assignCached(dst, item)
	case *uint32:
		return assignCached(dst, item)
	case *uint64:
		return assignCached(dst, item)
	case *bool:
		return assignCached(dst, item)
	case *float32:
		return assignCached(dst, item)
	case *float64:
		return assignCached(dst, item)
	case *[]string:
		return assignCached(dst, item)
	case *map[string]string:
		return assignCached(dst, item)
	case *map[string]any:
		return assignCached(dst, item)
	case *time.Duration:
		return assignCached(dst, item)
	case *time.Time:
		return assignCached(dst, item)
	}
	return assignReflect(item, v)
}

// assignCached is CacheGet's reflection-free path: it stores item in *dst
// when item's dynamic type is exactly T.
func assignCached[T any](dst *T, item any) error {
	if dst == nil {
		// Same message the reflection path reports for a typed nil pointer.
		return errors.New("dst pointer is not a valid element")
	}
	val, ok := item.(T)
	if !ok {
		return fmt.Errorf("%w: cannot assign %T to %T", ErrTypeMismatch, item, dst)
	}
	*dst = val
	return nil
}

// assignReflect handles destination types not covered by CacheGet's switch.
// It checks assignability first, so reflect.Value.Set cannot panic.
func assignReflect(item any, v any) error {
	ptr := reflect.ValueOf(v)
	if !ptr.IsValid() {
		return errors.New("dst pointer is not valid")
	}
	if ptr.Kind() != reflect.Ptr {
		return errors.New("dst pointer is not a pointer")
	}
	elem := ptr.Elem()
	if !elem.IsValid() {
		return errors.New("dst pointer is not a valid element")
	}
	src := reflect.ValueOf(item)
	if !src.IsValid() || !src.Type().AssignableTo(elem.Type()) {
		return fmt.Errorf("%w: cannot assign %T to %T", ErrTypeMismatch, item, v)
	}
	elem.Set(src)
	return nil
}