Changes to allow host "Set" to be standard, providers being able to use other query types, cache all responses, etc.
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Tyler 2020-02-07 22:38:22 -05:00
parent f726a5d5ae
commit 3383c5e4f9
10 changed files with 107 additions and 172 deletions

View File

@ -29,7 +29,7 @@ maxcount = 0 #If set zero. The Sum of cache itmes will be unlimit.
# Redis cache backend config # Redis cache backend config
[cache.redis] [cache.redis]
host = "192.168.1.71" host = "127.0.0.1"
port = 6379 port = 6379
db = 0 db = 0
password ="" password =""

View File

@ -1,56 +0,0 @@
#Toml config file
title = "GODNS"
Version = "0.2.3"
Author = "kenshin, tystuyfzand"
debug = false
[server]
host = ""
port = 53
[resolv]
# Domain-specific nameservers configuration, formatting keep compatible with Dnsmasq
# Semicolon separate multiple files.
resolv-file = "/etc/resolv.conf"
timeout = 5 # 5 seconds
# The concurrency interval request upstream recursive server
# Match the PR15, https://github.com/kenshinx/godns/pull/15
interval = 200 # 200 milliseconds
# When defined, this is preferred over regular DNS. This requires a resolver to be active besides this, only for the initial lookup.
# A hosts file entry will suffice as well.
# dns-over-https = "https://cloudflare-dns.com/dns-query"
setedns0 = false #Support for larger UDP DNS responses
[redis]
enable = true
host = "127.0.0.1"
port = 6379
db = 0
password =""
[memcache]
servers = ["127.0.0.1:11211"]
[log]
stdout = true
file = "./godns.log"
level = "INFO" #DEBUG | INFO |NOTICE | WARN | ERROR
[cache]
# backend option [memory|memcache|redis]
backend = "memory"
expire = 600 # 10 minutes
maxcount = 0 #If set zero. The Sum of cache items will be unlimit.
[hosts]
#If set false, will not query hosts file and redis hosts record
enable = true
host-file = "/etc/hosts"
redis-enable = false
redis-key = "godns:hosts"
ttl = 600
# Refresh interval can be high since we have automatic updating via push and fsnotify
refresh-interval = 300

View File

@ -13,12 +13,6 @@ import (
"time" "time"
) )
const (
notIPQuery = 0
_IP4Query = 4
_IP6Query = 6
)
type Handler struct { type Handler struct {
resolver *resolver.Resolver resolver *resolver.Resolver
cache, negCache cache.Cache cache, negCache cache.Cache
@ -29,46 +23,48 @@ func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hos
return &Handler{r, resolverCache, negCache, h} return &Handler{r, resolverCache, negCache, h}
} }
func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0] q := req.Question[0]
question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: dns.TypeToString[q.Qtype], Class: dns.ClassToString[q.Qclass]} question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]}
var remote net.IP var remote net.IP
if Net == "tcp" {
remote = w.RemoteAddr().(*net.TCPAddr).IP switch t := w.RemoteAddr().(type) {
} else { case *net.TCPAddr:
remote = w.RemoteAddr().(*net.UDPAddr).IP remote = t.IP
case *net.UDPAddr:
remote = t.IP
default:
return
} }
log.Info("%s lookup %s", remote, question.String()) log.Info("%s lookup %s", remote, question.String())
IPQuery := h.isIPQuery(q)
// Query hosts // Query hosts
if h.hosts != nil && IPQuery > 0 { if h.hosts != nil {
if ips, ttl, ok := h.hosts.Get(question.Name, IPQuery); ok { if vals, ttl, ok := h.hosts.Get(question.Type, question.Name); ok {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(req) m.SetReply(req)
switch IPQuery {
case _IP4Query:
hdr := dns.RR_Header{ hdr := dns.RR_Header{
Name: q.Name, Name: q.Name,
Rrtype: dns.TypeA, Rrtype: question.Type,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: uint32(ttl / time.Second), Ttl: uint32(ttl / time.Second),
} }
for _, ip := range ips {
m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: ip}) switch question.Type {
case dns.TypeA:
for _, val := range vals {
m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: net.ParseIP(val).To4()})
} }
case _IP6Query: case dns.TypeAAAA:
hdr := dns.RR_Header{ for _, val := range vals {
Name: q.Name, m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(val).To16()})
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: uint32(ttl / time.Second),
} }
for _, ip := range ips { case dns.TypeCNAME:
m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip}) for _, val := range vals {
m.Answer = append(m.Answer, &dns.CNAME{Hdr: hdr, Target: val})
} }
} }
@ -80,10 +76,10 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) {
} }
} }
// Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN'
key := KeyGen(question) key := KeyGen(question)
if IPQuery > 0 {
mesg, err := h.cache.Get(key) mesg, err := h.cache.Get(key)
if err != nil { if err != nil {
if mesg, err = h.negCache.Get(key); err != nil { if mesg, err = h.negCache.Get(key); err != nil {
log.Debug("%s didn't hit cache", question.String()) log.Debug("%s didn't hit cache", question.String())
@ -100,9 +96,8 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) {
w.WriteMsg(&msg) w.WriteMsg(&msg)
return return
} }
}
mesg, err := h.resolver.Lookup(Net, req) mesg, err = h.resolver.Lookup(network, req)
if err != nil { if err != nil {
log.Warn("Resolve query error %s", err) log.Warn("Resolve query error %s", err)
@ -117,7 +112,7 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) {
w.WriteMsg(mesg) w.WriteMsg(mesg)
if IPQuery > 0 && len(mesg.Answer) > 0 { if len(mesg.Answer) > 0 {
err = h.cache.Set(key, mesg) err = h.cache.Set(key, mesg)
if err != nil { if err != nil {
log.Warn("Set %s cache failed: %s", question.String(), err.Error()) log.Warn("Set %s cache failed: %s", question.String(), err.Error())
@ -132,21 +127,6 @@ func (h *Handler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) {
} }
} }
func (h *Handler) isIPQuery(q dns.Question) int {
if q.Qclass != dns.ClassINET {
return notIPQuery
}
switch q.Qtype {
case dns.TypeA:
return _IP4Query
case dns.TypeAAAA:
return _IP6Query
default:
return notIPQuery
}
}
func KeyGen(q resolver.Question) string { func KeyGen(q resolver.Question) string {
h := md5.New() h := md5.New()
h.Write([]byte(q.String())) h.Write([]byte(q.String()))

View File

@ -1,22 +1,15 @@
package hosts package hosts
import ( import (
"net"
"time" "time"
) )
const (
notIPQuery = 0
_IP4Query = 4
_IP6Query = 6
)
var ( var (
zeroDuration = time.Duration(0) zeroDuration = time.Duration(0)
) )
type Hosts interface { type Hosts interface {
Get(domain string, family int) ([]net.IP, time.Duration, bool) Get(queryType uint16, domain string) ([]string, time.Duration, bool)
} }
type ProviderList struct { type ProviderList struct {
@ -24,7 +17,8 @@ type ProviderList struct {
} }
type Provider interface { type Provider interface {
Get(domain string) ([]string, time.Duration, bool) Get(queryType uint16, domain string) ([]string, time.Duration, bool)
Set(t, domain, value string) (bool, error)
} }
func NewHosts(providers []Provider) Hosts { func NewHosts(providers []Provider) Hosts {
@ -34,38 +28,22 @@ func NewHosts(providers []Provider) Hosts {
/* /*
Match local /etc/hosts file first, remote redis records second Match local /etc/hosts file first, remote redis records second
*/ */
func (h *ProviderList) Get(domain string, family int) ([]net.IP, time.Duration, bool) { func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
var sips []string var vals []string
var ok bool var ok bool
var ip net.IP
var ips []net.IP
var ttl time.Duration var ttl time.Duration
for _, provider := range h.providers { for _, provider := range h.providers {
sips, ttl, ok = provider.Get(domain) vals, ttl, ok = provider.Get(queryType, domain)
if ok { if ok {
break break
} }
} }
if sips == nil { if vals == nil {
return nil, zeroDuration, false return nil, zeroDuration, false
} }
for _, sip := range sips { return vals, ttl, true
switch family {
case _IP4Query:
ip = net.ParseIP(sip).To4()
case _IP6Query:
ip = net.ParseIP(sip).To16()
default:
continue
}
if ip != nil {
ips = append(ips, ip)
}
}
return ips, ttl, ips != nil
} }

View File

@ -2,7 +2,9 @@ package hosts
import ( import (
"bufio" "bufio"
"errors"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
"github.com/ryanuber/go-glob" "github.com/ryanuber/go-glob"
"meow.tf/joker/godns/log" "meow.tf/joker/godns/log"
"meow.tf/joker/godns/utils" "meow.tf/joker/godns/utils"
@ -49,8 +51,13 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
return fp return fp
} }
func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) { func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
log.Debug("Checking file provider for %s", domain) log.Debug("Checking file provider for %s : %s", queryType, domain)
// Does not support CNAME/TXT/etc
if queryType != dns.TypeA && queryType != dns.TypeAAAA {
return nil, zeroDuration, false
}
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
@ -77,6 +84,10 @@ func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) {
return nil, time.Duration(0), false return nil, time.Duration(0), false
} }
func (f *FileHosts) Set(t, domain, value string) (bool, error) {
return false, errors.New("file provider does not support setting values")
}
var ( var (
hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$") hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$")
) )

View File

@ -2,6 +2,7 @@ package hosts
import ( import (
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v7"
"github.com/miekg/dns"
"github.com/ryanuber/go-glob" "github.com/ryanuber/go-glob"
"meow.tf/joker/godns/log" "meow.tf/joker/godns/log"
"strings" "strings"
@ -33,9 +34,14 @@ func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider
return rh return rh
} }
func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) { func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
log.Debug("Checking redis provider for %s", domain) log.Debug("Checking redis provider for %s", domain)
// Don't support queries other than A/AAAA for now
if queryType != dns.TypeA || queryType != dns.TypeAAAA {
return nil, zeroDuration, false
}
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
@ -62,7 +68,7 @@ func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) {
return nil, time.Duration(0), false return nil, time.Duration(0), false
} }
func (r *RedisHosts) Set(domain, ip string) (bool, error) { func (r *RedisHosts) Set(t, domain, ip string) (bool, error) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result() return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()

View File

@ -48,7 +48,7 @@ func main() {
} }
var resolverCache, negCache cache.Cache var resolverCache, negCache cache.Cache
r := resolver.NewResolver(settings.ResolvSettings{ r := resolver.NewResolver(resolver.Settings{
Timeout: viper.GetInt("resolv.timeout"), Timeout: viper.GetInt("resolv.timeout"),
Interval: viper.GetInt("resolv.interval"), Interval: viper.GetInt("resolv.interval"),
SetEDNS0: viper.GetBool("resolv.edns0"), SetEDNS0: viper.GetBool("resolv.edns0"),

View File

@ -1,11 +1,13 @@
package resolver package resolver
import "github.com/miekg/dns"
type Question struct { type Question struct {
Name string Name string
Type string Type uint16
Class string Class string
} }
func (q *Question) String() string { func (q *Question) String() string {
return q.Name + " " + q.Class + " " + q.Type return q.Name + " " + q.Class + " " + dns.TypeToString[q.Type]
} }

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"meow.tf/joker/godns/log" "meow.tf/joker/godns/log"
"meow.tf/joker/godns/settings"
"meow.tf/joker/godns/utils" "meow.tf/joker/godns/utils"
"net" "net"
"os" "os"
@ -37,13 +36,13 @@ type RResp struct {
type Resolver struct { type Resolver struct {
servers []string servers []string
domain_server *suffixTreeNode domain_server *suffixTreeNode
config *settings.ResolvSettings config *Settings
clients map[string]*dns.Client clients map[string]*dns.Client
clientLock sync.RWMutex clientLock sync.RWMutex
} }
func NewResolver(c settings.ResolvSettings) *Resolver { func NewResolver(c Settings) *Resolver {
r := &Resolver{ r := &Resolver{
servers: []string{}, servers: []string{},
domain_server: newSuffixTreeRoot(), domain_server: newSuffixTreeRoot(),
@ -109,26 +108,26 @@ func (r *Resolver) parseServerListFile(buf *os.File) {
r.domain_server.sinsert(strings.Split(domain, "."), ip) r.domain_server.sinsert(strings.Split(domain, "."), ip)
case 1: case 1:
srv_port := strings.Split(line, "#") srvPort := strings.Split(line, "#")
if len(srv_port) > 2 { if len(srvPort) > 2 {
continue continue
} }
ip := "" ip := ""
if ip = srv_port[0]; !utils.IsIP(ip) { if ip = srvPort[0]; !utils.IsIP(ip) {
continue continue
} }
port := "53" port := "53"
if len(srv_port) == 2 { if len(srvPort) == 2 {
if _, err := strconv.Atoi(srv_port[1]); err != nil { if _, err := strconv.Atoi(srvPort[1]); err != nil {
continue continue
} }
port = srv_port[1] port = srvPort[1]
} }
r.servers = append(r.servers, net.JoinHostPort(ip, port)) r.servers = append(r.servers, net.JoinHostPort(ip, port))
@ -222,8 +221,14 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
} }
func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) { func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) {
key := net
if net == "tcp-tls" {
key = net + ":" + nameserver
}
r.clientLock.RLock() r.clientLock.RLock()
client, exists := r.clients[net] client, exists := r.clients[key]
r.clientLock.RUnlock() r.clientLock.RUnlock()
if exists { if exists {
@ -249,7 +254,7 @@ func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) {
} }
r.clientLock.Lock() r.clientLock.Lock()
r.clients[net] = client r.clients[key] = client
r.clientLock.Lock() r.clientLock.Lock()
return client, nil return client, nil

9
resolver/settings.go Normal file
View File

@ -0,0 +1,9 @@
package resolver
type Settings struct {
Timeout int `toml:"timeout" env:"RESOLV_TIMEOUT"`
Interval int `toml:"interval" env:"RESOLV_INTERVAL"`
SetEDNS0 bool `toml:"setedns0" env:"RESOLV_EDNS0"`
ServerListFile []string `toml:"server-list-file" env:"SERVER_LIST_FILE"`
ResolvFile string `toml:"resolv-file" env:"RESOLV_FILE"`
}