From d8079551c92ef853df34332663812c7d023b3614 Mon Sep 17 00:00:00 2001 From: Tyler <tystuyfzand@gmail.com> Date: Thu, 15 Apr 2021 01:04:58 -0400 Subject: [PATCH] Add testing, cleanup, rework suffix tree to use nameservers. Parse nameservers from yaml. --- .drone.yml | 5 +- hosts/hosts_bolt.go | 6 +- hosts/hosts_file.go | 16 ++- hosts/hosts_redis.go | 6 +- main.go | 4 +- resolver/nameserver.go | 1 + resolver/resolver.go | 202 ++++++++++++++++++++++---------------- resolver/sfx_tree.go | 20 ++-- resolver/sfx_tree_test.go | 26 ++--- utils/utils_test.go | 16 +-- 10 files changed, 177 insertions(+), 125 deletions(-) diff --git a/.drone.yml b/.drone.yml index 972f4ed..2ed776e 100644 --- a/.drone.yml +++ b/.drone.yml @@ -3,13 +3,16 @@ name: amd64 type: docker steps: + - name: test + image: golang:alpine + commands: + - go test ... - name: build image: golang:alpine volumes: - name: build path: /build commands: - - apk --no-cache add git gcc musl-dev - go build -o /build/godns - name: docker image: plugins/docker diff --git a/hosts/hosts_bolt.go b/hosts/hosts_bolt.go index fe213ee..7a93e2f 100644 --- a/hosts/hosts_bolt.go +++ b/hosts/hosts_bolt.go @@ -2,6 +2,7 @@ package hosts import ( "encoding/json" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" bolt "go.etcd.io/bbolt" "strings" @@ -67,7 +68,10 @@ func (b *BoltHosts) List() (HostMap, error) { } func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) { - log.Debug("Checking bolt provider for %s : %s", queryType, domain) + log.WithFields(log.Fields{ + "queryType": dns.TypeToString[queryType], + "question": domain, + }).Debug("Checking bolt provider") domain = strings.ToLower(domain) diff --git a/hosts/hosts_file.go b/hosts/hosts_file.go index 66db4dd..cf7caaf 100644 --- a/hosts/hosts_file.go +++ b/hosts/hosts_file.go @@ -62,7 +62,10 @@ func (f *FileHosts) List() (HostMap, error) { } func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) { - log.Debug("Checking file provider for %s : %s", queryType, domain) + log.WithFields(log.Fields{ + "queryType": dns.TypeToString[queryType], + "question": domain, + }).Debug("Checking file provider") // Does not support CNAME/TXT/etc if queryType != dns.TypeA && queryType != dns.TypeAAAA { @@ -102,7 +105,10 @@ func (f *FileHosts) Refresh() { buf, err := os.Open(f.file) if err != nil { - log.Warn("Update hosts records from file failed %s", err) + log.WithFields(log.Fields{ + "file": f.file, + "error": err, + }).Warn("Hosts update from file failed") return } @@ -149,7 +155,11 @@ func (f *FileHosts) Refresh() { f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}} } } - log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts)) + + log.WithFields(log.Fields{ + "file": f.file, + "count": len(f.hosts), + }).Debug("Updated hosts records") } func (f *FileHosts) clear() { diff --git a/hosts/hosts_redis.go b/hosts/hosts_redis.go index 43e5327..ec980c0 100644 --- a/hosts/hosts_redis.go +++ b/hosts/hosts_redis.go @@ -3,6 +3,7 @@ package hosts import ( "encoding/json" "github.com/go-redis/redis/v7" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" "strings" ) @@ -46,7 +47,10 @@ func (r *RedisHosts) List() (HostMap, error) { } func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) { - log.Debug("Checking redis provider for %s", domain) + log.WithFields(log.Fields{ + "queryType": dns.TypeToString[queryType], + "question": domain, + }).Debug("Checking redis provider") domain = strings.ToLower(domain) diff --git a/main.go b/main.go index a19aa44..a713aa5 100644 --- a/main.go +++ b/main.go @@ -145,7 +145,7 @@ func main() { func profileCPU() { f, err := os.Create("godns.cprof") if err != nil { - log.Error("%s", err) + log.WithError(err).Error("Unable to profile cpu due to error") return } @@ -161,7 +161,7 @@ func profileMEM() { f, err := os.Create("godns.mprof") if err != nil { - log.Error("%s", err) + log.WithError(err).Error("Unable to profile memory due to error") return } diff --git a/resolver/nameserver.go b/resolver/nameserver.go index 7e63037..92f70e5 100644 --- a/resolver/nameserver.go +++ b/resolver/nameserver.go @@ -2,5 +2,6 @@ package resolver type Nameserver struct { net string + host string address string } diff --git a/resolver/resolver.go b/resolver/resolver.go index 9714eaa..e5325cf 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -1,14 +1,14 @@ package resolver import ( - "bufio" "errors" "fmt" log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "io" "meow.tf/joker/godns/utils" "net" "os" - "strconv" "strings" "sync" "time" @@ -38,6 +38,7 @@ type RResp struct { rtt time.Duration } +// Resolver contains a list of nameservers, domain-specific nameservers, and dns clients type Resolver struct { servers []*Nameserver domainServer *suffixTreeNode @@ -47,6 +48,7 @@ type Resolver struct { clientLock sync.RWMutex } +// NewResolver initializes a resolver from the specified settings func NewResolver(c Settings) *Resolver { r := &Resolver{ servers: make([]*Nameserver, 0), @@ -55,16 +57,18 @@ func NewResolver(c Settings) *Resolver { } if len(c.ServerListFile) > 0 { - r.ReadServerListFile(c.ServerListFile) + err := r.ReadServerListFile(c.ServerListFile) + + if err != nil { + log.WithError(err).Fatalln("Unable to read server list file") + } } if len(c.ResolvFile) > 0 { clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) if err != nil { - log.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) - log.Error("%s", err) - panic(err) + log.WithError(err).Fatalln("not a valid resolv.conf file") } for _, server := range clientConfig.Servers { @@ -75,81 +79,73 @@ func NewResolver(c Settings) *Resolver { return r } -func (r *Resolver) parseServerListFile(buf *os.File) { - scanner := bufio.NewScanner(buf) - - var line string - var idx int - - for scanner.Scan() { - line = strings.TrimSpace(scanner.Text()) - - if !strings.HasPrefix(line, "server") { - continue - } - - idx = strings.Index(line, "=") - - if idx == -1 { - continue - } - - line = strings.TrimSpace(line[idx+1:]) - - if strings.HasPrefix(line, "https://") { - r.servers = append(r.servers, &Nameserver{net: "https", address: line}) - continue - } - - tokens := strings.Split(line, "/") - switch len(tokens) { - case 3: - domain := tokens[1] - ip := tokens[2] - - if !utils.IsDomain(domain) || !utils.IsIP(ip) { - continue - } - - r.domainServer.sinsert(strings.Split(domain, "."), ip) - case 1: - srvPort := strings.Split(line, "#") - - if len(srvPort) > 2 { - continue - } - - ip := "" - - if ip = srvPort[0]; !utils.IsIP(ip) { - continue - } - - port := "53" - - if len(srvPort) == 2 { - if _, err := strconv.Atoi(srvPort[1]); err != nil { - continue - } - - port = srvPort[1] - } - - r.servers = append(r.servers, &Nameserver{address: net.JoinHostPort(ip, port)}) - } - } - +// server is a configuration struct for server lists +type server struct { + // Type is the nameserver type (https, udp, tcp-tls), optional + Type string + Server string + // Optional host for passing to TLS Config + Host string + Domains []string } -func (r *Resolver) ReadServerListFile(files []string) { +// parseServerListFile loads a YAML server list file. +func (r *Resolver) parseServerListFile(buf io.Reader) error { + v := viper.New() + + var err error + + v.SetConfigType("yaml") + + if err = v.ReadConfig(buf); err != nil { + return err + } + + list := make([]server, 0) + + if err = v.UnmarshalKey("servers", &list); err != nil { + return err + } + + for _, server := range list { + nameserver := &Nameserver{ + net: determineNet(server.Type, server.Server), + address: server.Server, + } + + if len(server.Domains) > 0 { + for _, domain := range server.Domains { + r.domainServer.sinsert(strings.Split(domain, "."), nameserver) + } + + continue + } + + r.servers = append(r.servers, nameserver) + } + + return nil +} + +// ReadServerListFile loads a list of server list files. +func (r *Resolver) ReadServerListFile(files []string) error { for _, file := range files { buf, err := os.Open(file) + if err != nil { - panic("Can't open " + file) + return err } - r.parseServerListFile(buf) + + err = r.parseServerListFile(buf) + buf.Close() + + if err != nil { + return err + } } + + return nil } // Lookup will ask each nameserver in top-to-bottom fashion, starting a new request @@ -170,15 +166,18 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error c, err := resolver.resolverFor(net, nameserver) if err != nil { - log.Warn("error:%s", err.Error()) + log.WithError(err).Warn("resolver failed to resolve") return } r, rtt, err := c.Exchange(req, nameserver.address) if err != nil { - log.Warn("%s socket error on %s", qname, nameserver) - log.Warn("error:%s", err.Error()) + log.WithFields(log.Fields{ + "error": err, + "question": qname, + "nameserver": nameserver.address, + }).Warn("Socket error encountered") return } // If SERVFAIL happen, should return immediately and try another upstream resolver. @@ -186,7 +185,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error // that it has been verified no such domain existas and ask other resolvers // would make no sense. See more about #20 if r != nil && r.Rcode != dns.RcodeSuccess { - log.Warn("%s failed to get an valid answer on %s", qname, nameserver) + log.WithFields(log.Fields{ + "question": qname, + "nameserver": nameserver.address, + }).Warn("Nameserver failed to get a valid answer") + if r.Rcode == dns.RcodeServerFailure { return } @@ -208,7 +211,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error // but exit early, if we have an answer select { case re := <-res: - log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver, re.rtt) + log.WithFields(log.Fields{ + "question": utils.UnFqdn(qname), + "nameserver": re.nameserver.address, + "rtt": re.rtt, + }).Debug("Resolve") return re.msg, nil case <-ticker.C: continue @@ -218,7 +225,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error wg.Wait() select { case re := <-res: - log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver.address, re.rtt) + log.WithFields(log.Fields{ + "question": utils.UnFqdn(qname), + "nameserver": re.nameserver.address, + "rtt": re.rtt, + }).Debug("Resolve") return re.msg, nil default: return nil, ResolvError{qname, net, nameservers} @@ -256,10 +267,16 @@ func (r *Resolver) resolverFor(network string, n *Nameserver) (*dns.Client, erro } if n.net == "tcp-tls" { - host, _, err := net.SplitHostPort(n.address) + host := n.host - if err != nil { - host = n.address + if host == "" { + var err error + + host, _, err = net.SplitHostPort(n.address) + + if err != nil { + host = n.address + } } client.TLSConfig = &tls.Config{ @@ -282,12 +299,13 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver { queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.' if v, found := r.domainServer.search(queryKeys); found { - log.Debug("%s found in domain server list, upstream: %v", qname, v) + log.WithFields(log.Fields{ + "question": qname, + "upstream": v.address, + }).Debug("Found in domain server list") //Ensure query the specific upstream nameserver in async Lookup() function. - return []*Nameserver{ - {net: "udp", address: net.JoinHostPort(v, "53")}, - } + return []*Nameserver{v} } return r.servers @@ -296,3 +314,15 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver { func (r *Resolver) Timeout() time.Duration { return time.Duration(r.config.Timeout) * time.Second } + +func determineNet(t, server string) string { + if t != "" { + return t + } + + if strings.HasPrefix(server, "https") { + return "https" + } + + return "udp" +} \ No newline at end of file diff --git a/resolver/sfx_tree.go b/resolver/sfx_tree.go index 41e4cf8..f0f170a 100644 --- a/resolver/sfx_tree.go +++ b/resolver/sfx_tree.go @@ -2,15 +2,15 @@ package resolver type suffixTreeNode struct { key string - value string + value *Nameserver children map[string]*suffixTreeNode } func newSuffixTreeRoot() *suffixTreeNode { - return newSuffixTree("", "") + return newSuffixTree("", nil) } -func newSuffixTree(key string, value string) *suffixTreeNode { +func newSuffixTree(key string, value *Nameserver) *suffixTreeNode { root := &suffixTreeNode{ key: key, value: value, @@ -21,11 +21,11 @@ func newSuffixTree(key string, value string) *suffixTreeNode { func (node *suffixTreeNode) ensureSubTree(key string) { if _, ok := node.children[key]; !ok { - node.children[key] = newSuffixTree(key, "") + node.children[key] = newSuffixTree(key, nil) } } -func (node *suffixTreeNode) insert(key string, value string) { +func (node *suffixTreeNode) insert(key string, value *Nameserver) { if c, ok := node.children[key]; ok { c.value = value } else { @@ -33,7 +33,7 @@ func (node *suffixTreeNode) insert(key string, value string) { } } -func (node *suffixTreeNode) sinsert(keys []string, value string) { +func (node *suffixTreeNode) sinsert(keys []string, value *Nameserver) { if len(keys) == 0 { return } @@ -48,9 +48,9 @@ func (node *suffixTreeNode) sinsert(keys []string, value string) { node.insert(key, value) } -func (node *suffixTreeNode) search(keys []string) (string, bool) { +func (node *suffixTreeNode) search(keys []string) (*Nameserver, bool) { if len(keys) == 0 { - return "", false + return nil, false } key := keys[len(keys)-1] @@ -58,8 +58,8 @@ func (node *suffixTreeNode) search(keys []string) (string, bool) { if nextValue, found := n.search(keys[:len(keys)-1]); found { return nextValue, found } - return n.value, (n.value != "") + return n.value, n.value != nil } - return "", false + return nil, false } diff --git a/resolver/sfx_tree_test.go b/resolver/sfx_tree_test.go index 9ff26b9..12194e4 100644 --- a/resolver/sfx_tree_test.go +++ b/resolver/sfx_tree_test.go @@ -11,43 +11,43 @@ func Test_Suffix_Tree(t *testing.T) { root := newSuffixTreeRoot() Convey("Google should not be found", t, func() { - root.insert("cn", "114.114.114.114") - root.sinsert([]string{"baidu", "cn"}, "166.111.8.28") - root.sinsert([]string{"sina", "cn"}, "114.114.114.114") + root.insert("cn", &Nameserver{address: "114.114.114.114"}) + root.sinsert([]string{"baidu", "cn"}, &Nameserver{address: "166.111.8.28"}) + root.sinsert([]string{"sina", "cn"}, &Nameserver{address: "114.114.114.114"}) v, found := root.search(strings.Split("google.com", ".")) So(found, ShouldEqual, false) v, found = root.search(strings.Split("baidu.cn", ".")) So(found, ShouldEqual, true) - So(v, ShouldEqual, "166.111.8.28") + So(v.address, ShouldEqual, "166.111.8.28") }) Convey("Google should be found", t, func() { - root.sinsert(strings.Split("com", "."), "") - root.sinsert(strings.Split("google.com", "."), "8.8.8.8") - root.sinsert(strings.Split("twitter.com", "."), "8.8.8.8") - root.sinsert(strings.Split("scholar.google.com", "."), "208.67.222.222") + root.sinsert(strings.Split("com", "."), &Nameserver{address: ""}) + root.sinsert(strings.Split("google.com", "."), &Nameserver{address: "8.8.8.8"}) + root.sinsert(strings.Split("twitter.com", "."), &Nameserver{address: "8.8.8.8"}) + root.sinsert(strings.Split("scholar.google.com", "."), &Nameserver{address: "208.67.222.222"}) v, found := root.search(strings.Split("google.com", ".")) So(found, ShouldEqual, true) - So(v, ShouldEqual, "8.8.8.8") + So(v.address, ShouldEqual, "8.8.8.8") v, found = root.search(strings.Split("www.google.com", ".")) So(found, ShouldEqual, true) - So(v, ShouldEqual, "8.8.8.8") + So(v.address, ShouldEqual, "8.8.8.8") v, found = root.search(strings.Split("scholar.google.com", ".")) So(found, ShouldEqual, true) - So(v, ShouldEqual, "208.67.222.222") + So(v.address, ShouldEqual, "208.67.222.222") v, found = root.search(strings.Split("twitter.com", ".")) So(found, ShouldEqual, true) - So(v, ShouldEqual, "8.8.8.8") + So(v.address, ShouldEqual, "8.8.8.8") v, found = root.search(strings.Split("baidu.cn", ".")) So(found, ShouldEqual, true) - So(v, ShouldEqual, "166.111.8.28") + So(v.address, ShouldEqual, "166.111.8.28") }) } diff --git a/utils/utils_test.go b/utils/utils_test.go index 8a9dc5a..ee825b8 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -9,23 +9,23 @@ import ( func TestHostDomainAndIP(t *testing.T) { Convey("Test Host File Domain and IP regex", t, func() { Convey("1.1.1.1 should be IP and not domain", func() { - So(isIP("1.1.1.1"), ShouldEqual, true) - So(isDomain("1.1.1.1"), ShouldEqual, false) + So(IsIP("1.1.1.1"), ShouldEqual, true) + So(IsDomain("1.1.1.1"), ShouldEqual, false) }) Convey("2001:470:20::2 should be IP and not domain", func() { - So(isIP("2001:470:20::2"), ShouldEqual, true) - So(isDomain("2001:470:20::2"), ShouldEqual, false) + So(IsIP("2001:470:20::2"), ShouldEqual, true) + So(IsDomain("2001:470:20::2"), ShouldEqual, false) }) Convey("`host` should not be domain and not IP", func() { - So(isDomain("host"), ShouldEqual, false) - So(isIP("host"), ShouldEqual, false) + So(IsDomain("host"), ShouldEqual, false) + So(IsIP("host"), ShouldEqual, false) }) Convey("`123.test` should be domain and not IP", func() { - So(isDomain("123.test"), ShouldEqual, true) - So(isIP("123.test"), ShouldEqual, false) + So(IsDomain("123.test"), ShouldEqual, true) + So(IsIP("123.test"), ShouldEqual, false) }) })