package main import ( "crypto/md5" "fmt" "meow.tf/joker/godns/cache" "meow.tf/joker/godns/hosts" "meow.tf/joker/godns/log" "meow.tf/joker/godns/resolver" "meow.tf/joker/godns/settings" "meow.tf/joker/godns/utils" "net" "time" "github.com/miekg/dns" ) const ( notIPQuery = 0 _IP4Query = 4 _IP6Query = 6 ) type GODNSHandler struct { resolver *resolver.Resolver cache, negCache cache.Cache hosts hosts.Hosts } func NewHandler() *GODNSHandler { var ( cacheConfig settings.CacheSettings r *resolver.Resolver resolverCache, negCache cache.Cache ) r = resolver.NewResolver(settings.Resolver()) cacheConfig = settings.Cache() switch cacheConfig.Backend { case "memory": cacheDuration := time.Duration(cacheConfig.Expire) * time.Second negCache = cache.NewMemoryCache(cacheDuration/2, cacheConfig.Maxcount) resolverCache = cache.NewMemoryCache(time.Duration(cacheConfig.Expire)*time.Second, cacheConfig.Maxcount) case "memcache": resolverCache = cache.NewMemcachedCache( settings.Memcache().Servers, int32(cacheConfig.Expire)) negCache = cache.NewMemcachedCache( settings.Memcache().Servers, int32(cacheConfig.Expire/2)) case "redis": resolverCache = cache.NewRedisCache( settings.Redis(), int32(cacheConfig.Expire)) negCache = cache.NewRedisCache( settings.Redis(), int32(cacheConfig.Expire/2)) default: log.Error("Invalid cache backend %s", cacheConfig.Backend) panic("Invalid cache backend") } h := hosts.NewHosts(settings.Hosts(), settings.Redis()) return &GODNSHandler{r, resolverCache, negCache, h} } func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: dns.TypeToString[q.Qtype], Class: dns.ClassToString[q.Qclass]} var remote net.IP if Net == "tcp" { remote = w.RemoteAddr().(*net.TCPAddr).IP } else { remote = w.RemoteAddr().(*net.UDPAddr).IP } log.Info("%s lookup %s", remote, question.String()) IPQuery := h.isIPQuery(q) // Query hosts if h.hosts != nil && IPQuery > 0 { if ips, ok := h.hosts.Get(question.Name, IPQuery); ok { m := new(dns.Msg) m.SetReply(req) switch IPQuery { case _IP4Query: hdr := dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: h.hosts.TTL(), } for _, ip := range ips { m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: ip}) } case _IP6Query: hdr := dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: h.hosts.TTL(), } for _, ip := range ips { m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip}) } } w.WriteMsg(m) log.Debug("%s found in hosts file", question.Name) return } else { log.Debug("%s didn't found in hosts file", question.Name) } } // Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN' key := KeyGen(question) if IPQuery > 0 { mesg, err := h.cache.Get(key) if err != nil { if mesg, err = h.negCache.Get(key); err != nil { log.Debug("%s didn't hit cache", question.String()) } else { log.Debug("%s hit negative cache", question.String()) dns.HandleFailed(w, req) return } } else { log.Debug("%s hit cache", question.String()) // we need this copy against concurrent modification of Id msg := *mesg msg.Id = req.Id w.WriteMsg(&msg) return } } mesg, err := h.resolver.Lookup(Net, req) if err != nil { log.Warn("Resolve query error %s", err) dns.HandleFailed(w, req) // cache the failure, too! if err = h.negCache.Set(key, nil); err != nil { log.Warn("Set %s negative cache failed: %v", question.String(), err) } return } w.WriteMsg(mesg) if IPQuery > 0 && len(mesg.Answer) > 0 { err = h.cache.Set(key, mesg) if err != nil { log.Warn("Set %s cache failed: %s", question.String(), err.Error()) } log.Debug("Insert %s into cache", question.String()) } } func (h *GODNSHandler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) { return func(w dns.ResponseWriter, req *dns.Msg) { h.do(net, w, req) } } func (h *GODNSHandler) isIPQuery(q dns.Question) int { if q.Qclass != dns.ClassINET { return notIPQuery } switch q.Qtype { case dns.TypeA: return _IP4Query case dns.TypeAAAA: return _IP6Query default: return notIPQuery } } func KeyGen(q resolver.Question) string { h := md5.New() h.Write([]byte(q.String())) x := h.Sum(nil) key := fmt.Sprintf("%x", x) return key }