godns/handler.go

248 lines
5.6 KiB
Go

package main
import (
"crypto/md5"
"fmt"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"meow.tf/joker/godns/cache"
"meow.tf/joker/godns/hosts"
"meow.tf/joker/godns/resolver"
"meow.tf/joker/godns/utils"
"net"
"time"
)
// Handler serves DNS requests. Queries are answered from the positive or
// negative cache, the optional hosts table, or the upstream resolver; replies
// written through writeMsg pass through the middleware chain first.
type Handler struct {
resolver *resolver.Resolver // upstream resolver consulted on a cache/hosts miss
middleware []MiddlewareFunc // applied in order by writeMsg before each write
cache, negCache cache.Cache // cache: positive answers; negCache: recorded resolution failures
hosts hosts.Hosts // optional local hosts table; nil skips the hosts lookup
}
// MiddlewareFunc post-processes a response before it is written. It receives
// the response writer w, the request r and the response m, and returns the
// message to hand to the next middleware (or to write, for the last one).
type MiddlewareFunc func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg
// TsigMiddleware returns a MiddlewareFunc that attaches a TSIG record to the
// response when the request carried a TSIG record that the server validated
// (w.TsigStatus() == nil). Unsigned requests, and requests whose signature did
// not validate, get their response passed through unchanged.
//
// The response reuses the key name and algorithm of the request's TSIG record:
// a TSIG response must be signed with the same key and algorithm the client
// used (RFC 8945, section 5.3). Previously HMAC-SHA256 was always requested,
// which breaks clients signing with any other algorithm. The fudge stays at
// 300 seconds.
//
// secretKey is currently unused and kept for API compatibility.
// NOTE(review): presumably meant to configure the server's TSIG secret —
// confirm with the callers that construct the dns.Server.
func TsigMiddleware(secretKey string) MiddlewareFunc {
	return func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg {
		// IsTsig returns the trailing TSIG RR of the request, or nil.
		t := r.IsTsig()
		if t == nil || w.TsigStatus() != nil {
			return m
		}
		m.SetTsig(t.Hdr.Name, t.Algorithm, 300, time.Now().Unix())
		return m
	}
}
// NewHandler assembles a Handler that answers queries through r. It stores
// successful answers in resolverCache and resolution failures in negCache,
// and consults the hosts table h (nil disables it) before the resolver.
// The middleware chain starts out empty.
func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler {
	handler := &Handler{
		resolver:   r,
		middleware: []MiddlewareFunc{},
		cache:      resolverCache,
		negCache:   negCache,
		hosts:      h,
	}
	return handler
}
// do handles a DNS request by dispatching on its opcode.
// network is the transport the request arrived on (udp, tcp, https, etc.)
// and is passed through to the opcode handlers.
//
// Queries are handed to h.query. Updates reach h.update only when they carry a
// TSIG record that verifies; otherwise an error reply is written. Any other
// opcode is answered with NOTIMP instead of being left unanswered, which would
// make the client wait for a timeout.
func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) {
	// reply writes a bare response to req carrying rcode.
	reply := func(rcode int) {
		m := new(dns.Msg)
		m.SetRcode(req, rcode)
		w.WriteMsg(m)
	}

	switch req.Opcode {
	case dns.OpcodeQuery:
		h.query(network, w, req)
	case dns.OpcodeUpdate:
		// The rcodes used here fit in the 4-bit header RCODE. RcodeBadSig (16)
		// is an extended rcode that miekg/dns will not pack without an EDNS0
		// OPT record, so a BADSIG reply built this way was never sent.
		// RFC 8945 signals signature failures with header RCODE NOTAUTH.
		if req.IsTsig() == nil {
			// Unsigned updates are not accepted.
			reply(dns.RcodeRefused)
			return
		}
		packed, err := req.Pack()
		if err != nil {
			reply(dns.RcodeNotAuth)
			return
		}
		// NOTE(review): verifying with an empty secret presumably never matches
		// a MAC computed with a real key, and re-packing req may not reproduce
		// the exact signed bytes. The server-side check (w.TsigStatus(), set
		// when the dns.Server has a TsigSecret) is likely what was intended —
		// confirm before switching.
		if err := dns.TsigVerify(packed, "", "", false); err != nil {
			reply(dns.RcodeNotAuth)
			return
		}
		h.update(network, w, req)
	default:
		reply(dns.RcodeNotImplemented)
	}
}
// writeMsg writes the response m to request r after passing it through every
// registered middleware, in registration order. Each middleware receives the
// message returned by the previous one. A write failure is logged, because
// this method has no way to return it.
func (h *Handler) writeMsg(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) {
	for _, f := range h.middleware {
		m = f(w, r, m)
	}
	if err := w.WriteMsg(m); err != nil {
		log.WithError(err).Warn("Unable to write response")
	}
}
// query answers a standard DNS query using only the request's first question.
//
// Lookup order:
//  1. the positive cache (h.cache): a hit is answered with a copy of the
//     cached message carrying the request's ID;
//  2. the negative cache (h.negCache): a hit is answered with SERVFAIL;
//  3. the hosts table (h.hosts), when configured;
//  4. the upstream resolver.
//
// Hosts answers, and resolver answers with a non-empty Answer section, are
// stored in the positive cache. Resolver errors are recorded in the negative
// cache and answered with SERVFAIL. A request without a question is answered
// with FORMERR.
func (h *Handler) query(network string, w dns.ResponseWriter, req *dns.Msg) {
	if len(req.Question) == 0 {
		// Nothing to look up; also avoids indexing an empty Question slice.
		m := new(dns.Msg)
		m.SetRcode(req, dns.RcodeFormatError)
		w.WriteMsg(m)
		return
	}
	q := req.Question[0]
	question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]}

	// The remote address is used only for logging, so an unfamiliar address
	// type must not stop the query from being answered.
	log.WithFields(log.Fields{
		"remote":   remoteHost(w.RemoteAddr()),
		"question": question.String(),
	}).Debug("Lookup question")

	key := KeyGen(question)

	// Check cache first
	if cached, err := h.cache.Get(key); err == nil {
		log.WithField("question", question.String()).Debug("hit cache")
		// Deep copy: the cached message may be shared between concurrent
		// requests. A shallow copy would still share its Answer/Extra slices,
		// which middleware may append to (e.g. SetTsig adds to Extra).
		msg := cached.Copy()
		msg.Id = req.Id
		h.writeMsg(w, req, msg)
		return
	}

	if _, err := h.negCache.Get(key); err == nil {
		log.WithField("question", question.String()).Debug("negative cache hit")
		m := new(dns.Msg)
		m.SetRcode(req, dns.RcodeServerFailure)
		w.WriteMsg(m)
		return
	}
	log.WithField("question", question.String()).Debug("no negative cache hit")

	// Query hosts
	if h.hosts != nil {
		if host, err := h.hosts.Get(question.Type, question.Name); err == nil {
			m := new(dns.Msg)
			m.SetReply(req)
			m.Answer = hostsAnswer(q, host.TTL, host.Values)

			// Cache an independent copy, so middleware run on write cannot
			// alter the cached entry if the cache keeps the pointer.
			if err := h.cache.Set(key, m.Copy()); err != nil {
				log.WithError(err).Error("Unable to insert hosts entry into cache")
			}

			h.writeMsg(w, req, m)
			log.WithField("question", question.Name).Debug("Found entry in hosts file")
			return
		}
		log.WithField("question", question.Name).Debug("No entry found in hosts file")
	}

	mesg, err := h.resolver.Lookup(network, req)
	if err != nil {
		log.WithError(err).Warn("Query resolution error")
		m := new(dns.Msg)
		m.SetRcode(req, dns.RcodeServerFailure)
		w.WriteMsg(m)
		// cache the failure, too!
		if err := h.negCache.Set(key, nil); err != nil {
			log.WithError(err).Warn("Negative cache save failed")
		}
		return
	}

	// Cache only answers that carry records. A copy is stored so the
	// middleware applied by writeMsg below cannot alter the cached entry.
	if len(mesg.Answer) > 0 {
		err := h.cache.Set(key, mesg.Copy())
		fields := log.Fields{
			"question": question.String(),
		}
		log.WithFields(fields).Debug("Insert record into cache")
		if err != nil {
			fields["error"] = err
			log.WithFields(fields).Warn("Unable to add to cache")
		}
	}

	h.writeMsg(w, req, mesg)
}

// remoteHost formats a client address for logging. It returns the bare IP for
// TCP and UDP peers, the address's String form for any other transport, and
// "" when no address is available.
func remoteHost(addr net.Addr) string {
	switch a := addr.(type) {
	case nil:
		return ""
	case *net.TCPAddr:
		return a.IP.String()
	case *net.UDPAddr:
		return a.IP.String()
	default:
		return a.String()
	}
}

// hostsAnswer builds the answer records for question q from a hosts entry with
// the given ttl and values:
//   - A/AAAA: one record per value; values that do not parse as an IP address
//     are skipped rather than emitted as empty records.
//   - CNAME: one record per value; targets are made fully qualified so the
//     reply can be packed.
//   - TXT: all values go into a single record.
//   - Any other type yields no records.
func hostsAnswer(q dns.Question, ttl time.Duration, values []string) []dns.RR {
	hdr := dns.RR_Header{
		Name:   q.Name,
		Rrtype: q.Qtype,
		Class:  dns.ClassINET,
		Ttl:    uint32(ttl / time.Second),
	}

	var answer []dns.RR
	switch q.Qtype {
	case dns.TypeA:
		for _, val := range values {
			if ip := net.ParseIP(val).To4(); ip != nil {
				answer = append(answer, &dns.A{Hdr: hdr, A: ip})
			}
		}
	case dns.TypeAAAA:
		for _, val := range values {
			if ip := net.ParseIP(val).To16(); ip != nil {
				answer = append(answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
			}
		}
	case dns.TypeCNAME:
		for _, val := range values {
			answer = append(answer, &dns.CNAME{Hdr: hdr, Target: dns.Fqdn(val)})
		}
	case dns.TypeTXT:
		answer = append(answer, &dns.TXT{Hdr: hdr, Txt: values})
	}
	return answer
}
// update handles a DNS UPDATE request after do has checked its TSIG record.
//
// Dynamic updates are not implemented. The previous body iterated over the
// request without effect and never wrote a response, so clients hung until
// they timed out. The client now receives an explicit NOTIMP reply, sent
// through the middleware chain (which may TSIG-sign it).
//
// TODO: apply the update. Walk req.Question (the zone section) and req.Ns (the
// update section); per RFC 2136 section 2.5.2, an RR with class ANY and
// RDLENGTH 0 deletes an RRset.
func (h *Handler) update(network string, w dns.ResponseWriter, req *dns.Msg) {
	m := new(dns.Msg)
	m.SetRcode(req, dns.RcodeNotImplemented)
	h.writeMsg(w, req, m)
}
// Bind returns a handler function for a dns server. The function serves every
// request through h and tags it with the given network name (udp, tcp,
// https, ...).
func (h *Handler) Bind(network string) func(w dns.ResponseWriter, req *dns.Msg) {
	serve := func(w dns.ResponseWriter, req *dns.Msg) {
		h.do(network, w, req)
	}
	return serve
}
// KeyGen derives the cache key for a question. The key is the lowercase
// hexadecimal MD5 digest of q.String().
func KeyGen(q resolver.Question) string {
	digest := md5.Sum([]byte(q.String()))
	return fmt.Sprintf("%x", digest[:])
}