package main import ( "crypto/md5" "fmt" "github.com/miekg/dns" "meow.tf/joker/godns/cache" "meow.tf/joker/godns/hosts" "meow.tf/joker/godns/log" "meow.tf/joker/godns/resolver" "meow.tf/joker/godns/utils" "net" "time" ) type Handler struct { resolver *resolver.Resolver cache, negCache cache.Cache hosts hosts.Hosts } func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler { return &Handler{r, resolverCache, negCache, h} } func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]} var remote net.IP switch t := w.RemoteAddr().(type) { case *net.TCPAddr: remote = t.IP case *net.UDPAddr: remote = t.IP default: return } log.Info("%s lookup %s", remote, question.String()) // Query hosts if h.hosts != nil { if vals, ttl, ok := h.hosts.Get(question.Type, question.Name); ok { m := new(dns.Msg) m.SetReply(req) hdr := dns.RR_Header{ Name: q.Name, Rrtype: question.Type, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), } switch question.Type { case dns.TypeA: for _, val := range vals { m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: net.ParseIP(val).To4()}) } case dns.TypeAAAA: for _, val := range vals { m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(val).To16()}) } case dns.TypeCNAME: for _, val := range vals { m.Answer = append(m.Answer, &dns.CNAME{Hdr: hdr, Target: val}) } } w.WriteMsg(m) log.Debug("%s found in hosts file", question.Name) return } else { log.Debug("%s didn't found in hosts file", question.Name) } } key := KeyGen(question) mesg, err := h.cache.Get(key) if err != nil { if mesg, err = h.negCache.Get(key); err != nil { log.Debug("%s didn't hit cache", question.String()) } else { log.Debug("%s hit negative cache", question.String()) dns.HandleFailed(w, req) return } } else { log.Debug("%s hit cache", question.String()) // we need this copy against concurrent modification of Id msg := *mesg msg.Id = req.Id w.WriteMsg(&msg) return } mesg, err = h.resolver.Lookup(network, req) if err != nil { log.Warn("Resolve query error %s", err) dns.HandleFailed(w, req) // cache the failure, too! if err = h.negCache.Set(key, nil); err != nil { log.Warn("Set %s negative cache failed: %v", question.String(), err) } return } w.WriteMsg(mesg) if len(mesg.Answer) > 0 { err = h.cache.Set(key, mesg) if err != nil { log.Warn("Set %s cache failed: %s", question.String(), err.Error()) } log.Debug("Insert %s into cache", question.String()) } } func (h *Handler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) { return func(w dns.ResponseWriter, req *dns.Msg) { h.do(net, w, req) } } func KeyGen(q resolver.Question) string { h := md5.New() h.Write([]byte(q.String())) x := h.Sum(nil) key := fmt.Sprintf("%x", x) return key }