This commit is contained in:
Tyler 2018-06-30 23:08:29 -04:00
parent 489adb58ef
commit a6f6c4e96d
3 changed files with 145 additions and 42 deletions

138
hosts.go
View File

@ -10,49 +10,62 @@ import (
"github.com/hoisie/redis"
"golang.org/x/net/publicsuffix"
"github.com/fsnotify/fsnotify"
)
type Hosts struct {
fileHosts *FileHosts
redisHosts *RedisHosts
providers []HostProvider
refreshInterval time.Duration
}
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
fileHosts := &FileHosts{
file: hs.HostsFile,
hosts: make(map[string]string),
type HostProvider interface {
Get(domain string) ([]string, bool)
Refresh()
}
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
providers := []HostProvider{
NewFileProvider(hs.HostsFile),
}
var redisHosts *RedisHosts
if hs.RedisEnable {
rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password}
redisHosts = &RedisHosts{
redis: rc,
key: hs.RedisKey,
hosts: make(map[string]string),
}
providers = append(providers, NewRedisProvider(rc, hs.RedisKey))
}
hosts := Hosts{fileHosts, redisHosts, time.Second * time.Duration(hs.RefreshInterval)}
hosts.refresh()
return hosts
return Hosts{providers, time.Second * time.Duration(hs.RefreshInterval)}
}
func (h *Hosts) refresh() {
ticker := time.NewTicker(h.refreshInterval)
go func() {
for {
// Force a refresh every refreshInterval
for _, provider := range h.providers {
provider.Refresh()
}
<-ticker.C
}
}()
}
/*
Match local /etc/hosts file first, remote redis records second
*/
func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
var sips []string
var ok bool
var ip net.IP
var ips []net.IP
sips, ok := h.fileHosts.Get(domain)
if !ok {
if h.redisHosts != nil {
sips, ok = h.redisHosts.Get(domain)
for _, provider := range h.providers {
sips, ok = provider.Get(domain)
if ok {
break
}
}
@ -74,32 +87,55 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
}
}
return ips, (ips != nil)
}
/*
Update hosts records from /etc/hosts file and redis per minute
*/
func (h *Hosts) refresh() {
ticker := time.NewTicker(h.refreshInterval)
go func() {
for {
h.fileHosts.Refresh()
if h.redisHosts != nil {
h.redisHosts.Refresh()
}
<-ticker.C
}
}()
return ips, ips != nil
}
type RedisHosts struct {
HostProvider
redis *redis.Client
key string
hosts map[string]string
mu sync.RWMutex
}
func NewRedisProvider(rc *redis.Client, key string) HostProvider {
rh := &RedisHosts{
redis: rc,
key: key,
hosts: make(map[string]string),
}
// Use pubsub to listen for key update events
go func() {
sub := make(chan string, 2)
sub <- "godns:update"
sub <- "godns:update_record"
messages := make(chan redis.Message, 0)
go rc.Subscribe(sub, nil, nil, nil, messages)
for {
msg := <- messages
if msg.Channel == "godns:update" {
rh.Refresh()
} else if msg.Channel == "godns:update_record" {
recordName := string(msg.Message)
b, err := rc.Hget(key, recordName)
if err != nil {
continue
}
rh.hosts[recordName] = string(b)
}
}
}()
return rh
}
func (r *RedisHosts) Get(domain string) ([]string, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
@ -152,11 +188,39 @@ func (r *RedisHosts) clear() {
}
type FileHosts struct {
HostProvider
file string
hosts map[string]string
mu sync.RWMutex
}
func NewFileProvider(file string) HostProvider {
fp := &FileHosts{
file: file,
hosts: make(map[string]string),
}
watcher, err := fsnotify.NewWatcher()
// Use fsnotify to notify us of host file changes
if err == nil {
watcher.Add(file)
go func() {
for {
e := <- watcher.Events
if e.Op == fsnotify.Write {
fp.Refresh()
}
}
}()
}
return fp
}
func (f *FileHosts) Get(domain string) ([]string, bool) {
f.mu.RLock()
defer f.mu.RUnlock()

View File

@ -48,36 +48,51 @@ func NewResolver(c ResolvSettings) *Resolver {
if len(c.ResolvFile) > 0 {
clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
if err != nil {
logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
logger.Error("%s", err)
panic(err)
}
for _, server := range clientConfig.Servers {
nameserver := net.JoinHostPort(server, clientConfig.Port)
r.servers = append(r.servers, nameserver)
r.servers = append(r.servers, net.JoinHostPort(server, clientConfig.Port))
}
}
if len(c.DOHServer) > 0 {
r.servers = append([]string{c.DOHServer}, r.servers...)
}
return r
}
func (r *Resolver) parseServerListFile(buf *os.File) {
scanner := bufio.NewScanner(buf)
var line string
var idx int
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
line = strings.TrimSpace(scanner.Text())
if !strings.HasPrefix(line, "server") {
continue
}
idx = strings.Index(line, "=")
if idx == -1 {
continue
}
sli := strings.Split(line, "=")
if len(sli) != 2 {
continue
}
line = strings.TrimSpace(sli[1])
line = strings.TrimSpace(line[idx:])
tokens := strings.Split(line, "/")
switch len(tokens) {
@ -88,25 +103,31 @@ func (r *Resolver) parseServerListFile(buf *os.File) {
if !isDomain(domain) || !isIP(ip) {
continue
}
r.domain_server.sinsert(strings.Split(domain, "."), ip)
case 1:
srv_port := strings.Split(line, "#")
if len(srv_port) > 2 {
continue
}
ip := ""
if ip = srv_port[0]; !isIP(ip) {
continue
}
port := "53"
if len(srv_port) == 2 {
if _, err := strconv.Atoi(srv_port[1]); err != nil {
continue
}
port = srv_port[1]
}
r.servers = append(r.servers, net.JoinHostPort(ip, port))
}
}
@ -135,6 +156,12 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
WriteTimeout: r.Timeout(),
}
httpC := &dns.Client{
Net: "https",
ReadTimeout: r.Timeout(),
WriteTimeout: r.Timeout(),
}
if net == "udp" && settings.ResolvConfig.SetEDNS0 {
req = req.SetEdns0(65535, true)
}
@ -145,7 +172,17 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
var wg sync.WaitGroup
L := func(nameserver string) {
defer wg.Done()
r, rtt, err := c.Exchange(req, nameserver)
var r *dns.Msg
var rtt time.Duration
var err error
if strings.HasPrefix(nameserver, "https") {
r, rtt, err = httpC.Exchange(req, nameserver)
} else {
r, rtt, err = c.Exchange(req, nameserver)
}
if err != nil {
logger.Warn("%s socket error on %s", qname, nameserver)
logger.Warn("error:%s", err.Error())
@ -215,6 +252,7 @@ func (r *Resolver) Nameservers(qname string) []string {
for _, nameserver := range r.servers {
ns = append(ns, nameserver)
}
return ns
}

View File

@ -39,6 +39,7 @@ type ResolvSettings struct {
SetEDNS0 bool
ServerListFile string `toml:"server-list-file"`
ResolvFile string `toml:"resolv-file"`
DOHServer string `toml:"dns-over-https"`
}
type DNSServerSettings struct {