Changes to allow host "Set" to be standard, providers being able to use other query types, cache all responses, etc.
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
f726a5d5ae
commit
3383c5e4f9
|
@ -29,7 +29,7 @@ maxcount = 0 #If set zero. The Sum of cache itmes will be unlimit.
|
|||
|
||||
# Redis cache backend config
|
||||
[cache.redis]
|
||||
host = "192.168.1.71"
|
||||
host = "127.0.0.1"
|
||||
port = 6379
|
||||
db = 0
|
||||
password =""
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
#Toml config file
|
||||
title = "GODNS"
|
||||
Version = "0.2.3"
|
||||
Author = "kenshin, tystuyfzand"
|
||||
|
||||
debug = false
|
||||
|
||||
[server]
|
||||
host = ""
|
||||
port = 53
|
||||
|
||||
[resolv]
|
||||
# Domain-specific nameservers configuration, formatting keep compatible with Dnsmasq
|
||||
# Semicolon separate multiple files.
|
||||
resolv-file = "/etc/resolv.conf"
|
||||
timeout = 5 # 5 seconds
|
||||
# The concurrency interval request upstream recursive server
|
||||
# Match the PR15, https://github.com/kenshinx/godns/pull/15
|
||||
interval = 200 # 200 milliseconds
|
||||
# When defined, this is preferred over regular DNS. This requires a resolver to be active besides this, only for the initial lookup.
|
||||
# A hosts file entry will suffice as well.
|
||||
# dns-over-https = "https://cloudflare-dns.com/dns-query"
|
||||
setedns0 = false #Support for larger UDP DNS responses
|
||||
|
||||
[redis]
|
||||
enable = true
|
||||
host = "127.0.0.1"
|
||||
port = 6379
|
||||
db = 0
|
||||
password =""
|
||||
|
||||
[memcache]
|
||||
servers = ["127.0.0.1:11211"]
|
||||
|
||||
[log]
|
||||
stdout = true
|
||||
file = "./godns.log"
|
||||
level = "INFO" #DEBUG | INFO |NOTICE | WARN | ERROR
|
||||
|
||||
[cache]
|
||||
# backend option [memory|memcache|redis]
|
||||
backend = "memory"
|
||||
expire = 600 # 10 minutes
|
||||
maxcount = 0 #If set zero. The Sum of cache items will be unlimit.
|
||||
|
||||
[hosts]
|
||||
#If set false, will not query hosts file and redis hosts record
|
||||
enable = true
|
||||
host-file = "/etc/hosts"
|
||||
redis-enable = false
|
||||
redis-key = "godns:hosts"
|
||||
ttl = 600
|
||||
# Refresh interval can be high since we have automatic updating via push and fsnotify
|
||||
refresh-interval = 300
|
||||
|
||||
|
114
handler.go
114
handler.go
|
@ -13,12 +13,6 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
notIPQuery = 0
|
||||
_IP4Query = 4
|
||||
_IP6Query = 6
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
resolver *resolver.Resolver
|
||||
cache, negCache cache.Cache
|
||||
|
@ -29,46 +23,48 @@ func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hos
|
|||
return &Handler{r, resolverCache, negCache, h}
|
||||
}
|
||||
|
||||
func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) {
|
||||
func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: dns.TypeToString[q.Qtype], Class: dns.ClassToString[q.Qclass]}
|
||||
question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: q.Qtype, Class: dns.ClassToString[q.Qclass]}
|
||||
|
||||
var remote net.IP
|
||||
if Net == "tcp" {
|
||||
remote = w.RemoteAddr().(*net.TCPAddr).IP
|
||||
} else {
|
||||
remote = w.RemoteAddr().(*net.UDPAddr).IP
|
||||
|
||||
switch t := w.RemoteAddr().(type) {
|
||||
case *net.TCPAddr:
|
||||
remote = t.IP
|
||||
case *net.UDPAddr:
|
||||
remote = t.IP
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("%s lookup %s", remote, question.String())
|
||||
|
||||
IPQuery := h.isIPQuery(q)
|
||||
|
||||
// Query hosts
|
||||
if h.hosts != nil && IPQuery > 0 {
|
||||
if ips, ttl, ok := h.hosts.Get(question.Name, IPQuery); ok {
|
||||
if h.hosts != nil {
|
||||
if vals, ttl, ok := h.hosts.Get(question.Type, question.Name); ok {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
|
||||
switch IPQuery {
|
||||
case _IP4Query:
|
||||
hdr := dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(ttl / time.Second),
|
||||
hdr := dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: question.Type,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(ttl / time.Second),
|
||||
}
|
||||
|
||||
switch question.Type {
|
||||
case dns.TypeA:
|
||||
for _, val := range vals {
|
||||
m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: net.ParseIP(val).To4()})
|
||||
}
|
||||
for _, ip := range ips {
|
||||
m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: ip})
|
||||
case dns.TypeAAAA:
|
||||
for _, val := range vals {
|
||||
m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(val).To16()})
|
||||
}
|
||||
case _IP6Query:
|
||||
hdr := dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(ttl / time.Second),
|
||||
}
|
||||
for _, ip := range ips {
|
||||
m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
|
||||
case dns.TypeCNAME:
|
||||
for _, val := range vals {
|
||||
m.Answer = append(m.Answer, &dns.CNAME{Hdr: hdr, Target: val})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,29 +76,28 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
// Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN'
|
||||
key := KeyGen(question)
|
||||
if IPQuery > 0 {
|
||||
mesg, err := h.cache.Get(key)
|
||||
if err != nil {
|
||||
if mesg, err = h.negCache.Get(key); err != nil {
|
||||
log.Debug("%s didn't hit cache", question.String())
|
||||
} else {
|
||||
log.Debug("%s hit negative cache", question.String())
|
||||
dns.HandleFailed(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
mesg, err := h.cache.Get(key)
|
||||
|
||||
if err != nil {
|
||||
if mesg, err = h.negCache.Get(key); err != nil {
|
||||
log.Debug("%s didn't hit cache", question.String())
|
||||
} else {
|
||||
log.Debug("%s hit cache", question.String())
|
||||
// we need this copy against concurrent modification of Id
|
||||
msg := *mesg
|
||||
msg.Id = req.Id
|
||||
w.WriteMsg(&msg)
|
||||
log.Debug("%s hit negative cache", question.String())
|
||||
dns.HandleFailed(w, req)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Debug("%s hit cache", question.String())
|
||||
// we need this copy against concurrent modification of Id
|
||||
msg := *mesg
|
||||
msg.Id = req.Id
|
||||
w.WriteMsg(&msg)
|
||||
return
|
||||
}
|
||||
|
||||
mesg, err := h.resolver.Lookup(Net, req)
|
||||
mesg, err = h.resolver.Lookup(network, req)
|
||||
|
||||
if err != nil {
|
||||
log.Warn("Resolve query error %s", err)
|
||||
|
@ -117,7 +112,7 @@ func (h *Handler) do(Net string, w dns.ResponseWriter, req *dns.Msg) {
|
|||
|
||||
w.WriteMsg(mesg)
|
||||
|
||||
if IPQuery > 0 && len(mesg.Answer) > 0 {
|
||||
if len(mesg.Answer) > 0 {
|
||||
err = h.cache.Set(key, mesg)
|
||||
if err != nil {
|
||||
log.Warn("Set %s cache failed: %s", question.String(), err.Error())
|
||||
|
@ -132,21 +127,6 @@ func (h *Handler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Handler) isIPQuery(q dns.Question) int {
|
||||
if q.Qclass != dns.ClassINET {
|
||||
return notIPQuery
|
||||
}
|
||||
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
return _IP4Query
|
||||
case dns.TypeAAAA:
|
||||
return _IP6Query
|
||||
default:
|
||||
return notIPQuery
|
||||
}
|
||||
}
|
||||
|
||||
func KeyGen(q resolver.Question) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(q.String()))
|
||||
|
|
|
@ -1,22 +1,15 @@
|
|||
package hosts
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
notIPQuery = 0
|
||||
_IP4Query = 4
|
||||
_IP6Query = 6
|
||||
)
|
||||
|
||||
var (
|
||||
zeroDuration = time.Duration(0)
|
||||
)
|
||||
|
||||
type Hosts interface {
|
||||
Get(domain string, family int) ([]net.IP, time.Duration, bool)
|
||||
Get(queryType uint16, domain string) ([]string, time.Duration, bool)
|
||||
}
|
||||
|
||||
type ProviderList struct {
|
||||
|
@ -24,7 +17,8 @@ type ProviderList struct {
|
|||
}
|
||||
|
||||
type Provider interface {
|
||||
Get(domain string) ([]string, time.Duration, bool)
|
||||
Get(queryType uint16, domain string) ([]string, time.Duration, bool)
|
||||
Set(t, domain, value string) (bool, error)
|
||||
}
|
||||
|
||||
func NewHosts(providers []Provider) Hosts {
|
||||
|
@ -34,38 +28,22 @@ func NewHosts(providers []Provider) Hosts {
|
|||
/*
|
||||
Match local /etc/hosts file first, remote redis records second
|
||||
*/
|
||||
func (h *ProviderList) Get(domain string, family int) ([]net.IP, time.Duration, bool) {
|
||||
var sips []string
|
||||
func (h *ProviderList) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
|
||||
var vals []string
|
||||
var ok bool
|
||||
var ip net.IP
|
||||
var ips []net.IP
|
||||
var ttl time.Duration
|
||||
|
||||
for _, provider := range h.providers {
|
||||
sips, ttl, ok = provider.Get(domain)
|
||||
vals, ttl, ok = provider.Get(queryType, domain)
|
||||
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sips == nil {
|
||||
if vals == nil {
|
||||
return nil, zeroDuration, false
|
||||
}
|
||||
|
||||
for _, sip := range sips {
|
||||
switch family {
|
||||
case _IP4Query:
|
||||
ip = net.ParseIP(sip).To4()
|
||||
case _IP6Query:
|
||||
ip = net.ParseIP(sip).To16()
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if ip != nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return ips, ttl, ips != nil
|
||||
return vals, ttl, true
|
||||
}
|
|
@ -2,7 +2,9 @@ package hosts
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ryanuber/go-glob"
|
||||
"meow.tf/joker/godns/log"
|
||||
"meow.tf/joker/godns/utils"
|
||||
|
@ -49,8 +51,13 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
|
|||
return fp
|
||||
}
|
||||
|
||||
func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) {
|
||||
log.Debug("Checking file provider for %s", domain)
|
||||
func (f *FileHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
|
||||
log.Debug("Checking file provider for %s : %s", queryType, domain)
|
||||
|
||||
// Does not support CNAME/TXT/etc
|
||||
if queryType != dns.TypeA && queryType != dns.TypeAAAA {
|
||||
return nil, zeroDuration, false
|
||||
}
|
||||
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
@ -77,6 +84,10 @@ func (f *FileHosts) Get(domain string) ([]string, time.Duration, bool) {
|
|||
return nil, time.Duration(0), false
|
||||
}
|
||||
|
||||
func (f *FileHosts) Set(t, domain, value string) (bool, error) {
|
||||
return false, errors.New("file provider does not support setting values")
|
||||
}
|
||||
|
||||
var (
|
||||
hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$")
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@ package hosts
|
|||
|
||||
import (
|
||||
"github.com/go-redis/redis/v7"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ryanuber/go-glob"
|
||||
"meow.tf/joker/godns/log"
|
||||
"strings"
|
||||
|
@ -33,9 +34,14 @@ func NewRedisProvider(rc *redis.Client, key string, ttl time.Duration) Provider
|
|||
return rh
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) {
|
||||
func (r *RedisHosts) Get(queryType uint16, domain string) ([]string, time.Duration, bool) {
|
||||
log.Debug("Checking redis provider for %s", domain)
|
||||
|
||||
// Don't support queries other than A/AAAA for now
|
||||
if queryType != dns.TypeA || queryType != dns.TypeAAAA {
|
||||
return nil, zeroDuration, false
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
|
@ -62,7 +68,7 @@ func (r *RedisHosts) Get(domain string) ([]string, time.Duration, bool) {
|
|||
return nil, time.Duration(0), false
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Set(domain, ip string) (bool, error) {
|
||||
func (r *RedisHosts) Set(t, domain, ip string) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.redis.HSet(r.key, strings.ToLower(domain), []byte(ip)).Result()
|
||||
|
|
2
main.go
2
main.go
|
@ -48,7 +48,7 @@ func main() {
|
|||
}
|
||||
|
||||
var resolverCache, negCache cache.Cache
|
||||
r := resolver.NewResolver(settings.ResolvSettings{
|
||||
r := resolver.NewResolver(resolver.Settings{
|
||||
Timeout: viper.GetInt("resolv.timeout"),
|
||||
Interval: viper.GetInt("resolv.interval"),
|
||||
SetEDNS0: viper.GetBool("resolv.edns0"),
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package resolver
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
type Question struct {
|
||||
Name string
|
||||
Type string
|
||||
Type uint16
|
||||
Class string
|
||||
}
|
||||
|
||||
func (q *Question) String() string {
|
||||
return q.Name + " " + q.Class + " " + q.Type
|
||||
return q.Name + " " + q.Class + " " + dns.TypeToString[q.Type]
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"meow.tf/joker/godns/log"
|
||||
"meow.tf/joker/godns/settings"
|
||||
"meow.tf/joker/godns/utils"
|
||||
"net"
|
||||
"os"
|
||||
|
@ -37,13 +36,13 @@ type RResp struct {
|
|||
type Resolver struct {
|
||||
servers []string
|
||||
domain_server *suffixTreeNode
|
||||
config *settings.ResolvSettings
|
||||
config *Settings
|
||||
|
||||
clients map[string]*dns.Client
|
||||
clientLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewResolver(c settings.ResolvSettings) *Resolver {
|
||||
func NewResolver(c Settings) *Resolver {
|
||||
r := &Resolver{
|
||||
servers: []string{},
|
||||
domain_server: newSuffixTreeRoot(),
|
||||
|
@ -109,26 +108,26 @@ func (r *Resolver) parseServerListFile(buf *os.File) {
|
|||
|
||||
r.domain_server.sinsert(strings.Split(domain, "."), ip)
|
||||
case 1:
|
||||
srv_port := strings.Split(line, "#")
|
||||
srvPort := strings.Split(line, "#")
|
||||
|
||||
if len(srv_port) > 2 {
|
||||
if len(srvPort) > 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := ""
|
||||
|
||||
if ip = srv_port[0]; !utils.IsIP(ip) {
|
||||
if ip = srvPort[0]; !utils.IsIP(ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
port := "53"
|
||||
|
||||
if len(srv_port) == 2 {
|
||||
if _, err := strconv.Atoi(srv_port[1]); err != nil {
|
||||
if len(srvPort) == 2 {
|
||||
if _, err := strconv.Atoi(srvPort[1]); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
port = srv_port[1]
|
||||
port = srvPort[1]
|
||||
}
|
||||
|
||||
r.servers = append(r.servers, net.JoinHostPort(ip, port))
|
||||
|
@ -222,8 +221,14 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error
|
|||
}
|
||||
|
||||
func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) {
|
||||
key := net
|
||||
|
||||
if net == "tcp-tls" {
|
||||
key = net + ":" + nameserver
|
||||
}
|
||||
|
||||
r.clientLock.RLock()
|
||||
client, exists := r.clients[net]
|
||||
client, exists := r.clients[key]
|
||||
r.clientLock.RUnlock()
|
||||
|
||||
if exists {
|
||||
|
@ -249,7 +254,7 @@ func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) {
|
|||
}
|
||||
|
||||
r.clientLock.Lock()
|
||||
r.clients[net] = client
|
||||
r.clients[key] = client
|
||||
r.clientLock.Lock()
|
||||
|
||||
return client, nil
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
package resolver
|
||||
|
||||
type Settings struct {
|
||||
Timeout int `toml:"timeout" env:"RESOLV_TIMEOUT"`
|
||||
Interval int `toml:"interval" env:"RESOLV_INTERVAL"`
|
||||
SetEDNS0 bool `toml:"setedns0" env:"RESOLV_EDNS0"`
|
||||
ServerListFile []string `toml:"server-list-file" env:"SERVER_LIST_FILE"`
|
||||
ResolvFile string `toml:"resolv-file" env:"RESOLV_FILE"`
|
||||
}
|
Loading…
Reference in New Issue