From d8079551c92ef853df34332663812c7d023b3614 Mon Sep 17 00:00:00 2001
From: Tyler <tystuyfzand@gmail.com>
Date: Thu, 15 Apr 2021 01:04:58 -0400
Subject: [PATCH] Add testing, cleanup, rework suffix tree to use nameservers.
 Parse nameservers from yaml.

---
 .drone.yml                |   5 +-
 hosts/hosts_bolt.go       |   6 +-
 hosts/hosts_file.go       |  16 ++-
 hosts/hosts_redis.go      |   6 +-
 main.go                   |   4 +-
 resolver/nameserver.go    |   1 +
 resolver/resolver.go      | 202 ++++++++++++++++++++++----------------
 resolver/sfx_tree.go      |  20 ++--
 resolver/sfx_tree_test.go |  26 ++---
 utils/utils_test.go       |  16 +--
 10 files changed, 177 insertions(+), 125 deletions(-)

diff --git a/.drone.yml b/.drone.yml
index 972f4ed..2ed776e 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -3,13 +3,16 @@ name: amd64
 type: docker
 
 steps:
+  - name: test
+    image: golang:alpine
+    commands:
+      - go test ...
   - name: build
     image: golang:alpine
     volumes:
       - name: build
         path: /build
     commands:
-      - apk --no-cache add git gcc musl-dev
       - go build -o /build/godns
   - name: docker
     image: plugins/docker
diff --git a/hosts/hosts_bolt.go b/hosts/hosts_bolt.go
index fe213ee..7a93e2f 100644
--- a/hosts/hosts_bolt.go
+++ b/hosts/hosts_bolt.go
@@ -2,6 +2,7 @@ package hosts
 
 import (
 	"encoding/json"
+	"github.com/miekg/dns"
 	log "github.com/sirupsen/logrus"
 	bolt "go.etcd.io/bbolt"
 	"strings"
@@ -67,7 +68,10 @@ func (b *BoltHosts) List() (HostMap, error) {
 }
 
 func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
-	log.Debug("Checking bolt provider for %s : %s", queryType, domain)
+	log.WithFields(log.Fields{
+		"queryType": dns.TypeToString[queryType],
+		"question": domain,
+	}).Debug("Checking bolt provider")
 
 	domain = strings.ToLower(domain)
 
diff --git a/hosts/hosts_file.go b/hosts/hosts_file.go
index 66db4dd..cf7caaf 100644
--- a/hosts/hosts_file.go
+++ b/hosts/hosts_file.go
@@ -62,7 +62,10 @@ func (f *FileHosts) List() (HostMap, error) {
 }
 
 func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
-	log.Debug("Checking file provider for %s : %s", queryType, domain)
+	log.WithFields(log.Fields{
+		"queryType": dns.TypeToString[queryType],
+		"question": domain,
+	}).Debug("Checking file provider")
 
 	// Does not support CNAME/TXT/etc
 	if queryType != dns.TypeA && queryType != dns.TypeAAAA {
@@ -102,7 +105,10 @@ func (f *FileHosts) Refresh() {
 	buf, err := os.Open(f.file)
 
 	if err != nil {
-		log.Warn("Update hosts records from file failed %s", err)
+		log.WithFields(log.Fields{
+			"file": f.file,
+			"error": err,
+		}).Warn("Hosts update from file failed")
 		return
 	}
 
@@ -149,7 +155,11 @@ func (f *FileHosts) Refresh() {
 			f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}}
 		}
 	}
-	log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts))
+
+	log.WithFields(log.Fields{
+		"file": f.file,
+		"count": len(f.hosts),
+	}).Debug("Updated hosts records")
 }
 
 func (f *FileHosts) clear() {
diff --git a/hosts/hosts_redis.go b/hosts/hosts_redis.go
index 43e5327..ec980c0 100644
--- a/hosts/hosts_redis.go
+++ b/hosts/hosts_redis.go
@@ -3,6 +3,7 @@ package hosts
 import (
 	"encoding/json"
 	"github.com/go-redis/redis/v7"
+	"github.com/miekg/dns"
 	log "github.com/sirupsen/logrus"
 	"strings"
 )
@@ -46,7 +47,10 @@ func (r *RedisHosts) List() (HostMap, error) {
 }
 
 func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
-	log.Debug("Checking redis provider for %s", domain)
+	log.WithFields(log.Fields{
+		"queryType": dns.TypeToString[queryType],
+		"question": domain,
+	}).Debug("Checking redis provider")
 
 	domain = strings.ToLower(domain)
 
diff --git a/main.go b/main.go
index a19aa44..a713aa5 100644
--- a/main.go
+++ b/main.go
@@ -145,7 +145,7 @@ func main() {
 func profileCPU() {
 	f, err := os.Create("godns.cprof")
 	if err != nil {
-		log.Error("%s", err)
+		log.WithError(err).Error("Unable to profile cpu due to error")
 		return
 	}
 
@@ -161,7 +161,7 @@ func profileMEM() {
 	f, err := os.Create("godns.mprof")
 
 	if err != nil {
-		log.Error("%s", err)
+		log.WithError(err).Error("Unable to profile memory due to error")
 		return
 	}
 
diff --git a/resolver/nameserver.go b/resolver/nameserver.go
index 7e63037..92f70e5 100644
--- a/resolver/nameserver.go
+++ b/resolver/nameserver.go
@@ -2,5 +2,6 @@ package resolver
 
 type Nameserver struct {
 	net string
+	host string
 	address string
 }
diff --git a/resolver/resolver.go b/resolver/resolver.go
index 9714eaa..e5325cf 100644
--- a/resolver/resolver.go
+++ b/resolver/resolver.go
@@ -1,14 +1,14 @@
 package resolver
 
 import (
-	"bufio"
 	"errors"
 	"fmt"
 	log "github.com/sirupsen/logrus"
+	"github.com/spf13/viper"
+	"io"
 	"meow.tf/joker/godns/utils"
 	"net"
 	"os"
-	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -38,6 +38,7 @@ type RResp struct {
 	rtt        time.Duration
 }
 
+// Resolver contains a list of nameservers, domain-specific nameservers, and dns clients
 type Resolver struct {
 	servers      []*Nameserver
 	domainServer *suffixTreeNode
@@ -47,6 +48,7 @@ type Resolver struct {
 	clientLock sync.RWMutex
 }
 
+// NewResolver initializes a resolver from the specified settings
 func NewResolver(c Settings) *Resolver {
 	r := &Resolver{
 		servers:      make([]*Nameserver, 0),
@@ -55,16 +57,18 @@ func NewResolver(c Settings) *Resolver {
 	}
 
 	if len(c.ServerListFile) > 0 {
-		r.ReadServerListFile(c.ServerListFile)
+		err := r.ReadServerListFile(c.ServerListFile)
+
+		if err != nil {
+			log.WithError(err).Fatalln("Unable to read server list file")
+		}
 	}
 
 	if len(c.ResolvFile) > 0 {
 		clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
 
 		if err != nil {
-			log.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
-			log.Error("%s", err)
-			panic(err)
+			log.WithError(err).Fatalln("not a valid resolv.conf file")
 		}
 
 		for _, server := range clientConfig.Servers {
@@ -75,81 +79,73 @@ func NewResolver(c Settings) *Resolver {
 	return r
 }
 
-func (r *Resolver) parseServerListFile(buf *os.File) {
-	scanner := bufio.NewScanner(buf)
-
-	var line string
-	var idx int
-
-	for scanner.Scan() {
-		line = strings.TrimSpace(scanner.Text())
-
-		if !strings.HasPrefix(line, "server") {
-			continue
-		}
-
-		idx = strings.Index(line, "=")
-
-		if idx == -1 {
-			continue
-		}
-
-		line = strings.TrimSpace(line[idx+1:])
-
-		if strings.HasPrefix(line, "https://") {
-			r.servers = append(r.servers, &Nameserver{net: "https", address: line})
-			continue
-		}
-
-		tokens := strings.Split(line, "/")
-		switch len(tokens) {
-		case 3:
-			domain := tokens[1]
-			ip := tokens[2]
-
-			if !utils.IsDomain(domain) || !utils.IsIP(ip) {
-				continue
-			}
-
-			r.domainServer.sinsert(strings.Split(domain, "."), ip)
-		case 1:
-			srvPort := strings.Split(line, "#")
-
-			if len(srvPort) > 2 {
-				continue
-			}
-
-			ip := ""
-
-			if ip = srvPort[0]; !utils.IsIP(ip) {
-				continue
-			}
-
-			port := "53"
-
-			if len(srvPort) == 2 {
-				if _, err := strconv.Atoi(srvPort[1]); err != nil {
-					continue
-				}
-
-				port = srvPort[1]
-			}
-
-			r.servers = append(r.servers, &Nameserver{address: net.JoinHostPort(ip, port)})
-		}
-	}
-
+// server is a configuration struct for server lists
+type server struct {
+	// Type is the nameserver type (https, udp, tcp-tls), optional
+	Type string
+	Server string
+	// Optional host for passing to TLS Config
+	Host string
+	Domains []string
 }
 
-func (r *Resolver) ReadServerListFile(files []string) {
+// parseServerListFile loads a YAML server list file.
+func (r *Resolver) parseServerListFile(buf io.Reader) error {
+	v := viper.New()
+
+	var err error
+
+	v.SetConfigType("yaml")
+
+	if err = v.ReadConfig(buf); err != nil {
+		return err
+	}
+
+	list := make([]server, 0)
+
+	if err = v.UnmarshalKey("servers", &list); err != nil {
+		return err
+	}
+
+	for _, server := range list {
+		nameserver := &Nameserver{
+			net: determineNet(server.Type, server.Server),
+			address: server.Server,
+		}
+
+		if len(server.Domains) > 0 {
+			for _, domain := range server.Domains {
+				r.domainServer.sinsert(strings.Split(domain, "."), nameserver)
+			}
+
+			continue
+		}
+
+		r.servers = append(r.servers, nameserver)
+	}
+
+	return nil
+}
+
+// ReadServerListFile loads a list of server list files.
+func (r *Resolver) ReadServerListFile(files []string) error {
 	for _, file := range files {
 		buf, err := os.Open(file)
+
 		if err != nil {
-			panic("Can't open " + file)
+			return err
 		}
-		r.parseServerListFile(buf)
+
+		err = r.parseServerListFile(buf)
+
 		buf.Close()
+
+		if err != nil {
+			return err
+		}
 	}
+
+	return nil
 }
 
 // Lookup will ask each nameserver in top-to-bottom fashion, starting a new request
@@ -170,15 +166,18 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 		c, err := resolver.resolverFor(net, nameserver)
 
 		if err != nil {
-			log.Warn("error:%s", err.Error())
+			log.WithError(err).Warn("resolver failed to resolve")
 			return
 		}
 
 		r, rtt, err := c.Exchange(req, nameserver.address)
 
 		if err != nil {
-			log.Warn("%s socket error on %s", qname, nameserver)
-			log.Warn("error:%s", err.Error())
+			log.WithFields(log.Fields{
+				"error": err,
+				"question": qname,
+				"nameserver": nameserver.address,
+			}).Warn("Socket error encountered")
 			return
 		}
 		// If SERVFAIL happen, should return immediately and try another upstream resolver.
@@ -186,7 +185,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 		// that it has been verified no such domain existas and ask other resolvers
 		// would make no sense. See more about #20
 		if r != nil && r.Rcode != dns.RcodeSuccess {
-			log.Warn("%s failed to get an valid answer on %s", qname, nameserver)
+			log.WithFields(log.Fields{
+				"question": qname,
+				"nameserver": nameserver.address,
+			}).Warn("Nameserver failed to get a valid answer")
+
 			if r.Rcode == dns.RcodeServerFailure {
 				return
 			}
@@ -208,7 +211,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 		// but exit early, if we have an answer
 		select {
 		case re := <-res:
-			log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver, re.rtt)
+			log.WithFields(log.Fields{
+				"question": utils.UnFqdn(qname),
+				"nameserver": re.nameserver.address,
+				"rtt": re.rtt,
+			}).Debug("Resolve")
 			return re.msg, nil
 		case <-ticker.C:
 			continue
@@ -218,7 +225,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 	wg.Wait()
 	select {
 	case re := <-res:
-		log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver.address, re.rtt)
+		log.WithFields(log.Fields{
+			"question": utils.UnFqdn(qname),
+			"nameserver": re.nameserver.address,
+			"rtt": re.rtt,
+		}).Debug("Resolve")
 		return re.msg, nil
 	default:
 		return nil, ResolvError{qname, net, nameservers}
@@ -256,10 +267,16 @@ func (r *Resolver) resolverFor(network string, n *Nameserver) (*dns.Client, erro
 	}
 
 	if n.net == "tcp-tls" {
-		host, _, err := net.SplitHostPort(n.address)
+		host := n.host
 
-		if err != nil {
-			host = n.address
+		if host == "" {
+			var err error
+
+			host, _, err = net.SplitHostPort(n.address)
+
+			if err != nil {
+				host = n.address
+			}
 		}
 
 		client.TLSConfig = &tls.Config{
@@ -282,12 +299,13 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver {
 	queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.'
 
 	if v, found := r.domainServer.search(queryKeys); found {
-		log.Debug("%s found in domain server list, upstream: %v", qname, v)
+		log.WithFields(log.Fields{
+			"question": qname,
+			"upstream": v.address,
+		}).Debug("Found in domain server list")
 
 		//Ensure query the specific upstream nameserver in async Lookup() function.
-		return []*Nameserver{
-			{net: "udp", address: net.JoinHostPort(v, "53")},
-		}
+		return []*Nameserver{v}
 	}
 
 	return r.servers
@@ -296,3 +314,15 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver {
 func (r *Resolver) Timeout() time.Duration {
 	return time.Duration(r.config.Timeout) * time.Second
 }
+
+func determineNet(t, server string) string {
+	if t != "" {
+		return t
+	}
+
+	if strings.HasPrefix(server, "https") {
+		return "https"
+	}
+
+	return "udp"
+}
\ No newline at end of file
diff --git a/resolver/sfx_tree.go b/resolver/sfx_tree.go
index 41e4cf8..f0f170a 100644
--- a/resolver/sfx_tree.go
+++ b/resolver/sfx_tree.go
@@ -2,15 +2,15 @@ package resolver
 
 type suffixTreeNode struct {
 	key      string
-	value    string
+	value    *Nameserver
 	children map[string]*suffixTreeNode
 }
 
 func newSuffixTreeRoot() *suffixTreeNode {
-	return newSuffixTree("", "")
+	return newSuffixTree("", nil)
 }
 
-func newSuffixTree(key string, value string) *suffixTreeNode {
+func newSuffixTree(key string, value *Nameserver) *suffixTreeNode {
 	root := &suffixTreeNode{
 		key:      key,
 		value:    value,
@@ -21,11 +21,11 @@ func newSuffixTree(key string, value string) *suffixTreeNode {
 
 func (node *suffixTreeNode) ensureSubTree(key string) {
 	if _, ok := node.children[key]; !ok {
-		node.children[key] = newSuffixTree(key, "")
+		node.children[key] = newSuffixTree(key, nil)
 	}
 }
 
-func (node *suffixTreeNode) insert(key string, value string) {
+func (node *suffixTreeNode) insert(key string, value *Nameserver) {
 	if c, ok := node.children[key]; ok {
 		c.value = value
 	} else {
@@ -33,7 +33,7 @@ func (node *suffixTreeNode) insert(key string, value string) {
 	}
 }
 
-func (node *suffixTreeNode) sinsert(keys []string, value string) {
+func (node *suffixTreeNode) sinsert(keys []string, value *Nameserver) {
 	if len(keys) == 0 {
 		return
 	}
@@ -48,9 +48,9 @@ func (node *suffixTreeNode) sinsert(keys []string, value string) {
 	node.insert(key, value)
 }
 
-func (node *suffixTreeNode) search(keys []string) (string, bool) {
+func (node *suffixTreeNode) search(keys []string) (*Nameserver, bool) {
 	if len(keys) == 0 {
-		return "", false
+		return nil, false
 	}
 
 	key := keys[len(keys)-1]
@@ -58,8 +58,8 @@ func (node *suffixTreeNode) search(keys []string) (string, bool) {
 		if nextValue, found := n.search(keys[:len(keys)-1]); found {
 			return nextValue, found
 		}
-		return n.value, (n.value != "")
+		return n.value, n.value != nil
 	}
 
-	return "", false
+	return nil, false
 }
diff --git a/resolver/sfx_tree_test.go b/resolver/sfx_tree_test.go
index 9ff26b9..12194e4 100644
--- a/resolver/sfx_tree_test.go
+++ b/resolver/sfx_tree_test.go
@@ -11,43 +11,43 @@ func Test_Suffix_Tree(t *testing.T) {
 	root := newSuffixTreeRoot()
 
 	Convey("Google should not be found", t, func() {
-		root.insert("cn", "114.114.114.114")
-		root.sinsert([]string{"baidu", "cn"}, "166.111.8.28")
-		root.sinsert([]string{"sina", "cn"}, "114.114.114.114")
+		root.insert("cn", &Nameserver{address: "114.114.114.114"})
+		root.sinsert([]string{"baidu", "cn"}, &Nameserver{address: "166.111.8.28"})
+		root.sinsert([]string{"sina", "cn"}, &Nameserver{address: "114.114.114.114"})
 
 		v, found := root.search(strings.Split("google.com", "."))
 		So(found, ShouldEqual, false)
 
 		v, found = root.search(strings.Split("baidu.cn", "."))
 		So(found, ShouldEqual, true)
-		So(v, ShouldEqual, "166.111.8.28")
+		So(v.address, ShouldEqual, "166.111.8.28")
 	})
 
 	Convey("Google should be found", t, func() {
-		root.sinsert(strings.Split("com", "."), "")
-		root.sinsert(strings.Split("google.com", "."), "8.8.8.8")
-		root.sinsert(strings.Split("twitter.com", "."), "8.8.8.8")
-		root.sinsert(strings.Split("scholar.google.com", "."), "208.67.222.222")
+		root.sinsert(strings.Split("com", "."), &Nameserver{address: ""})
+		root.sinsert(strings.Split("google.com", "."), &Nameserver{address: "8.8.8.8"})
+		root.sinsert(strings.Split("twitter.com", "."), &Nameserver{address: "8.8.8.8"})
+		root.sinsert(strings.Split("scholar.google.com", "."), &Nameserver{address: "208.67.222.222"})
 
 		v, found := root.search(strings.Split("google.com", "."))
 		So(found, ShouldEqual, true)
-		So(v, ShouldEqual, "8.8.8.8")
+		So(v.address, ShouldEqual, "8.8.8.8")
 
 		v, found = root.search(strings.Split("www.google.com", "."))
 		So(found, ShouldEqual, true)
-		So(v, ShouldEqual, "8.8.8.8")
+		So(v.address, ShouldEqual, "8.8.8.8")
 
 		v, found = root.search(strings.Split("scholar.google.com", "."))
 		So(found, ShouldEqual, true)
-		So(v, ShouldEqual, "208.67.222.222")
+		So(v.address, ShouldEqual, "208.67.222.222")
 
 		v, found = root.search(strings.Split("twitter.com", "."))
 		So(found, ShouldEqual, true)
-		So(v, ShouldEqual, "8.8.8.8")
+		So(v.address, ShouldEqual, "8.8.8.8")
 
 		v, found = root.search(strings.Split("baidu.cn", "."))
 		So(found, ShouldEqual, true)
-		So(v, ShouldEqual, "166.111.8.28")
+		So(v.address, ShouldEqual, "166.111.8.28")
 	})
 
 }
diff --git a/utils/utils_test.go b/utils/utils_test.go
index 8a9dc5a..ee825b8 100644
--- a/utils/utils_test.go
+++ b/utils/utils_test.go
@@ -9,23 +9,23 @@ import (
 func TestHostDomainAndIP(t *testing.T) {
 	Convey("Test Host File Domain and IP regex", t, func() {
 		Convey("1.1.1.1 should be IP and not domain", func() {
-			So(isIP("1.1.1.1"), ShouldEqual, true)
-			So(isDomain("1.1.1.1"), ShouldEqual, false)
+			So(IsIP("1.1.1.1"), ShouldEqual, true)
+			So(IsDomain("1.1.1.1"), ShouldEqual, false)
 		})
 
 		Convey("2001:470:20::2 should be IP and not domain", func() {
-			So(isIP("2001:470:20::2"), ShouldEqual, true)
-			So(isDomain("2001:470:20::2"), ShouldEqual, false)
+			So(IsIP("2001:470:20::2"), ShouldEqual, true)
+			So(IsDomain("2001:470:20::2"), ShouldEqual, false)
 		})
 
 		Convey("`host` should not be domain and not IP", func() {
-			So(isDomain("host"), ShouldEqual, false)
-			So(isIP("host"), ShouldEqual, false)
+			So(IsDomain("host"), ShouldEqual, false)
+			So(IsIP("host"), ShouldEqual, false)
 		})
 
 		Convey("`123.test` should be domain and not IP", func() {
-			So(isDomain("123.test"), ShouldEqual, true)
-			So(isIP("123.test"), ShouldEqual, false)
+			So(IsDomain("123.test"), ShouldEqual, true)
+			So(IsIP("123.test"), ShouldEqual, false)
 		})
 
 	})