Add bolt provider, rewrite hosts, start of api, start of update via nsupdate
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -8,8 +9,18 @@ var (
|
||||
zeroDuration = time.Duration(0)
|
||||
)
|
||||
|
||||
type Host struct {
|
||||
Type uint16 `json:"type"`
|
||||
TTL time.Duration `json:"ttl"`
|
||||
Values []string `json:"values"`
|
||||
}
|
||||
|
||||
func (h *Host) TypeString() string {
|
||||
return dns.TypeToString[h.Type]
|
||||
}
|
||||
|
||||
type Hosts interface {
|
||||
Get(queryType uint16, domain string) ([]string, time.Duration, bool)
|
||||
Get(queryType uint16, domain string) (*Host, error)
|
||||
}
|
||||
|
||||
type ProviderList struct {
|
||||
@ -17,33 +28,26 @@ type ProviderList struct {
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
Get(queryType uint16, domain string) ([]string, time.Duration, bool)
|
||||
Set(t, domain, value string) (bool, error)
|
||||
Get(queryType uint16, domain string) (*Host, error)
|
||||
Set(domain string, host *Host) error
|
||||
}
|
||||
|
||||
func NewHosts(providers []Provider) Hosts {
|
||||
return &ProviderList{providers}
|
||||
}
|
||||
|
||||
/*
|
||||
Match local /etc/hosts file first, remote redis records second
|
||||
*/
|
||||
func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
|
||||
var vals []string
|
||||
var ok bool
|
||||
var ttl time.Duration
|
||||
// Get Matches values to providers, loping each in order
|
||||
func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
|
||||
var host *Host
|
||||
var err error
|
||||
|
||||
for _, provider := range h.providers {
|
||||
vals, ttl, ok = provider.Get(queryType, domain)
|
||||
host, err = provider.Get(queryType, domain)
|
||||
|
||||
if ok {
|
||||
if host != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if vals == nil {
|
||||
return nil, zeroDuration, false
|
||||
}
|
||||
|
||||
return vals, ttl, true
|
||||
return host, err
|
||||
}
|
124
hosts/hosts_bolt.go
Normal file
124
hosts/hosts_bolt.go
Normal file
@ -0,0 +1,124 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
recordBucket = "records"
|
||||
)
|
||||
|
||||
type BoltHosts struct {
|
||||
Provider
|
||||
|
||||
db *bolt.DB
|
||||
}
|
||||
|
||||
func NewBoltProvider(file string) Provider {
|
||||
db, err := bolt.Open(file, 0600, &bolt.Options{})
|
||||
|
||||
if err != nil {
|
||||
log.WithError(err).Fatalln("Unable to open database")
|
||||
}
|
||||
|
||||
err = db.Update(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(recordBucket))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &BoltHosts{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
log.Debug("Checking bolt provider for %s : %s", queryType, domain)
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
var err error
|
||||
|
||||
key := domain + "_" + dns.TypeToString[queryType]
|
||||
var v []byte
|
||||
|
||||
err = b.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("records"))
|
||||
|
||||
v = b.Get([]byte(key))
|
||||
|
||||
if string(v) == "" {
|
||||
return errors.New( "Record not found, key: " + key)
|
||||
}
|
||||
|
||||
v = b.Get([]byte("*." + key))
|
||||
|
||||
if string(v) == "" {
|
||||
return errors.New( "Record not found, key: " + key)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var h []Host
|
||||
|
||||
if err = json.Unmarshal(v, &h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, host := range h {
|
||||
if host.Type == queryType {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
func (b *BoltHosts) Set(domain string, host *Host) error {
|
||||
err := b.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte(recordBucket))
|
||||
|
||||
hosts := []*Host{host}
|
||||
|
||||
existing := b.Get([]byte(domain))
|
||||
|
||||
if existing != nil {
|
||||
err := json.Unmarshal(existing, &hosts)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
|
||||
hostBytes, err := json.Marshal(hosts)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.Put([]byte(domain), hostBytes)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
@ -6,7 +6,7 @@ import (
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ryanuber/go-glob"
|
||||
"meow.tf/joker/godns/log"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"meow.tf/joker/godns/utils"
|
||||
"os"
|
||||
"regexp"
|
||||
@ -15,11 +15,17 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnsupportedType = errors.New("unsupported type")
|
||||
errRecordNotFound = errors.New("record not found")
|
||||
errUnsupportedOperation = errors.New("unsupported operation")
|
||||
)
|
||||
|
||||
type FileHosts struct {
|
||||
Provider
|
||||
|
||||
file string
|
||||
hosts map[string]string
|
||||
hosts map[string]Host
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
@ -27,7 +33,7 @@ type FileHosts struct {
|
||||
func NewFileProvider(file string, ttl time.Duration) Provider {
|
||||
fp := &FileHosts{
|
||||
file: file,
|
||||
hosts: make(map[string]string),
|
||||
hosts: make(map[string]Host),
|
||||
ttl: ttl,
|
||||
}
|
||||
|
||||
@ -51,41 +57,41 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
|
||||
return fp
|
||||
}
|
||||
|
||||
func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
|
||||
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
log.Debug("Checking file provider for %s : %s", queryType, domain)
|
||||
|
||||
// Does not support CNAME/TXT/etc
|
||||
if queryType != dns.TypeA && queryType != dns.TypeAAAA {
|
||||
return nil, zeroDuration, false
|
||||
return nil, errUnsupportedType
|
||||
}
|
||||
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
if ip, ok := f.hosts[domain]; ok {
|
||||
return strings.Split(ip, ","), f.ttl, true
|
||||
if host, ok := f.hosts[domain]; ok {
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
if idx := strings.Index(domain, "."); idx != -1 {
|
||||
wildcard := "*." + domain[strings.Index(domain, ".") + 1:]
|
||||
|
||||
if ip, ok := f.hosts[wildcard]; ok {
|
||||
return strings.Split(ip, ","), f.ttl, true
|
||||
if host, ok := f.hosts[wildcard]; ok {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
|
||||
for host, ip := range f.hosts {
|
||||
if glob.Glob(host, domain) {
|
||||
return strings.Split(ip, ","), f.ttl, true
|
||||
for hostname, host := range f.hosts {
|
||||
if glob.Glob(hostname, domain) {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, time.Duration(0), false
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
func (f *FileHosts) Set(t, domain, value string) (bool, error) {
|
||||
return false, errors.New("file provider does not support setting values")
|
||||
func (f *FileHosts) Set(domain string, host *Host) error {
|
||||
return errUnsupportedOperation
|
||||
}
|
||||
|
||||
var (
|
||||
@ -112,7 +118,7 @@ func (f *FileHosts) Refresh() {
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.HasPrefix(line, "#") || line == "" {
|
||||
if line == "" || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -140,12 +146,12 @@ func (f *FileHosts) Refresh() {
|
||||
continue
|
||||
}
|
||||
|
||||
f.hosts[strings.ToLower(domain)] = ip
|
||||
f.hosts[strings.ToLower(domain)] = Host{Values: []string{ip}}
|
||||
}
|
||||
}
|
||||
log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts))
|
||||
}
|
||||
|
||||
func (f *FileHosts) clear() {
|
||||
f.hosts = make(map[string]string)
|
||||
f.hosts = make(map[string]Host)
|
||||
}
|
||||
|
@ -1,13 +1,10 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/go-redis/redis/v7"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ryanuber/go-glob"
|
||||
"meow.tf/joker/godns/log"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RedisHosts struct {
|
||||
@ -15,79 +12,75 @@ type RedisHosts struct {
|
||||
|
||||
redis *redis.Client
|
||||
key string
|
||||
hosts map[string]string
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider {
|
||||
func NewRedisProvider(rc *redis.Client, key string) Provider {
|
||||
rh := &RedisHosts{
|
||||
redis: rc,
|
||||
key: key,
|
||||
hosts: make(map[string]string),
|
||||
ttl: ttl,
|
||||
}
|
||||
|
||||
// Force an initial refresh
|
||||
rh.Refresh()
|
||||
|
||||
return rh
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
|
||||
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
log.Debug("Checking redis provider for %s", domain)
|
||||
|
||||
// Don't support queries other than A/AAAA for now
|
||||
if queryType != dns.TypeA || queryType != dns.TypeAAAA {
|
||||
return nil, zeroDuration, false
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
if ip, ok := r.hosts[domain]; ok {
|
||||
return strings.Split(ip, ","), r.ttl, true
|
||||
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
|
||||
var h []Host
|
||||
|
||||
if err = json.Unmarshal([]byte(res), &h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, host := range h {
|
||||
if host.Type == queryType {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idx := strings.Index(domain, "."); idx != -1 {
|
||||
wildcard := "*." + domain[strings.Index(domain, ".")+1:]
|
||||
|
||||
if ip, ok := r.hosts[wildcard]; ok {
|
||||
return strings.Split(ip, ","), r.ttl, true
|
||||
if res, err := r.redis.HGet(r.key, wildcard).Result(); res != "" && err == nil {
|
||||
var h []Host
|
||||
|
||||
if err = json.Unmarshal([]byte(res), &h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, host := range h {
|
||||
if host.Type == queryType {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for host, ip := range r.hosts {
|
||||
if glob.Glob(host, domain) {
|
||||
return strings.Split(ip, ","), r.ttl, true
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Set(domain string, host *Host) error {
|
||||
hosts := []*Host{host}
|
||||
|
||||
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
|
||||
if err = json.Unmarshal([]byte(res), &hosts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
|
||||
return nil, time.Duration(0), false
|
||||
}
|
||||
b, err := json.Marshal(hosts)
|
||||
|
||||
func (r *RedisHosts) Set(t, domain, ip string) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Refresh() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.clear()
|
||||
|
||||
var err error
|
||||
r.hosts, err = r.redis.HGetAll(r.key).Result()
|
||||
if err != nil {
|
||||
log.Warn("Update hosts records from redis failed %s", err)
|
||||
} else {
|
||||
log.Debug("Update hosts records from redis")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisHosts) clear() {
|
||||
r.hosts = make(map[string]string)
|
||||
}
|
||||
_, err = r.redis.HSet(r.key, strings.ToLower(domain), b).Result()
|
||||
|
||||
return err
|
||||
}
|
Reference in New Issue
Block a user