godns/main.go
Tyler d8079551c9
Some checks failed
continuous-integration/drone/push Build is failing
Add testing, cleanup, rework suffix tree to use nameservers. Parse nameservers from yaml.
2021-04-15 01:04:58 -04:00

198 lines
4.2 KiB
Go

package main
import (
"flag"
"github.com/go-redis/redis/v7"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"io"
"meow.tf/joker/godns/api"
"meow.tf/joker/godns/cache"
"meow.tf/joker/godns/hosts"
"meow.tf/joker/godns/resolver"
"meow.tf/joker/godns/settings"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"time"
)
const (
Version = "0.3.0"
)
var (
cfgFile string
)
func init() {
flag.StringVar(&cfgFile, "config", "/etc/godns/godns.conf", "")
}
func main() {
initLogger()
viper.SetConfigFile(cfgFile)
viper.AutomaticEnv()
if err := viper.ReadInConfig(); err == nil {
log.WithField("file", viper.ConfigFileUsed()).Info("Using configuration from file")
}
server := &Server{
host: viper.GetString("server.host"),
networks: viper.GetStringSlice("server.nets"),
rTimeout: 5 * time.Second,
wTimeout: 5 * time.Second,
}
var resolverSettings resolver.Settings
/*
resolver.Settings{
Timeout: viper.GetInt("resolv.timeout"),
Interval: viper.GetInt("resolv.interval"),
SetEDNS0: viper.GetBool("resolv.edns0"),
ServerListFile: viper.GetStringSlice("resolv.server-list"),
ResolvFile: viper.GetString("resolv.file"),
}
*/
viper.UnmarshalKey("resolv", &resolverSettings)
var resolverCache, negCache cache.Cache
r := resolver.NewResolver(resolverSettings)
cacheDuration := viper.GetDuration("cache.expire")
backend := viper.GetString("cache.backend")
var redisConfig settings.RedisSettings
viper.UnmarshalKey("cache.redis", &redisConfig)
switch backend {
case "memory":
cacheMaxCount := viper.GetInt("cache.memory.maxCount")
negCache = cache.NewMemoryCache(cacheDuration/2, cacheMaxCount)
resolverCache = cache.NewMemoryCache(cacheDuration, cacheMaxCount)
case "memcached":
servers := viper.GetStringSlice("cache.memcached.servers")
resolverCache = cache.NewMemcachedCache(servers, int32(cacheDuration.Seconds()))
negCache = cache.NewMemcachedCache(servers, int32(cacheDuration.Seconds()/2))
case "redis":
resolverCache = cache.NewRedisCache(redisConfig, cacheDuration)
negCache = cache.NewRedisCache(redisConfig, cacheDuration/2)
default:
log.WithField("backend", backend).Fatalln("Invalid cache backend")
}
providers := make([]hosts.Provider, 0)
if viper.GetBool("hosts.file.enable") {
providers = append(providers, hosts.NewFileProvider(viper.GetString("hosts.file.file"), viper.GetDuration("hosts.file.ttl")))
}
if viper.GetBool("hosts.bolt.enable") {
providers = append(providers, hosts.NewBoltProvider(viper.GetString("hosts.bolt.file")))
}
if viper.GetBool("hosts.redis.enable") {
rc := redis.NewClient(&redis.Options{Addr: redisConfig.Addr(), DB: redisConfig.DB, Password: redisConfig.Password})
providers = append(providers, hosts.NewRedisProvider(rc, viper.GetString("hosts.redis.key")))
}
h := hosts.NewHosts(providers)
a := api.New()
hosts.EnableAPI(h, a.Router())
if viper.GetBool("api.enabled") {
go func() {
err := a.Start()
if err != nil {
log.WithError(err).Fatalln("Unable to bind API")
}
}()
}
handler := NewHandler(r, resolverCache, negCache, h)
server.Run(handler)
log.Infof("joker dns %s (%s)", Version, runtime.Version())
if viper.GetBool("debug") {
go profileCPU()
go profileMEM()
}
sig := make(chan os.Signal)
signal.Notify(sig, os.Interrupt)
<- sig
log.Info("signal received, stopping")
}
// profileCPU records a CPU profile into godns.cprof in the working directory.
// Profiling is stopped and the file closed six minutes after it starts.
// Failures are logged rather than treated as fatal, because profiling is only
// enabled as a debugging aid.
func profileCPU() {
	f, err := os.Create("godns.cprof")
	if err != nil {
		log.WithError(err).Error("Unable to profile cpu due to error")
		return
	}

	// StartCPUProfile fails if profiling is already enabled; without this check
	// the file would be leaked and left empty.
	if err := pprof.StartCPUProfile(f); err != nil {
		log.WithError(err).Error("Unable to profile cpu due to error")
		_ = f.Close()
		return
	}

	time.AfterFunc(6*time.Minute, func() {
		pprof.StopCPUProfile()
		if err := f.Close(); err != nil {
			log.WithError(err).Error("Unable to close cpu profile")
		}
	})
}
// profileMEM writes a heap profile to godns.mprof in the working directory
// five minutes after it is called. Failures are logged rather than treated as
// fatal, because profiling is only enabled as a debugging aid.
func profileMEM() {
	f, err := os.Create("godns.mprof")
	if err != nil {
		log.WithError(err).Error("Unable to profile memory due to error")
		return
	}

	time.AfterFunc(5*time.Minute, func() {
		// The heap profile reflects the most recently completed collection;
		// forcing one gives up-to-date statistics (as runtime/pprof recommends).
		runtime.GC()
		if err := pprof.WriteHeapProfile(f); err != nil {
			log.WithError(err).Error("Unable to write memory profile")
		}
		// Close errors matter for a file we wrote to: they can signal lost data.
		if err := f.Close(); err != nil {
			log.WithError(err).Error("Unable to close memory profile")
		}
	})
}
// initLogger configures the package-level logrus logger from viper:
//
//   - log.file:  when set, log output is also appended to this file, in
//     addition to the logger's current output.
//   - log.level: a logrus level name; when empty the current level is kept.
//
// It must be called after the configuration has been loaded, otherwise every
// key reads as its zero value. Problems are reported through the logger
// instead of aborting startup, and a failure in one setting does not prevent
// the other from being applied.
func initLogger() {
	// NOTE(review): the log.stdout key is not acted on; output always goes to
	// the logger's existing writer (plus log.file when configured).

	if file := viper.GetString("log.file"); file != "" {
		// Append instead of truncating so logs from earlier runs survive a restart.
		if f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644); err != nil {
			log.WithError(err).WithField("file", file).Error("Unable to open log file")
		} else {
			log.SetOutput(io.MultiWriter(f, log.StandardLogger().Out))
		}
	}

	levelName := viper.GetString("log.level")
	if levelName == "" {
		return
	}

	level, err := log.ParseLevel(levelName)
	if err != nil {
		log.WithError(err).WithField("level", levelName).Warn("Invalid log level, keeping current level")
		return
	}

	log.SetLevel(level)
}