diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5e56e04 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/bin diff --git a/bin/godns b/bin/godns deleted file mode 100755 index a1806f8..0000000 Binary files a/bin/godns and /dev/null differ diff --git a/handler.go b/handler.go index 41dc174..8585670 100644 --- a/handler.go +++ b/handler.go @@ -1,10 +1,10 @@ package main import ( - "github.com/miekg/dns" - "net" "sync" "time" + + "github.com/miekg/dns" ) type Question struct { @@ -13,6 +13,12 @@ type Question struct { qclass string } +const ( + notIPQuery = 0 + _IP4Query = 4 + _IP6Query = 6 +) + func (q *Question) String() string { return q.qname + " " + q.qclass + " " + q.qtype } @@ -76,24 +82,47 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { Debug("Question: %s", Q.String()) + IPQuery := h.isIPQuery(q) + // Query hosts - if settings.Hosts.Enable && h.isIPQuery(q) { - if ip, ok := h.hosts.Get(Q.qname); ok { + if settings.Hosts.Enable && IPQuery > 0 { + if ip, ok := h.hosts.Get(Q.qname, IPQuery); ok { m := new(dns.Msg) m.SetReply(req) - rr_header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: settings.Hosts.TTL} - a := &dns.A{rr_header, net.ParseIP(ip)} - m.Answer = append(m.Answer, a) + + switch IPQuery { + case _IP4Query: + rr_header := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: settings.Hosts.TTL, + } + a := &dns.A{rr_header, ip} + m.Answer = append(m.Answer, a) + case _IP6Query: + rr_header := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: settings.Hosts.TTL, + } + aaaa := &dns.AAAA{rr_header, ip} + m.Answer = append(m.Answer, aaaa) + } + w.WriteMsg(m) - Debug("%s found in hosts", Q.qname) + Debug("%s found in hosts file", Q.qname) return + } else { + Debug("%s didn't found in hosts file", Q.qname) } } // Only query cache when qtype == 'A' , qclass == 'IN' key := KeyGen(Q) - if h.isIPQuery(q) { + if IPQuery > 0 { mesg, err := h.cache.Get(key) if err != nil { Debug("%s didn't hit cache: %s", Q.String(), err) @@ -118,7 +147,7 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { w.WriteMsg(mesg) - if h.isIPQuery(q) { + if IPQuery > 0 { err = h.cache.Set(key, mesg) if err != nil { @@ -138,8 +167,19 @@ func (h *GODNSHandler) DoUDP(w dns.ResponseWriter, req *dns.Msg) { h.do("udp", w, req) } -func (h *GODNSHandler) isIPQuery(q dns.Question) bool { - return q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET +func (h *GODNSHandler) isIPQuery(q dns.Question) int { + if q.Qclass != dns.ClassINET { + return notIPQuery + } + + switch q.Qtype { + case dns.TypeA: + return _IP4Query + case dns.TypeAAAA: + return _IP6Query + default: + return notIPQuery + } } func UnFqdn(s string) string { diff --git a/hosts.go b/hosts.go index 5856910..17d6445 100644 --- a/hosts.go +++ b/hosts.go @@ -2,10 +2,12 @@ package main import ( "bufio" - "github.com/hoisie/redis" + "net" "os" "regexp" "strings" + + "github.com/hoisie/redis" ) type Hosts struct { @@ -29,14 +31,24 @@ func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { 3. Match local /etc/hosts file first, remote redis records second */ -func (h *Hosts) Get(domain string) (ip string, ok bool) { - if ip, ok = h.FileHosts[domain]; ok { - return +func (h *Hosts) Get(domain string, family int) (ip net.IP, ok bool) { + var sip string + + if sip, ok = h.FileHosts[domain]; !ok { + if sip, ok = h.RedisHosts.Get(domain); !ok { + return nil, false + } } - if ip, ok = h.RedisHosts.Get(domain); ok { - return + + switch family { + case _IP4Query: + ip = net.ParseIP(sip).To4() + return ip, (ip != nil) + case _IP6Query: + ip = net.ParseIP(sip).To16() + return ip, (ip != nil) } - return "", false + return nil, false } func (h *Hosts) GetAll() map[string]string { @@ -114,11 +126,13 @@ func (f *FileHosts) GetAll() map[string]string { } func (f *FileHosts) isDomain(domain string) bool { - match, _ := regexp.MatchString("^[a-z]", domain) + if f.isIP(domain) { + return false + } + match, _ := regexp.MatchString("^[a-zA-Z0-9][a-zA-Z0-9-]", domain) return match } func (f *FileHosts) isIP(ip string) bool { - match, _ := regexp.MatchString("^[1-9]", ip) - return match + return (net.ParseIP(ip) != nil) } diff --git a/hosts_test.go b/hosts_test.go new file mode 100644 index 0000000..af70747 --- /dev/null +++ b/hosts_test.go @@ -0,0 +1,34 @@ +package main + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestHostDomainAndIP(t *testing.T) { + Convey("Test Host File Domain and IP regex", t, func() { + f := &FileHosts{} + + Convey("1.1.1.1 should be IP and not domain", func() { + So(f.isIP("1.1.1.1"), ShouldEqual, true) + So(f.isDomain("1.1.1.1"), ShouldEqual, false) + }) + + Convey("2001:470:20::2 should be IP and not domain", func() { + So(f.isIP("2001:470:20::2"), ShouldEqual, true) + So(f.isDomain("2001:470:20::2"), ShouldEqual, false) + }) + + Convey("`host` should be domain and not IP", func() { + So(f.isDomain("host"), ShouldEqual, true) + So(f.isIP("host"), ShouldEqual, false) + }) + + Convey("`123.test` should be domain and not IP", func() { + So(f.isDomain("123.test"), ShouldEqual, true) + So(f.isIP("123.test"), ShouldEqual, false) + }) + + }) +}