diff --git a/etc/china.conf b/etc/china.conf new file mode 100644 index 0000000..e886276 --- /dev/null +++ b/etc/china.conf @@ -0,0 +1,5 @@ +server=8.8.8.8#53 +server=127.0.0.1#5553 + +server=/baidu.com/114.114.114.114 +# refer https://github.com/felixonmars/dnsmasq-china-list diff --git a/godns.conf b/etc/godns.conf similarity index 88% rename from godns.conf rename to etc/godns.conf index 9a5e1d4..50429f2 100644 --- a/godns.conf +++ b/etc/godns.conf @@ -12,6 +12,8 @@ host = "127.0.0.1" port = 53 [resolv] +#Domain-specific nameservers configuration, formatting keep compatible with Dnsmasq +server-list-file = "./etc/china.conf" resolv-file = "/etc/resolv.conf" timeout = 5 # 5 seconds # The concurrency interval request upstream recursive server @@ -21,6 +23,7 @@ interval = 200 # 200 milliseconds setedns0 = false #Support for larger UDP DNS responses [redis] +enable = true host = "127.0.0.1" port = 6379 db = 0 diff --git a/handler.go b/handler.go index b8a26f4..1de7ee1 100644 --- a/handler.go +++ b/handler.go @@ -32,21 +32,12 @@ type GODNSHandler struct { func NewHandler() *GODNSHandler { var ( - clientConfig *dns.ClientConfig cacheConfig CacheSettings resolver *Resolver cache, negCache Cache ) - resolvConfig := settings.ResolvConfig - clientConfig, err := dns.ClientConfigFromFile(resolvConfig.ResolvFile) - if err != nil { - logger.Warn(":%s is not a valid resolv.conf file\n", resolvConfig.ResolvFile) - logger.Error(err.Error()) - panic(err) - } - clientConfig.Timeout = resolvConfig.Timeout - resolver = &Resolver{clientConfig} + resolver = NewResolver(settings.ResolvConfig) cacheConfig = settings.Cache switch cacheConfig.Backend { diff --git a/hosts.go b/hosts.go index d094f8e..360efe8 100644 --- a/hosts.go +++ b/hosts.go @@ -218,7 +218,7 @@ func (f *FileHosts) Refresh() { } ip := sli[0] - if !f.isIP(ip) { + if !isIP(ip) { continue } @@ -241,14 +241,14 @@ func (f *FileHosts) clear() { f.hosts = make(map[string]string) } -func (f *FileHosts) isDomain(domain string) bool { - if f.isIP(domain) { +func isDomain(domain string) bool { + if isIP(domain) { return false } match, _ := regexp.MatchString(`^([a-zA-Z0-9\*]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}$`, domain) return match } -func (f *FileHosts) isIP(ip string) bool { +func isIP(ip string) bool { return (net.ParseIP(ip) != nil) } diff --git a/hosts_test.go b/hosts_test.go index 2a795a8..715b333 100644 --- a/hosts_test.go +++ b/hosts_test.go @@ -8,26 +8,24 @@ import ( func TestHostDomainAndIP(t *testing.T) { Convey("Test Host File Domain and IP regex", t, func() { - f := &FileHosts{} - Convey("1.1.1.1 should be IP and not domain", func() { - So(f.isIP("1.1.1.1"), ShouldEqual, true) - So(f.isDomain("1.1.1.1"), ShouldEqual, false) + So(isIP("1.1.1.1"), ShouldEqual, true) + So(isDomain("1.1.1.1"), ShouldEqual, false) }) Convey("2001:470:20::2 should be IP and not domain", func() { - So(f.isIP("2001:470:20::2"), ShouldEqual, true) - So(f.isDomain("2001:470:20::2"), ShouldEqual, false) + So(isIP("2001:470:20::2"), ShouldEqual, true) + So(isDomain("2001:470:20::2"), ShouldEqual, false) }) Convey("`host` should not be domain and not IP", func() { - So(f.isDomain("host"), ShouldEqual, false) - So(f.isIP("host"), ShouldEqual, false) + So(isDomain("host"), ShouldEqual, false) + So(isIP("host"), ShouldEqual, false) }) Convey("`123.test` should be domain and not IP", func() { - So(f.isDomain("123.test"), ShouldEqual, true) - So(f.isIP("123.test"), ShouldEqual, false) + So(isDomain("123.test"), ShouldEqual, true) + So(isIP("123.test"), ShouldEqual, false) }) }) diff --git a/resolver.go b/resolver.go index e20b9d8..0789453 100644 --- a/resolver.go +++ b/resolver.go @@ -1,8 +1,10 @@ package main import ( + "bufio" "fmt" - "net" + "os" + "strconv" "strings" "sync" "time" @@ -21,7 +23,92 @@ func (e ResolvError) Error() string { } type Resolver struct { - config *dns.ClientConfig + servers []string + domain_server *suffixTreeNode + config *ResolvSettings +} + +func NewResolver(c ResolvSettings) *Resolver { + r := &Resolver{ + servers: []string{}, + domain_server: newSuffixTreeRoot(), + config: &c, + } + + if len(c.ServerListFile) > 0 { + r.ReadServerListFile(c.ServerListFile) + // Debug("%v", r.servers) + } + + if len(c.ResolvFile) > 0 { + clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) + if err != nil { + logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) + logger.Error("%s", err) + panic(err) + } + for _, server := range clientConfig.Servers { + nameserver := server + ":" + clientConfig.Port + r.servers = append(r.servers, nameserver) + } + } + + return r +} + +func (r *Resolver) ReadServerListFile(file string) { + buf, err := os.Open(file) + if err != nil { + panic("Can't open " + file) + } + scanner := bufio.NewScanner(buf) + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + + if !strings.HasPrefix(line, "server") { + continue + } + + sli := strings.Split(line, "=") + if len(sli) != 2 { + continue + } + + line = strings.TrimSpace(sli[1]) + + tokens := strings.Split(line, "/") + switch len(tokens) { + case 3: + domain := tokens[1] + ip := tokens[2] + if !isDomain(domain) || !isIP(ip) { + continue + } + r.domain_server.sinsert(strings.Split(domain, "."), ip) + case 1: + srv_port := strings.Split(line, "#") + if len(srv_port) > 2 { + continue + } + + ip := "" + if ip = srv_port[0]; !isIP(ip) { + continue + } + + port := "53" + if len(srv_port) == 2 { + if _, err := strconv.Atoi(srv_port[1]); err != nil { + continue + } + port = srv_port[1] + } + r.servers = append(r.servers, ip+":"+port) + + } + } + } // Lookup will ask each nameserver in top-to-bottom fashion, starting a new request @@ -60,7 +147,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error return } } else { - logger.Debug("%s resolv on %s (%s) ttl: %d", UnFqdn(qname), nameserver, net, rtt) + logger.Debug("%s resolv on %s (%s) ttl: %v", UnFqdn(qname), nameserver, net, rtt) } select { case res <- r: @@ -71,12 +158,14 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error ticker := time.NewTicker(time.Duration(settings.ResolvConfig.Interval) * time.Millisecond) defer ticker.Stop() // Start lookup on each nameserver top-down, in every second - for _, nameserver := range r.Nameservers() { + nameservers := r.Nameservers(qname) + for _, nameserver := range nameservers { wg.Add(1) go L(nameserver) // but exit early, if we have an answer select { case r := <-res: + // logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), nameserver, rtt) return r, nil case <-ticker.C: continue @@ -86,25 +175,32 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error wg.Wait() select { case r := <-res: + // logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), nameserver, rtt) return r, nil default: - return nil, ResolvError{qname, net, r.Nameservers()} + return nil, ResolvError{qname, net, nameservers} } - } // Namservers return the array of nameservers, with port number appended. // '#' in the name is treated as port separator, as with dnsmasq. -func (r *Resolver) Nameservers() (ns []string) { - for _, server := range r.config.Servers { - if i := strings.IndexByte(server, '#'); i > 0 { - server = net.JoinHostPort(server[:i], server[i+1:]) - } else { - server = net.JoinHostPort(server, r.config.Port) - } - ns = append(ns, server) + +func (r *Resolver) Nameservers(qname string) []string { + queryKeys := strings.Split(qname, ".") + queryKeys = queryKeys[:len(queryKeys)-1] // ignore last '.' + + ns := []string{} + if v, found := r.domain_server.search(queryKeys); found { + logger.Debug("found upstream: %v", v) + server := v + nameserver := server + ":53" + ns = append(ns, nameserver) } - return + + for _, nameserver := range r.servers { + ns = append(ns, nameserver) + } + return ns } func (r *Resolver) Timeout() time.Duration { diff --git a/settings.go b/settings.go index 2f8d22c..f2310a6 100644 --- a/settings.go +++ b/settings.go @@ -34,10 +34,11 @@ type Settings struct { } type ResolvSettings struct { - ResolvFile string `toml:"resolv-file"` - Timeout int - Interval int - SetEDNS0 bool + Timeout int + Interval int + SetEDNS0 bool + ServerListFile string `toml:"server-list-file"` + ResolvFile string `toml:"resolv-file"` } type DNSServerSettings struct { @@ -93,7 +94,7 @@ func init() { var configFile string - flag.StringVar(&configFile, "c", "godns.conf", "Look for godns toml-formatting config file in this directory") + flag.StringVar(&configFile, "c", "./etc/godns.conf", "Look for godns toml-formatting config file in this directory") flag.Parse() if _, err := toml.DecodeFile(configFile, &settings); err != nil { diff --git a/sfx_tree.go b/sfx_tree.go new file mode 100644 index 0000000..bfaca0e --- /dev/null +++ b/sfx_tree.go @@ -0,0 +1,65 @@ +package main + +type suffixTreeNode struct { + key string + value string + children map[string]*suffixTreeNode +} + +func newSuffixTreeRoot() *suffixTreeNode { + return newSuffixTree("", "") +} + +func newSuffixTree(key string, value string) *suffixTreeNode { + root := &suffixTreeNode{ + key: key, + value: value, + children: map[string]*suffixTreeNode{}, + } + return root +} + +func (node *suffixTreeNode) ensureSubTree(key string) { + if _, ok := node.children[key]; !ok { + node.children[key] = newSuffixTree(key, "") + } +} + +func (node *suffixTreeNode) insert(key string, value string) { + if c, ok := node.children[key]; ok { + c.value = value + } else { + node.children[key] = newSuffixTree(key, value) + } +} + +func (node *suffixTreeNode) sinsert(keys []string, value string) { + if len(keys) == 0 { + return + } + + key := keys[len(keys)-1] + if len(keys) > 1 { + node.ensureSubTree(key) + node.children[key].sinsert(keys[:len(keys)-1], value) + return + } + + node.insert(key, value) +} + +func (node *suffixTreeNode) search(keys []string) (string, bool) { + if len(keys) == 0 { + return "", false + } + + key := keys[len(keys)-1] + if n, ok := node.children[key]; ok { + if nextValue, found := n.search(keys[:len(keys)-1]); found { + return nextValue, found + } + return n.value, (n.value != "") + } + + return "", false +} diff --git a/sfx_tree_test.go b/sfx_tree_test.go new file mode 100644 index 0000000..47471dd --- /dev/null +++ b/sfx_tree_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "strings" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func Test_Suffix_Tree(t *testing.T) { + root := newSuffixTreeRoot() + + Convey("Google should not be found", t, func() { + root.insert("cn", "114.114.114.114") + root.sinsert([]string{"baidu", "cn"}, "166.111.8.28") + root.sinsert([]string{"sina", "cn"}, "114.114.114.114") + + v, found := root.search(strings.Split("google.com", ".")) + So(found, ShouldEqual, false) + + v, found = root.search(strings.Split("baidu.cn", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "166.111.8.28") + }) + + Convey("Google should be found", t, func() { + root.sinsert(strings.Split("com", "."), "") + root.sinsert(strings.Split("google.com", "."), "8.8.8.8") + root.sinsert(strings.Split("twitter.com", "."), "8.8.8.8") + root.sinsert(strings.Split("scholar.google.com", "."), "208.67.222.222") + + v, found := root.search(strings.Split("google.com", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "8.8.8.8") + + v, found = root.search(strings.Split("scholar.google.com", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "208.67.222.222") + + v, found = root.search(strings.Split("twitter.com", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "8.8.8.8") + + v, found = root.search(strings.Split("baidu.cn", ".")) + So(found, ShouldEqual, true) + So(v, ShouldEqual, "166.111.8.28") + }) + +}