This commit is contained in:
Tyler 2018-06-30 23:08:29 -04:00
parent 489adb58ef
commit a6f6c4e96d
3 changed files with 145 additions and 42 deletions

136
hosts.go
View File

@ -10,49 +10,62 @@ import (
"github.com/hoisie/redis" "github.com/hoisie/redis"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
"github.com/fsnotify/fsnotify"
) )
type Hosts struct { type Hosts struct {
fileHosts *FileHosts providers []HostProvider
redisHosts *RedisHosts
refreshInterval time.Duration refreshInterval time.Duration
} }
type HostProvider interface {
Get(domain string) ([]string, bool)
Refresh()
}
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
fileHosts := &FileHosts{ providers := []HostProvider{
file: hs.HostsFile, NewFileProvider(hs.HostsFile),
hosts: make(map[string]string),
} }
var redisHosts *RedisHosts
if hs.RedisEnable { if hs.RedisEnable {
rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password} rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password}
redisHosts = &RedisHosts{
redis: rc, providers = append(providers, NewRedisProvider(rc, hs.RedisKey))
key: hs.RedisKey,
hosts: make(map[string]string),
}
} }
hosts := Hosts{fileHosts, redisHosts, time.Second * time.Duration(hs.RefreshInterval)} return Hosts{providers, time.Second * time.Duration(hs.RefreshInterval)}
hosts.refresh() }
return hosts
func (h *Hosts) refresh() {
ticker := time.NewTicker(h.refreshInterval)
go func() {
for {
// Force a refresh every refreshInterval
for _, provider := range h.providers {
provider.Refresh()
}
<-ticker.C
}
}()
} }
/* /*
Match local /etc/hosts file first, remote redis records second Match local /etc/hosts file first, remote redis records second
*/ */
func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) { func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
var sips []string var sips []string
var ok bool
var ip net.IP var ip net.IP
var ips []net.IP var ips []net.IP
sips, ok := h.fileHosts.Get(domain) for _, provider := range h.providers {
if !ok { sips, ok = provider.Get(domain)
if h.redisHosts != nil {
sips, ok = h.redisHosts.Get(domain) if ok {
break
} }
} }
@ -74,32 +87,55 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
} }
} }
return ips, (ips != nil) return ips, ips != nil
}
/*
Update hosts records from /etc/hosts file and redis per minute
*/
func (h *Hosts) refresh() {
ticker := time.NewTicker(h.refreshInterval)
go func() {
for {
h.fileHosts.Refresh()
if h.redisHosts != nil {
h.redisHosts.Refresh()
}
<-ticker.C
}
}()
} }
type RedisHosts struct { type RedisHosts struct {
HostProvider
redis *redis.Client redis *redis.Client
key string key string
hosts map[string]string hosts map[string]string
mu sync.RWMutex mu sync.RWMutex
} }
func NewRedisProvider(rc *redis.Client, key string) HostProvider {
rh := &RedisHosts{
redis: rc,
key: key,
hosts: make(map[string]string),
}
// Use pubsub to listen for key update events
go func() {
sub := make(chan string, 2)
sub <- "godns:update"
sub <- "godns:update_record"
messages := make(chan redis.Message, 0)
go rc.Subscribe(sub, nil, nil, nil, messages)
for {
msg := <- messages
if msg.Channel == "godns:update" {
rh.Refresh()
} else if msg.Channel == "godns:update_record" {
recordName := string(msg.Message)
b, err := rc.Hget(key, recordName)
if err != nil {
continue
}
rh.hosts[recordName] = string(b)
}
}
}()
return rh
}
func (r *RedisHosts) Get(domain string) ([]string, bool) { func (r *RedisHosts) Get(domain string) ([]string, bool) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
@ -152,11 +188,39 @@ func (r *RedisHosts) clear() {
} }
type FileHosts struct { type FileHosts struct {
HostProvider
file string file string
hosts map[string]string hosts map[string]string
mu sync.RWMutex mu sync.RWMutex
} }
func NewFileProvider(file string) HostProvider {
fp := &FileHosts{
file: file,
hosts: make(map[string]string),
}
watcher, err := fsnotify.NewWatcher()
// Use fsnotify to notify us of host file changes
if err == nil {
watcher.Add(file)
go func() {
for {
e := <- watcher.Events
if e.Op == fsnotify.Write {
fp.Refresh()
}
}
}()
}
return fp
}
func (f *FileHosts) Get(domain string) ([]string, bool) { func (f *FileHosts) Get(domain string) ([]string, bool) {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()

View File

@ -48,36 +48,51 @@ func NewResolver(c ResolvSettings) *Resolver {
if len(c.ResolvFile) > 0 { if len(c.ResolvFile) > 0 {
clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
if err != nil { if err != nil {
logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
logger.Error("%s", err) logger.Error("%s", err)
panic(err) panic(err)
} }
for _, server := range clientConfig.Servers { for _, server := range clientConfig.Servers {
nameserver := net.JoinHostPort(server, clientConfig.Port) r.servers = append(r.servers, net.JoinHostPort(server, clientConfig.Port))
r.servers = append(r.servers, nameserver)
} }
} }
if len(c.DOHServer) > 0 {
r.servers = append([]string{c.DOHServer}, r.servers...)
}
return r return r
} }
func (r *Resolver) parseServerListFile(buf *os.File) { func (r *Resolver) parseServerListFile(buf *os.File) {
scanner := bufio.NewScanner(buf) scanner := bufio.NewScanner(buf)
var line string
var idx int
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line = strings.TrimSpace(scanner.Text())
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "server") { if !strings.HasPrefix(line, "server") {
continue continue
} }
idx = strings.Index(line, "=")
if idx == -1 {
continue
}
sli := strings.Split(line, "=") sli := strings.Split(line, "=")
if len(sli) != 2 { if len(sli) != 2 {
continue continue
} }
line = strings.TrimSpace(sli[1]) line = strings.TrimSpace(line[idx:])
tokens := strings.Split(line, "/") tokens := strings.Split(line, "/")
switch len(tokens) { switch len(tokens) {
@ -88,25 +103,31 @@ func (r *Resolver) parseServerListFile(buf *os.File) {
if !isDomain(domain) || !isIP(ip) { if !isDomain(domain) || !isIP(ip) {
continue continue
} }
r.domain_server.sinsert(strings.Split(domain, "."), ip) r.domain_server.sinsert(strings.Split(domain, "."), ip)
case 1: case 1:
srv_port := strings.Split(line, "#") srv_port := strings.Split(line, "#")
if len(srv_port) > 2 { if len(srv_port) > 2 {
continue continue
} }
ip := "" ip := ""
if ip = srv_port[0]; !isIP(ip) { if ip = srv_port[0]; !isIP(ip) {
continue continue
} }
port := "53" port := "53"
if len(srv_port) == 2 { if len(srv_port) == 2 {
if _, err := strconv.Atoi(srv_port[1]); err != nil { if _, err := strconv.Atoi(srv_port[1]); err != nil {
continue continue
} }
port = srv_port[1] port = srv_port[1]
} }
r.servers = append(r.servers, net.JoinHostPort(ip, port)) r.servers = append(r.servers, net.JoinHostPort(ip, port))
} }
} }
@ -135,6 +156,12 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
WriteTimeout: r.Timeout(), WriteTimeout: r.Timeout(),
} }
httpC := &dns.Client{
Net: "https",
ReadTimeout: r.Timeout(),
WriteTimeout: r.Timeout(),
}
if net == "udp" && settings.ResolvConfig.SetEDNS0 { if net == "udp" && settings.ResolvConfig.SetEDNS0 {
req = req.SetEdns0(65535, true) req = req.SetEdns0(65535, true)
} }
@ -145,7 +172,17 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
var wg sync.WaitGroup var wg sync.WaitGroup
L := func(nameserver string) { L := func(nameserver string) {
defer wg.Done() defer wg.Done()
r, rtt, err := c.Exchange(req, nameserver)
var r *dns.Msg
var rtt time.Duration
var err error
if strings.HasPrefix(nameserver, "https") {
r, rtt, err = httpC.Exchange(req, nameserver)
} else {
r, rtt, err = c.Exchange(req, nameserver)
}
if err != nil { if err != nil {
logger.Warn("%s socket error on %s", qname, nameserver) logger.Warn("%s socket error on %s", qname, nameserver)
logger.Warn("error:%s", err.Error()) logger.Warn("error:%s", err.Error())
@ -215,6 +252,7 @@ func (r *Resolver) Nameservers(qname string) []string {
for _, nameserver := range r.servers { for _, nameserver := range r.servers {
ns = append(ns, nameserver) ns = append(ns, nameserver)
} }
return ns return ns
} }

View File

@ -39,6 +39,7 @@ type ResolvSettings struct {
SetEDNS0 bool SetEDNS0 bool
ServerListFile string `toml:"server-list-file"` ServerListFile string `toml:"server-list-file"`
ResolvFile string `toml:"resolv-file"` ResolvFile string `toml:"resolv-file"`
DOHServer string `toml:"dns-over-https"`
} }
type DNSServerSettings struct { type DNSServerSettings struct {