Add bolt provider, rewrite hosts, start of api, start of update via nsupdate
This commit is contained in:
		@ -1,6 +1,7 @@
 | 
			
		||||
package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -8,8 +9,18 @@ var (
 | 
			
		||||
	zeroDuration = time.Duration(0)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Host struct {
 | 
			
		||||
	Type uint16 `json:"type"`
 | 
			
		||||
	TTL time.Duration `json:"ttl"`
 | 
			
		||||
	Values []string `json:"values"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Host) TypeString() string {
 | 
			
		||||
	return dns.TypeToString[h.Type]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Hosts interface {
 | 
			
		||||
	Get(queryType uint16, domain string) ([]string, time.Duration, bool)
 | 
			
		||||
	Get(queryType uint16, domain string) (*Host, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ProviderList struct {
 | 
			
		||||
@ -17,33 +28,26 @@ type ProviderList struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Provider interface {
 | 
			
		||||
	Get(queryType uint16, domain string) ([]string, time.Duration, bool)
 | 
			
		||||
	Set(t, domain, value string) (bool, error)
 | 
			
		||||
	Get(queryType uint16, domain string) (*Host, error)
 | 
			
		||||
	Set(domain string, host *Host) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHosts(providers []Provider) Hosts {
 | 
			
		||||
	return &ProviderList{providers}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
Match local /etc/hosts file first, remote redis records second
 | 
			
		||||
*/
 | 
			
		||||
func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
	var vals []string
 | 
			
		||||
	var ok bool
 | 
			
		||||
	var ttl time.Duration
 | 
			
		||||
// Get Matches values to providers, loping each in order
 | 
			
		||||
func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	var host *Host
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	for _, provider := range h.providers {
 | 
			
		||||
		vals, ttl, ok = provider.Get(queryType, domain)
 | 
			
		||||
		host, err = provider.Get(queryType, domain)
 | 
			
		||||
 | 
			
		||||
		if ok {
 | 
			
		||||
		if host != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if vals == nil {
 | 
			
		||||
		return nil, zeroDuration, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return vals, ttl, true
 | 
			
		||||
	return host, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										124
									
								
								hosts/hosts_bolt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								hosts/hosts_bolt.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,124 @@
 | 
			
		||||
package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	bolt "go.etcd.io/bbolt"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	recordBucket = "records"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type BoltHosts struct {
 | 
			
		||||
	Provider
 | 
			
		||||
 | 
			
		||||
	db *bolt.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewBoltProvider(file string) Provider {
 | 
			
		||||
	db, err := bolt.Open(file, 0600, &bolt.Options{})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.WithError(err).Fatalln("Unable to open database")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = db.Update(func(tx *bolt.Tx) error {
 | 
			
		||||
		_, err := tx.CreateBucketIfNotExists([]byte(recordBucket))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return &BoltHosts{
 | 
			
		||||
		db: db,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	log.Debug("Checking bolt provider for %s : %s", queryType, domain)
 | 
			
		||||
 | 
			
		||||
	domain = strings.ToLower(domain)
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	key := domain + "_" + dns.TypeToString[queryType]
 | 
			
		||||
	var v []byte
 | 
			
		||||
 | 
			
		||||
	err = b.db.View(func(tx *bolt.Tx) error {
 | 
			
		||||
		b := tx.Bucket([]byte("records"))
 | 
			
		||||
 | 
			
		||||
		v = b.Get([]byte(key))
 | 
			
		||||
 | 
			
		||||
		if string(v) == "" {
 | 
			
		||||
			return errors.New( "Record not found, key:  " + key)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		v = b.Get([]byte("*." + key))
 | 
			
		||||
 | 
			
		||||
		if string(v) == "" {
 | 
			
		||||
			return errors.New( "Record not found, key:  " + key)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var h []Host
 | 
			
		||||
 | 
			
		||||
	if err = json.Unmarshal(v, &h); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, host := range h {
 | 
			
		||||
		if host.Type == queryType {
 | 
			
		||||
			return &host, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, errRecordNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *BoltHosts) Set(domain string, host *Host) error {
 | 
			
		||||
	err := b.db.Update(func(tx *bolt.Tx) error {
 | 
			
		||||
		b := tx.Bucket([]byte(recordBucket))
 | 
			
		||||
 | 
			
		||||
		hosts := []*Host{host}
 | 
			
		||||
 | 
			
		||||
		existing := b.Get([]byte(domain))
 | 
			
		||||
 | 
			
		||||
		if existing != nil {
 | 
			
		||||
			err := json.Unmarshal(existing, &hosts)
 | 
			
		||||
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			hosts = append(hosts, host)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		hostBytes, err := json.Marshal(hosts)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = b.Put([]byte(domain), hostBytes)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"github.com/fsnotify/fsnotify"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/ryanuber/go-glob"
 | 
			
		||||
	"meow.tf/joker/godns/log"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"meow.tf/joker/godns/utils"
 | 
			
		||||
	"os"
 | 
			
		||||
	"regexp"
 | 
			
		||||
@ -15,11 +15,17 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	errUnsupportedType = errors.New("unsupported type")
 | 
			
		||||
	errRecordNotFound = errors.New("record not found")
 | 
			
		||||
	errUnsupportedOperation = errors.New("unsupported operation")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type FileHosts struct {
 | 
			
		||||
	Provider
 | 
			
		||||
 | 
			
		||||
	file  string
 | 
			
		||||
	hosts map[string]string
 | 
			
		||||
	hosts map[string]Host
 | 
			
		||||
	mu    sync.RWMutex
 | 
			
		||||
	ttl time.Duration
 | 
			
		||||
}
 | 
			
		||||
@ -27,7 +33,7 @@ type FileHosts struct {
 | 
			
		||||
func NewFileProvider(file string, ttl time.Duration) Provider {
 | 
			
		||||
	fp := &FileHosts{
 | 
			
		||||
		file: file,
 | 
			
		||||
		hosts: make(map[string]string),
 | 
			
		||||
		hosts: make(map[string]Host),
 | 
			
		||||
		ttl: ttl,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -51,41 +57,41 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
 | 
			
		||||
	return fp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	log.Debug("Checking file provider for %s : %s", queryType, domain)
 | 
			
		||||
 | 
			
		||||
	// Does not support CNAME/TXT/etc
 | 
			
		||||
	if queryType != dns.TypeA && queryType != dns.TypeAAAA {
 | 
			
		||||
		return nil, zeroDuration, false
 | 
			
		||||
		return nil, errUnsupportedType
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	f.mu.RLock()
 | 
			
		||||
	defer f.mu.RUnlock()
 | 
			
		||||
	domain = strings.ToLower(domain)
 | 
			
		||||
 | 
			
		||||
	if ip, ok := f.hosts[domain]; ok {
 | 
			
		||||
		return strings.Split(ip, ","), f.ttl, true
 | 
			
		||||
	if host, ok := f.hosts[domain]; ok {
 | 
			
		||||
		return &host, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if idx := strings.Index(domain, "."); idx != -1 {
 | 
			
		||||
		wildcard := "*." + domain[strings.Index(domain, ".") + 1:]
 | 
			
		||||
 | 
			
		||||
		if ip, ok := f.hosts[wildcard]; ok {
 | 
			
		||||
			return strings.Split(ip, ","), f.ttl, true
 | 
			
		||||
		if host, ok := f.hosts[wildcard]; ok {
 | 
			
		||||
			return &host, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for host, ip := range f.hosts {
 | 
			
		||||
		if glob.Glob(host, domain) {
 | 
			
		||||
			return strings.Split(ip, ","), f.ttl, true
 | 
			
		||||
	for hostname, host := range f.hosts {
 | 
			
		||||
		if glob.Glob(hostname, domain) {
 | 
			
		||||
			return &host, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, time.Duration(0), false
 | 
			
		||||
	return nil, errRecordNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) Set(t, domain, value string) (bool, error) {
 | 
			
		||||
	return false, errors.New("file provider does not support setting values")
 | 
			
		||||
func (f *FileHosts) Set(domain string, host *Host) error {
 | 
			
		||||
	return errUnsupportedOperation
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
@ -112,7 +118,7 @@ func (f *FileHosts) Refresh() {
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
 | 
			
		||||
		if strings.HasPrefix(line, "#") || line == "" {
 | 
			
		||||
		if line == "" || line[0] == '#' {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -140,12 +146,12 @@ func (f *FileHosts) Refresh() {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			f.hosts[strings.ToLower(domain)] = ip
 | 
			
		||||
			f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) clear() {
 | 
			
		||||
	f.hosts = make(map[string]string)
 | 
			
		||||
	f.hosts = make(map[string]Host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,13 +1,10 @@
 | 
			
		||||
package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/go-redis/redis/v7"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/ryanuber/go-glob"
 | 
			
		||||
	"meow.tf/joker/godns/log"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RedisHosts struct {
 | 
			
		||||
@ -15,79 +12,75 @@ type RedisHosts struct {
 | 
			
		||||
 | 
			
		||||
	redis *redis.Client
 | 
			
		||||
	key   string
 | 
			
		||||
	hosts map[string]string
 | 
			
		||||
	mu    sync.RWMutex
 | 
			
		||||
	ttl time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider {
 | 
			
		||||
func NewRedisProvider(rc *redis.Client, key string) Provider {
 | 
			
		||||
	rh := &RedisHosts{
 | 
			
		||||
		redis: rc,
 | 
			
		||||
		key:   key,
 | 
			
		||||
		hosts: make(map[string]string),
 | 
			
		||||
		ttl: ttl,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Force an initial refresh
 | 
			
		||||
	rh.Refresh()
 | 
			
		||||
 | 
			
		||||
	return rh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
 | 
			
		||||
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	log.Debug("Checking redis provider for %s", domain)
 | 
			
		||||
 | 
			
		||||
	// Don't support queries other than A/AAAA for now
 | 
			
		||||
	if queryType != dns.TypeA || queryType != dns.TypeAAAA {
 | 
			
		||||
		return nil, zeroDuration, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.mu.RLock()
 | 
			
		||||
	defer r.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
	domain = strings.ToLower(domain)
 | 
			
		||||
 | 
			
		||||
	if ip, ok := r.hosts[domain]; ok {
 | 
			
		||||
		return strings.Split(ip, ","), r.ttl, true
 | 
			
		||||
	if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
 | 
			
		||||
		var h []Host
 | 
			
		||||
 | 
			
		||||
		if err = json.Unmarshal([]byte(res), &h); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, host := range h {
 | 
			
		||||
			if host.Type == queryType {
 | 
			
		||||
				return &host, nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if idx := strings.Index(domain, "."); idx != -1 {
 | 
			
		||||
		wildcard := "*." + domain[strings.Index(domain, ".")+1:]
 | 
			
		||||
 | 
			
		||||
		if ip, ok := r.hosts[wildcard]; ok {
 | 
			
		||||
			return strings.Split(ip, ","), r.ttl, true
 | 
			
		||||
		if res, err := r.redis.HGet(r.key, wildcard).Result(); res != "" && err == nil {
 | 
			
		||||
			var h []Host
 | 
			
		||||
 | 
			
		||||
			if err = json.Unmarshal([]byte(res), &h); err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, host := range h {
 | 
			
		||||
				if host.Type == queryType {
 | 
			
		||||
					return &host, nil
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for host, ip := range r.hosts {
 | 
			
		||||
		if glob.Glob(host, domain) {
 | 
			
		||||
			return strings.Split(ip, ","), r.ttl, true
 | 
			
		||||
	return nil, errRecordNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Set(domain string, host *Host) error {
 | 
			
		||||
	hosts := []*Host{host}
 | 
			
		||||
 | 
			
		||||
	if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
 | 
			
		||||
		if err = json.Unmarshal([]byte(res), &hosts); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		hosts = append(hosts, host)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, time.Duration(0), false
 | 
			
		||||
}
 | 
			
		||||
	b, err := json.Marshal(hosts)
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Set(t, domain, ip string) (bool, error) {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
	return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Refresh() {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
	r.clear()
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	r.hosts, err = r.redis.HGetAll(r.key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Warn("Update hosts records from redis failed %s", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Debug("Update hosts records from redis")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) clear() {
 | 
			
		||||
	r.hosts = make(map[string]string)
 | 
			
		||||
}
 | 
			
		||||
	_, err = r.redis.HSet(r.key, strings.ToLower(domain), b).Result()
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user