Add testing, cleanup, rework suffix tree to use nameservers. Parse nameservers from yaml.
continuous-integration/drone/push Build is failing Details

This commit is contained in:
Tyler 2021-04-15 01:04:58 -04:00
parent b6efd0df0c
commit d8079551c9
10 changed files with 177 additions and 125 deletions

View File

@ -3,13 +3,16 @@ name: amd64
type: docker type: docker
steps: steps:
- name: test
image: golang:alpine
commands:
- go test ...
- name: build - name: build
image: golang:alpine image: golang:alpine
volumes: volumes:
- name: build - name: build
path: /build path: /build
commands: commands:
- apk --no-cache add git gcc musl-dev
- go build -o /build/godns - go build -o /build/godns
- name: docker - name: docker
image: plugins/docker image: plugins/docker

View File

@ -2,6 +2,7 @@ package hosts
import ( import (
"encoding/json" "encoding/json"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
"strings" "strings"
@ -67,7 +68,10 @@ func (b *BoltHosts) List() (HostMap, error) {
} }
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) { func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking bolt provider for %s : %s", queryType, domain) log.WithFields(log.Fields{
"queryType": dns.TypeToString[queryType],
"question": domain,
}).Debug("Checking bolt provider")
domain = strings.ToLower(domain) domain = strings.ToLower(domain)

View File

@ -62,7 +62,10 @@ func (f *FileHosts) List() (HostMap, error) {
} }
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) { func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking file provider for %s : %s", queryType, domain) log.WithFields(log.Fields{
"queryType": dns.TypeToString[queryType],
"question": domain,
}).Debug("Checking file provider")
// Does not support CNAME/TXT/etc // Does not support CNAME/TXT/etc
if queryType != dns.TypeA && queryType != dns.TypeAAAA { if queryType != dns.TypeA && queryType != dns.TypeAAAA {
@ -102,7 +105,10 @@ func (f *FileHosts) Refresh() {
buf, err := os.Open(f.file) buf, err := os.Open(f.file)
if err != nil { if err != nil {
log.Warn("Update hosts records from file failed %s", err) log.WithFields(log.Fields{
"file": f.file,
"error": err,
}).Warn("Hosts update from file failed")
return return
} }
@ -149,7 +155,11 @@ func (f *FileHosts) Refresh() {
f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}} f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}}
} }
} }
log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts))
log.WithFields(log.Fields{
"file": f.file,
"count": len(f.hosts),
}).Debug("Updated hosts records")
} }
func (f *FileHosts) clear() { func (f *FileHosts) clear() {

View File

@ -3,6 +3,7 @@ package hosts
import ( import (
"encoding/json" "encoding/json"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v7"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"strings" "strings"
) )
@ -46,7 +47,10 @@ func (r *RedisHosts) List() (HostMap, error) {
} }
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) { func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking redis provider for %s", domain) log.WithFields(log.Fields{
"queryType": dns.TypeToString[queryType],
"question": domain,
}).Debug("Checking redis provider")
domain = strings.ToLower(domain) domain = strings.ToLower(domain)

View File

@ -145,7 +145,7 @@ func main() {
func profileCPU() { func profileCPU() {
f, err := os.Create("godns.cprof") f, err := os.Create("godns.cprof")
if err != nil { if err != nil {
log.Error("%s", err) log.WithError(err).Error("Unable to profile cpu due to error")
return return
} }
@ -161,7 +161,7 @@ func profileMEM() {
f, err := os.Create("godns.mprof") f, err := os.Create("godns.mprof")
if err != nil { if err != nil {
log.Error("%s", err) log.WithError(err).Error("Unable to profile memory due to error")
return return
} }

View File

@ -2,5 +2,6 @@ package resolver
type Nameserver struct { type Nameserver struct {
net string net string
host string
address string address string
} }

View File

@ -1,14 +1,14 @@
package resolver package resolver
import ( import (
"bufio"
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"io"
"meow.tf/joker/godns/utils" "meow.tf/joker/godns/utils"
"net" "net"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -38,6 +38,7 @@ type RResp struct {
rtt time.Duration rtt time.Duration
} }
// Resolver contains a list of nameservers, domain-specific nameservers, and dns clients
type Resolver struct { type Resolver struct {
servers []*Nameserver servers []*Nameserver
domainServer *suffixTreeNode domainServer *suffixTreeNode
@ -47,6 +48,7 @@ type Resolver struct {
clientLock sync.RWMutex clientLock sync.RWMutex
} }
// NewResolver initializes a resolver from the specified settings
func NewResolver(c Settings) *Resolver { func NewResolver(c Settings) *Resolver {
r := &Resolver{ r := &Resolver{
servers: make([]*Nameserver, 0), servers: make([]*Nameserver, 0),
@ -55,16 +57,18 @@ func NewResolver(c Settings) *Resolver {
} }
if len(c.ServerListFile) > 0 { if len(c.ServerListFile) > 0 {
r.ReadServerListFile(c.ServerListFile) err := r.ReadServerListFile(c.ServerListFile)
if err != nil {
log.WithError(err).Fatalln("Unable to read server list file")
}
} }
if len(c.ResolvFile) > 0 { if len(c.ResolvFile) > 0 {
clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
if err != nil { if err != nil {
log.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) log.WithError(err).Fatalln("not a valid resolv.conf file")
log.Error("%s", err)
panic(err)
} }
for _, server := range clientConfig.Servers { for _, server := range clientConfig.Servers {
@ -75,81 +79,73 @@ func NewResolver(c Settings) *Resolver {
return r return r
} }
func (r *Resolver) parseServerListFile(buf *os.File) { // server is a configuration struct for server lists
scanner := bufio.NewScanner(buf) type server struct {
// Type is the nameserver type (https, udp, tcp-tls), optional
var line string Type string
var idx int Server string
// Optional host for passing to TLS Config
for scanner.Scan() { Host string
line = strings.TrimSpace(scanner.Text()) Domains []string
if !strings.HasPrefix(line, "server") {
continue
}
idx = strings.Index(line, "=")
if idx == -1 {
continue
}
line = strings.TrimSpace(line[idx+1:])
if strings.HasPrefix(line, "https://") {
r.servers = append(r.servers, &Nameserver{net: "https", address: line})
continue
}
tokens := strings.Split(line, "/")
switch len(tokens) {
case 3:
domain := tokens[1]
ip := tokens[2]
if !utils.IsDomain(domain) || !utils.IsIP(ip) {
continue
}
r.domainServer.sinsert(strings.Split(domain, "."), ip)
case 1:
srvPort := strings.Split(line, "#")
if len(srvPort) > 2 {
continue
}
ip := ""
if ip = srvPort[0]; !utils.IsIP(ip) {
continue
}
port := "53"
if len(srvPort) == 2 {
if _, err := strconv.Atoi(srvPort[1]); err != nil {
continue
}
port = srvPort[1]
}
r.servers = append(r.servers, &Nameserver{address: net.JoinHostPort(ip, port)})
}
}
} }
func (r *Resolver) ReadServerListFile(files []string) { // parseServerListFile loads a YAML server list file.
func (r *Resolver) parseServerListFile(buf io.Reader) error {
v := viper.New()
var err error
v.SetConfigType("yaml")
if err = v.ReadConfig(buf); err != nil {
return err
}
list := make([]server, 0)
if err = v.UnmarshalKey("servers", &list); err != nil {
return err
}
for _, server := range list {
nameserver := &Nameserver{
net: determineNet(server.Type, server.Server),
address: server.Server,
}
if len(server.Domains) > 0 {
for _, domain := range server.Domains {
r.domainServer.sinsert(strings.Split(domain, "."), nameserver)
}
continue
}
r.servers = append(r.servers, nameserver)
}
return nil
}
// ReadServerListFile loads a list of server list files.
func (r *Resolver) ReadServerListFile(files []string) error {
for _, file := range files { for _, file := range files {
buf, err := os.Open(file) buf, err := os.Open(file)
if err != nil { if err != nil {
panic("Can't open " + file) return err
} }
r.parseServerListFile(buf)
err = r.parseServerListFile(buf)
buf.Close() buf.Close()
if err != nil {
return err
}
} }
return nil
} }
// Lookup will ask each nameserver in top-to-bottom fashion, starting a new request // Lookup will ask each nameserver in top-to-bottom fashion, starting a new request
@ -170,15 +166,18 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
c, err := resolver.resolverFor(net, nameserver) c, err := resolver.resolverFor(net, nameserver)
if err != nil { if err != nil {
log.Warn("error:%s", err.Error()) log.WithError(err).Warn("resolver failed to resolve")
return return
} }
r, rtt, err := c.Exchange(req, nameserver.address) r, rtt, err := c.Exchange(req, nameserver.address)
if err != nil { if err != nil {
log.Warn("%s socket error on %s", qname, nameserver) log.WithFields(log.Fields{
log.Warn("error:%s", err.Error()) "error": err,
"question": qname,
"nameserver": nameserver.address,
}).Warn("Socket error encountered")
return return
} }
// If SERVFAIL happen, should return immediately and try another upstream resolver. // If SERVFAIL happen, should return immediately and try another upstream resolver.
@ -186,7 +185,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
// that it has been verified no such domain existas and ask other resolvers // that it has been verified no such domain existas and ask other resolvers
// would make no sense. See more about #20 // would make no sense. See more about #20
if r != nil && r.Rcode != dns.RcodeSuccess { if r != nil && r.Rcode != dns.RcodeSuccess {
log.Warn("%s failed to get an valid answer on %s", qname, nameserver) log.WithFields(log.Fields{
"question": qname,
"nameserver": nameserver.address,
}).Warn("Nameserver failed to get a valid answer")
if r.Rcode == dns.RcodeServerFailure { if r.Rcode == dns.RcodeServerFailure {
return return
} }
@ -208,7 +211,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
// but exit early, if we have an answer // but exit early, if we have an answer
select { select {
case re := <-res: case re := <-res:
log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver, re.rtt) log.WithFields(log.Fields{
"question": utils.UnFqdn(qname),
"nameserver": re.nameserver.address,
"rtt": re.rtt,
}).Debug("Resolve")
return re.msg, nil return re.msg, nil
case <-ticker.C: case <-ticker.C:
continue continue
@ -218,7 +225,11 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
wg.Wait() wg.Wait()
select { select {
case re := <-res: case re := <-res:
log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver.address, re.rtt) log.WithFields(log.Fields{
"question": utils.UnFqdn(qname),
"nameserver": re.nameserver.address,
"rtt": re.rtt,
}).Debug("Resolve")
return re.msg, nil return re.msg, nil
default: default:
return nil, ResolvError{qname, net, nameservers} return nil, ResolvError{qname, net, nameservers}
@ -256,10 +267,16 @@ func (r *Resolver) resolverFor(network string, n *Nameserver) (*dns.Client, erro
} }
if n.net == "tcp-tls" { if n.net == "tcp-tls" {
host, _, err := net.SplitHostPort(n.address) host := n.host
if err != nil { if host == "" {
host = n.address var err error
host, _, err = net.SplitHostPort(n.address)
if err != nil {
host = n.address
}
} }
client.TLSConfig = &tls.Config{ client.TLSConfig = &tls.Config{
@ -282,12 +299,13 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver {
queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.' queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.'
if v, found := r.domainServer.search(queryKeys); found { if v, found := r.domainServer.search(queryKeys); found {
log.Debug("%s found in domain server list, upstream: %v", qname, v) log.WithFields(log.Fields{
"question": qname,
"upstream": v.address,
}).Debug("Found in domain server list")
//Ensure query the specific upstream nameserver in async Lookup() function. //Ensure query the specific upstream nameserver in async Lookup() function.
return []*Nameserver{ return []*Nameserver{v}
{net: "udp", address: net.JoinHostPort(v, "53")},
}
} }
return r.servers return r.servers
@ -296,3 +314,15 @@ func (r *Resolver) Nameservers(qname string) []*Nameserver {
func (r *Resolver) Timeout() time.Duration { func (r *Resolver) Timeout() time.Duration {
return time.Duration(r.config.Timeout) * time.Second return time.Duration(r.config.Timeout) * time.Second
} }
func determineNet(t, server string) string {
if t != "" {
return t
}
if strings.HasPrefix(server, "https") {
return "https"
}
return "udp"
}

View File

@ -2,15 +2,15 @@ package resolver
type suffixTreeNode struct { type suffixTreeNode struct {
key string key string
value string value *Nameserver
children map[string]*suffixTreeNode children map[string]*suffixTreeNode
} }
func newSuffixTreeRoot() *suffixTreeNode { func newSuffixTreeRoot() *suffixTreeNode {
return newSuffixTree("", "") return newSuffixTree("", nil)
} }
func newSuffixTree(key string, value string) *suffixTreeNode { func newSuffixTree(key string, value *Nameserver) *suffixTreeNode {
root := &suffixTreeNode{ root := &suffixTreeNode{
key: key, key: key,
value: value, value: value,
@ -21,11 +21,11 @@ func newSuffixTree(key string, value string) *suffixTreeNode {
func (node *suffixTreeNode) ensureSubTree(key string) { func (node *suffixTreeNode) ensureSubTree(key string) {
if _, ok := node.children[key]; !ok { if _, ok := node.children[key]; !ok {
node.children[key] = newSuffixTree(key, "") node.children[key] = newSuffixTree(key, nil)
} }
} }
func (node *suffixTreeNode) insert(key string, value string) { func (node *suffixTreeNode) insert(key string, value *Nameserver) {
if c, ok := node.children[key]; ok { if c, ok := node.children[key]; ok {
c.value = value c.value = value
} else { } else {
@ -33,7 +33,7 @@ func (node *suffixTreeNode) insert(key string, value string) {
} }
} }
func (node *suffixTreeNode) sinsert(keys []string, value string) { func (node *suffixTreeNode) sinsert(keys []string, value *Nameserver) {
if len(keys) == 0 { if len(keys) == 0 {
return return
} }
@ -48,9 +48,9 @@ func (node *suffixTreeNode) sinsert(keys []string, value string) {
node.insert(key, value) node.insert(key, value)
} }
func (node *suffixTreeNode) search(keys []string) (string, bool) { func (node *suffixTreeNode) search(keys []string) (*Nameserver, bool) {
if len(keys) == 0 { if len(keys) == 0 {
return "", false return nil, false
} }
key := keys[len(keys)-1] key := keys[len(keys)-1]
@ -58,8 +58,8 @@ func (node *suffixTreeNode) search(keys []string) (string, bool) {
if nextValue, found := n.search(keys[:len(keys)-1]); found { if nextValue, found := n.search(keys[:len(keys)-1]); found {
return nextValue, found return nextValue, found
} }
return n.value, (n.value != "") return n.value, n.value != nil
} }
return "", false return nil, false
} }

View File

@ -11,43 +11,43 @@ func Test_Suffix_Tree(t *testing.T) {
root := newSuffixTreeRoot() root := newSuffixTreeRoot()
Convey("Google should not be found", t, func() { Convey("Google should not be found", t, func() {
root.insert("cn", "114.114.114.114") root.insert("cn", &Nameserver{address: "114.114.114.114"})
root.sinsert([]string{"baidu", "cn"}, "166.111.8.28") root.sinsert([]string{"baidu", "cn"}, &Nameserver{address: "166.111.8.28"})
root.sinsert([]string{"sina", "cn"}, "114.114.114.114") root.sinsert([]string{"sina", "cn"}, &Nameserver{address: "114.114.114.114"})
v, found := root.search(strings.Split("google.com", ".")) v, found := root.search(strings.Split("google.com", "."))
So(found, ShouldEqual, false) So(found, ShouldEqual, false)
v, found = root.search(strings.Split("baidu.cn", ".")) v, found = root.search(strings.Split("baidu.cn", "."))
So(found, ShouldEqual, true) So(found, ShouldEqual, true)
So(v, ShouldEqual, "166.111.8.28") So(v.address, ShouldEqual, "166.111.8.28")
}) })
Convey("Google should be found", t, func() { Convey("Google should be found", t, func() {
root.sinsert(strings.Split("com", "."), "") root.sinsert(strings.Split("com", "."), &Nameserver{address: ""})
root.sinsert(strings.Split("google.com", "."), "8.8.8.8") root.sinsert(strings.Split("google.com", "."), &Nameserver{address: "8.8.8.8"})
root.sinsert(strings.Split("twitter.com", "."), "8.8.8.8") root.sinsert(strings.Split("twitter.com", "."), &Nameserver{address: "8.8.8.8"})
root.sinsert(strings.Split("scholar.google.com", "."), "208.67.222.222") root.sinsert(strings.Split("scholar.google.com", "."), &Nameserver{address: "208.67.222.222"})
v, found := root.search(strings.Split("google.com", ".")) v, found := root.search(strings.Split("google.com", "."))
So(found, ShouldEqual, true) So(found, ShouldEqual, true)
So(v, ShouldEqual, "8.8.8.8") So(v.address, ShouldEqual, "8.8.8.8")
v, found = root.search(strings.Split("www.google.com", ".")) v, found = root.search(strings.Split("www.google.com", "."))
So(found, ShouldEqual, true) So(found, ShouldEqual, true)
So(v, ShouldEqual, "8.8.8.8") So(v.address, ShouldEqual, "8.8.8.8")
v, found = root.search(strings.Split("scholar.google.com", ".")) v, found = root.search(strings.Split("scholar.google.com", "."))
So(found, ShouldEqual, true) So(found, ShouldEqual, true)
So(v, ShouldEqual, "208.67.222.222") So(v.address, ShouldEqual, "208.67.222.222")
v, found = root.search(strings.Split("twitter.com", ".")) v, found = root.search(strings.Split("twitter.com", "."))
So(found, ShouldEqual, true) So(found, ShouldEqual, true)
So(v, ShouldEqual, "8.8.8.8") So(v.address, ShouldEqual, "8.8.8.8")
v, found = root.search(strings.Split("baidu.cn", ".")) v, found = root.search(strings.Split("baidu.cn", "."))
So(found, ShouldEqual, true) So(found, ShouldEqual, true)
So(v, ShouldEqual, "166.111.8.28") So(v.address, ShouldEqual, "166.111.8.28")
}) })
} }

View File

@ -9,23 +9,23 @@ import (
func TestHostDomainAndIP(t *testing.T) { func TestHostDomainAndIP(t *testing.T) {
Convey("Test Host File Domain and IP regex", t, func() { Convey("Test Host File Domain and IP regex", t, func() {
Convey("1.1.1.1 should be IP and not domain", func() { Convey("1.1.1.1 should be IP and not domain", func() {
So(isIP("1.1.1.1"), ShouldEqual, true) So(IsIP("1.1.1.1"), ShouldEqual, true)
So(isDomain("1.1.1.1"), ShouldEqual, false) So(IsDomain("1.1.1.1"), ShouldEqual, false)
}) })
Convey("2001:470:20::2 should be IP and not domain", func() { Convey("2001:470:20::2 should be IP and not domain", func() {
So(isIP("2001:470:20::2"), ShouldEqual, true) So(IsIP("2001:470:20::2"), ShouldEqual, true)
So(isDomain("2001:470:20::2"), ShouldEqual, false) So(IsDomain("2001:470:20::2"), ShouldEqual, false)
}) })
Convey("`host` should not be domain and not IP", func() { Convey("`host` should not be domain and not IP", func() {
So(isDomain("host"), ShouldEqual, false) So(IsDomain("host"), ShouldEqual, false)
So(isIP("host"), ShouldEqual, false) So(IsIP("host"), ShouldEqual, false)
}) })
Convey("`123.test` should be domain and not IP", func() { Convey("`123.test` should be domain and not IP", func() {
So(isDomain("123.test"), ShouldEqual, true) So(IsDomain("123.test"), ShouldEqual, true)
So(isIP("123.test"), ShouldEqual, false) So(IsIP("123.test"), ShouldEqual, false)
}) })
}) })