diff --git a/cache.go b/cache.go index 5054c5f..126491e 100644 --- a/cache.go +++ b/cache.go @@ -58,7 +58,7 @@ type MemoryCache struct { Backend map[string]Mesg Expire time.Duration Maxcount int - mu *sync.RWMutex + mu sync.RWMutex } func (c *MemoryCache) Get(key string) (*dns.Msg, error) { @@ -92,9 +92,9 @@ func (c *MemoryCache) Set(key string, msg *dns.Msg) error { } func (c *MemoryCache) Remove(key string) { - c.mu.RLock() + c.mu.Lock() delete(c.Backend, key) - c.mu.RUnlock() + c.mu.Unlock() } func (c *MemoryCache) Exists(key string) bool { diff --git a/handler.go b/handler.go index 310ea16..6dba38d 100644 --- a/handler.go +++ b/handler.go @@ -1,7 +1,6 @@ package main import ( - "sync" "time" "github.com/miekg/dns" @@ -24,18 +23,18 @@ func (q *Question) String() string { } type GODNSHandler struct { - resolver *Resolver - cache Cache - hosts Hosts + resolver *Resolver + cache, negCache Cache + hosts Hosts } func NewHandler() *GODNSHandler { var ( - clientConfig *dns.ClientConfig - cacheConfig CacheSettings - resolver *Resolver - cache Cache + clientConfig *dns.ClientConfig + cacheConfig CacheSettings + resolver *Resolver + cache, negCache Cache ) resolvConfig := settings.ResolvConfig @@ -52,10 +51,14 @@ func NewHandler() *GODNSHandler { switch cacheConfig.Backend { case "memory": cache = &MemoryCache{ - Backend: make(map[string]Mesg), + Backend: make(map[string]Mesg, cacheConfig.Maxcount), Expire: time.Duration(cacheConfig.Expire) * time.Second, Maxcount: cacheConfig.Maxcount, - mu: new(sync.RWMutex), + } + negCache = &MemoryCache{ + Backend: make(map[string]Mesg), + Expire: time.Duration(cacheConfig.Expire) * time.Second / 2, + Maxcount: cacheConfig.Maxcount, } case "redis": // cache = &MemoryCache{ @@ -72,7 +75,7 @@ func NewHandler() *GODNSHandler { hosts := NewHosts(settings.Hosts, settings.Redis) - return &GODNSHandler{resolver, cache, hosts} + return &GODNSHandler{resolver, cache, negCache, hosts} } func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { @@ -123,11 +126,19 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { if IPQuery > 0 { mesg, err := h.cache.Get(key) if err != nil { - Debug("%s didn't hit cache: %s", Q.String(), err) + if mesg, err = h.negCache.Get(key); err != nil { + Debug("%s didn't hit cache: %s", Q.String(), err) + } else { + Debug("%s hit negative cache", Q.String()) + dns.HandleFailed(w, req) + return + } } else { Debug("%s hit cache", Q.String()) - mesg.Id = req.Id - w.WriteMsg(mesg) + // we need this copy against concurrent modification of Id + msg := *mesg + msg.Id = req.Id + w.WriteMsg(&msg) return } } @@ -137,6 +148,11 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { if err != nil { Debug("%s", err) dns.HandleFailed(w, req) + + // cache the failure, too! + if err = h.negCache.Set(key, nil); err != nil { + Debug("Set %s negative cache failed: %v", Q.String(), err) + } return }