godns/handler.go

137 lines
3.0 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"crypto/md5"
"fmt"
"github.com/miekg/dns"
"meow.tf/joker/godns/cache"
"meow.tf/joker/godns/hosts"
"meow.tf/joker/godns/log"
"meow.tf/joker/godns/resolver"
"meow.tf/joker/godns/utils"
"net"
"time"
)
// Handler answers DNS queries from an optional static hosts table, a
// positive/negative response cache, or an upstream resolver.
type Handler struct {
	resolver *resolver.Resolver // upstream resolver consulted on cache misses
	cache    cache.Cache        // responses that carried answer records
	negCache cache.Cache        // keys whose upstream lookup failed
	hosts    hosts.Hosts        // static hosts table; nil disables it
}
// NewHandler builds a Handler that resolves through r, stores successful
// answers in resolverCache and failed lookups in negCache, and consults the
// hosts table h (which may be nil) before any of them.
func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler {
	handler := &Handler{
		resolver: r,
		cache:    resolverCache,
		negCache: negCache,
		hosts:    h,
	}
	return handler
}
// do answers a single DNS query that arrived over network (the name given to
// Bind, which is passed through to the upstream resolver).
//
// Sources are consulted in this order: the hosts table (if configured), the
// positive cache, the negative cache, and finally the upstream resolver.
// Upstream failures are answered with SERVFAIL and recorded in the negative
// cache; upstream responses that carry answer records are cached.
func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) {
	// Only the first question is answered. A message without any question
	// would make req.Question[0] panic, so reject it as malformed instead.
	if len(req.Question) == 0 {
		m := new(dns.Msg)
		m.SetRcodeFormatError(req)
		h.writeReply(w, m)
		return
	}
	q := req.Question[0]
	question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]}

	var remote net.IP
	switch t := w.RemoteAddr().(type) {
	case *net.TCPAddr:
		remote = t.IP
	case *net.UDPAddr:
		remote = t.IP
	default:
		// Unknown transport: the query is dropped without a reply.
		return
	}
	log.Info("%s lookup %s", remote, question.String())

	// Static hosts entries take precedence over the caches and upstream.
	if h.hosts != nil {
		if vals, ttl, ok := h.hosts.Get(question.Type, question.Name); ok {
			m := new(dns.Msg)
			m.SetReply(req)
			hdr := dns.RR_Header{
				Name:   q.Name,
				Rrtype: question.Type,
				Class:  dns.ClassINET,
				Ttl:    uint32(ttl / time.Second),
			}
			for _, val := range vals {
				if rr := h.hostsRR(hdr, val); rr != nil {
					m.Answer = append(m.Answer, rr)
				}
			}
			h.writeReply(w, m)
			log.Debug("%s found in hosts file", question.Name)
			return
		}
		log.Debug("%s didn't found in hosts file", question.Name)
	}

	key := KeyGen(question)
	if cached, err := h.cache.Get(key); err == nil {
		log.Debug("%s hit cache", question.String())
		// Cached messages are shared between concurrent requests; copy the
		// struct so setting this request's Id does not race with other readers.
		msg := *cached
		msg.Id = req.Id
		h.writeReply(w, &msg)
		return
	}
	if _, err := h.negCache.Get(key); err == nil {
		log.Debug("%s hit negative cache", question.String())
		dns.HandleFailed(w, req)
		return
	}
	log.Debug("%s didn't hit cache", question.String())

	mesg, err := h.resolver.Lookup(network, req)
	if err != nil {
		log.Warn("Resolve query error %s", err)
		dns.HandleFailed(w, req)
		// Cache the failure too, so repeated queries for this key are answered
		// from the negative cache instead of going upstream again.
		if err := h.negCache.Set(key, nil); err != nil {
			log.Warn("Set %s negative cache failed: %v", question.String(), err)
		}
		return
	}
	h.writeReply(w, mesg)

	// Responses without answer records are not cached.
	if len(mesg.Answer) > 0 {
		if err := h.cache.Set(key, mesg); err != nil {
			log.Warn("Set %s cache failed: %s", question.String(), err.Error())
			return
		}
		log.Debug("Insert %s into cache", question.String())
	}
}

// hostsRR builds the resource record for one hosts-table value, using hdr
// (whose Rrtype selects the record kind). It returns nil for record types
// other than A, AAAA and CNAME, and for values that are not a valid address
// for the type. A bad entry is therefore skipped instead of producing an
// address record with an empty address.
func (h *Handler) hostsRR(hdr dns.RR_Header, val string) dns.RR {
	switch hdr.Rrtype {
	case dns.TypeA:
		if ip := net.ParseIP(val).To4(); ip != nil {
			return &dns.A{Hdr: hdr, A: ip}
		}
	case dns.TypeAAAA:
		if ip := net.ParseIP(val); ip != nil {
			return &dns.AAAA{Hdr: hdr, AAAA: ip.To16()}
		}
	case dns.TypeCNAME:
		// Domain names must be fully qualified to pack. dns.Fqdn is a no-op
		// when val already ends in a dot.
		return &dns.CNAME{Hdr: hdr, Target: dns.Fqdn(val)}
	}
	return nil
}

// writeReply sends m to the client and logs any write error instead of
// silently discarding it.
func (h *Handler) writeReply(w dns.ResponseWriter, m *dns.Msg) {
	if err := w.WriteMsg(m); err != nil {
		log.Warn("Write reply to %v failed: %v", w.RemoteAddr(), err)
	}
}
// Bind returns a DNS handler function for queries received on the given
// network. Every query it handles is served by h.do, and the network name is
// forwarded to upstream lookups.
func (h *Handler) Bind(network string) func(w dns.ResponseWriter, req *dns.Msg) {
	handle := func(w dns.ResponseWriter, req *dns.Msg) {
		h.do(network, w, req)
	}
	return handle
}
// KeyGen derives the cache key for q as the lowercase hexadecimal MD5 digest
// of q.String().
func KeyGen(q resolver.Question) string {
	digest := md5.Sum([]byte(q.String()))
	return fmt.Sprintf("%x", digest[:])
}