Add testing, cleanup, rework suffix tree to use nameservers. Parse nameservers from yaml.
This commit is contained in:
		@ -2,5 +2,6 @@ package resolver
 | 
			
		||||
 | 
			
		||||
type Nameserver struct {
 | 
			
		||||
	net string
 | 
			
		||||
	host string
 | 
			
		||||
	address string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,14 +1,14 @@
 | 
			
		||||
package resolver
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"github.com/spf13/viper"
 | 
			
		||||
	"io"
 | 
			
		||||
	"meow.tf/joker/godns/utils"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
@ -38,6 +38,7 @@ type RResp struct {
 | 
			
		||||
	rtt        time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Resolver contains a list of nameservers, domain-specific nameservers, and dns clients
 | 
			
		||||
type Resolver struct {
 | 
			
		||||
	servers      []*Nameserver
 | 
			
		||||
	domainServer *suffixTreeNode
 | 
			
		||||
@ -47,6 +48,7 @@ type Resolver struct {
 | 
			
		||||
	clientLock sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewResolver initializes a resolver from the specified settings
 | 
			
		||||
func NewResolver(c Settings) *Resolver {
 | 
			
		||||
	r := &Resolver{
 | 
			
		||||
		servers:      make([]*Nameserver, 0),
 | 
			
		||||
@ -55,16 +57,18 @@ func NewResolver(c Settings) *Resolver {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(c.ServerListFile) > 0 {
 | 
			
		||||
		r.ReadServerListFile(c.ServerListFile)
 | 
			
		||||
		err := r.ReadServerListFile(c.ServerListFile)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).Fatalln("Unable to read server list file")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(c.ResolvFile) > 0 {
 | 
			
		||||
		clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
 | 
			
		||||
			log.Error("%s", err)
 | 
			
		||||
			panic(err)
 | 
			
		||||
			log.WithError(err).Fatalln("not a valid resolv.conf file")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, server := range clientConfig.Servers {
 | 
			
		||||
@ -75,81 +79,73 @@ func NewResolver(c Settings) *Resolver {
 | 
			
		||||
	return r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Resolver) parseServerListFile(buf *os.File) {
 | 
			
		||||
	scanner := bufio.NewScanner(buf)
 | 
			
		||||
 | 
			
		||||
	var line string
 | 
			
		||||
	var idx int
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line = strings.TrimSpace(scanner.Text())
 | 
			
		||||
 | 
			
		||||
		if !strings.HasPrefix(line, "server") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		idx = strings.Index(line, "=")
 | 
			
		||||
 | 
			
		||||
		if idx == -1 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		line = strings.TrimSpace(line[idx+1:])
 | 
			
		||||
 | 
			
		||||
		if strings.HasPrefix(line, "https://") {
 | 
			
		||||
			r.servers = append(r.servers, &Nameserver{net: "https", address: line})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tokens := strings.Split(line, "/")
 | 
			
		||||
		switch len(tokens) {
 | 
			
		||||
		case 3:
 | 
			
		||||
			domain := tokens[1]
 | 
			
		||||
			ip := tokens[2]
 | 
			
		||||
 | 
			
		||||
			if !utils.IsDomain(domain) || !utils.IsIP(ip) {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			r.domainServer.sinsert(strings.Split(domain, "."), ip)
 | 
			
		||||
		case 1:
 | 
			
		||||
			srvPort := strings.Split(line, "#")
 | 
			
		||||
 | 
			
		||||
			if len(srvPort) > 2 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			ip := ""
 | 
			
		||||
 | 
			
		||||
			if ip = srvPort[0]; !utils.IsIP(ip) {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			port := "53"
 | 
			
		||||
 | 
			
		||||
			if len(srvPort) == 2 {
 | 
			
		||||
				if _, err := strconv.Atoi(srvPort[1]); err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				port = srvPort[1]
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			r.servers = append(r.servers, &Nameserver{address: net.JoinHostPort(ip, port)})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
// server is a configuration struct for server lists
 | 
			
		||||
type server struct {
 | 
			
		||||
	// Type is the nameserver type (https, udp, tcp-tls), optional
 | 
			
		||||
	Type string
 | 
			
		||||
	Server string
 | 
			
		||||
	// Optional host for passing to TLS Config
 | 
			
		||||
	Host string
 | 
			
		||||
	Domains []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Resolver) ReadServerListFile(files []string) {
 | 
			
		||||
// parseServerListFile loads a YAML server list file.
 | 
			
		||||
func (r *Resolver) parseServerListFile(buf io.Reader) error {
 | 
			
		||||
	v := viper.New()
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	v.SetConfigType("yaml")
 | 
			
		||||
 | 
			
		||||
	if err = v.ReadConfig(buf); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	list := make([]server, 0)
 | 
			
		||||
 | 
			
		||||
	if err = v.UnmarshalKey("servers", &list); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, server := range list {
 | 
			
		||||
		nameserver := &Nameserver{
 | 
			
		||||
			net: determineNet(server.Type, server.Server),
 | 
			
		||||
			address: server.Server,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(server.Domains) > 0 {
 | 
			
		||||
			for _, domain := range server.Domains {
 | 
			
		||||
				r.domainServer.sinsert(strings.Split(domain, "."), nameserver)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		r.servers = append(r.servers, nameserver)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadServerListFile loads a list of server list files.
 | 
			
		||||
func (r *Resolver) ReadServerListFile(files []string) error {
 | 
			
		||||
	for _, file := range files {
 | 
			
		||||
		buf, err := os.Open(file)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			panic("Can't open " + file)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		r.parseServerListFile(buf)
 | 
			
		||||
 | 
			
		||||
		err = r.parseServerListFile(buf)
 | 
			
		||||
 | 
			
		||||
		buf.Close()
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Lookup will ask each nameserver in top-to-bottom fashion, starting a new request
 | 
			
		||||
@ -170,15 +166,18 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 | 
			
		||||
		c, err := resolver.resolverFor(net, nameserver)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Warn("error:%s", err.Error())
 | 
			
		||||
			log.WithError(err).Warn("resolver failed to resolve")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		r, rtt, err := c.Exchange(req, nameserver.address)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Warn("%s socket error on %s", qname, nameserver)
 | 
			
		||||
			log.Warn("error:%s", err.Error())
 | 
			
		||||
			log.WithFields(log.Fields{
 | 
			
		||||
				"error": err,
 | 
			
		||||
				"question": qname,
 | 
			
		||||
				"nameserver": nameserver.address,
 | 
			
		||||
			}).Warn("Socket error encountered")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// If SERVFAIL happen, should return immediately and try another upstream resolver.
 | 
			
		||||
@ -186,7 +185,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 | 
			
		||||
		// that it has been verified no such domain existas and ask other resolvers
 | 
			
		||||
		// would make no sense. See more about #20
 | 
			
		||||
		if r != nil && r.Rcode != dns.RcodeSuccess {
 | 
			
		||||
			log.Warn("%s failed to get an valid answer on %s", qname, nameserver)
 | 
			
		||||
			log.WithFields(log.Fields{
 | 
			
		||||
				"question": qname,
 | 
			
		||||
				"nameserver": nameserver.address,
 | 
			
		||||
			}).Warn("Nameserver failed to get a valid answer")
 | 
			
		||||
 | 
			
		||||
			if r.Rcode == dns.RcodeServerFailure {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
@ -208,7 +211,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 | 
			
		||||
		// but exit early, if we have an answer
 | 
			
		||||
		select {
 | 
			
		||||
		case re := <-res:
 | 
			
		||||
			log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver, re.rtt)
 | 
			
		||||
			log.WithFields(log.Fields{
 | 
			
		||||
				"question": utils.UnFqdn(qname),
 | 
			
		||||
				"nameserver": re.nameserver.address,
 | 
			
		||||
				"rtt": re.rtt,
 | 
			
		||||
			}).Debug("Resolve")
 | 
			
		||||
			return re.msg, nil
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			continue
 | 
			
		||||
@ -218,7 +225,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	select {
 | 
			
		||||
	case re := <-res:
 | 
			
		||||
		log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver.address, re.rtt)
 | 
			
		||||
		log.WithFields(log.Fields{
 | 
			
		||||
			"question": utils.UnFqdn(qname),
 | 
			
		||||
			"nameserver": re.nameserver.address,
 | 
			
		||||
			"rtt": re.rtt,
 | 
			
		||||
		}).Debug("Resolve")
 | 
			
		||||
		return re.msg, nil
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, ResolvError{qname, net, nameservers}
 | 
			
		||||
@ -256,10 +267,16 @@ func (r *Resolver) resolverFor(network string, n *Nameserver) (*dns.Client, erro
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if n.net == "tcp-tls" {
 | 
			
		||||
		host, _, err := net.SplitHostPort(n.address)
 | 
			
		||||
		host := n.host
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			host = n.address
 | 
			
		||||
		if host == "" {
 | 
			
		||||
			var err error
 | 
			
		||||
 | 
			
		||||
			host, _, err = net.SplitHostPort(n.address)
 | 
			
		||||
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				host = n.address
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		client.TLSConfig = &tls.Config{
 | 
			
		||||
@ -282,12 +299,13 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver {
 | 
			
		||||
	queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.'
 | 
			
		||||
 | 
			
		||||
	if v, found := r.domainServer.search(queryKeys); found {
 | 
			
		||||
		log.Debug("%s found in domain server list, upstream: %v", qname, v)
 | 
			
		||||
		log.WithFields(log.Fields{
 | 
			
		||||
			"question": qname,
 | 
			
		||||
			"upstream": v.address,
 | 
			
		||||
		}).Debug("Found in domain server list")
 | 
			
		||||
 | 
			
		||||
		//Ensure query the specific upstream nameserver in async Lookup() function.
 | 
			
		||||
		return []*Nameserver{
 | 
			
		||||
			{net: "udp", address: net.JoinHostPort(v, "53")},
 | 
			
		||||
		}
 | 
			
		||||
		return []*Nameserver{v}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return r.servers
 | 
			
		||||
@ -296,3 +314,15 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver {
 | 
			
		||||
func (r *Resolver) Timeout() time.Duration {
 | 
			
		||||
	return time.Duration(r.config.Timeout) * time.Second
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func determineNet(t, server string) string {
 | 
			
		||||
	if t != "" {
 | 
			
		||||
		return t
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(server, "https") {
 | 
			
		||||
		return "https"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return "udp"
 | 
			
		||||
}
 | 
			
		||||
@ -2,15 +2,15 @@ package resolver
 | 
			
		||||
 | 
			
		||||
type suffixTreeNode struct {
 | 
			
		||||
	key      string
 | 
			
		||||
	value    string
 | 
			
		||||
	value    *Nameserver
 | 
			
		||||
	children map[string]*suffixTreeNode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSuffixTreeRoot() *suffixTreeNode {
 | 
			
		||||
	return newSuffixTree("", "")
 | 
			
		||||
	return newSuffixTree("", nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSuffixTree(key string, value string) *suffixTreeNode {
 | 
			
		||||
func newSuffixTree(key string, value *Nameserver) *suffixTreeNode {
 | 
			
		||||
	root := &suffixTreeNode{
 | 
			
		||||
		key:      key,
 | 
			
		||||
		value:    value,
 | 
			
		||||
@ -21,11 +21,11 @@ func newSuffixTree(key string, value string) *suffixTreeNode {
 | 
			
		||||
 | 
			
		||||
func (node *suffixTreeNode) ensureSubTree(key string) {
 | 
			
		||||
	if _, ok := node.children[key]; !ok {
 | 
			
		||||
		node.children[key] = newSuffixTree(key, "")
 | 
			
		||||
		node.children[key] = newSuffixTree(key, nil)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *suffixTreeNode) insert(key string, value string) {
 | 
			
		||||
func (node *suffixTreeNode) insert(key string, value *Nameserver) {
 | 
			
		||||
	if c, ok := node.children[key]; ok {
 | 
			
		||||
		c.value = value
 | 
			
		||||
	} else {
 | 
			
		||||
@ -33,7 +33,7 @@ func (node *suffixTreeNode) insert(key string, value string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *suffixTreeNode) sinsert(keys []string, value string) {
 | 
			
		||||
func (node *suffixTreeNode) sinsert(keys []string, value *Nameserver) {
 | 
			
		||||
	if len(keys) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@ -48,9 +48,9 @@ func (node *suffixTreeNode) sinsert(keys []string, value string) {
 | 
			
		||||
	node.insert(key, value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *suffixTreeNode) search(keys []string) (string, bool) {
 | 
			
		||||
func (node *suffixTreeNode) search(keys []string) (*Nameserver, bool) {
 | 
			
		||||
	if len(keys) == 0 {
 | 
			
		||||
		return "", false
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	key := keys[len(keys)-1]
 | 
			
		||||
@ -58,8 +58,8 @@ func (node *suffixTreeNode) search(keys []string) (string, bool) {
 | 
			
		||||
		if nextValue, found := n.search(keys[:len(keys)-1]); found {
 | 
			
		||||
			return nextValue, found
 | 
			
		||||
		}
 | 
			
		||||
		return n.value, (n.value != "")
 | 
			
		||||
		return n.value, n.value != nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return "", false
 | 
			
		||||
	return nil, false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -11,43 +11,43 @@ func Test_Suffix_Tree(t *testing.T) {
 | 
			
		||||
	root := newSuffixTreeRoot()
 | 
			
		||||
 | 
			
		||||
	Convey("Google should not be found", t, func() {
 | 
			
		||||
		root.insert("cn", "114.114.114.114")
 | 
			
		||||
		root.sinsert([]string{"baidu", "cn"}, "166.111.8.28")
 | 
			
		||||
		root.sinsert([]string{"sina", "cn"}, "114.114.114.114")
 | 
			
		||||
		root.insert("cn", &Nameserver{address: "114.114.114.114"})
 | 
			
		||||
		root.sinsert([]string{"baidu", "cn"}, &Nameserver{address: "166.111.8.28"})
 | 
			
		||||
		root.sinsert([]string{"sina", "cn"}, &Nameserver{address: "114.114.114.114"})
 | 
			
		||||
 | 
			
		||||
		v, found := root.search(strings.Split("google.com", "."))
 | 
			
		||||
		So(found, ShouldEqual, false)
 | 
			
		||||
 | 
			
		||||
		v, found = root.search(strings.Split("baidu.cn", "."))
 | 
			
		||||
		So(found, ShouldEqual, true)
 | 
			
		||||
		So(v, ShouldEqual, "166.111.8.28")
 | 
			
		||||
		So(v.address, ShouldEqual, "166.111.8.28")
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	Convey("Google should be found", t, func() {
 | 
			
		||||
		root.sinsert(strings.Split("com", "."), "")
 | 
			
		||||
		root.sinsert(strings.Split("google.com", "."), "8.8.8.8")
 | 
			
		||||
		root.sinsert(strings.Split("twitter.com", "."), "8.8.8.8")
 | 
			
		||||
		root.sinsert(strings.Split("scholar.google.com", "."), "208.67.222.222")
 | 
			
		||||
		root.sinsert(strings.Split("com", "."), &Nameserver{address: ""})
 | 
			
		||||
		root.sinsert(strings.Split("google.com", "."), &Nameserver{address: "8.8.8.8"})
 | 
			
		||||
		root.sinsert(strings.Split("twitter.com", "."), &Nameserver{address: "8.8.8.8"})
 | 
			
		||||
		root.sinsert(strings.Split("scholar.google.com", "."), &Nameserver{address: "208.67.222.222"})
 | 
			
		||||
 | 
			
		||||
		v, found := root.search(strings.Split("google.com", "."))
 | 
			
		||||
		So(found, ShouldEqual, true)
 | 
			
		||||
		So(v, ShouldEqual, "8.8.8.8")
 | 
			
		||||
		So(v.address, ShouldEqual, "8.8.8.8")
 | 
			
		||||
 | 
			
		||||
		v, found = root.search(strings.Split("www.google.com", "."))
 | 
			
		||||
		So(found, ShouldEqual, true)
 | 
			
		||||
		So(v, ShouldEqual, "8.8.8.8")
 | 
			
		||||
		So(v.address, ShouldEqual, "8.8.8.8")
 | 
			
		||||
 | 
			
		||||
		v, found = root.search(strings.Split("scholar.google.com", "."))
 | 
			
		||||
		So(found, ShouldEqual, true)
 | 
			
		||||
		So(v, ShouldEqual, "208.67.222.222")
 | 
			
		||||
		So(v.address, ShouldEqual, "208.67.222.222")
 | 
			
		||||
 | 
			
		||||
		v, found = root.search(strings.Split("twitter.com", "."))
 | 
			
		||||
		So(found, ShouldEqual, true)
 | 
			
		||||
		So(v, ShouldEqual, "8.8.8.8")
 | 
			
		||||
		So(v.address, ShouldEqual, "8.8.8.8")
 | 
			
		||||
 | 
			
		||||
		v, found = root.search(strings.Split("baidu.cn", "."))
 | 
			
		||||
		So(found, ShouldEqual, true)
 | 
			
		||||
		So(v, ShouldEqual, "166.111.8.28")
 | 
			
		||||
		So(v.address, ShouldEqual, "166.111.8.28")
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user