From 3383c5e4f95413b2e72808f189c3e3f64dc59be1 Mon Sep 17 00:00:00 2001 From: Tyler Date: Fri, 7 Feb 2020 22:38:22 -0500 Subject: [PATCH] Changes to allow host "Set" to be standard, providers being able to use other query types, cache all responses, etc. --- etc/godns.conf | 2 +- etc/godns.example.conf | 56 -------------------- handler.go | 114 +++++++++++++++++------------------------ hosts/hosts.go | 38 +++----------- hosts/hosts_file.go | 15 +++++- hosts/hosts_redis.go | 10 +++- main.go | 2 +- resolver/question.go | 6 ++- resolver/resolver.go | 27 ++++++---- resolver/settings.go | 9 ++++ 10 files changed, 107 insertions(+), 172 deletions(-) delete mode 100644 etc/godns.example.conf create mode 100644 resolver/settings.go diff --git a/etc/godns.conf b/etc/godns.conf index 2ba16bb..38fa4c5 100644 --- a/etc/godns.conf +++ b/etc/godns.conf @@ -29,7 +29,7 @@ maxcount = 0 #If set zero. The Sum of cache itmes will be unlimit. # Redis cache backend config [cache.redis] - host = "192.168.1.71" + host = "127.0.0.1" port = 6379 db = 0 password ="" diff --git a/etc/godns.example.conf b/etc/godns.example.conf deleted file mode 100644 index 6b3b8cd..0000000 --- a/etc/godns.example.conf +++ /dev/null @@ -1,56 +0,0 @@ -#Toml config file -title = "GODNS" -Version = "0.2.3" -Author = "kenshin, tystuyfzand" - -debug = false - -[server] -host = "" -port = 53 - -[resolv] -# Domain-specific nameservers configuration, formatting keep compatible with Dnsmasq -# Semicolon separate multiple files. -resolv-file = "/etc/resolv.conf" -timeout = 5 # 5 seconds -# The concurrency interval request upstream recursive server -# Match the PR15, https://github.com/kenshinx/godns/pull/15 -interval = 200 # 200 milliseconds -# When defined, this is preferred over regular DNS. This requires a resolver to be active besides this, only for the initial lookup. -# A hosts file entry will suffice as well. -# dns-over-https = "https://cloudflare-dns.com/dns-query" -setedns0 = false #Support for larger UDP DNS responses - -[redis] -enable = true -host = "127.0.0.1" -port = 6379 -db = 0 -password ="" - -[memcache] -servers = ["127.0.0.1:11211"] - -[log] -stdout = true -file = "./godns.log" -level = "INFO" #DEBUG | INFO |NOTICE | WARN | ERROR - -[cache] -# backend option [memory|memcache|redis] -backend = "memory" -expire = 600 # 10 minutes -maxcount = 0 #If set zero. The Sum of cache items will be unlimit. - -[hosts] -#If set false, will not query hosts file and redis hosts record -enable = true -host-file = "/etc/hosts" -redis-enable = false -redis-key = "godns:hosts" -ttl = 600 -# Refresh interval can be high since we have automatic updating via push and fsnotify -refresh-interval = 300 - - diff --git a/handler.go b/handler.go index 128b053..39ab535 100644 --- a/handler.go +++ b/handler.go @@ -13,12 +13,6 @@ import ( "time" ) -const ( - notIPQuery = 0 - _IP4Query = 4 - _IP6Query = 6 -) - type Handler struct { resolver *resolver.Resolver cache, negCache cache.Cache @@ -29,46 +23,48 @@ func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hos return &Handler{r, resolverCache, negCache, h} } -func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { +func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] - question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: dns.TypeToString[q.Qtype], Class: dns.ClassToString[q.Qclass]} + question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]} var remote net.IP - if Net == "tcp" { - remote = w.RemoteAddr().(*net.TCPAddr).IP - } else { - remote = w.RemoteAddr().(*net.UDPAddr).IP + + switch t := w.RemoteAddr().(type) { + case *net.TCPAddr: + remote = t.IP + case *net.UDPAddr: + remote = t.IP + default: + return } + log.Info("%s lookup %s", remote, question.String()) - IPQuery := h.isIPQuery(q) - // Query hosts - if h.hosts != nil && IPQuery > 0 { - if ips, ttl, ok := h.hosts.Get(question.Name, IPQuery); ok { + if h.hosts != nil { + if vals, ttl, ok := h.hosts.Get(question.Type, question.Name); ok { m := new(dns.Msg) m.SetReply(req) - switch IPQuery { - case _IP4Query: - hdr := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), + hdr := dns.RR_Header{ + Name: q.Name, + Rrtype: question.Type, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + } + + switch question.Type { + case dns.TypeA: + for _, val := range vals { + m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: net.ParseIP(val).To4()}) } - for _, ip := range ips { - m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: ip}) + case dns.TypeAAAA: + for _, val := range vals { + m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(val).To16()}) } - case _IP6Query: - hdr := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), - } - for _, ip := range ips { - m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip}) + case dns.TypeCNAME: + for _, val := range vals { + m.Answer = append(m.Answer, &dns.CNAME{Hdr: hdr, Target: val}) } } @@ -80,29 +76,28 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { } } - // Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN' key := KeyGen(question) - if IPQuery > 0 { - mesg, err := h.cache.Get(key) - if err != nil { - if mesg, err = h.negCache.Get(key); err != nil { - log.Debug("%s didn't hit cache", question.String()) - } else { - log.Debug("%s hit negative cache", question.String()) - dns.HandleFailed(w, req) - return - } + + mesg, err := h.cache.Get(key) + + if err != nil { + if mesg, err = h.negCache.Get(key); err != nil { + log.Debug("%s didn't hit cache", question.String()) } else { - log.Debug("%s hit cache", question.String()) - // we need this copy against concurrent modification of Id - msg := *mesg - msg.Id = req.Id - w.WriteMsg(&msg) + log.Debug("%s hit negative cache", question.String()) + dns.HandleFailed(w, req) return } + } else { + log.Debug("%s hit cache", question.String()) + // we need this copy against concurrent modification of Id + msg := *mesg + msg.Id = req.Id + w.WriteMsg(&msg) + return } - mesg, err := h.resolver.Lookup(Net, req) + mesg, err = h.resolver.Lookup(network, req) if err != nil { log.Warn("Resolve query error %s", err) @@ -117,7 +112,7 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { w.WriteMsg(mesg) - if IPQuery > 0 && len(mesg.Answer) > 0 { + if len(mesg.Answer) > 0 { err = h.cache.Set(key, mesg) if err != nil { log.Warn("Set %s cache failed: %s", question.String(), err.Error()) @@ -132,21 +127,6 @@ func (h *Handler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) { } } -func (h *Handler) isIPQuery(q dns.Question) int { - if q.Qclass != dns.ClassINET { - return notIPQuery - } - - switch q.Qtype { - case dns.TypeA: - return _IP4Query - case dns.TypeAAAA: - return _IP6Query - default: - return notIPQuery - } -} - func KeyGen(q resolver.Question) string { h := md5.New() h.Write([]byte(q.String())) diff --git a/hosts/hosts.go b/hosts/hosts.go index 3fdf9ad..d7ecf66 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -1,22 +1,15 @@ package hosts import ( - "net" "time" ) -const ( - notIPQuery = 0 - _IP4Query = 4 - _IP6Query = 6 -) - var ( zeroDuration = time.Duration(0) ) type Hosts interface { - Get(domain string, family int) ([]net.IP, time.Duration, bool) + Get(queryType uint16, domain string) ([]string, time.Duration, bool) } type ProviderList struct { @@ -24,7 +17,8 @@ type ProviderList struct { } type Provider interface { - Get(domain string) ([]string, time.Duration, bool) + Get(queryType uint16, domain string) ([]string, time.Duration, bool) + Set(t, domain, value string) (bool, error) } func NewHosts(providers []Provider) Hosts { @@ -34,38 +28,22 @@ func NewHosts(providers []Provider) Hosts { /* Match local /etc/hosts file first, remote redis records second */ -func (h *ProviderList) Get(domain string, family int) ([]net.IP, time.Duration, bool) { - var sips []string +func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) { + var vals []string var ok bool - var ip net.IP - var ips []net.IP var ttl time.Duration for _, provider := range h.providers { - sips, ttl, ok = provider.Get(domain) + vals, ttl, ok = provider.Get(queryType, domain) if ok { break } } - if sips == nil { + if vals == nil { return nil, zeroDuration, false } - for _, sip := range sips { - switch family { - case _IP4Query: - ip = net.ParseIP(sip).To4() - case _IP6Query: - ip = net.ParseIP(sip).To16() - default: - continue - } - if ip != nil { - ips = append(ips, ip) - } - } - - return ips, ttl, ips != nil + return vals, ttl, true } \ No newline at end of file diff --git a/hosts/hosts_file.go b/hosts/hosts_file.go index ae3454d..c451de5 100644 --- a/hosts/hosts_file.go +++ b/hosts/hosts_file.go @@ -2,7 +2,9 @@ package hosts import ( "bufio" + "errors" "github.com/fsnotify/fsnotify" + "github.com/miekg/dns" "github.com/ryanuber/go-glob" "meow.tf/joker/godns/log" "meow.tf/joker/godns/utils" @@ -49,8 +51,13 @@ func NewFileProvider(file string, ttl time.Duration) Provider { return fp } -func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) { - log.Debug("Checking file provider for %s", domain) +func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) { + log.Debug("Checking file provider for %s : %s", queryType, domain) + + // Does not support CNAME/TXT/etc + if queryType != dns.TypeA && queryType != dns.TypeAAAA { + return nil, zeroDuration, false + } f.mu.RLock() defer f.mu.RUnlock() @@ -77,6 +84,10 @@ func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) { return nil, time.Duration(0), false } +func (f *FileHosts) Set(t, domain, value string) (bool, error) { + return false, errors.New("file provider does not support setting values") +} + var ( hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$") ) diff --git a/hosts/hosts_redis.go b/hosts/hosts_redis.go index cbf916f..a78df72 100644 --- a/hosts/hosts_redis.go +++ b/hosts/hosts_redis.go @@ -2,6 +2,7 @@ package hosts import ( "github.com/go-redis/redis/v7" + "github.com/miekg/dns" "github.com/ryanuber/go-glob" "meow.tf/joker/godns/log" "strings" @@ -33,9 +34,14 @@ func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider return rh } -func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) { +func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) { log.Debug("Checking redis provider for %s", domain) + // Don't support queries other than A/AAAA for now + if queryType != dns.TypeA || queryType != dns.TypeAAAA { + return nil, zeroDuration, false + } + r.mu.RLock() defer r.mu.RUnlock() @@ -62,7 +68,7 @@ func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) { return nil, time.Duration(0), false } -func (r *RedisHosts) Set(domain, ip string) (bool, error) { +func (r *RedisHosts) Set(t, domain, ip string) (bool, error) { r.mu.Lock() defer r.mu.Unlock() return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result() diff --git a/main.go b/main.go index 5119516..8c79f80 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,7 @@ func main() { } var resolverCache, negCache cache.Cache - r := resolver.NewResolver(settings.ResolvSettings{ + r := resolver.NewResolver(resolver.Settings{ Timeout: viper.GetInt("resolv.timeout"), Interval: viper.GetInt("resolv.interval"), SetEDNS0: viper.GetBool("resolv.edns0"), diff --git a/resolver/question.go b/resolver/question.go index ecab487..ace3337 100644 --- a/resolver/question.go +++ b/resolver/question.go @@ -1,11 +1,13 @@ package resolver +import "github.com/miekg/dns" + type Question struct { Name string - Type string + Type uint16 Class string } func (q *Question) String() string { - return q.Name + " " + q.Class + " " + q.Type + return q.Name + " " + q.Class + " " + dns.TypeToString[q.Type] } diff --git a/resolver/resolver.go b/resolver/resolver.go index 1cdafef..87a2213 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "meow.tf/joker/godns/log" - "meow.tf/joker/godns/settings" "meow.tf/joker/godns/utils" "net" "os" @@ -37,13 +36,13 @@ type RResp struct { type Resolver struct { servers []string domain_server *suffixTreeNode - config *settings.ResolvSettings + config *Settings clients map[string]*dns.Client clientLock sync.RWMutex } -func NewResolver(c settings.ResolvSettings) *Resolver { +func NewResolver(c Settings) *Resolver { r := &Resolver{ servers: []string{}, domain_server: newSuffixTreeRoot(), @@ -109,26 +108,26 @@ func (r *Resolver) parseServerListFile(buf *os.File) { r.domain_server.sinsert(strings.Split(domain, "."), ip) case 1: - srv_port := strings.Split(line, "#") + srvPort := strings.Split(line, "#") - if len(srv_port) > 2 { + if len(srvPort) > 2 { continue } ip := "" - if ip = srv_port[0]; !utils.IsIP(ip) { + if ip = srvPort[0]; !utils.IsIP(ip) { continue } port := "53" - if len(srv_port) == 2 { - if _, err := strconv.Atoi(srv_port[1]); err != nil { + if len(srvPort) == 2 { + if _, err := strconv.Atoi(srvPort[1]); err != nil { continue } - port = srv_port[1] + port = srvPort[1] } r.servers = append(r.servers, net.JoinHostPort(ip, port)) @@ -222,8 +221,14 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error } func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) { + key := net + + if net == "tcp-tls" { + key = net + ":" + nameserver + } + r.clientLock.RLock() - client, exists := r.clients[net] + client, exists := r.clients[key] r.clientLock.RUnlock() if exists { @@ -249,7 +254,7 @@ func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) { } r.clientLock.Lock() - r.clients[net] = client + r.clients[key] = client r.clientLock.Lock() return client, nil diff --git a/resolver/settings.go b/resolver/settings.go new file mode 100644 index 0000000..47d4749 --- /dev/null +++ b/resolver/settings.go @@ -0,0 +1,9 @@ +package resolver + +type Settings struct { + Timeout int `toml:"timeout" env:"RESOLV_TIMEOUT"` + Interval int `toml:"interval" env:"RESOLV_INTERVAL"` + SetEDNS0 bool `toml:"setedns0" env:"RESOLV_EDNS0"` + ServerListFile []string `toml:"server-list-file" env:"SERVER_LIST_FILE"` + ResolvFile string `toml:"resolv-file" env:"RESOLV_FILE"` +}