Add bolt provider, rewrite hosts, start of api, start of update via nsupdate

This commit is contained in:
Tyler
2021-04-14 23:42:24 -04:00
parent 3383c5e4f9
commit f38586dcb0
18 changed files with 684 additions and 539 deletions

View File

@ -1,6 +1,7 @@
package hosts
import (
"github.com/miekg/dns"
"time"
)
@ -8,8 +9,18 @@ var (
zeroDuration = time.Duration(0)
)
type Host struct {
Type uint16 `json:"type"`
TTL time.Duration `json:"ttl"`
Values []string `json:"values"`
}
func (h *Host) TypeString() string {
return dns.TypeToString[h.Type]
}
type Hosts interface {
Get(queryType uint16, domain string) ([]string, time.Duration, bool)
Get(queryType uint16, domain string) (*Host, error)
}
type ProviderList struct {
@ -17,33 +28,26 @@ type ProviderList struct {
}
type Provider interface {
Get(queryType uint16, domain string) ([]string, time.Duration, bool)
Set(t, domain, value string) (bool, error)
Get(queryType uint16, domain string) (*Host, error)
Set(domain string, host *Host) error
}
func NewHosts(providers []Provider) Hosts {
return &ProviderList{providers}
}
/*
Match local /etc/hosts file first, remote redis records second
*/
func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
var vals []string
var ok bool
var ttl time.Duration
// Get Matches values to providers, loping each in order
func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
var host *Host
var err error
for _, provider := range h.providers {
vals, ttl, ok = provider.Get(queryType, domain)
host, err = provider.Get(queryType, domain)
if ok {
if host != nil {
break
}
}
if vals == nil {
return nil, zeroDuration, false
}
return vals, ttl, true
return host, err
}

124
hosts/hosts_bolt.go Normal file
View File

@ -0,0 +1,124 @@
package hosts
import (
"encoding/json"
"errors"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
"strings"
)
const (
recordBucket = "records"
)
type BoltHosts struct {
Provider
db *bolt.DB
}
func NewBoltProvider(file string) Provider {
db, err := bolt.Open(file, 0600, &bolt.Options{})
if err != nil {
log.WithError(err).Fatalln("Unable to open database")
}
err = db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(recordBucket))
if err != nil {
return err
}
return nil
})
return &BoltHosts{
db: db,
}
}
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking bolt provider for %s : %s", queryType, domain)
domain = strings.ToLower(domain)
var err error
key := domain + "_" + dns.TypeToString[queryType]
var v []byte
err = b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("records"))
v = b.Get([]byte(key))
if string(v) == "" {
return errors.New( "Record not found, key: " + key)
}
v = b.Get([]byte("*." + key))
if string(v) == "" {
return errors.New( "Record not found, key: " + key)
}
return nil
})
if err != nil {
return nil, err
}
var h []Host
if err = json.Unmarshal(v, &h); err != nil {
return nil, err
}
for _, host := range h {
if host.Type == queryType {
return &host, nil
}
}
return nil, errRecordNotFound
}
func (b *BoltHosts) Set(domain string, host *Host) error {
err := b.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(recordBucket))
hosts := []*Host{host}
existing := b.Get([]byte(domain))
if existing != nil {
err := json.Unmarshal(existing, &hosts)
if err != nil {
return err
}
hosts = append(hosts, host)
}
hostBytes, err := json.Marshal(hosts)
if err != nil {
return err
}
err = b.Put([]byte(domain), hostBytes)
if err != nil {
return err
}
return nil
})
return err
}

View File

@ -6,7 +6,7 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
"github.com/ryanuber/go-glob"
"meow.tf/joker/godns/log"
log "github.com/sirupsen/logrus"
"meow.tf/joker/godns/utils"
"os"
"regexp"
@ -15,11 +15,17 @@ import (
"time"
)
var (
errUnsupportedType = errors.New("unsupported type")
errRecordNotFound = errors.New("record not found")
errUnsupportedOperation = errors.New("unsupported operation")
)
type FileHosts struct {
Provider
file string
hosts map[string]string
hosts map[string]Host
mu sync.RWMutex
ttl time.Duration
}
@ -27,7 +33,7 @@ type FileHosts struct {
func NewFileProvider(file string, ttl time.Duration) Provider {
fp := &FileHosts{
file: file,
hosts: make(map[string]string),
hosts: make(map[string]Host),
ttl: ttl,
}
@ -51,41 +57,41 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
return fp
}
func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking file provider for %s : %s", queryType, domain)
// Does not support CNAME/TXT/etc
if queryType != dns.TypeA && queryType != dns.TypeAAAA {
return nil, zeroDuration, false
return nil, errUnsupportedType
}
f.mu.RLock()
defer f.mu.RUnlock()
domain = strings.ToLower(domain)
if ip, ok := f.hosts[domain]; ok {
return strings.Split(ip, ","), f.ttl, true
if host, ok := f.hosts[domain]; ok {
return &host, nil
}
if idx := strings.Index(domain, "."); idx != -1 {
wildcard := "*." + domain[strings.Index(domain, ".") + 1:]
if ip, ok := f.hosts[wildcard]; ok {
return strings.Split(ip, ","), f.ttl, true
if host, ok := f.hosts[wildcard]; ok {
return &host, nil
}
}
for host, ip := range f.hosts {
if glob.Glob(host, domain) {
return strings.Split(ip, ","), f.ttl, true
for hostname, host := range f.hosts {
if glob.Glob(hostname, domain) {
return &host, nil
}
}
return nil, time.Duration(0), false
return nil, errRecordNotFound
}
func (f *FileHosts) Set(t, domain, value string) (bool, error) {
return false, errors.New("file provider does not support setting values")
func (f *FileHosts) Set(domain string, host *Host) error {
return errUnsupportedOperation
}
var (
@ -112,7 +118,7 @@ func (f *FileHosts) Refresh() {
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "#") || line == "" {
if line == "" || line[0] == '#' {
continue
}
@ -140,12 +146,12 @@ func (f *FileHosts) Refresh() {
continue
}
f.hosts[strings.ToLower(domain)] = ip
f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}}
}
}
log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts))
}
func (f *FileHosts) clear() {
f.hosts = make(map[string]string)
f.hosts = make(map[string]Host)
}

View File

@ -1,13 +1,10 @@
package hosts
import (
"encoding/json"
"github.com/go-redis/redis/v7"
"github.com/miekg/dns"
"github.com/ryanuber/go-glob"
"meow.tf/joker/godns/log"
log "github.com/sirupsen/logrus"
"strings"
"sync"
"time"
)
type RedisHosts struct {
@ -15,79 +12,75 @@ type RedisHosts struct {
redis *redis.Client
key string
hosts map[string]string
mu sync.RWMutex
ttl time.Duration
}
func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider {
func NewRedisProvider(rc *redis.Client, key string) Provider {
rh := &RedisHosts{
redis: rc,
key: key,
hosts: make(map[string]string),
ttl: ttl,
}
// Force an initial refresh
rh.Refresh()
return rh
}
func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking redis provider for %s", domain)
// Don't support queries other than A/AAAA for now
if queryType != dns.TypeA || queryType != dns.TypeAAAA {
return nil, zeroDuration, false
}
r.mu.RLock()
defer r.mu.RUnlock()
domain = strings.ToLower(domain)
if ip, ok := r.hosts[domain]; ok {
return strings.Split(ip, ","), r.ttl, true
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
var h []Host
if err = json.Unmarshal([]byte(res), &h); err != nil {
return nil, err
}
for _, host := range h {
if host.Type == queryType {
return &host, nil
}
}
}
if idx := strings.Index(domain, "."); idx != -1 {
wildcard := "*." + domain[strings.Index(domain, ".")+1:]
if ip, ok := r.hosts[wildcard]; ok {
return strings.Split(ip, ","), r.ttl, true
if res, err := r.redis.HGet(r.key, wildcard).Result(); res != "" && err == nil {
var h []Host
if err = json.Unmarshal([]byte(res), &h); err != nil {
return nil, err
}
for _, host := range h {
if host.Type == queryType {
return &host, nil
}
}
}
}
for host, ip := range r.hosts {
if glob.Glob(host, domain) {
return strings.Split(ip, ","), r.ttl, true
return nil, errRecordNotFound
}
func (r *RedisHosts) Set(domain string, host *Host) error {
hosts := []*Host{host}
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
if err = json.Unmarshal([]byte(res), &hosts); err != nil {
return err
}
hosts = append(hosts, host)
}
return nil, time.Duration(0), false
}
b, err := json.Marshal(hosts)
func (r *RedisHosts) Set(t, domain, ip string) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()
}
func (r *RedisHosts) Refresh() {
r.mu.Lock()
defer r.mu.Unlock()
r.clear()
var err error
r.hosts, err = r.redis.HGetAll(r.key).Result()
if err != nil {
log.Warn("Update hosts records from redis failed %s", err)
} else {
log.Debug("Update hosts records from redis")
return err
}
}
func (r *RedisHosts) clear() {
r.hosts = make(map[string]string)
}
_, err = r.redis.HSet(r.key, strings.ToLower(domain), b).Result()
return err
}