Updates
This commit is contained in:
parent
489adb58ef
commit
a6f6c4e96d
136
hosts.go
136
hosts.go
|
@ -10,49 +10,62 @@ import (
|
||||||
|
|
||||||
"github.com/hoisie/redis"
|
"github.com/hoisie/redis"
|
||||||
"golang.org/x/net/publicsuffix"
|
"golang.org/x/net/publicsuffix"
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Hosts struct {
|
type Hosts struct {
|
||||||
fileHosts *FileHosts
|
providers []HostProvider
|
||||||
redisHosts *RedisHosts
|
|
||||||
refreshInterval time.Duration
|
refreshInterval time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HostProvider interface {
|
||||||
|
Get(domain string) ([]string, bool)
|
||||||
|
Refresh()
|
||||||
|
}
|
||||||
|
|
||||||
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
|
func NewHosts(hs HostsSettings, rs RedisSettings) Hosts {
|
||||||
fileHosts := &FileHosts{
|
providers := []HostProvider{
|
||||||
file: hs.HostsFile,
|
NewFileProvider(hs.HostsFile),
|
||||||
hosts: make(map[string]string),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var redisHosts *RedisHosts
|
|
||||||
if hs.RedisEnable {
|
if hs.RedisEnable {
|
||||||
rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password}
|
rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password}
|
||||||
redisHosts = &RedisHosts{
|
|
||||||
redis: rc,
|
providers = append(providers, NewRedisProvider(rc, hs.RedisKey))
|
||||||
key: hs.RedisKey,
|
|
||||||
hosts: make(map[string]string),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts := Hosts{fileHosts, redisHosts, time.Second * time.Duration(hs.RefreshInterval)}
|
return Hosts{providers, time.Second * time.Duration(hs.RefreshInterval)}
|
||||||
hosts.refresh()
|
}
|
||||||
return hosts
|
|
||||||
|
|
||||||
|
func (h *Hosts) refresh() {
|
||||||
|
ticker := time.NewTicker(h.refreshInterval)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
// Force a refresh every refreshInterval
|
||||||
|
for _, provider := range h.providers {
|
||||||
|
provider.Refresh()
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ticker.C
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Match local /etc/hosts file first, remote redis records second
|
Match local /etc/hosts file first, remote redis records second
|
||||||
*/
|
*/
|
||||||
func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
|
func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
|
||||||
|
|
||||||
var sips []string
|
var sips []string
|
||||||
|
var ok bool
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
var ips []net.IP
|
var ips []net.IP
|
||||||
|
|
||||||
sips, ok := h.fileHosts.Get(domain)
|
for _, provider := range h.providers {
|
||||||
if !ok {
|
sips, ok = provider.Get(domain)
|
||||||
if h.redisHosts != nil {
|
|
||||||
sips, ok = h.redisHosts.Get(domain)
|
if ok {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,32 +87,55 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, (ips != nil)
|
return ips, ips != nil
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Update hosts records from /etc/hosts file and redis per minute
|
|
||||||
*/
|
|
||||||
func (h *Hosts) refresh() {
|
|
||||||
ticker := time.NewTicker(h.refreshInterval)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
h.fileHosts.Refresh()
|
|
||||||
if h.redisHosts != nil {
|
|
||||||
h.redisHosts.Refresh()
|
|
||||||
}
|
|
||||||
<-ticker.C
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RedisHosts struct {
|
type RedisHosts struct {
|
||||||
|
HostProvider
|
||||||
|
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
key string
|
key string
|
||||||
hosts map[string]string
|
hosts map[string]string
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewRedisProvider(rc *redis.Client, key string) HostProvider {
|
||||||
|
rh := &RedisHosts{
|
||||||
|
redis: rc,
|
||||||
|
key: key,
|
||||||
|
hosts: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pubsub to listen for key update events
|
||||||
|
go func() {
|
||||||
|
sub := make(chan string, 2)
|
||||||
|
sub <- "godns:update"
|
||||||
|
sub <- "godns:update_record"
|
||||||
|
messages := make(chan redis.Message, 0)
|
||||||
|
go rc.Subscribe(sub, nil, nil, nil, messages)
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg := <- messages
|
||||||
|
|
||||||
|
if msg.Channel == "godns:update" {
|
||||||
|
rh.Refresh()
|
||||||
|
} else if msg.Channel == "godns:update_record" {
|
||||||
|
recordName := string(msg.Message)
|
||||||
|
|
||||||
|
b, err := rc.Hget(key, recordName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rh.hosts[recordName] = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return rh
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RedisHosts) Get(domain string) ([]string, bool) {
|
func (r *RedisHosts) Get(domain string) ([]string, bool) {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
|
@ -152,11 +188,39 @@ func (r *RedisHosts) clear() {
|
||||||
}
|
}
|
||||||
|
|
||||||
type FileHosts struct {
|
type FileHosts struct {
|
||||||
|
HostProvider
|
||||||
|
|
||||||
file string
|
file string
|
||||||
hosts map[string]string
|
hosts map[string]string
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewFileProvider(file string) HostProvider {
|
||||||
|
fp := &FileHosts{
|
||||||
|
file: file,
|
||||||
|
hosts: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
|
||||||
|
// Use fsnotify to notify us of host file changes
|
||||||
|
if err == nil {
|
||||||
|
watcher.Add(file)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
e := <- watcher.Events
|
||||||
|
|
||||||
|
if e.Op == fsnotify.Write {
|
||||||
|
fp.Refresh()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fp
|
||||||
|
}
|
||||||
|
|
||||||
func (f *FileHosts) Get(domain string) ([]string, bool) {
|
func (f *FileHosts) Get(domain string) ([]string, bool) {
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
50
resolver.go
50
resolver.go
|
@ -48,36 +48,51 @@ func NewResolver(c ResolvSettings) *Resolver {
|
||||||
|
|
||||||
if len(c.ResolvFile) > 0 {
|
if len(c.ResolvFile) > 0 {
|
||||||
clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
|
clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
|
logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile)
|
||||||
logger.Error("%s", err)
|
logger.Error("%s", err)
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, server := range clientConfig.Servers {
|
for _, server := range clientConfig.Servers {
|
||||||
nameserver := net.JoinHostPort(server, clientConfig.Port)
|
r.servers = append(r.servers, net.JoinHostPort(server, clientConfig.Port))
|
||||||
r.servers = append(r.servers, nameserver)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(c.DOHServer) > 0 {
|
||||||
|
r.servers = append([]string{c.DOHServer}, r.servers...)
|
||||||
|
}
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resolver) parseServerListFile(buf *os.File) {
|
func (r *Resolver) parseServerListFile(buf *os.File) {
|
||||||
scanner := bufio.NewScanner(buf)
|
scanner := bufio.NewScanner(buf)
|
||||||
|
|
||||||
|
var line string
|
||||||
|
var idx int
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line = strings.TrimSpace(scanner.Text())
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
|
|
||||||
if !strings.HasPrefix(line, "server") {
|
if !strings.HasPrefix(line, "server") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
idx = strings.Index(line, "=")
|
||||||
|
|
||||||
|
if idx == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
sli := strings.Split(line, "=")
|
sli := strings.Split(line, "=")
|
||||||
|
|
||||||
if len(sli) != 2 {
|
if len(sli) != 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
line = strings.TrimSpace(sli[1])
|
line = strings.TrimSpace(line[idx:])
|
||||||
|
|
||||||
tokens := strings.Split(line, "/")
|
tokens := strings.Split(line, "/")
|
||||||
switch len(tokens) {
|
switch len(tokens) {
|
||||||
|
@ -88,25 +103,31 @@ func (r *Resolver) parseServerListFile(buf *os.File) {
|
||||||
if !isDomain(domain) || !isIP(ip) {
|
if !isDomain(domain) || !isIP(ip) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r.domain_server.sinsert(strings.Split(domain, "."), ip)
|
r.domain_server.sinsert(strings.Split(domain, "."), ip)
|
||||||
case 1:
|
case 1:
|
||||||
srv_port := strings.Split(line, "#")
|
srv_port := strings.Split(line, "#")
|
||||||
|
|
||||||
if len(srv_port) > 2 {
|
if len(srv_port) > 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := ""
|
ip := ""
|
||||||
|
|
||||||
if ip = srv_port[0]; !isIP(ip) {
|
if ip = srv_port[0]; !isIP(ip) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
port := "53"
|
port := "53"
|
||||||
|
|
||||||
if len(srv_port) == 2 {
|
if len(srv_port) == 2 {
|
||||||
if _, err := strconv.Atoi(srv_port[1]); err != nil {
|
if _, err := strconv.Atoi(srv_port[1]); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
port = srv_port[1]
|
port = srv_port[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
r.servers = append(r.servers, net.JoinHostPort(ip, port))
|
r.servers = append(r.servers, net.JoinHostPort(ip, port))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -135,6 +156,12 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
|
||||||
WriteTimeout: r.Timeout(),
|
WriteTimeout: r.Timeout(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
httpC := &dns.Client{
|
||||||
|
Net: "https",
|
||||||
|
ReadTimeout: r.Timeout(),
|
||||||
|
WriteTimeout: r.Timeout(),
|
||||||
|
}
|
||||||
|
|
||||||
if net == "udp" && settings.ResolvConfig.SetEDNS0 {
|
if net == "udp" && settings.ResolvConfig.SetEDNS0 {
|
||||||
req = req.SetEdns0(65535, true)
|
req = req.SetEdns0(65535, true)
|
||||||
}
|
}
|
||||||
|
@ -145,7 +172,17 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
L := func(nameserver string) {
|
L := func(nameserver string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
r, rtt, err := c.Exchange(req, nameserver)
|
|
||||||
|
var r *dns.Msg
|
||||||
|
var rtt time.Duration
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if strings.HasPrefix(nameserver, "https") {
|
||||||
|
r, rtt, err = httpC.Exchange(req, nameserver)
|
||||||
|
} else {
|
||||||
|
r, rtt, err = c.Exchange(req, nameserver)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("%s socket error on %s", qname, nameserver)
|
logger.Warn("%s socket error on %s", qname, nameserver)
|
||||||
logger.Warn("error:%s", err.Error())
|
logger.Warn("error:%s", err.Error())
|
||||||
|
@ -215,6 +252,7 @@ func (r *Resolver) Nameservers(qname string) []string {
|
||||||
for _, nameserver := range r.servers {
|
for _, nameserver := range r.servers {
|
||||||
ns = append(ns, nameserver)
|
ns = append(ns, nameserver)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ns
|
return ns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ type ResolvSettings struct {
|
||||||
SetEDNS0 bool
|
SetEDNS0 bool
|
||||||
ServerListFile string `toml:"server-list-file"`
|
ServerListFile string `toml:"server-list-file"`
|
||||||
ResolvFile string `toml:"resolv-file"`
|
ResolvFile string `toml:"resolv-file"`
|
||||||
|
DOHServer string `toml:"dns-over-https"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSServerSettings struct {
|
type DNSServerSettings struct {
|
||||||
|
|
Loading…
Reference in New Issue