diff --git a/hosts.go b/hosts.go index 6645cbc..db0d121 100644 --- a/hosts.go +++ b/hosts.go @@ -1,16 +1,10 @@ package main import ( - "bufio" "net" - "os" - "strings" - "sync" "time" "github.com/hoisie/redis" - "golang.org/x/net/publicsuffix" - "github.com/fsnotify/fsnotify" ) type Hosts struct { @@ -89,217 +83,3 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) { return ips, ips != nil } - -type RedisHosts struct { - HostProvider - - redis *redis.Client - key string - hosts map[string]string - mu sync.RWMutex -} - -func NewRedisProvider(rc *redis.Client, key string) HostProvider { - rh := &RedisHosts{ - redis: rc, - key: key, - hosts: make(map[string]string), - } - - // Use pubsub to listen for key update events - go func() { - sub := make(chan string, 2) - sub <- "godns:update" - sub <- "godns:update_record" - messages := make(chan redis.Message, 0) - go rc.Subscribe(sub, nil, nil, nil, messages) - - for { - msg := <- messages - - if msg.Channel == "godns:update" { - rh.Refresh() - } else if msg.Channel == "godns:update_record" { - recordName := string(msg.Message) - - b, err := rc.Hget(key, recordName) - - if err != nil { - continue - } - - rh.hosts[recordName] = string(b) - } - } - }() - - return rh -} - -func (r *RedisHosts) Get(domain string) ([]string, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - - domain = strings.ToLower(domain) - ip, ok := r.hosts[domain] - if ok { - return strings.Split(ip, ","), true - } - - sld, err := publicsuffix.EffectiveTLDPlusOne(domain) - if err != nil { - return nil, false - } - - for host, ip := range r.hosts { - if strings.HasPrefix(host, "*.") { - old, err := publicsuffix.EffectiveTLDPlusOne(host) - if err != nil { - continue - } - if sld == old { - return strings.Split(ip, ","), true - } - } - } - return nil, false -} - -func (r *RedisHosts) Set(domain, ip string) (bool, error) { - r.mu.Lock() - defer r.mu.Unlock() - return r.redis.Hset(r.key, strings.ToLower(domain), []byte(ip)) -} - -func (r *RedisHosts) Refresh() { - r.mu.Lock() - defer r.mu.Unlock() - r.clear() - err := r.redis.Hgetall(r.key, r.hosts) - if err != nil { - logger.Warn("Update hosts records from redis failed %s", err) - } else { - logger.Debug("Update hosts records from redis") - } -} - -func (r *RedisHosts) clear() { - r.hosts = make(map[string]string) -} - -type FileHosts struct { - HostProvider - - file string - hosts map[string]string - mu sync.RWMutex -} - -func NewFileProvider(file string) HostProvider { - fp := &FileHosts{ - file: file, - hosts: make(map[string]string), - } - - watcher, err := fsnotify.NewWatcher() - - // Use fsnotify to notify us of host file changes - if err == nil { - watcher.Add(file) - - go func() { - for { - e := <- watcher.Events - - if e.Op == fsnotify.Write { - fp.Refresh() - } - } - }() - } - - return fp -} - -func (f *FileHosts) Get(domain string) ([]string, bool) { - f.mu.RLock() - defer f.mu.RUnlock() - domain = strings.ToLower(domain) - ip, ok := f.hosts[domain] - if ok { - return []string{ip}, true - } - - sld, err := publicsuffix.EffectiveTLDPlusOne(domain) - if err != nil { - return nil, false - } - - for host, ip := range f.hosts { - if strings.HasPrefix(host, "*.") { - old, err := publicsuffix.EffectiveTLDPlusOne(host) - if err != nil { - continue - } - if sld == old { - return []string{ip}, true - } - } - } - - return nil, false -} - -func (f *FileHosts) Refresh() { - buf, err := os.Open(f.file) - if err != nil { - logger.Warn("Update hosts records from file failed %s", err) - return - } - defer buf.Close() - - f.mu.Lock() - defer f.mu.Unlock() - - f.clear() - - scanner := bufio.NewScanner(buf) - for scanner.Scan() { - - line := scanner.Text() - line = strings.TrimSpace(line) - line = strings.Replace(line, "\t", " ", -1) - - if strings.HasPrefix(line, "#") || line == "" { - continue - } - - sli := strings.Split(line, " ") - - if len(sli) < 2 { - continue - } - - ip := sli[0] - if !isIP(ip) { - continue - } - - // Would have multiple columns of domain in line. - // Such as "127.0.0.1 localhost localhost.domain" on linux. - // The domains may not strict standard, like "local" so don't check with f.isDomain(domain). - for i := 1; i <= len(sli)-1; i++ { - domain := strings.TrimSpace(sli[i]) - if domain == "" { - continue - } - - f.hosts[strings.ToLower(domain)] = ip - } - } - logger.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts)) -} - -func (f *FileHosts) clear() { - f.hosts = make(map[string]string) -} diff --git a/hosts_file.go b/hosts_file.go new file mode 100644 index 0000000..985bd42 --- /dev/null +++ b/hosts_file.go @@ -0,0 +1,136 @@ +package main + +import ( + "sync" + "github.com/fsnotify/fsnotify" + "strings" + "golang.org/x/net/publicsuffix" + "os" + "bufio" + "regexp" +) + +type FileHosts struct { + HostProvider + + file string + hosts map[string]string + mu sync.RWMutex +} + +func NewFileProvider(file string) HostProvider { + fp := &FileHosts{ + file: file, + hosts: make(map[string]string), + } + + watcher, err := fsnotify.NewWatcher() + + // Use fsnotify to notify us of host file changes + if err == nil { + watcher.Add(file) + + go func() { + for { + e := <- watcher.Events + + if e.Op == fsnotify.Write { + fp.Refresh() + } + } + }() + } + + return fp +} + +func (f *FileHosts) Get(domain string) ([]string, bool) { + f.mu.RLock() + defer f.mu.RUnlock() + domain = strings.ToLower(domain) + ip, ok := f.hosts[domain] + if ok { + return []string{ip}, true + } + + sld, err := publicsuffix.EffectiveTLDPlusOne(domain) + if err != nil { + return nil, false + } + + for host, ip := range f.hosts { + if strings.HasPrefix(host, "*.") { + old, err := publicsuffix.EffectiveTLDPlusOne(host) + if err != nil { + continue + } + if sld == old { + return []string{ip}, true + } + } + } + + return nil, false +} + +var ( + hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$") +) + +func (f *FileHosts) Refresh() { + buf, err := os.Open(f.file) + + if err != nil { + logger.Warn("Update hosts records from file failed %s", err) + return + } + + defer buf.Close() + + f.mu.Lock() + defer f.mu.Unlock() + + f.clear() + + scanner := bufio.NewScanner(buf) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.HasPrefix(line, "#") || line == "" { + continue + } + + m := hostRegexp.FindStringSubmatch(line) + + if m == nil { + continue + } + + ip := m[1] + + if !isIP(ip) { + continue + } + + domains := strings.Split(m[2], " ") + + // Would have multiple columns of domain in line. + // Such as "127.0.0.1 localhost localhost.domain" on linux. + // The domains may not strict standard, like "local" so don't check with f.isDomain(domain). + for _, domain := range domains { + domain = strings.TrimSpace(domain) + + if domain == "" { + continue + } + + f.hosts[strings.ToLower(domain)] = ip + } + } + logger.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts)) +} + +func (f *FileHosts) clear() { + f.hosts = make(map[string]string) +} diff --git a/hosts_redis.go b/hosts_redis.go new file mode 100644 index 0000000..79c57f2 --- /dev/null +++ b/hosts_redis.go @@ -0,0 +1,110 @@ +package main + +import ( + "github.com/hoisie/redis" + "sync" + "strings" + "golang.org/x/net/publicsuffix" +) + +type RedisHosts struct { + HostProvider + + redis *redis.Client + key string + hosts map[string]string + mu sync.RWMutex +} + +func NewRedisProvider(rc *redis.Client, key string) HostProvider { + rh := &RedisHosts{ + redis: rc, + key: key, + hosts: make(map[string]string), + } + + // Use pubsub to listen for key update events + go func() { + keyspaceEvent := "__keyspace@0__:" + key + + sub := make(chan string, 2) + sub <- keyspaceEvent + sub <- "godns:update" + sub <- "godns:update_record" + messages := make(chan redis.Message, 0) + go rc.Subscribe(sub, nil, nil, nil, messages) + + for { + msg := <- messages + + if msg.Channel == "godns:update" { + rh.Refresh() + } else if msg.Channel == "godns:update_record" { + recordName := string(msg.Message) + + b, err := rc.Hget(key, recordName) + + if err != nil { + continue + } + + rh.hosts[recordName] = string(b) + } else if msg.Channel == keyspaceEvent { + rh.Refresh() + } + } + }() + + return rh +} + +func (r *RedisHosts) Get(domain string) ([]string, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + domain = strings.ToLower(domain) + ip, ok := r.hosts[domain] + if ok { + return strings.Split(ip, ","), true + } + + sld, err := publicsuffix.EffectiveTLDPlusOne(domain) + if err != nil { + return nil, false + } + + for host, ip := range r.hosts { + if strings.HasPrefix(host, "*.") { + old, err := publicsuffix.EffectiveTLDPlusOne(host) + if err != nil { + continue + } + if sld == old { + return strings.Split(ip, ","), true + } + } + } + return nil, false +} + +func (r *RedisHosts) Set(domain, ip string) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.redis.Hset(r.key, strings.ToLower(domain), []byte(ip)) +} + +func (r *RedisHosts) Refresh() { + r.mu.Lock() + defer r.mu.Unlock() + r.clear() + err := r.redis.Hgetall(r.key, r.hosts) + if err != nil { + logger.Warn("Update hosts records from redis failed %s", err) + } else { + logger.Debug("Update hosts records from redis") + } +} + +func (r *RedisHosts) clear() { + r.hosts = make(map[string]string) +}