Changes to allow host "Set" to be standard, providers being able to use other query types, cache all responses, etc.
This commit is contained in:
		@ -1,22 +1,15 @@
 | 
			
		||||
package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	notIPQuery = 0
 | 
			
		||||
	_IP4Query  = 4
 | 
			
		||||
	_IP6Query  = 6
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	zeroDuration = time.Duration(0)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Hosts interface {
 | 
			
		||||
	Get(domain string, family int) ([]net.IP, time.Duration, bool)
 | 
			
		||||
	Get(queryType uint16, domain string) ([]string, time.Duration, bool)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ProviderList struct {
 | 
			
		||||
@ -24,7 +17,8 @@ type ProviderList struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Provider interface {
 | 
			
		||||
	Get(domain string) ([]string, time.Duration, bool)
 | 
			
		||||
	Get(queryType uint16, domain string) ([]string, time.Duration, bool)
 | 
			
		||||
	Set(t, domain, value string) (bool, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHosts(providers []Provider) Hosts {
 | 
			
		||||
@ -34,38 +28,22 @@ func NewHosts(providers []Provider) Hosts {
 | 
			
		||||
/*
 | 
			
		||||
Match local /etc/hosts file first, remote redis records second
 | 
			
		||||
*/
 | 
			
		||||
func (h *ProviderList) Get(domain string, family int) ([]net.IP, time.Duration, bool) {
 | 
			
		||||
	var sips []string
 | 
			
		||||
func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	var vals []string
 | 
			
		||||
	var ok bool
 | 
			
		||||
	var ip net.IP
 | 
			
		||||
	var ips []net.IP
 | 
			
		||||
	var ttl time.Duration
 | 
			
		||||
 | 
			
		||||
	for _, provider := range h.providers {
 | 
			
		||||
		sips, ttl, ok = provider.Get(domain)
 | 
			
		||||
		vals, ttl, ok = provider.Get(queryType, domain)
 | 
			
		||||
 | 
			
		||||
		if ok {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sips == nil {
 | 
			
		||||
	if vals == nil {
 | 
			
		||||
		return nil, zeroDuration, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, sip := range sips {
 | 
			
		||||
		switch family {
 | 
			
		||||
		case _IP4Query:
 | 
			
		||||
			ip = net.ParseIP(sip).To4()
 | 
			
		||||
		case _IP6Query:
 | 
			
		||||
			ip = net.ParseIP(sip).To16()
 | 
			
		||||
		default:
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if ip != nil {
 | 
			
		||||
			ips = append(ips, ip)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ips, ttl, ips != nil
 | 
			
		||||
	return vals, ttl, true
 | 
			
		||||
}
 | 
			
		||||
@ -2,7 +2,9 @@ package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/fsnotify/fsnotify"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/ryanuber/go-glob"
 | 
			
		||||
	"meow.tf/joker/godns/log"
 | 
			
		||||
	"meow.tf/joker/godns/utils"
 | 
			
		||||
@ -49,8 +51,13 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
 | 
			
		||||
	return fp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	log.Debug("Checking file provider for %s", domain)
 | 
			
		||||
func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	log.Debug("Checking file provider for %s : %s", queryType, domain)
 | 
			
		||||
 | 
			
		||||
	// Does not support CNAME/TXT/etc
 | 
			
		||||
	if queryType != dns.TypeA && queryType != dns.TypeAAAA {
 | 
			
		||||
		return nil, zeroDuration, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	f.mu.RLock()
 | 
			
		||||
	defer f.mu.RUnlock()
 | 
			
		||||
@ -77,6 +84,10 @@ func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	return nil, time.Duration(0), false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) Set(t, domain, value string) (bool, error) {
 | 
			
		||||
	return false, errors.New("file provider does not support setting values")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/go-redis/redis/v7"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/ryanuber/go-glob"
 | 
			
		||||
	"meow.tf/joker/godns/log"
 | 
			
		||||
	"strings"
 | 
			
		||||
@ -33,9 +34,14 @@ func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider
 | 
			
		||||
	return rh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	log.Debug("Checking redis provider for %s", domain)
 | 
			
		||||
 | 
			
		||||
	// Don't support queries other than A/AAAA for now
 | 
			
		||||
	if queryType != dns.TypeA || queryType != dns.TypeAAAA {
 | 
			
		||||
		return nil, zeroDuration, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.mu.RLock()
 | 
			
		||||
	defer r.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
@ -62,7 +68,7 @@ func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	return nil, time.Duration(0), false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Set(domain, ip string) (bool, error) {
 | 
			
		||||
func (r *RedisHosts) Set(t, domain, ip string) (bool, error) {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
	return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user