package main import ( "crypto/md5" "fmt" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "meow.tf/joker/godns/cache" "meow.tf/joker/godns/hosts" "meow.tf/joker/godns/resolver" "meow.tf/joker/godns/utils" "net" "time" ) type Handler struct { resolver *resolver.Resolver middleware []MiddlewareFunc cache, negCache cache.Cache hosts hosts.Hosts } type MiddlewareFunc func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg func TsigMiddleware(secretKey string) MiddlewareFunc { return func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg { if r.IsTsig() != nil { if w.TsigStatus() == nil { m.SetTsig(r.Extra[len(r.Extra)-1].(*dns.TSIG).Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix()) } } return m } } func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler { return &Handler{r, make([]MiddlewareFunc, 0), resolverCache, negCache, h} } // do handles a dns request. // network will decide which network type it is (udp, tcp, https, etc) func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) { switch req.Opcode { case dns.OpcodeQuery: h.query(network, w, req) case dns.OpcodeUpdate: if req.IsTsig() == nil { m := new(dns.Msg) m.SetRcode(req, dns.RcodeBadSig) w.WriteMsg(m) return } packed, err := req.Pack() if err != nil { m := new(dns.Msg) m.SetRcode(req, dns.RcodeBadSig) w.WriteMsg(m) return } if err := dns.TsigVerify(packed, "", "", false); err != nil { m := new(dns.Msg) m.SetRcode(req, dns.RcodeBadSig) w.WriteMsg(m) return } h.update(network, w, req) } } // writeMsg writes a *dns.Msg after passing it through middleware. func (h *Handler) writeMsg(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) { for _, f := range h.middleware { m = f(w, r, m) } w.WriteMsg(m) } // query handles dns queries. func (h *Handler) query(network string, w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]} var remote net.IP switch t := w.RemoteAddr().(type) { case *net.TCPAddr: remote = t.IP case *net.UDPAddr: remote = t.IP default: return } log.WithFields(log.Fields{ "remote": remote, "question": question.String(), }).Debug("Lookup question") key := KeyGen(question) // Check cache first mesg, err := h.cache.Get(key) if err != nil { if mesg, err = h.negCache.Get(key); err != nil { log.WithField("question", question.String()).Debug("no negative cache hit") } else { log.WithField("question", question.String()).Debug("negative cache hit") m := new(dns.Msg) m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) return } } else { log.WithField("question", question.String()).Debug("hit cache") // we need this copy against concurrent modification of Id msg := *mesg msg.Id = req.Id h.writeMsg(w, req, &msg) return } // Query hosts if h.hosts != nil { if host, err := h.hosts.Get(question.Type, question.Name); err == nil { m := new(dns.Msg) m.SetReply(req) hdr := dns.RR_Header{ Name: q.Name, Rrtype: question.Type, Class: dns.ClassINET, Ttl: uint32(host.TTL / time.Second), } switch question.Type { case dns.TypeA: for _, val := range host.Values { m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: net.ParseIP(val).To4()}) } case dns.TypeAAAA: for _, val := range host.Values { m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(val).To16()}) } case dns.TypeCNAME: for _, val := range host.Values { m.Answer = append(m.Answer, &dns.CNAME{Hdr: hdr, Target: val}) } case dns.TypeTXT: m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: host.Values}) } // Insert into cache before using any middleware err = h.cache.Set(key, m) if err != nil { log.WithError(err).Error("Unable to insert hosts entry into cache") } // Write the message h.writeMsg(w, req, m) log.WithField("question", question.Name).Debug("Found entry in hosts file") return } else { log.WithField("question", question.Name).Debug("No entry found in hosts file") } } mesg, err = h.resolver.Lookup(network, req) if err != nil { log.WithError(err).Warn("Query resolution error") m := new(dns.Msg) m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) // cache the failure, too! if err = h.negCache.Set(key, nil); err != nil { log.WithError(err).Warn("Negative cache save failed") } return } // Cache if mesg.Answer length > 0 // This is done BEFORE sending it so we don't modify the request first if len(mesg.Answer) > 0 { err = h.cache.Set(key, mesg) fields := log.Fields{ "question": question.String(), } log.WithFields(fields).Debug("Insert record into cache") if err != nil { fields["error"] = err log.WithFields(fields).Warn("Unable to add to cache") } } // Write message h.writeMsg(w, req, mesg) } func (h *Handler) update(network string, w dns.ResponseWriter, req *dns.Msg) { for _, question := range req.Question { if _, ok := dns.IsDomainName(question.Name); !ok { continue } for _, rr := range req.Ns { hdr := rr.Header() if hdr.Class == dns.ClassANY && hdr.Rdlength == 0 { // Delete record continue } } } } func (h *Handler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) { return func(w dns.ResponseWriter, req *dns.Msg) { h.do(net, w, req) } } func KeyGen(q resolver.Question) string { h := md5.New() h.Write([]byte(q.String())) x := h.Sum(nil) key := fmt.Sprintf("%x", x) return key }