Updates
This commit is contained in:
parent
489adb58ef
commit
a6f6c4e96d
138
hosts.go
138
hosts.go
|
@ -10,49 +10,62 @@ import (
|
|||
|
||||
"github.com/hoisie/redis"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
type Hosts struct {
|
||||
fileHosts *FileHosts
|
||||
redisHosts *RedisHosts
|
||||
providers []HostProvider
|
||||
refreshInterval time.Duration
|
||||
}
|
||||
|
||||
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
|
||||
fileHosts := &FileHosts{
|
||||
file: hs.HostsFile,
|
||||
hosts: make(map[string]string),
|
||||
type HostProvider interface {
|
||||
Get(domain string) ([]string, bool)
|
||||
Refresh()
|
||||
}
|
||||
|
||||
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
|
||||
providers := []HostProvider{
|
||||
NewFileProvider(hs.HostsFile),
|
||||
}
|
||||
|
||||
var redisHosts *RedisHosts
|
||||
if hs.RedisEnable {
|
||||
rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password}
|
||||
redisHosts = &RedisHosts{
|
||||
redis: rc,
|
||||
key: hs.RedisKey,
|
||||
hosts: make(map[string]string),
|
||||
}
|
||||
|
||||
providers = append(providers, NewRedisProvider(rc, hs.RedisKey))
|
||||
}
|
||||
|
||||
hosts := Hosts{fileHosts, redisHosts, time.Second * time.Duration(hs.RefreshInterval)}
|
||||
hosts.refresh()
|
||||
return hosts
|
||||
return Hosts{providers, time.Second * time.Duration(hs.RefreshInterval)}
|
||||
}
|
||||
|
||||
func (h *Hosts) refresh() {
|
||||
ticker := time.NewTicker(h.refreshInterval)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
// Force a refresh every refreshInterval
|
||||
for _, provider := range h.providers {
|
||||
provider.Refresh()
|
||||
}
|
||||
|
||||
<-ticker.C
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
/*
|
||||
Match local /etc/hosts file first, remote redis records second
|
||||
*/
|
||||
func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
|
||||
|
||||
var sips []string
|
||||
var ok bool
|
||||
var ip net.IP
|
||||
var ips []net.IP
|
||||
|
||||
sips, ok := h.fileHosts.Get(domain)
|
||||
if !ok {
|
||||
if h.redisHosts != nil {
|
||||
sips, ok = h.redisHosts.Get(domain)
|
||||
for _, provider := range h.providers {
|
||||
sips, ok = provider.Get(domain)
|
||||
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,32 +87,55 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
|
|||
}
|
||||
}
|
||||
|
||||
return ips, (ips != nil)
|
||||
}
|
||||
|
||||
/*
|
||||
Update hosts records from /etc/hosts file and redis per minute
|
||||
*/
|
||||
func (h *Hosts) refresh() {
|
||||
ticker := time.NewTicker(h.refreshInterval)
|
||||
go func() {
|
||||
for {
|
||||
h.fileHosts.Refresh()
|
||||
if h.redisHosts != nil {
|
||||
h.redisHosts.Refresh()
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}()
|
||||
return ips, ips != nil
|
||||
}
|
||||
|
||||
type RedisHosts struct {
|
||||
HostProvider
|
||||
|
||||
redis *redis.Client
|
||||
key string
|
||||
hosts map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRedisProvider(rc *redis.Client, key string) HostProvider {
|
||||
rh := &RedisHosts{
|
||||
redis: rc,
|
||||
key: key,
|
||||
hosts: make(map[string]string),
|
||||
}
|
||||
|
||||
// Use pubsub to listen for key update events
|
||||
go func() {
|
||||
sub := make(chan string, 2)
|
||||
sub <- "godns:update"
|
||||
sub <- "godns:update_record"
|
||||
messages := make(chan redis.Message, 0)
|
||||
go rc.Subscribe(sub, nil, nil, nil, messages)
|
||||
|
||||
for {
|
||||
msg := <- messages
|
||||
|
||||
if msg.Channel == "godns:update" {
|
||||
rh.Refresh()
|
||||
} else if msg.Channel == "godns:update_record" {
|
||||
recordName := string(msg.Message)
|
||||
|
||||
b, err := rc.Hget(key, recordName)
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rh.hosts[recordName] = string(b)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return rh
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Get(domain string) ([]string, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
@ -152,11 +188,39 @@ func (r *RedisHosts) clear() {
|
|||
}
|
||||
|
||||
type FileHosts struct {
|
||||
HostProvider
|
||||
|
||||
file string
|
||||
hosts map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewFileProvider(file string) HostProvider {
|
||||
fp := &FileHosts{
|
||||
file: file,
|
||||
hosts: make(map[string]string),
|
||||
}
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
|
||||
// Use fsnotify to notify us of host file changes
|
||||
if err == nil {
|
||||
watcher.Add(file)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
e := <- watcher.Events
|
||||
|
||||
if e.Op == fsnotify.Write {
|
||||
fp.Refresh()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return fp
|
||||
}
|
||||
|
||||
func (f *FileHosts) Get(domain string) ([]string, bool) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
|
50
resolver.go
50
resolver.go
|
@ -48,36 +48,51 @@ func NewResolver(c ResolvSettings) *Resolver {
|
|||
|
||||
if len(c.ResolvFile) > 0 {
|
||||
clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
|
||||
|
||||
if err != nil {
|
||||
logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
|
||||
logger.Error("%s", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, server := range clientConfig.Servers {
|
||||
nameserver := net.JoinHostPort(server, clientConfig.Port)
|
||||
r.servers = append(r.servers, nameserver)
|
||||
r.servers = append(r.servers, net.JoinHostPort(server, clientConfig.Port))
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.DOHServer) > 0 {
|
||||
r.servers = append([]string{c.DOHServer}, r.servers...)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Resolver) parseServerListFile(buf *os.File) {
|
||||
scanner := bufio.NewScanner(buf)
|
||||
|
||||
var line string
|
||||
var idx int
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimSpace(scanner.Text())
|
||||
|
||||
if !strings.HasPrefix(line, "server") {
|
||||
continue
|
||||
}
|
||||
|
||||
idx = strings.Index(line, "=")
|
||||
|
||||
if idx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
sli := strings.Split(line, "=")
|
||||
|
||||
if len(sli) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(sli[1])
|
||||
line = strings.TrimSpace(line[idx:])
|
||||
|
||||
tokens := strings.Split(line, "/")
|
||||
switch len(tokens) {
|
||||
|
@ -88,25 +103,31 @@ func (r *Resolver) parseServerListFile(buf *os.File) {
|
|||
if !isDomain(domain) || !isIP(ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
r.domain_server.sinsert(strings.Split(domain, "."), ip)
|
||||
case 1:
|
||||
srv_port := strings.Split(line, "#")
|
||||
|
||||
if len(srv_port) > 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := ""
|
||||
|
||||
if ip = srv_port[0]; !isIP(ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
port := "53"
|
||||
|
||||
if len(srv_port) == 2 {
|
||||
if _, err := strconv.Atoi(srv_port[1]); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
port = srv_port[1]
|
||||
}
|
||||
|
||||
r.servers = append(r.servers, net.JoinHostPort(ip, port))
|
||||
}
|
||||
}
|
||||
|
@ -135,6 +156,12 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
|
|||
WriteTimeout: r.Timeout(),
|
||||
}
|
||||
|
||||
httpC := &dns.Client{
|
||||
Net: "https",
|
||||
ReadTimeout: r.Timeout(),
|
||||
WriteTimeout: r.Timeout(),
|
||||
}
|
||||
|
||||
if net == "udp" && settings.ResolvConfig.SetEDNS0 {
|
||||
req = req.SetEdns0(65535, true)
|
||||
}
|
||||
|
@ -145,7 +172,17 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
|
|||
var wg sync.WaitGroup
|
||||
L := func(nameserver string) {
|
||||
defer wg.Done()
|
||||
r, rtt, err := c.Exchange(req, nameserver)
|
||||
|
||||
var r *dns.Msg
|
||||
var rtt time.Duration
|
||||
var err error
|
||||
|
||||
if strings.HasPrefix(nameserver, "https") {
|
||||
r, rtt, err = httpC.Exchange(req, nameserver)
|
||||
} else {
|
||||
r, rtt, err = c.Exchange(req, nameserver)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Warn("%s socket error on %s", qname, nameserver)
|
||||
logger.Warn("error:%s", err.Error())
|
||||
|
@ -215,6 +252,7 @@ func (r *Resolver) Nameservers(qname string) []string {
|
|||
for _, nameserver := range r.servers {
|
||||
ns = append(ns, nameserver)
|
||||
}
|
||||
|
||||
return ns
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ type ResolvSettings struct {
|
|||
SetEDNS0 bool
|
||||
ServerListFile string `toml:"server-list-file"`
|
||||
ResolvFile string `toml:"resolv-file"`
|
||||
DOHServer string `toml:"dns-over-https"`
|
||||
}
|
||||
|
||||
type DNSServerSettings struct {
|
||||
|
|
Loading…
Reference in New Issue