From a6f6c4e96d5cfd4c2c62cf31056c5cecb3293dbb Mon Sep 17 00:00:00 2001 From: Tyler Date: Sat, 30 Jun 2018 23:08:29 -0400 Subject: [PATCH] Updates --- hosts.go | 136 ++++++++++++++++++++++++++++++++++++++-------------- resolver.go | 50 ++++++++++++++++--- settings.go | 1 + 3 files changed, 145 insertions(+), 42 deletions(-) diff --git a/hosts.go b/hosts.go index c234ef7..6645cbc 100644 --- a/hosts.go +++ b/hosts.go @@ -10,49 +10,62 @@ import ( "github.com/hoisie/redis" "golang.org/x/net/publicsuffix" + "github.com/fsnotify/fsnotify" ) type Hosts struct { - fileHosts *FileHosts - redisHosts *RedisHosts + providers []HostProvider refreshInterval time.Duration } +type HostProvider interface { + Get(domain string) ([]string, bool) + Refresh() +} + func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { - fileHosts := &FileHosts{ - file: hs.HostsFile, - hosts: make(map[string]string), + providers := []HostProvider{ + NewFileProvider(hs.HostsFile), } - var redisHosts *RedisHosts if hs.RedisEnable { rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password} - redisHosts = &RedisHosts{ - redis: rc, - key: hs.RedisKey, - hosts: make(map[string]string), - } + + providers = append(providers, NewRedisProvider(rc, hs.RedisKey)) } - hosts := Hosts{fileHosts, redisHosts, time.Second * time.Duration(hs.RefreshInterval)} - hosts.refresh() - return hosts + return Hosts{providers, time.Second * time.Duration(hs.RefreshInterval)} +} +func (h *Hosts) refresh() { + ticker := time.NewTicker(h.refreshInterval) + + go func() { + for { + // Force a refresh every refreshInterval + for _, provider := range h.providers { + provider.Refresh() + } + + <-ticker.C + } + }() } /* Match local /etc/hosts file first, remote redis records second */ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) { - var sips []string + var ok bool var ip net.IP var ips []net.IP - sips, ok := h.fileHosts.Get(domain) - if !ok { - if h.redisHosts != nil { - sips, ok = h.redisHosts.Get(domain) + for _, provider := range h.providers { + sips, ok = provider.Get(domain) + + if ok { + break } } @@ -74,32 +87,55 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) { } } - return ips, (ips != nil) -} - -/* -Update hosts records from /etc/hosts file and redis per minute -*/ -func (h *Hosts) refresh() { - ticker := time.NewTicker(h.refreshInterval) - go func() { - for { - h.fileHosts.Refresh() - if h.redisHosts != nil { - h.redisHosts.Refresh() - } - <-ticker.C - } - }() + return ips, ips != nil } type RedisHosts struct { + HostProvider + redis *redis.Client key string hosts map[string]string mu sync.RWMutex } +func NewRedisProvider(rc *redis.Client, key string) HostProvider { + rh := &RedisHosts{ + redis: rc, + key: key, + hosts: make(map[string]string), + } + + // Use pubsub to listen for key update events + go func() { + sub := make(chan string, 2) + sub <- "godns:update" + sub <- "godns:update_record" + messages := make(chan redis.Message, 0) + go rc.Subscribe(sub, nil, nil, nil, messages) + + for { + msg := <- messages + + if msg.Channel == "godns:update" { + rh.Refresh() + } else if msg.Channel == "godns:update_record" { + recordName := string(msg.Message) + + b, err := rc.Hget(key, recordName) + + if err != nil { + continue + } + + rh.hosts[recordName] = string(b) + } + } + }() + + return rh +} + func (r *RedisHosts) Get(domain string) ([]string, bool) { r.mu.RLock() defer r.mu.RUnlock() @@ -152,11 +188,39 @@ func (r *RedisHosts) clear() { } type FileHosts struct { + HostProvider + file string hosts map[string]string mu sync.RWMutex } +func NewFileProvider(file string) HostProvider { + fp := &FileHosts{ + file: file, + hosts: make(map[string]string), + } + + watcher, err := fsnotify.NewWatcher() + + // Use fsnotify to notify us of host file changes + if err == nil { + watcher.Add(file) + + go func() { + for { + e := <- watcher.Events + + if e.Op == fsnotify.Write { + fp.Refresh() + } + } + }() + } + + return fp +} + func (f *FileHosts) Get(domain string) ([]string, bool) { f.mu.RLock() defer f.mu.RUnlock() diff --git a/resolver.go b/resolver.go index 1afe8b8..6aa7627 100644 --- a/resolver.go +++ b/resolver.go @@ -48,36 +48,51 @@ func NewResolver(c ResolvSettings) *Resolver { if len(c.ResolvFile) > 0 { clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) + if err != nil { logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) logger.Error("%s", err) panic(err) } + for _, server := range clientConfig.Servers { - nameserver := net.JoinHostPort(server, clientConfig.Port) - r.servers = append(r.servers, nameserver) + r.servers = append(r.servers, net.JoinHostPort(server, clientConfig.Port)) } } + if len(c.DOHServer) > 0 { + r.servers = append([]string{c.DOHServer}, r.servers...) + } + return r } func (r *Resolver) parseServerListFile(buf *os.File) { scanner := bufio.NewScanner(buf) + + var line string + var idx int + for scanner.Scan() { - line := scanner.Text() - line = strings.TrimSpace(line) + line = strings.TrimSpace(scanner.Text()) if !strings.HasPrefix(line, "server") { continue } + idx = strings.Index(line, "=") + + if idx == -1 { + continue + } + sli := strings.Split(line, "=") + if len(sli) != 2 { continue } - line = strings.TrimSpace(sli[1]) + line = strings.TrimSpace(line[idx:]) tokens := strings.Split(line, "/") switch len(tokens) { @@ -88,25 +103,31 @@ func (r *Resolver) parseServerListFile(buf *os.File) { if !isDomain(domain) || !isIP(ip) { continue } + r.domain_server.sinsert(strings.Split(domain, "."), ip) case 1: srv_port := strings.Split(line, "#") + if len(srv_port) > 2 { continue } ip := "" + if ip = srv_port[0]; !isIP(ip) { continue } port := "53" + if len(srv_port) == 2 { if _, err := strconv.Atoi(srv_port[1]); err != nil { continue } + port = srv_port[1] } + r.servers = append(r.servers, net.JoinHostPort(ip, port)) } } @@ -135,6 +156,12 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error WriteTimeout: r.Timeout(), } + httpC := &dns.Client{ + Net: "https", + ReadTimeout: r.Timeout(), + WriteTimeout: r.Timeout(), + } + if net == "udp" && settings.ResolvConfig.SetEDNS0 { req = req.SetEdns0(65535, true) } @@ -145,7 +172,17 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error var wg sync.WaitGroup L := func(nameserver string) { defer wg.Done() - r, rtt, err := c.Exchange(req, nameserver) + + var r *dns.Msg + var rtt time.Duration + var err error + + if strings.HasPrefix(nameserver, "https") { + r, rtt, err = httpC.Exchange(req, nameserver) + } else { + r, rtt, err = c.Exchange(req, nameserver) + } + if err != nil { logger.Warn("%s socket error on %s", qname, nameserver) logger.Warn("error:%s", err.Error()) @@ -215,6 +252,7 @@ func (r *Resolver) Nameservers(qname string) []string { for _, nameserver := range r.servers { ns = append(ns, nameserver) } + return ns } diff --git a/settings.go b/settings.go index f2310a6..dc4a2e5 100644 --- a/settings.go +++ b/settings.go @@ -39,6 +39,7 @@ type ResolvSettings struct { SetEDNS0 bool ServerListFile string `toml:"server-list-file"` ResolvFile string `toml:"resolv-file"` + DOHServer string `toml:"dns-over-https"` } type DNSServerSettings struct {