Redo settings a bit, move a lot of init logic to main.go

This commit is contained in:
Tyler
2020-01-25 13:48:26 -05:00
parent 991ae3ecb5
commit f726a5d5ae
12 changed files with 152 additions and 394 deletions

View File

@ -1,12 +1,8 @@
package hosts
import (
"meow.tf/joker/godns/log"
"meow.tf/joker/godns/settings"
"net"
"time"
"github.com/hoisie/redis"
)
const (
@ -15,70 +11,38 @@ const (
_IP6Query = 6
)
var (
zeroDuration = time.Duration(0)
)
type Hosts interface {
Get(domain string, family int) ([]net.IP, bool)
TTL() uint32
Get(domain string, family int) ([]net.IP, time.Duration, bool)
}
type ProviderList struct {
settings settings.HostsSettings
providers []HostProvider
refreshInterval time.Duration
providers []Provider
}
type HostProvider interface {
Get(domain string) ([]string, bool)
Refresh()
type Provider interface {
Get(domain string) ([]string, time.Duration, bool)
}
func NewHosts(hs settings.HostsSettings, rs settings.RedisSettings) Hosts {
providers := []HostProvider{
NewFileProvider(hs.HostsFile),
}
if hs.RedisEnable {
log.Info("Redis is enabled: %s", rs.Addr())
rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password}
providers = append(providers, NewRedisProvider(rc, hs.RedisKey))
}
h := &ProviderList{hs, providers, time.Second * time.Duration(hs.RefreshInterval)}
if h.refreshInterval > 0 {
h.refresh()
}
return h
}
func (h *ProviderList) refresh() {
ticker := time.NewTicker(h.refreshInterval)
go func() {
for {
// Force a refresh every refreshInterval
for _, provider := range h.providers {
provider.Refresh()
}
<-ticker.C
}
}()
func NewHosts(providers []Provider) Hosts {
return &ProviderList{providers}
}
/*
Match local /etc/hosts file first, remote redis records second
*/
func (h *ProviderList) Get(domain string, family int) ([]net.IP, bool) {
func (h *ProviderList) Get(domain string, family int) ([]net.IP, time.Duration, bool) {
var sips []string
var ok bool
var ip net.IP
var ips []net.IP
var ttl time.Duration
for _, provider := range h.providers {
sips, ok = provider.Get(domain)
sips, ttl, ok = provider.Get(domain)
if ok {
break
@ -86,7 +50,7 @@ func (h *ProviderList) Get(domain string, family int) ([]net.IP, bool) {
}
if sips == nil {
return nil, false
return nil, zeroDuration, false
}
for _, sip := range sips {
@ -103,9 +67,5 @@ func (h *ProviderList) Get(domain string, family int) ([]net.IP, bool) {
}
}
return ips, ips != nil
}
func (h *ProviderList) TTL() uint32 {
return h.settings.TTL
return ips, ttl, ips != nil
}

View File

@ -10,20 +10,23 @@ import (
"regexp"
"strings"
"sync"
"time"
)
type FileHosts struct {
HostProvider
Provider
file string
hosts map[string]string
mu sync.RWMutex
ttl time.Duration
}
func NewFileProvider(file string) HostProvider {
func NewFileProvider(file string, ttl time.Duration) Provider {
fp := &FileHosts{
file: file,
hosts: make(map[string]string),
ttl: ttl,
}
watcher, err := fsnotify.NewWatcher()
@ -46,7 +49,7 @@ func NewFileProvider(file string) HostProvider {
return fp
}
func (f *FileHosts) Get(domain string) ([]string, bool) {
func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) {
log.Debug("Checking file provider for %s", domain)
f.mu.RLock()
@ -54,24 +57,24 @@ func (f *FileHosts) Get(domain string) ([]string, bool) {
domain = strings.ToLower(domain)
if ip, ok := f.hosts[domain]; ok {
return strings.Split(ip, ","), true
return strings.Split(ip, ","), f.ttl, true
}
if idx := strings.Index(domain, "."); idx != -1 {
wildcard := "*." + domain[strings.Index(domain, ".") + 1:]
if ip, ok := f.hosts[wildcard]; ok {
return strings.Split(ip, ","), true
return strings.Split(ip, ","), f.ttl, true
}
}
for host, ip := range f.hosts {
if glob.Glob(host, domain) {
return strings.Split(ip, ","), true
return strings.Split(ip, ","), f.ttl, true
}
}
return nil, false
return nil, time.Duration(0), false
}
var (

View File

@ -1,83 +1,39 @@
package hosts
import (
"github.com/hoisie/redis"
"github.com/go-redis/redis/v7"
"github.com/ryanuber/go-glob"
"meow.tf/joker/godns/log"
"strings"
"sync"
"time"
)
type RedisHosts struct {
HostProvider
Provider
redis *redis.Client
key string
hosts map[string]string
mu sync.RWMutex
ttl time.Duration
}
func NewRedisProvider(rc *redis.Client, key string) HostProvider {
func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider {
rh := &RedisHosts{
redis: rc,
key: key,
hosts: make(map[string]string),
ttl: ttl,
}
// Force an initial refresh
rh.Refresh()
// Use pubsub to listen for key update events
go func() {
keyspaceEvent := "__keyspace@0__:" + key
sub := make(chan string, 3)
sub <- keyspaceEvent
sub <- "godns:update"
sub <- "godns:update_record"
messages := make(chan redis.Message, 0)
go rc.Subscribe(sub, nil, nil, nil, messages)
for {
msg := <-messages
if msg.Channel == "godns:update" {
log.Debug("Refreshing redis records due to update")
rh.Refresh()
} else if msg.Channel == "godns:update_record" {
recordName := string(msg.Message)
b, err := rc.Hget(key, recordName)
if err != nil {
log.Warn("Record %s does not exist, but was updated", recordName)
continue
}
log.Debug("Record %s was updated to %s", recordName, string(b))
rh.mu.Lock()
rh.hosts[recordName] = string(b)
rh.mu.Unlock()
} else if msg.Channel == "godns:remove_record" {
log.Debug("Record %s was removed", msg.Message)
recordName := string(msg.Message)
rh.mu.Lock()
delete(rh.hosts, recordName)
rh.mu.Unlock()
} else if msg.Channel == keyspaceEvent {
log.Debug("Refreshing redis records due to update")
rh.Refresh()
}
}
}()
return rh
}
func (r *RedisHosts) Get(domain string) ([]string, bool) {
func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) {
log.Debug("Checking redis provider for %s", domain)
r.mu.RLock()
@ -86,37 +42,39 @@ func (r *RedisHosts) Get(domain string) ([]string, bool) {
domain = strings.ToLower(domain)
if ip, ok := r.hosts[domain]; ok {
return strings.Split(ip, ","), true
return strings.Split(ip, ","), r.ttl, true
}
if idx := strings.Index(domain, "."); idx != -1 {
wildcard := "*." + domain[strings.Index(domain, ".")+1:]
if ip, ok := r.hosts[wildcard]; ok {
return strings.Split(ip, ","), true
return strings.Split(ip, ","), r.ttl, true
}
}
for host, ip := range r.hosts {
if glob.Glob(host, domain) {
return strings.Split(ip, ","), true
return strings.Split(ip, ","), r.ttl, true
}
}
return nil, false
return nil, time.Duration(0), false
}
func (r *RedisHosts) Set(domain, ip string) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.redis.Hset(r.key, strings.ToLower(domain), []byte(ip))
return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()
}
func (r *RedisHosts) Refresh() {
r.mu.Lock()
defer r.mu.Unlock()
r.clear()
err := r.redis.Hgetall(r.key, r.hosts)
var err error
r.hosts, err = r.redis.HGetAll(r.key).Result()
if err != nil {
log.Warn("Update hosts records from redis failed %s", err)
} else {