diff --git a/cache.go b/cache/cache.go similarity index 74% rename from cache.go rename to cache/cache.go index ebaabc9..f580f42 100644 --- a/cache.go +++ b/cache/cache.go @@ -1,10 +1,8 @@ -package main +package cache import ( "fmt" - "time" - "crypto/md5" "github.com/miekg/dns" ) @@ -39,23 +37,11 @@ func (e SerializerError) Error() string { return fmt.Sprintf("Serializer error: got %v", e.err) } -type Mesg struct { - Msg *dns.Msg - Expire time.Time -} - type Cache interface { Get(key string) (Msg *dns.Msg, err error) Set(key string, Msg *dns.Msg) error Exists(key string) bool Remove(key string) error Full() bool -} - -func KeyGen(q Question) string { - h := md5.New() - h.Write([]byte(q.String())) - x := h.Sum(nil) - key := fmt.Sprintf("%x", x) - return key + Purge() error } \ No newline at end of file diff --git a/cache_memcached.go b/cache/cache_memcached.go similarity index 93% rename from cache_memcached.go rename to cache/cache_memcached.go index a30f829..5b7b55a 100644 --- a/cache_memcached.go +++ b/cache/cache_memcached.go @@ -1,4 +1,4 @@ -package main +package cache import ( "github.com/bradfitz/gomemcache/memcache" @@ -64,3 +64,7 @@ func (m *MemcachedCache) Full() bool { // memcache is never full (LRU) return false } + +func (m *MemcachedCache) Purge() error { + return m.backend.DeleteAll() +} diff --git a/cache_memory.go b/cache/cache_memory.go similarity index 64% rename from cache_memory.go rename to cache/cache_memory.go index 5e57f21..d4b839a 100644 --- a/cache_memory.go +++ b/cache/cache_memory.go @@ -1,4 +1,4 @@ -package main +package cache import ( "github.com/miekg/dns" @@ -6,18 +6,35 @@ import ( "time" ) +func NewMemoryCache(expire time.Duration, maxCount int) *MemoryCache { + return &MemoryCache{ + backend: make(map[string]Mesg), + Expire: expire, + Maxcount: maxCount, + } +} + +type Mesg struct { + Msg *dns.Msg + Expire time.Time +} + type MemoryCache struct { Cache - Backend map[string]Mesg + backend map[string]Mesg Expire time.Duration Maxcount int mu sync.RWMutex } +func (c *MemoryCache) initialize() { + c.backend = make(map[string]Mesg) +} + func (c *MemoryCache) Get(key string) (*dns.Msg, error) { c.mu.RLock() - mesg, ok := c.Backend[key] + mesg, ok := c.backend[key] c.mu.RUnlock() if !ok { return nil, KeyNotFound{key} @@ -40,21 +57,21 @@ func (c *MemoryCache) Set(key string, msg *dns.Msg) error { expire := time.Now().Add(c.Expire) mesg := Mesg{msg, expire} c.mu.Lock() - c.Backend[key] = mesg + c.backend[key] = mesg c.mu.Unlock() return nil } func (c *MemoryCache) Remove(key string) error { c.mu.Lock() - delete(c.Backend, key) + delete(c.backend, key) c.mu.Unlock() return nil } func (c *MemoryCache) Exists(key string) bool { c.mu.RLock() - _, ok := c.Backend[key] + _, ok := c.backend[key] c.mu.RUnlock() return ok } @@ -62,7 +79,7 @@ func (c *MemoryCache) Exists(key string) bool { func (c *MemoryCache) Length() int { c.mu.RLock() defer c.mu.RUnlock() - return len(c.Backend) + return len(c.backend) } func (c *MemoryCache) Full() bool { @@ -72,3 +89,12 @@ func (c *MemoryCache) Full() bool { } return c.Length() >= c.Maxcount } + +func (c *MemoryCache) Purge() error { + c.mu.Lock() + defer c.mu.Unlock() + + c.initialize() + + return nil +} \ No newline at end of file diff --git a/cache_redis.go b/cache/cache_redis.go similarity index 53% rename from cache_redis.go rename to cache/cache_redis.go index eb42267..b3eb57a 100644 --- a/cache_redis.go +++ b/cache/cache_redis.go @@ -1,12 +1,14 @@ -package main +package cache import ( - "github.com/hoisie/redis" + "github.com/go-redis/redis/v7" "github.com/miekg/dns" + "meow.tf/joker/godns/settings" + "time" ) -func NewRedisCache(c RedisSettings, expire int32) *RedisCache { - rc := &redis.Client{Addr: c.Addr(), Db: c.DB, Password: c.Password} +func NewRedisCache(c settings.RedisSettings, expire int32) *RedisCache { + rc := redis.NewClient(&redis.Options{Addr: c.Addr(), DB: c.DB, Password: c.Password}) return &RedisCache{ backend: rc, @@ -25,6 +27,8 @@ func (m *RedisCache) Set(key string, msg *dns.Msg) error { var val []byte var err error + key = "cache:" + key + // handle cases for negacache where it sets nil values if msg == nil { val = []byte("nil") @@ -34,16 +38,21 @@ func (m *RedisCache) Set(key string, msg *dns.Msg) error { if err != nil { err = SerializerError{err} } - return m.backend.Setex(key, int64(m.expire), val) + return m.backend.Set(key, val, time.Duration(m.expire) * time.Second).Err() } func (m *RedisCache) Get(key string) (*dns.Msg, error) { var msg dns.Msg - item, err := m.backend.Get(key) + var err error + key = "cache:" + key + + item, err := m.backend.Get(key).Bytes() + if err != nil { err = KeyNotFound{key} return &msg, err } + err = msg.Unpack(item) if err != nil { err = SerializerError{err} @@ -52,19 +61,40 @@ func (m *RedisCache) Get(key string) (*dns.Msg, error) { } func (m *RedisCache) Exists(key string) bool { - exists, err := m.backend.Exists(key) + res, err := m.backend.Exists(key).Result() + if err != nil { return true } - return exists + + return res == 1 } func (m *RedisCache) Remove(key string) error { - _, err := m.backend.Del(key) - return err + return m.backend.Del(key).Err() } func (m *RedisCache) Full() bool { // redis is never full (LRU) return false } + +func (m *RedisCache) Purge() error { + iter := m.backend.Scan(0, "cache:*", 0).Iterator() + + if iter.Err() != nil { + return iter.Err() + } + + var err error + + for iter.Next() { + err = m.backend.Del(iter.Val()).Err() + + if err != nil { + return err + } + } + + return nil +} \ No newline at end of file diff --git a/etc/godns.conf b/etc/godns.conf index 41ce675..d75abea 100644 --- a/etc/godns.conf +++ b/etc/godns.conf @@ -9,7 +9,7 @@ Debug = false [server] host = "0.0.0.0" -port = 53 +nets = ["tcp:53", "udp:53"] [resolv] # Domain-specific nameservers configuration, formatting keep compatible with Dnsmasq diff --git a/go.mod b/go.mod index 583846b..748461d 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module godns +module meow.tf/joker/godns go 1.12 @@ -7,18 +7,19 @@ require ( github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/caarlos0/env v3.5.0+incompatible github.com/fsnotify/fsnotify v1.4.7 + github.com/go-redis/redis/v7 v7.0.0-beta.5 github.com/gopherjs/gopherjs v0.0.0-20190915194858-d3ddacdb130f // indirect github.com/hoisie/redis v0.0.0-20160730154456-b5c6e81454e0 - github.com/miekg/dns v1.1.18 + github.com/kr/pretty v0.2.0 // indirect + github.com/miekg/dns v1.1.27 + github.com/mitchellh/go-homedir v1.1.0 github.com/ryanuber/go-glob v1.0.0 github.com/smartystreets/assertions v1.0.1 // indirect - github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337 - github.com/stretchr/objx v0.2.0 // indirect + github.com/smartystreets/goconvey v1.6.4 + github.com/spf13/viper v1.6.2 github.com/stretchr/testify v1.4.0 // indirect - golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392 // indirect - golang.org/x/net v0.0.0-20190926025831-c00fd9afed17 // indirect + golang.org/x/crypto v0.0.0-20200117160349-530e935923ad // indirect + golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa // indirect golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e // indirect - golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe // indirect - golang.org/x/tools v0.0.0-20190925230517-ea99b82c7b93 // indirect - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect + golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9 // indirect ) diff --git a/handler.go b/handler.go index cf97b4c..7f72983 100644 --- a/handler.go +++ b/handler.go @@ -1,6 +1,14 @@ package main import ( + "crypto/md5" + "fmt" + "meow.tf/joker/godns/cache" + "meow.tf/joker/godns/hosts" + "meow.tf/joker/godns/log" + "meow.tf/joker/godns/resolver" + "meow.tf/joker/godns/settings" + "meow.tf/joker/godns/utils" "net" "time" @@ -13,75 +21,56 @@ const ( _IP6Query = 6 ) -type Question struct { - qname string - qtype string - qclass string -} - -func (q *Question) String() string { - return q.qname + " " + q.qclass + " " + q.qtype -} - type GODNSHandler struct { - resolver *Resolver - cache, negCache Cache - hosts Hosts + resolver *resolver.Resolver + cache, negCache cache.Cache + hosts hosts.Hosts } func NewHandler() *GODNSHandler { var ( - cacheConfig CacheSettings - resolver *Resolver - cache, negCache Cache + cacheConfig settings.CacheSettings + r *resolver.Resolver + resolverCache, negCache cache.Cache ) - resolver = NewResolver(settings.ResolvConfig) + r = resolver.NewResolver(settings.Resolver()) - cacheConfig = settings.Cache + cacheConfig = settings.Cache() switch cacheConfig.Backend { case "memory": - cache = &MemoryCache{ - Backend: make(map[string]Mesg, cacheConfig.Maxcount), - Expire: time.Duration(cacheConfig.Expire) * time.Second, - Maxcount: cacheConfig.Maxcount, - } - negCache = &MemoryCache{ - Backend: make(map[string]Mesg), - Expire: time.Duration(cacheConfig.Expire) * time.Second / 2, - Maxcount: cacheConfig.Maxcount, - } + cacheDuration := time.Duration(cacheConfig.Expire) * time.Second + + negCache = cache.NewMemoryCache(cacheDuration/2, cacheConfig.Maxcount) + resolverCache = cache.NewMemoryCache(time.Duration(cacheConfig.Expire)*time.Second, cacheConfig.Maxcount) case "memcache": - cache = NewMemcachedCache( - settings.Memcache.Servers, + resolverCache = cache.NewMemcachedCache( + settings.Memcache().Servers, int32(cacheConfig.Expire)) - negCache = NewMemcachedCache( - settings.Memcache.Servers, + negCache = cache.NewMemcachedCache( + settings.Memcache().Servers, int32(cacheConfig.Expire/2)) case "redis": - cache = NewRedisCache( - settings.Redis, + resolverCache = cache.NewRedisCache( + settings.Redis(), int32(cacheConfig.Expire)) - negCache = NewRedisCache( - settings.Redis, + negCache = cache.NewRedisCache( + settings.Redis(), int32(cacheConfig.Expire/2)) default: - logger.Error("Invalid cache backend %s", cacheConfig.Backend) + log.Error("Invalid cache backend %s", cacheConfig.Backend) panic("Invalid cache backend") } - var hosts Hosts - if settings.Hosts.Enable { - hosts = NewHosts(settings.Hosts, settings.Redis) - } + h := hosts.NewHosts(settings.Hosts(), settings.Redis()) - return &GODNSHandler{resolver, cache, negCache, hosts} + return &GODNSHandler{r, resolverCache, negCache, h} } func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] - Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]} + question := resolver.Question{Name: utils.UnFqdn(q.Name), Type: dns.TypeToString[q.Qtype], Class: dns.ClassToString[q.Qclass]} var remote net.IP if Net == "tcp" { @@ -89,63 +78,61 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { } else { remote = w.RemoteAddr().(*net.UDPAddr).IP } - logger.Info("%s lookup %s", remote, Q.String()) + log.Info("%s lookup %s", remote, question.String()) IPQuery := h.isIPQuery(q) // Query hosts - if settings.Hosts.Enable && IPQuery > 0 { - if ips, ok := h.hosts.Get(Q.qname, IPQuery); ok { + if h.hosts != nil && IPQuery > 0 { + if ips, ok := h.hosts.Get(question.Name, IPQuery); ok { m := new(dns.Msg) m.SetReply(req) switch IPQuery { case _IP4Query: - rr_header := dns.RR_Header{ + hdr := dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, - Ttl: settings.Hosts.TTL, + Ttl: h.hosts.TTL(), } for _, ip := range ips { - a := &dns.A{rr_header, ip} - m.Answer = append(m.Answer, a) + m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: ip}) } case _IP6Query: - rr_header := dns.RR_Header{ + hdr := dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, - Ttl: settings.Hosts.TTL, + Ttl: h.hosts.TTL(), } for _, ip := range ips { - aaaa := &dns.AAAA{rr_header, ip} - m.Answer = append(m.Answer, aaaa) + m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip}) } } w.WriteMsg(m) - logger.Debug("%s found in hosts file", Q.qname) + log.Debug("%s found in hosts file", question.Name) return } else { - logger.Debug("%s didn't found in hosts file", Q.qname) + log.Debug("%s didn't found in hosts file", question.Name) } } // Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN' - key := KeyGen(Q) + key := KeyGen(question) if IPQuery > 0 { mesg, err := h.cache.Get(key) if err != nil { if mesg, err = h.negCache.Get(key); err != nil { - logger.Debug("%s didn't hit cache", Q.String()) + log.Debug("%s didn't hit cache", question.String()) } else { - logger.Debug("%s hit negative cache", Q.String()) + log.Debug("%s hit negative cache", question.String()) dns.HandleFailed(w, req) return } } else { - logger.Debug("%s hit cache", Q.String()) + log.Debug("%s hit cache", question.String()) // we need this copy against concurrent modification of Id msg := *mesg msg.Id = req.Id @@ -157,12 +144,12 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { mesg, err := h.resolver.Lookup(Net, req) if err != nil { - logger.Warn("Resolve query error %s", err) + log.Warn("Resolve query error %s", err) dns.HandleFailed(w, req) // cache the failure, too! if err = h.negCache.Set(key, nil); err != nil { - logger.Warn("Set %s negative cache failed: %v", Q.String(), err) + log.Warn("Set %s negative cache failed: %v", question.String(), err) } return } @@ -172,18 +159,16 @@ func (h *GODNSHandler) do(Net string, w dns.ResponseWriter, req *dns.Msg) { if IPQuery > 0 && len(mesg.Answer) > 0 { err = h.cache.Set(key, mesg) if err != nil { - logger.Warn("Set %s cache failed: %s", Q.String(), err.Error()) + log.Warn("Set %s cache failed: %s", question.String(), err.Error()) } - logger.Debug("Insert %s into cache", Q.String()) + log.Debug("Insert %s into cache", question.String()) } } -func (h *GODNSHandler) DoTCP(w dns.ResponseWriter, req *dns.Msg) { - h.do("tcp", w, req) -} - -func (h *GODNSHandler) DoUDP(w dns.ResponseWriter, req *dns.Msg) { - h.do("udp", w, req) +func (h *GODNSHandler) Bind(net string) func(w dns.ResponseWriter, req *dns.Msg) { + return func(w dns.ResponseWriter, req *dns.Msg) { + h.do(net, w, req) + } } func (h *GODNSHandler) isIPQuery(q dns.Question) int { @@ -201,9 +186,10 @@ func (h *GODNSHandler) isIPQuery(q dns.Question) int { } } -func UnFqdn(s string) string { - if dns.IsFqdn(s) { - return s[:len(s)-1] - } - return s +func KeyGen(q resolver.Question) string { + h := md5.New() + h.Write([]byte(q.String())) + x := h.Sum(nil) + key := fmt.Sprintf("%x", x) + return key } diff --git a/hosts.go b/hosts/hosts.go similarity index 65% rename from hosts.go rename to hosts/hosts.go index 657cda7..748232d 100644 --- a/hosts.go +++ b/hosts/hosts.go @@ -1,13 +1,27 @@ -package main +package hosts import ( + "meow.tf/joker/godns/log" + "meow.tf/joker/godns/settings" "net" "time" "github.com/hoisie/redis" ) -type Hosts struct { +const ( + notIPQuery = 0 + _IP4Query = 4 + _IP6Query = 6 +) + +type Hosts interface { + Get(domain string, family int) ([]net.IP, bool) + TTL() uint32 +} + +type ProviderList struct { + settings settings.HostsSettings providers []HostProvider refreshInterval time.Duration } @@ -17,20 +31,20 @@ type HostProvider interface { Refresh() } -func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { +func NewHosts(hs settings.HostsSettings, rs settings.RedisSettings) Hosts { providers := []HostProvider{ NewFileProvider(hs.HostsFile), } if hs.RedisEnable { - logger.Info("Redis is enabled: %s", rs.Addr()) + log.Info("Redis is enabled: %s", rs.Addr()) rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password} providers = append(providers, NewRedisProvider(rc, hs.RedisKey)) } - h := Hosts{providers, time.Second * time.Duration(hs.RefreshInterval)} + h := &ProviderList{hs, providers, time.Second * time.Duration(hs.RefreshInterval)} if h.refreshInterval > 0 { h.refresh() @@ -39,7 +53,7 @@ func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { return h } -func (h *Hosts) refresh() { +func (h *ProviderList) refresh() { ticker := time.NewTicker(h.refreshInterval) go func() { @@ -57,7 +71,7 @@ func (h *Hosts) refresh() { /* Match local /etc/hosts file first, remote redis records second */ -func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) { +func (h *ProviderList) Get(domain string, family int) ([]net.IP, bool) { var sips []string var ok bool var ip net.IP @@ -91,3 +105,7 @@ func (h *Hosts) Get(domain string, family int) ([]net.IP, bool) { return ips, ips != nil } + +func (h *ProviderList) TTL() uint32 { + return h.settings.TTL +} \ No newline at end of file diff --git a/hosts_file.go b/hosts/hosts_file.go similarity index 90% rename from hosts_file.go rename to hosts/hosts_file.go index b1b1af8..65109b9 100644 --- a/hosts_file.go +++ b/hosts/hosts_file.go @@ -1,9 +1,10 @@ -package main +package hosts import ( "bufio" "github.com/fsnotify/fsnotify" "github.com/ryanuber/go-glob" + "meow.tf/joker/godns/log" "os" "regexp" "strings" @@ -45,7 +46,7 @@ func NewFileProvider(file string) HostProvider { } func (f *FileHosts) Get(domain string) ([]string, bool) { - logger.Debug("Checking file provider for %s", domain) + log.Debug("Checking file provider for %s", domain) f.mu.RLock() defer f.mu.RUnlock() @@ -80,7 +81,7 @@ func (f *FileHosts) Refresh() { buf, err := os.Open(f.file) if err != nil { - logger.Warn("Update hosts records from file failed %s", err) + log.Warn("Update hosts records from file failed %s", err) return } @@ -127,7 +128,7 @@ func (f *FileHosts) Refresh() { f.hosts[strings.ToLower(domain)] = ip } } - logger.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts)) + log.Debug("update hosts records from %s, total %d records.", f.file, len(f.hosts)) } func (f *FileHosts) clear() { diff --git a/hosts_redis.go b/hosts/hosts_redis.go similarity index 81% rename from hosts_redis.go rename to hosts/hosts_redis.go index 760aa26..be86a45 100644 --- a/hosts_redis.go +++ b/hosts/hosts_redis.go @@ -1,8 +1,9 @@ -package main +package hosts import ( "github.com/hoisie/redis" "github.com/ryanuber/go-glob" + "meow.tf/joker/godns/log" "strings" "sync" ) @@ -41,7 +42,7 @@ func NewRedisProvider(rc *redis.Client, key string) HostProvider { msg := <-messages if msg.Channel == "godns:update" { - logger.Debug("Refreshing redis records due to update") + log.Debug("Refreshing redis records due to update") rh.Refresh() } else if msg.Channel == "godns:update_record" { recordName := string(msg.Message) @@ -49,17 +50,17 @@ func NewRedisProvider(rc *redis.Client, key string) HostProvider { b, err := rc.Hget(key, recordName) if err != nil { - logger.Warn("Record %s does not exist, but was updated", recordName) + log.Warn("Record %s does not exist, but was updated", recordName) continue } - logger.Debug("Record %s was updated to %s", recordName, string(b)) + log.Debug("Record %s was updated to %s", recordName, string(b)) rh.mu.Lock() rh.hosts[recordName] = string(b) rh.mu.Unlock() } else if msg.Channel == "godns:remove_record" { - logger.Debug("Record %s was removed", msg.Message) + log.Debug("Record %s was removed", msg.Message) recordName := string(msg.Message) @@ -67,7 +68,7 @@ func NewRedisProvider(rc *redis.Client, key string) HostProvider { delete(rh.hosts, recordName) rh.mu.Unlock() } else if msg.Channel == keyspaceEvent { - logger.Debug("Refreshing redis records due to update") + log.Debug("Refreshing redis records due to update") rh.Refresh() } } @@ -77,7 +78,7 @@ func NewRedisProvider(rc *redis.Client, key string) HostProvider { } func (r *RedisHosts) Get(domain string) ([]string, bool) { - logger.Debug("Checking redis provider for %s", domain) + log.Debug("Checking redis provider for %s", domain) r.mu.RLock() defer r.mu.RUnlock() @@ -117,9 +118,9 @@ func (r *RedisHosts) Refresh() { r.clear() err := r.redis.Hgetall(r.key, r.hosts) if err != nil { - logger.Warn("Update hosts records from redis failed %s", err) + log.Warn("Update hosts records from redis failed %s", err) } else { - logger.Debug("Update hosts records from redis") + log.Debug("Update hosts records from redis") } } diff --git a/log.go b/log/log.go similarity index 98% rename from log.go rename to log/log.go index a355bf9..4a8cd0d 100644 --- a/log.go +++ b/log/log.go @@ -1,4 +1,4 @@ -package main +package log import ( "fmt" @@ -16,6 +16,10 @@ const ( LevelError ) +var ( + logger *GoDNSLogger +) + type logMesg struct { Level int Mesg string diff --git a/log_test.go b/log/log_test.go similarity index 98% rename from log_test.go rename to log/log_test.go index ca7eb26..79662c1 100644 --- a/log_test.go +++ b/log/log_test.go @@ -1,4 +1,4 @@ -package main +package log import ( "bufio" diff --git a/log/logger.go b/log/logger.go new file mode 100644 index 0000000..21eae12 --- /dev/null +++ b/log/logger.go @@ -0,0 +1,33 @@ +package log + +func init() { + logger = NewLogger() +} + +func Debug(format string, v ...interface{}) { + logger.Debug(format, v...) +} + +func Info(format string, v ...interface{}) { + logger.Info(format, v...) +} + +func Notice(format string, v ...interface{}) { + logger.Notice(format, v...) +} + +func Warn(format string, v ...interface{}) { + logger.Warn(format, v...) +} + +func Error(format string, v ...interface{}) { + logger.Error(format, v...) +} + +func SetLogger(handlerType string, config map[string]interface{}) { + logger.SetLogger(handlerType, config) +} + +func SetLevel(level int) { + logger.SetLevel(level) +} \ No newline at end of file diff --git a/main.go b/main.go index 2fbef68..a64e63c 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,11 @@ package main import ( + "flag" + "fmt" + "github.com/spf13/viper" + "meow.tf/joker/godns/log" + "meow.tf/joker/godns/settings" "os" "os/signal" "runtime" @@ -8,26 +13,43 @@ import ( "time" ) -var ( - logger *GoDNSLogger +const ( + Version = "0.3.0" ) -func main() { +var ( + cfgFile string +) +func init() { + flag.StringVar(&cfgFile, "config", "/etc/godns/godns.conf", "") +} + +func main() { initLogger() + viper.SetConfigFile(cfgFile) + + viper.AutomaticEnv() + + if err := viper.ReadInConfig(); err == nil { + fmt.Println("Using config file:", viper.ConfigFileUsed()) + } + + serverSettings := settings.Server() + server := &Server{ - host: settings.Server.Host, - port: settings.Server.Port, + host: serverSettings.Host, + port: serverSettings.Port, rTimeout: 5 * time.Second, wTimeout: 5 * time.Second, } server.Run() - logger.Info("godns %s start", settings.Version) + log.Info("godns %s (%s) start", Version, runtime.Version()) - if settings.Debug { + if settings.Debug() { go profileCPU() go profileMEM() } @@ -35,21 +57,15 @@ func main() { sig := make(chan os.Signal) signal.Notify(sig, os.Interrupt) -forever: - for { - select { - case <-sig: - logger.Info("signal received, stopping") - break forever - } - } + <- sig + log.Info("signal received, stopping") } func profileCPU() { f, err := os.Create("godns.cprof") if err != nil { - logger.Error("%s", err) + log.Error("%s", err) return } @@ -63,8 +79,9 @@ func profileCPU() { func profileMEM() { f, err := os.Create("godns.mprof") + if err != nil { - logger.Error("%s", err) + log.Error("%s", err) return } @@ -76,18 +93,17 @@ func profileMEM() { } func initLogger() { - logger = NewLogger() + logSettings := settings.Log() - if settings.Log.Stdout { - logger.SetLogger("console", nil) + if viper.GetBool("log.stdout") { + log.SetLogger("console", nil) } - if settings.Log.File != "" { - config := map[string]interface{}{"file": settings.Log.File} - logger.SetLogger("file", config) + if file := viper.GetString("log.file"); file != "" { + log.SetLogger("file", map[string]interface{}{"file": file}) } - logger.SetLevel(settings.Log.LogLevel()) + log.SetLevel(logSettings.LogLevel()) } func init() { diff --git a/resolver/nameserver.go b/resolver/nameserver.go new file mode 100644 index 0000000..3b5d73a --- /dev/null +++ b/resolver/nameserver.go @@ -0,0 +1,5 @@ +package resolver + +type Nameserver struct { + address string +} diff --git a/resolver/question.go b/resolver/question.go new file mode 100644 index 0000000..ecab487 --- /dev/null +++ b/resolver/question.go @@ -0,0 +1,11 @@ +package resolver + +type Question struct { + Name string + Type string + Class string +} + +func (q *Question) String() string { + return q.Name + " " + q.Class + " " + q.Type +} diff --git a/resolver.go b/resolver/resolver.go similarity index 71% rename from resolver.go rename to resolver/resolver.go index 39931e0..5c7c0c3 100644 --- a/resolver.go +++ b/resolver/resolver.go @@ -1,9 +1,12 @@ -package main +package resolver import ( "bufio" + "errors" "fmt" - "log" + "meow.tf/joker/godns/log" + "meow.tf/joker/godns/settings" + "meow.tf/joker/godns/utils" "net" "os" "strconv" @@ -11,9 +14,8 @@ import ( "sync" "time" - "github.com/miekg/dns" - "errors" "crypto/tls" + "github.com/miekg/dns" ) type ResolvError struct { @@ -35,14 +37,13 @@ type RResp struct { type Resolver struct { servers []string domain_server *suffixTreeNode - config *ResolvSettings + config *settings.ResolvSettings - tcpClient *dns.Client - udpClient *dns.Client - httpsClient *dns.Client + clients map[string]*dns.Client + clientLock sync.RWMutex } -func NewResolver(c ResolvSettings) *Resolver { +func NewResolver(c settings.ResolvSettings) *Resolver { r := &Resolver{ servers: []string{}, domain_server: newSuffixTreeRoot(), @@ -51,16 +52,14 @@ func NewResolver(c ResolvSettings) *Resolver { if len(c.ServerListFile) > 0 { r.ReadServerListFile(c.ServerListFile) - - log.Println("Read servers", strings.Join(r.servers, ", ")) } if len(c.ResolvFile) > 0 { clientConfig, err := dns.ClientConfigFromFile(c.ResolvFile) if err != nil { - logger.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) - logger.Error("%s", err) + log.Error(":%s is not a valid resolv.conf file\n", c.ResolvFile) + log.Error("%s", err) panic(err) } @@ -73,26 +72,6 @@ func NewResolver(c ResolvSettings) *Resolver { r.servers = append([]string{c.DOHServer}, r.servers...) } - timeout := r.Timeout() - - r.udpClient = &dns.Client{ - Net: "udp", - ReadTimeout: timeout, - WriteTimeout: timeout, - } - - r.tcpClient = &dns.Client{ - Net: "tcp", - ReadTimeout: timeout, - WriteTimeout: timeout, - } - - r.httpsClient = &dns.Client{ - Net: "https", - ReadTimeout: timeout, - WriteTimeout: timeout, - } - return r } @@ -128,7 +107,7 @@ func (r *Resolver) parseServerListFile(buf *os.File) { domain := tokens[1] ip := tokens[2] - if !isDomain(domain) || !isIP(ip) { + if !utils.IsDomain(domain) || !utils.IsIP(ip) { continue } @@ -142,7 +121,7 @@ func (r *Resolver) parseServerListFile(buf *os.File) { ip := "" - if ip = srv_port[0]; !isIP(ip) { + if ip = srv_port[0]; !utils.IsIP(ip) { continue } @@ -169,8 +148,8 @@ func (r *Resolver) ReadServerListFile(path string) { if err != nil { panic("Can't open " + file) } - defer buf.Close() r.parseServerListFile(buf) + buf.Close() } } @@ -178,7 +157,7 @@ func (r *Resolver) ReadServerListFile(path string) { // in every second, and return as early as possbile (have an answer). // It returns an error if no request has succeeded. func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error) { - if net == "udp" && settings.ResolvConfig.SetEDNS0 { + if net == "udp" && r.config.SetEDNS0 { req = req.SetEdns0(65535, true) } @@ -192,15 +171,15 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error c, err := resolver.resolverFor(net, nameserver) if err != nil { - logger.Warn("error:%s", err.Error()) + log.Warn("error:%s", err.Error()) return } r, rtt, err := c.Exchange(req, nameserver) if err != nil { - logger.Warn("%s socket error on %s", qname, nameserver) - logger.Warn("error:%s", err.Error()) + log.Warn("%s socket error on %s", qname, nameserver) + log.Warn("error:%s", err.Error()) return } // If SERVFAIL happen, should return immediately and try another upstream resolver. @@ -208,7 +187,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error // that it has been verified no such domain existas and ask other resolvers // would make no sense. See more about #20 if r != nil && r.Rcode != dns.RcodeSuccess { - logger.Warn("%s failed to get an valid answer on %s", qname, nameserver) + log.Warn("%s failed to get an valid answer on %s", qname, nameserver) if r.Rcode == dns.RcodeServerFailure { return } @@ -220,7 +199,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error } } - ticker := time.NewTicker(time.Duration(settings.ResolvConfig.Interval) * time.Millisecond) + ticker := time.NewTicker(time.Duration(r.config.Interval) * time.Millisecond) defer ticker.Stop() // Start lookup on each nameserver top-down, in every second nameservers := r.Nameservers(qname) @@ -230,7 +209,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error // but exit early, if we have an answer select { case re := <-res: - logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), re.nameserver, re.rtt) + log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver, re.rtt) return re.msg, nil case <-ticker.C: continue @@ -240,7 +219,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error wg.Wait() select { case re := <-res: - logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), re.nameserver, re.rtt) + log.Debug("%s resolv on %s rtt: %v", utils.UnFqdn(qname), re.nameserver, re.rtt) return re.msg, nil default: return nil, ResolvError{qname, net, nameservers} @@ -248,25 +227,37 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error } func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) { - if strings.HasPrefix(nameserver, "https") { - return r.httpsClient, nil - } else if strings.HasSuffix(nameserver, ":853") { - // TODO We need to set the server name so we can confirm the TLS connection. This may require a rewrite of storing nameservers. - return &dns.Client{ - Net: "tcp-tls", - ReadTimeout: r.Timeout(), - WriteTimeout: r.Timeout(), - TLSConfig: &tls.Config{ - ServerName: "", - }, - }, nil - } else if net == "udp" { - return r.udpClient, nil - } else if net == "tcp" { - return r.tcpClient, nil + r.clientLock.RLock() + client, exists := r.clients[net] + r.clientLock.RUnlock() + + if exists { + return client, nil } - return nil, errors.New("no client for nameserver") + if net != "tcp" && net != "tcp-tls" && net != "https" && net != "udp" { + return nil, errors.New("unknown network type") + } + + timeout := r.Timeout() + + client = &dns.Client{ + Net: net, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + if strings.HasSuffix(nameserver, ":853") { + client.TLSConfig = &tls.Config{ + + } + } + + r.clientLock.Lock() + r.clients[net] = client + r.clientLock.Lock() + + return client, nil } // Namservers return the array of nameservers, with port number appended. @@ -278,7 +269,7 @@ func (r *Resolver) Nameservers(qname string) []string { ns := []string{} if v, found := r.domain_server.search(queryKeys); found { - logger.Debug("%s be found in domain server list, upstream: %v", qname, v) + log.Debug("%s found in domain server list, upstream: %v", qname, v) ns = append(ns, net.JoinHostPort(v, "53")) //Ensure query the specific upstream nameserver in async Lookup() function. diff --git a/sfx_tree.go b/resolver/sfx_tree.go similarity index 98% rename from sfx_tree.go rename to resolver/sfx_tree.go index bfaca0e..41e4cf8 100644 --- a/sfx_tree.go +++ b/resolver/sfx_tree.go @@ -1,4 +1,4 @@ -package main +package resolver type suffixTreeNode struct { key string diff --git a/sfx_tree_test.go b/resolver/sfx_tree_test.go similarity index 98% rename from sfx_tree_test.go rename to resolver/sfx_tree_test.go index b2f0ae1..9ff26b9 100644 --- a/sfx_tree_test.go +++ b/resolver/sfx_tree_test.go @@ -1,4 +1,4 @@ -package main +package resolver import ( "strings" diff --git a/server.go b/server.go index 020ee48..a1109e4 100644 --- a/server.go +++ b/server.go @@ -1,8 +1,9 @@ package main import ( - "net" - "strconv" + "github.com/spf13/viper" + "meow.tf/joker/godns/log" + "strings" "time" "github.com/miekg/dns" @@ -15,45 +16,49 @@ type Server struct { wTimeout time.Duration } -func (s *Server) Addr() string { - return net.JoinHostPort(s.host, strconv.Itoa(s.port)) -} - func (s *Server) Run() { handler := NewHandler() - tcpHandler := dns.NewServeMux() - tcpHandler.HandleFunc(".", handler.DoTCP) + nets := viper.GetStringSlice("networks") - udpHandler := dns.NewServeMux() - udpHandler.HandleFunc(".", handler.DoUDP) + var addr string + var split []string - tcpServer := &dns.Server{ - Addr: s.Addr(), - Net: "tcp", - Handler: tcpHandler, - ReadTimeout: s.rTimeout, - WriteTimeout: s.wTimeout, + // Defaults: tcp, udp + for _, net := range nets { + split = strings.Split(net, ":") + + net = split[0] + + addr = s.host + + if len(split) == 1 { + addr += ":53" + } else { + addr += ":" + split[1] + } + + h := dns.NewServeMux() + h.HandleFunc(".", handler.Bind(net)) + + server := &dns.Server{ + Addr: addr, + Net: net, + Handler: h, + ReadTimeout: s.rTimeout, + WriteTimeout: s.wTimeout, + } + + go s.start(server) } - - udpServer := &dns.Server{ - Addr: s.Addr(), - Net: "udp", - Handler: udpHandler, - UDPSize: 65535, - ReadTimeout: s.rTimeout, - WriteTimeout: s.wTimeout, - } - - go s.start(udpServer) - go s.start(tcpServer) } func (s *Server) start(ds *dns.Server) { - logger.Info("Start %s listener on %s", ds.Net, s.Addr()) + log.Info("Start %s listener on %s", ds.Net, ds.Addr) err := ds.ListenAndServe() + if err != nil { - logger.Error("Start %s listener on %s failed:%s", ds.Net, s.Addr(), err.Error()) + log.Error("Start %s listener on %s failed:%s", ds.Net, ds.Addr, err.Error()) } } diff --git a/settings.go b/settings/settings.go similarity index 82% rename from settings.go rename to settings/settings.go index 96ca08c..bbd0201 100644 --- a/settings.go +++ b/settings/settings.go @@ -1,8 +1,9 @@ -package main +package settings import ( "flag" "fmt" + "meow.tf/joker/godns/log" "os" "strconv" @@ -15,11 +16,20 @@ var ( ) var LogLevelMap = map[string]int{ - "DEBUG": LevelDebug, - "INFO": LevelInfo, - "NOTICE": LevelNotice, - "WARN": LevelWarn, - "ERROR": LevelError, + "DEBUG": log.LevelDebug, + "INFO": log.LevelInfo, + "NOTICE": log.LevelNotice, + "WARN": log.LevelWarn, + "ERROR": log.LevelError, +} + +type HostsSettings struct { + Enable bool `toml:"enable" env:"HOSTS_ENABLE"` + HostsFile string `toml:"host-file" env:"HOSTS_FILE"` + RedisEnable bool `toml:"redis-enable" env:"REDIS_HOSTS_ENABLE"` + RedisKey string `toml:"redis-key" env:"REDIS_HOSTS_KEY"` + TTL uint32 `toml:"ttl" env:"HOSTS_TTL"` + RefreshInterval uint32 `toml:"refresh-interval" env:"HOSTS_REFRESH_INTERVAL"` } type Settings struct { @@ -83,15 +93,6 @@ type CacheSettings struct { Maxcount int `toml:"maxcount" env:"CACHE_MAX_COUNT"` } -type HostsSettings struct { - Enable bool `toml:"enable" env:"HOSTS_ENABLE"` - HostsFile string `toml:"host-file" env:"HOSTS_FILE"` - RedisEnable bool `toml:"redis-enable" env:"REDIS_HOSTS_ENABLE"` - RedisKey string `toml:"redis-key" env:"REDIS_HOSTS_KEY"` - TTL uint32 `toml:"ttl" env:"HOSTS_TTL"` - RefreshInterval uint32 `toml:"refresh-interval" env:"HOSTS_REFRESH_INTERVAL"` -} - func init() { var configFile string @@ -110,4 +111,36 @@ func init() { env.Parse(&settings.Log) env.Parse(&settings.Cache) env.Parse(&settings.Hosts) +} + +func Resolver() ResolvSettings { + return settings.ResolvConfig +} + +func Cache() CacheSettings { + return settings.Cache +} + +func Server() DNSServerSettings { + return settings.Server +} + +func Hosts() HostsSettings { + return settings.Hosts +} + +func Redis() RedisSettings { + return settings.Redis +} + +func Memcache() MemcacheSettings { + return settings.Memcache +} + +func Debug() bool { + return settings.Debug +} + +func Log() LogSettings { + return settings.Log } \ No newline at end of file diff --git a/utils.go b/utils.go deleted file mode 100644 index 9a19434..0000000 --- a/utils.go +++ /dev/null @@ -1,18 +0,0 @@ -package main - -import ( - "net" - "regexp" -) - -func isDomain(domain string) bool { - if isIP(domain) { - return false - } - match, _ := regexp.MatchString(`^([a-zA-Z0-9\*]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}$`, domain) - return match -} - -func isIP(ip string) bool { - return net.ParseIP(ip) != nil -} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..5a1c128 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,26 @@ +package utils + +import ( + "github.com/miekg/dns" + "net" + "regexp" +) + +func IsDomain(domain string) bool { + if IsIP(domain) { + return false + } + match, _ := regexp.MatchString(`^([a-zA-Z0-9\*]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}$`, domain) + return match +} + +func IsIP(ip string) bool { + return net.ParseIP(ip) != nil +} + +func UnFqdn(s string) string { + if dns.IsFqdn(s) { + return s[:len(s)-1] + } + return s +} \ No newline at end of file diff --git a/utils_test.go b/utils/utils_test.go similarity index 98% rename from utils_test.go rename to utils/utils_test.go index 715b333..8a9dc5a 100644 --- a/utils_test.go +++ b/utils/utils_test.go @@ -1,4 +1,4 @@ -package main +package utils import ( "testing"