diff --git a/godns.conf b/godns.conf index 7ce4b6e..5234a62 100644 --- a/godns.conf +++ b/godns.conf @@ -14,6 +14,7 @@ port = 53 [resolv] resolv-file = "/etc/resolv.conf" timeout = 5 # 5 seconds +domain-server-file = "/etc/godns.d/servers" [redis] host = "127.0.0.1" diff --git a/handler.go b/handler.go index 91467c3..ad6a268 100644 --- a/handler.go +++ b/handler.go @@ -33,21 +33,12 @@ type GODNSHandler struct { func NewHandler() *GODNSHandler { var ( - clientConfig *dns.ClientConfig - cacheConfig CacheSettings - resolver *Resolver - cache Cache + cacheConfig CacheSettings + resolver *Resolver + cache Cache ) - resolvConfig := settings.ResolvConfig - clientConfig, err := dns.ClientConfigFromFile(resolvConfig.ResolvFile) - if err != nil { - logger.Printf(":%s is not a valid resolv.conf file\n", resolvConfig.ResolvFile) - logger.Println(err) - panic(err) - } - clientConfig.Timeout = resolvConfig.Timeout - resolver = &Resolver{clientConfig} + resolver = NewResolver(settings.ResolvConfig) cacheConfig = settings.Cache switch cacheConfig.Backend { diff --git a/hosts.go b/hosts.go index 17d6445..7390e76 100644 --- a/hosts.go +++ b/hosts.go @@ -116,7 +116,7 @@ func (f *FileHosts) GetAll() map[string]string { domain := sli[len(sli)-1] ip := sli[0] - if !f.isDomain(domain) || !f.isIP(ip) { + if !isDomain(domain) || !isIP(ip) { continue } @@ -125,14 +125,14 @@ func (f *FileHosts) GetAll() map[string]string { return hosts } -func (f *FileHosts) isDomain(domain string) bool { - if f.isIP(domain) { +func isDomain(domain string) bool { + if isIP(domain) { return false } match, _ := regexp.MatchString("^[a-zA-Z0-9][a-zA-Z0-9-]", domain) return match } -func (f *FileHosts) isIP(ip string) bool { +func isIP(ip string) bool { return (net.ParseIP(ip) != nil) } diff --git a/hosts_test.go b/hosts_test.go index af70747..4499e8a 100644 --- a/hosts_test.go +++ b/hosts_test.go @@ -8,26 +8,24 @@ import ( func TestHostDomainAndIP(t *testing.T) { Convey("Test Host File Domain and IP regex", t, func() { - f := &FileHosts{} - Convey("1.1.1.1 should be IP and not domain", func() { - So(f.isIP("1.1.1.1"), ShouldEqual, true) - So(f.isDomain("1.1.1.1"), ShouldEqual, false) + So(isIP("1.1.1.1"), ShouldEqual, true) + So(isDomain("1.1.1.1"), ShouldEqual, false) }) Convey("2001:470:20::2 should be IP and not domain", func() { - So(f.isIP("2001:470:20::2"), ShouldEqual, true) - So(f.isDomain("2001:470:20::2"), ShouldEqual, false) + So(isIP("2001:470:20::2"), ShouldEqual, true) + So(isDomain("2001:470:20::2"), ShouldEqual, false) }) Convey("`host` should be domain and not IP", func() { - So(f.isDomain("host"), ShouldEqual, true) - So(f.isIP("host"), ShouldEqual, false) + So(isDomain("host"), ShouldEqual, true) + So(isIP("host"), ShouldEqual, false) }) Convey("`123.test` should be domain and not IP", func() { - So(f.isDomain("123.test"), ShouldEqual, true) - So(f.isIP("123.test"), ShouldEqual, false) + So(isDomain("123.test"), ShouldEqual, true) + So(isIP("123.test"), ShouldEqual, false) }) }) diff --git a/resolver.go b/resolver.go index 11756c1..438ed4c 100644 --- a/resolver.go +++ b/resolver.go @@ -1,10 +1,13 @@ package main import ( + "bufio" "fmt" - "github.com/miekg/dns" + "os" "strings" "time" + + "github.com/miekg/dns" ) type ResolvError struct { @@ -18,7 +21,62 @@ func (e ResolvError) Error() string { } type Resolver struct { - config *dns.ClientConfig + config *dns.ClientConfig + domain_server *suffixTreeNode +} + +func NewResolver(c ResolvSettings) *Resolver { + var clientConfig *dns.ClientConfig + clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) + if err != nil { + logger.Printf(":%s is not a valid resolv.conf file\n", c.ResolvFile) + logger.Println(err) + panic(err) + } + clientConfig.Timeout = c.Timeout + + domain_server := newSuffixTreeRoot() + r := &Resolver{clientConfig, domain_server} + + if len(c.DomainServerFile) > 0 { + r.ReadDomainServerFile(c.DomainServerFile) + } + return r +} + +func (r *Resolver) ReadDomainServerFile(file string) { + buf, err := os.Open(file) + if err != nil { + panic("Can't open " + file) + } + scanner := bufio.NewScanner(buf) + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + + if !strings.HasPrefix(line, "server") { + continue + } + + sli := strings.Split(line, "=") + if len(sli) != 2 { + continue + } + + line = strings.TrimSpace(sli[1]) + + tokens := strings.Split(line, "/") + if len(tokens) != 3 { + continue + } + domain := tokens[1] + ip := tokens[2] + if !isDomain(domain) || !isIP(ip) { + continue + } + r.domain_server.sinsert(strings.Split(domain, "."), ip) + } + } func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error) { @@ -29,8 +87,8 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error } qname := req.Question[0].Name - - for _, nameserver := range r.Nameservers() { + nameservers := r.Nameservers(qname) + for _, nameserver := range nameservers { r, rtt, err := c.Exchange(req, nameserver) if err != nil { Debug("%s socket error on %s", qname, nameserver) @@ -41,19 +99,30 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error Debug("%s failed to get an valid answer on %s", qname, nameserver) continue } - Debug("%s resolv on %s ttl: %d", UnFqdn(qname), nameserver, rtt) + Debug("%s resolv on %s rtt: %v", UnFqdn(qname), nameserver, rtt) return r, nil } - return nil, ResolvError{qname, r.Nameservers()} - + return nil, ResolvError{qname, nameservers} } -func (r *Resolver) Nameservers() (ns []string) { +func (r *Resolver) Nameservers(qname string) []string { + + queryKeys := strings.Split(qname, ".") + queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.' + + ns := []string{} + if v, found := r.domain_server.search(queryKeys); found { + Debug("found upstream: %v", v) + server := v + nameserver := server + ":53" + ns = append(ns, nameserver) + } + for _, server := range r.config.Servers { nameserver := server + ":" + r.config.Port ns = append(ns, nameserver) } - return + return ns } func (r *Resolver) Timeout() time.Duration { diff --git a/settings.go b/settings.go index c729b8d..0660532 100644 --- a/settings.go +++ b/settings.go @@ -3,9 +3,10 @@ package main import ( "flag" "fmt" - "github.com/BurntSushi/toml" "os" "strconv" + + "github.com/BurntSushi/toml" ) var ( @@ -24,8 +25,9 @@ type Settings struct { } type ResolvSettings struct { - ResolvFile string `toml:"resolv-file"` - Timeout int + ResolvFile string `toml:"resolv-file"` + DomainServerFile string `toml:"domain-server-file"` + Timeout int } type DNSServerSettings struct { diff --git a/sfx_tree.go b/sfx_tree.go new file mode 100644 index 0000000..bfaca0e --- /dev/null +++ b/sfx_tree.go @@ -0,0 +1,65 @@ +package main + +type suffixTreeNode struct { + key string + value string + children map[string]*suffixTreeNode +} + +func newSuffixTreeRoot() *suffixTreeNode { + return newSuffixTree("", "") +} + +func newSuffixTree(key string, value string) *suffixTreeNode { + root := &suffixTreeNode{ + key: key, + value: value, + children: map[string]*suffixTreeNode{}, + } + return root +} + +func (node *suffixTreeNode) ensureSubTree(key string) { + if _, ok := node.children[key]; !ok { + node.children[key] = newSuffixTree(key, "") + } +} + +func (node *suffixTreeNode) insert(key string, value string) { + if c, ok := node.children[key]; ok { + c.value = value + } else { + node.children[key] = newSuffixTree(key, value) + } +} + +func (node *suffixTreeNode) sinsert(keys []string, value string) { + if len(keys) == 0 { + return + } + + key := keys[len(keys)-1] + if len(keys) > 1 { + node.ensureSubTree(key) + node.children[key].sinsert(keys[:len(keys)-1], value) + return + } + + node.insert(key, value) +} + +func (node *suffixTreeNode) search(keys []string) (string, bool) { + if len(keys) == 0 { + return "", false + } + + key := keys[len(keys)-1] + if n, ok := node.children[key]; ok { + if nextValue, found := n.search(keys[:len(keys)-1]); found { + return nextValue, found + } + return n.value, (n.value != "") + } + + return "", false +} diff --git a/sfx_tree_test.go b/sfx_tree_test.go new file mode 100644 index 0000000..47471dd --- /dev/null +++ b/sfx_tree_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "strings" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func Test_Suffix_Tree(t *testing.T) { + root := newSuffixTreeRoot() + + Convey("Google should not be found", t, func() { + root.insert("cn", "114.114.114.114") + root.sinsert([]string{"baidu", "cn"}, "166.111.8.28") + root.sinsert([]string{"sina", "cn"}, "114.114.114.114") + + v, found := root.search(strings.Split("google.com", ".")) + So(found, ShouldEqual, false) + + v, found = root.search(strings.Split("baidu.cn", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "166.111.8.28") + }) + + Convey("Google should be found", t, func() { + root.sinsert(strings.Split("com", "."), "") + root.sinsert(strings.Split("google.com", "."), "8.8.8.8") + root.sinsert(strings.Split("twitter.com", "."), "8.8.8.8") + root.sinsert(strings.Split("scholar.google.com", "."), "208.67.222.222") + + v, found := root.search(strings.Split("google.com", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "8.8.8.8") + + v, found = root.search(strings.Split("scholar.google.com", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "208.67.222.222") + + v, found = root.search(strings.Split("twitter.com", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "8.8.8.8") + + v, found = root.search(strings.Split("baidu.cn", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "166.111.8.28") + }) + +}