From e7236b13dee24cd63aa65302f75921f502ac029c Mon Sep 17 00:00:00 2001 From: Tyler Date: Mon, 15 Aug 2022 02:16:22 -0400 Subject: [PATCH] Massive refactoring, struct cleanup, supporting more features Features: - Protocol lists (http, https), managed by http responses - Working TLS Checks - Root certificate parsing for TLS checks - Moving configuration into a Config struct, no more direct viper access --- armbianmirror_suite_test.go | 2 +- check.go | 110 ++++++++++++--- check_test.go | 29 ++-- cmd/main.go | 112 +++++++++++++++ config.go | 157 ++++++++++++++-------- dlrouter.yaml | 4 + go.mod | 5 +- go.sum | 5 +- http.go | 67 +++++---- main.go | 154 --------------------- map.go | 2 +- map_test.go | 2 +- middleware.go => middleware/middleware.go | 2 +- mirrors.go | 20 +-- redirector.go | 130 ++++++++++++++++++ servers.go | 48 +++---- util.go | 4 +- util/certificates.go | 46 +++++++ 18 files changed, 574 insertions(+), 325 deletions(-) create mode 100644 cmd/main.go delete mode 100644 main.go rename middleware.go => middleware/middleware.go (98%) create mode 100644 redirector.go create mode 100644 util/certificates.go diff --git a/armbianmirror_suite_test.go b/armbianmirror_suite_test.go index 04b1730..07476b7 100644 --- a/armbianmirror_suite_test.go +++ b/armbianmirror_suite_test.go @@ -1,4 +1,4 @@ -package main +package redirector import ( "testing" diff --git a/check.go b/check.go index afdb0fd..93552d4 100644 --- a/check.go +++ b/check.go @@ -1,4 +1,4 @@ -package main +package redirector import ( "crypto/tls" @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "runtime" + "strings" "time" ) @@ -18,8 +19,14 @@ var ( ErrCertExpired = errors.New("certificate is expired") ) +func (r *Redirector) checkHttp(scheme string) ServerCheck { + return func(server *Server, logFields log.Fields) (bool, error) { + return r.checkHttpScheme(server, scheme, logFields) + } +} + // checkHttp checks a URL for validity, and checks redirects -func checkHttp(server *Server, logFields log.Fields) (bool, error) { +func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) { u := &url.URL{ Scheme: "http", Host: server.Host, @@ -48,13 +55,20 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) { logFields["url"] = location - // Check that we don't redirect to https from a http url - if u.Scheme == "http" { - res, err := checkRedirect(location) + switch u.Scheme { + case "http": + res, err := r.checkRedirect(u.Scheme, location) if !res || err != nil { - return res, err + // If we don't support http, we remove it from supported protocols + server.Protocols = server.Protocols.Remove("http") + } else { + // Otherwise, we verify https support + r.checkProtocol(server, "https") } + case "https": + // We don't want to allow downgrading, so this is an error. + return r.checkRedirect(u.Scheme, location) } } @@ -66,8 +80,20 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) { return false, nil } +func (r *Redirector) checkProtocol(server *Server, scheme string) { + res, err := r.checkHttpScheme(server, scheme, log.Fields{}) + + if !res || err != nil { + return + } + + if !server.Protocols.Contains(scheme) { + server.Protocols = server.Protocols.Append(scheme) + } +} + // checkRedirect parses a location header response and checks the scheme -func checkRedirect(locationHeader string) (bool, error) { +func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) { newUrl, err := url.Parse(locationHeader) if err != nil { @@ -76,20 +102,41 @@ func checkRedirect(locationHeader string) (bool, error) { if newUrl.Scheme == "https" { return false, ErrHttpsRedirect + } else if originatingScheme == "https" && newUrl.Scheme == "https" { + return false, ErrHttpsRedirect } return true, nil } // checkTLS checks tls certificates from a host, ensures they're valid, and not expired. -func checkTLS(server *Server, logFields log.Fields) (bool, error) { - host, port, err := net.SplitHostPort(server.Host) +func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) { + var host, port string + var err error + + if strings.Contains(server.Host, ":") { + host, port, err = net.SplitHostPort(server.Host) + + if err != nil { + return false, err + } + } else { + host = server.Host + } + + log.WithFields(log.Fields{ + "server": server.Host, + "host": host, + "port": port, + }).Info("Checking TLS server") if port == "" { port = "443" } - conn, err := tls.Dial("tcp", host+":"+port, checkTLSConfig) + conn, err := tls.Dial("tcp", host+":"+port, &tls.Config{ + RootCAs: r.config.RootCAs, + }) if err != nil { return false, err @@ -107,18 +154,38 @@ func checkTLS(server *Server, logFields log.Fields) (bool, error) { state := conn.ConnectionState() - opts := x509.VerifyOptions{ - CurrentTime: time.Now(), + peerPool := x509.NewCertPool() + + for _, intermediate := range state.PeerCertificates { + if !intermediate.IsCA { + continue + } + + peerPool.AddCert(intermediate) } - for _, cert := range state.PeerCertificates { - if _, err := cert.Verify(opts); err != nil { - logFields["peerCert"] = cert.Subject.String() - return false, err - } - if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { - return false, err + opts := x509.VerifyOptions{ + Roots: r.config.RootCAs, + Intermediates: peerPool, + CurrentTime: time.Now(), + } + + // We want only the leaf certificate, as this will verify up the chain for us. + cert := state.PeerCertificates[0] + + if _, err := cert.Verify(opts); err != nil { + logFields["peerCert"] = cert.Subject.String() + + if authErr, ok := err.(x509.UnknownAuthorityError); ok { + logFields["authCert"] = authErr.Cert.Subject.String() + logFields["ca"] = authErr.Cert.Issuer } + return false, err + } + + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + logFields["peerCert"] = cert.Subject.String() + return false, err } for _, chain := range state.VerifiedChains { @@ -130,5 +197,10 @@ func checkTLS(server *Server, logFields log.Fields) (bool, error) { } } + // If https is valid, append it + if !server.Protocols.Contains("https") { + server.Protocols = server.Protocols.Append("https") + } + return true, nil } diff --git a/check_test.go b/check_test.go index 7b4c758..5be3df9 100644 --- a/check_test.go +++ b/check_test.go @@ -1,4 +1,4 @@ -package main +package redirector import ( "crypto/rand" @@ -58,11 +58,15 @@ var _ = Describe("Check suite", func() { httpServer *httptest.Server server *Server handler http.HandlerFunc + r *Redirector ) BeforeEach(func() { httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler(w, r) })) + r = New(&Config{ + RootCAs: x509.NewCertPool(), + }) }) AfterEach(func() { httpServer.Close() @@ -89,7 +93,7 @@ var _ = Describe("Check suite", func() { w.WriteHeader(http.StatusOK) } - res, err := checkHttp(server, log.Fields{}) + res, err := r.checkHttpScheme(server, "http", log.Fields{}) Expect(res).To(BeTrue()) Expect(err).To(BeNil()) @@ -100,7 +104,7 @@ var _ = Describe("Check suite", func() { w.WriteHeader(http.StatusMovedPermanently) } - res, err := checkHttp(server, log.Fields{}) + res, err := r.checkHttpScheme(server, "http", log.Fields{}) Expect(res).To(BeFalse()) Expect(err).To(Equal(ErrHttpsRedirect)) @@ -141,7 +145,7 @@ var _ = Describe("Check suite", func() { setupCerts(time.Now(), time.Now().Add(24*time.Hour)) }) It("Should fail due to invalid ca", func() { - res, err := checkTLS(server, log.Fields{}) + res, err := r.checkTLS(server, log.Fields{}) Expect(res).To(BeFalse()) Expect(err).ToNot(BeNil()) @@ -151,20 +155,15 @@ var _ = Describe("Check suite", func() { pool.AddCert(x509Cert) - checkTLSConfig = &tls.Config{RootCAs: pool} + r.config.RootCAs = pool - res, err := checkTLS(server, log.Fields{}) + res, err := r.checkTLS(server, log.Fields{}) Expect(res).To(BeFalse()) Expect(err).ToNot(BeNil()) - - checkTLSConfig = nil }) }) Context("Expiration tests", func() { - AfterEach(func() { - checkTLSConfig = nil - }) It("Should fail due to not yet valid certificate", func() { setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour)) @@ -173,10 +172,10 @@ var _ = Describe("Check suite", func() { pool.AddCert(x509Cert) - checkTLSConfig = &tls.Config{RootCAs: pool} + r.config.RootCAs = pool // Check TLS - res, err := checkTLS(server, log.Fields{}) + res, err := r.checkTLS(server, log.Fields{}) Expect(res).To(BeFalse()) Expect(err).ToNot(BeNil()) @@ -189,10 +188,10 @@ var _ = Describe("Check suite", func() { pool.AddCert(x509Cert) - checkTLSConfig = &tls.Config{RootCAs: pool} + r.config.RootCAs = pool // Check TLS - res, err := checkTLS(server, log.Fields{}) + res, err := r.checkTLS(server, log.Fields{}) Expect(res).To(BeFalse()) Expect(err).ToNot(BeNil()) diff --git a/cmd/main.go b/cmd/main.go new file mode 100644 index 0000000..e371aa2 --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,112 @@ +package main + +import ( + "flag" + "github.com/armbian/redirector" + "github.com/armbian/redirector/util" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "os" + "os/signal" + "syscall" +) + +var ( + configFlag = flag.String("config", "", "configuration file path") + flagDebug = flag.Bool("debug", false, "Enable debug logging") +) + +func main() { + flag.Parse() + + if *flagDebug { + log.SetLevel(log.DebugLevel) + } + + viper.SetDefault("bind", ":8080") + viper.SetDefault("cacheSize", 1024) + viper.SetDefault("topChoices", 3) + viper.SetDefault("reloadKey", redirector.RandomSequence(32)) + + viper.SetConfigName("dlrouter") // name of config file (without extension) + viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name + viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in + viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths + viper.AddConfigPath(".") // optionally look for config in the working directory + + if *configFlag != "" { + viper.SetConfigFile(*configFlag) + } + + config := &redirector.Config{} + + loadConfig := func(fatal bool) { + log.Info("Reading configuration") + + // Bind reload to reading in the viper config, then deserializing + if err := viper.ReadInConfig(); err != nil { + log.WithError(err).Error("Unable to unmarshal configuration") + + if fatal { + os.Exit(1) + } + } + + log.Info("Unmarshalling configuration") + + if err := viper.Unmarshal(config); err != nil { + log.WithError(err).Error("Unable to unmarshal configuration") + + if fatal { + os.Exit(1) + } + } + + log.Info("Updating root certificates") + + certs, err := util.LoadCACerts() + + if err != nil { + log.WithError(err).Error("Unable to load certificates") + + if fatal { + os.Exit(1) + } + } + + config.RootCAs = certs + } + + config.ReloadFunc = func() { + loadConfig(false) + } + + loadConfig(true) + + redir := redirector.New(config) + + // Because we have a bind address, we can start it without the return value. + redir.Start() + + log.Info("Ready") + + c := make(chan os.Signal) + + signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP) + + for { + sig := <-c + + if sig != syscall.SIGHUP { + break + } + + loadConfig(false) + + err := redir.ReloadConfig() + + if err != nil { + log.WithError(err).Warning("Did not reload configuration due to error") + } + } +} diff --git a/config.go b/config.go index 216eaf9..e105437 100644 --- a/config.go +++ b/config.go @@ -1,109 +1,149 @@ -package main +package redirector import ( + "crypto/x509" lru "github.com/hashicorp/golang-lru" "github.com/oschwald/maxminddb-golang" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" log "github.com/sirupsen/logrus" - "github.com/spf13/viper" "net" "net/url" "strings" "sync" ) -func reloadConfig() error { - log.Info("Loading configuration...") +type Config struct { + BindAddress string `mapstructure:"bind"` + GeoDBPath string `mapstructure:"geodb"` + MapFile string `mapstructure:"dl_map"` + CacheSize int `mapstructure:"cacheSize"` + TopChoices int `mapstructure:"topChoices"` + ReloadToken string `mapstructure:"reloadToken"` + ServerList []ServerConfig `mapstructure:"servers"` + ReloadFunc func() + RootCAs *x509.CertPool +} - err := viper.ReadInConfig() // Find and read the config file +type ProtocolList []string - if err != nil { // Handle errors reading the config file - return errors.Wrap(err, "Unable to read configuration") - } - - // db will never be reloaded. - if db == nil { - // Load maxmind database - db, err = maxminddb.Open(viper.GetString("geodb")) - - if err != nil { - return errors.Wrap(err, "Unable to open database") +func (p ProtocolList) Contains(value string) bool { + for _, val := range p { + if value == val { + return true } } + return false +} + +func (p ProtocolList) Append(value string) ProtocolList { + return append(p, value) +} + +func (p ProtocolList) Remove(value string) ProtocolList { + index := -1 + + for i, val := range p { + if value == val { + index = i + break + } + } + + if index == -1 { + return p + } + + p[index] = p[len(p)-1] + return p[:len(p)-1] +} + +func (r *Redirector) ReloadConfig() error { + log.Info("Loading configuration...") + + var err error + + // Load maxmind database + if r.db != nil { + err = r.db.Close() + + if err != nil { + return errors.Wrap(err, "Unable to close database") + } + } + + // db can be hot-reloaded if the file changed + r.db, err = maxminddb.Open(r.config.GeoDBPath) + + if err != nil { + return errors.Wrap(err, "Unable to open database") + } + // Refresh server cache if size changed - if serverCache == nil { - serverCache, err = lru.New(viper.GetInt("cacheSize")) + if r.serverCache == nil { + r.serverCache, err = lru.New(r.config.CacheSize) } else { - serverCache.Resize(viper.GetInt("cacheSize")) + r.serverCache.Resize(r.config.CacheSize) } // Purge the cache to ensure we don't have any invalid servers in it - serverCache.Purge() - - // Set top choice count - topChoices = viper.GetInt("topChoices") + r.serverCache.Purge() // Reload map file - if err := reloadMap(); err != nil { + if err := r.reloadMap(); err != nil { return errors.Wrap(err, "Unable to load map file") } // Reload server list - if err := reloadServers(); err != nil { + if err := r.reloadServers(); err != nil { return errors.Wrap(err, "Unable to load servers") } // Create mirror map mirrors := make(map[string][]*Server) - for _, server := range servers { + for _, server := range r.servers { mirrors[server.Continent] = append(mirrors[server.Continent], server) } mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...) - regionMap = mirrors + r.regionMap = mirrors hosts := make(map[string]*Server) - for _, server := range servers { + for _, server := range r.servers { hosts[server.Host] = server } - hostMap = hosts + r.hostMap = hosts // Check top choices size - if topChoices > len(servers) { - topChoices = len(servers) + if r.config.TopChoices > len(r.servers) { + r.config.TopChoices = len(r.servers) } // Force check - go servers.Check() + go r.servers.Check(r.checks) return nil } -func reloadServers() error { - var serverList []ServerConfig - - if err := viper.UnmarshalKey("servers", &serverList); err != nil { - return err - } - +func (r *Redirector) reloadServers() error { + log.WithField("count", len(r.config.ServerList)).Info("Loading servers") var wg sync.WaitGroup existing := make(map[string]int) - for i, server := range servers { + for i, server := range r.servers { existing[server.Host] = i } hosts := make(map[string]bool) - for _, server := range serverList { + for _, server := range r.config.ServerList { wg.Add(1) var prefix string @@ -133,19 +173,19 @@ func reloadServers() error { go func(i int, server ServerConfig, u *url.URL) { defer wg.Done() - s := addServer(server, u) + s := r.addServer(server, u) if _, ok := existing[u.Host]; ok { - s.Redirects = servers[i].Redirects + s.Redirects = r.servers[i].Redirects - servers[i] = s + r.servers[i] = s } else { s.Redirects = promauto.NewCounter(prometheus.CounterOpts{ Name: "armbian_router_redirects_" + metricReplacer.Replace(u.Host), Help: "The number of redirects for server " + u.Host, }) - servers = append(servers, s) + r.servers = append(r.servers, s) log.WithFields(log.Fields{ "server": u.Host, @@ -160,16 +200,16 @@ func reloadServers() error { wg.Wait() // Remove servers that no longer exist in the config - for i := len(servers) - 1; i >= 0; i-- { - if _, exists := hosts[servers[i].Host]; exists { + for i := len(r.servers) - 1; i >= 0; i-- { + if _, exists := hosts[r.servers[i].Host]; exists { continue } log.WithFields(log.Fields{ - "server": servers[i].Host, + "server": r.servers[i].Host, }).Info("Removed server") - servers = append(servers[:i], servers[i+1:]...) + r.servers = append(r.servers[:i], r.servers[i+1:]...) } return nil @@ -179,7 +219,7 @@ var metricReplacer = strings.NewReplacer(".", "_", "-", "_") // addServer takes ServerConfig and constructs a server. // This will create duplicate servers, but it will overwrite existing ones when changed. -func addServer(server ServerConfig, u *url.URL) *Server { +func (r *Redirector) addServer(server ServerConfig, u *url.URL) *Server { s := &Server{ Available: true, Host: u.Host, @@ -188,6 +228,15 @@ func addServer(server ServerConfig, u *url.URL) *Server { Longitude: server.Longitude, Continent: server.Continent, Weight: server.Weight, + Protocols: ProtocolList{"http", "https"}, + } + + if len(server.Protocols) > 0 { + for _, proto := range server.Protocols { + if !s.Protocols.Contains(proto) { + s.Protocols = s.Protocols.Append(proto) + } + } } // Defaults to 10 to allow servers to be set lower for lower priority @@ -206,7 +255,7 @@ func addServer(server ServerConfig, u *url.URL) *Server { } var city City - err = db.Lookup(ips[0], &city) + err = r.db.Lookup(ips[0], &city) if err != nil { log.WithFields(log.Fields{ @@ -229,8 +278,8 @@ func addServer(server ServerConfig, u *url.URL) *Server { return s } -func reloadMap() error { - mapFile := viper.GetString("dl_map") +func (r *Redirector) reloadMap() error { + mapFile := r.config.MapFile if mapFile == "" { return nil @@ -244,7 +293,7 @@ func reloadMap() error { return err } - dlMap = newMap + r.dlMap = newMap return nil } diff --git a/dlrouter.yaml b/dlrouter.yaml index c559ce9..41b4303 100644 --- a/dlrouter.yaml +++ b/dlrouter.yaml @@ -34,6 +34,10 @@ servers: - server: mirrors.bfsu.edu.cn/armbian/ - server: mirrors.dotsrc.org/armbian-apt/ weight: 15 + protocols: + - http + - https + - rsync - server: mirrors.netix.net/armbian/apt/ - server: mirrors.nju.edu.cn/armbian/ - server: mirrors.sustech.edu.cn/armbian/ diff --git a/go.mod b/go.mod index 08766c5..6c36b3a 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,11 @@ -module meow.tf/armbian-router +module github.com/armbian/redirector -go 1.17 +go 1.19 require ( github.com/chi-middleware/logrus-logger v0.2.0 github.com/go-chi/chi/v5 v5.0.7 + github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d github.com/hashicorp/golang-lru v0.5.4 github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff github.com/onsi/ginkgo/v2 v2.1.4 diff --git a/go.sum b/go.sum index 1cb0b01..0f6f734 100644 --- a/go.sum +++ b/go.sum @@ -128,7 +128,6 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -198,7 +197,6 @@ github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= @@ -210,6 +208,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0= github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d h1:Kp5G1kHMb2fAD9OiqWDXro4qLB8bQ2NusoorYya4Lbo= +github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g= github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0= github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -680,7 +680,6 @@ golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/http.go b/http.go index ec8f511..9a2b144 100644 --- a/http.go +++ b/http.go @@ -1,10 +1,9 @@ -package main +package redirector import ( "encoding/json" "fmt" "github.com/jmcvetta/randutil" - "github.com/spf13/viper" "net" "net/http" "net/url" @@ -14,10 +13,10 @@ import ( ) // statusHandler is a simple handler that will always return 200 OK with a body of "OK" -func statusHandler(w http.ResponseWriter, r *http.Request) { +func (r *Redirector) statusHandler(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) - if r.Method != http.MethodHead { + if req.Method != http.MethodHead { w.Write([]byte("OK")) } } @@ -25,8 +24,8 @@ func statusHandler(w http.ResponseWriter, r *http.Request) { // redirectHandler is the default "not found" handler which handles redirects // if the environment variable OVERRIDE_IP is set, it will use that ip address // this is useful for local testing when you're on the local network -func redirectHandler(w http.ResponseWriter, r *http.Request) { - ipStr, _, err := net.SplitHostPort(r.RemoteAddr) +func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) { + ipStr, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -50,11 +49,11 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { // If the path has a prefix of region/NA, it will use specific regions instead // of the default geographical distance - if strings.HasPrefix(r.URL.Path, "/region") { - parts := strings.Split(r.URL.Path, "/") + if strings.HasPrefix(req.URL.Path, "/region") { + parts := strings.Split(req.URL.Path, "/") // region = parts[2] - if mirrors, ok := regionMap[parts[2]]; ok { + if mirrors, ok := r.regionMap[parts[2]]; ok { choices := make([]randutil.Choice, len(mirrors)) for i, item := range mirrors { @@ -77,13 +76,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { server = choice.Item.(*Server) - r.URL.Path = strings.Join(parts[3:], "/") + req.URL.Path = strings.Join(parts[3:], "/") } } + // If we don't have a scheme, we'll use http by default + scheme := req.URL.Scheme + + if scheme == "" { + scheme = "http" + } + // If none of the above exceptions are matched, we use the geographical distance based on IP if server == nil { - server, distance, err = servers.Closest(ip) + server, distance, err = r.servers.Closest(r, scheme, ip) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -91,27 +97,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { } } - // If we don't have a scheme, we'll use https by default - scheme := r.URL.Scheme - - if scheme == "" { - scheme = "https" - } - // redirectPath is a combination of server path (which can be something like /armbian) // and the URL path. // Example: /armbian + /some/path = /armbian/some/path - redirectPath := path.Join(server.Path, r.URL.Path) + redirectPath := path.Join(server.Path, req.URL.Path) // If we have a dlMap, we map the url to a final path instead - if dlMap != nil { - if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists { + if r.dlMap != nil { + if newPath, exists := r.dlMap[strings.TrimLeft(req.URL.Path, "/")]; exists { downloadsMapped.Inc() redirectPath = path.Join(server.Path, newPath) } } - if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") { + if strings.HasSuffix(req.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") { redirectPath += "/" } @@ -136,15 +135,13 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { // reloadHandler is an http handler which lets us reload the server configuration // It is only enabled when the reloadToken is set in the configuration -func reloadHandler(w http.ResponseWriter, r *http.Request) { - expectedToken := viper.GetString("reloadToken") - - if expectedToken == "" { +func (r *Redirector) reloadHandler(w http.ResponseWriter, req *http.Request) { + if r.config.ReloadToken == "" { w.WriteHeader(http.StatusUnauthorized) return } - token := r.Header.Get("Authorization") + token := req.Header.Get("Authorization") if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") { w.WriteHeader(http.StatusUnauthorized) @@ -153,12 +150,12 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) { token = token[strings.Index(token, " ")+1:] - if token != expectedToken { + if token != r.config.ReloadToken { w.WriteHeader(http.StatusUnauthorized) return } - if err := reloadConfig(); err != nil { + if err := r.ReloadConfig(); err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error())) return @@ -168,19 +165,19 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) } -func dlMapHandler(w http.ResponseWriter, r *http.Request) { - if dlMap == nil { +func (r *Redirector) dlMapHandler(w http.ResponseWriter, req *http.Request) { + if r.dlMap == nil { w.WriteHeader(http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(dlMap) + json.NewEncoder(w).Encode(r.dlMap) } -func geoIPHandler(w http.ResponseWriter, r *http.Request) { - ipStr, _, err := net.SplitHostPort(r.RemoteAddr) +func (r *Redirector) geoIPHandler(w http.ResponseWriter, req *http.Request) { + ipStr, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -190,7 +187,7 @@ func geoIPHandler(w http.ResponseWriter, r *http.Request) { ip := net.ParseIP(ipStr) var city City - err = db.Lookup(ip, &city) + err = r.db.Lookup(ip, &city) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/main.go b/main.go deleted file mode 100644 index 9115c55..0000000 --- a/main.go +++ /dev/null @@ -1,154 +0,0 @@ -package main - -import ( - "flag" - "github.com/chi-middleware/logrus-logger" - "github.com/go-chi/chi/v5" - lru "github.com/hashicorp/golang-lru" - "github.com/oschwald/maxminddb-golang" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/prometheus/client_golang/prometheus/promhttp" - log "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "net/http" - "os" - "os/signal" - "syscall" -) - -var ( - db *maxminddb.Reader - servers ServerList - regionMap map[string][]*Server - hostMap map[string]*Server - dlMap map[string]string - topChoices int - - redirectsServed = promauto.NewCounter(prometheus.CounterOpts{ - Name: "armbian_router_redirects", - Help: "The total number of processed redirects", - }) - - downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{ - Name: "armbian_router_download_maps", - Help: "The total number of mapped download paths", - }) - - serverCache *lru.Cache -) - -type LocationLookup struct { - Location struct { - Latitude float64 `maxminddb:"latitude"` - Longitude float64 `maxminddb:"longitude"` - } `maxminddb:"location"` -} - -// City represents a MaxmindDB city -type City struct { - Continent struct { - Code string `maxminddb:"code" json:"code"` - GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"` - Names map[string]string `maxminddb:"names" json:"names"` - } `maxminddb:"continent" json:"continent"` - Country struct { - GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"` - IsoCode string `maxminddb:"iso_code" json:"iso_code"` - Names map[string]string `maxminddb:"names" json:"names"` - } `maxminddb:"country" json:"country"` - Location struct { - AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'` - Latitude float64 `maxminddb:"latitude" json:"latitude"` - Longitude float64 `maxminddb:"longitude" json:"longitude"` - } `maxminddb:"location"` - RegisteredCountry struct { - GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"` - IsoCode string `maxminddb:"iso_code" json:"iso_code"` - Names map[string]string `maxminddb:"names" json:"names"` - } `maxminddb:"registered_country" json:"registered_country"` -} - -type ServerConfig struct { - Server string `mapstructure:"server" yaml:"server"` - Latitude float64 `mapstructure:"latitude" yaml:"latitude"` - Longitude float64 `mapstructure:"longitude" yaml:"longitude"` - Continent string `mapstructure:"continent"` - Weight int `mapstructure:"weight" yaml:"weight"` -} - -var ( - configFlag = flag.String("config", "", "configuration file path") - flagDebug = flag.Bool("debug", false, "Enable debug logging") -) - -func main() { - flag.Parse() - - if *flagDebug { - log.SetLevel(log.DebugLevel) - } - - viper.SetDefault("bind", ":8080") - viper.SetDefault("cacheSize", 1024) - viper.SetDefault("topChoices", 3) - viper.SetDefault("reloadKey", randSeq(32)) - - viper.SetConfigName("dlrouter") // name of config file (without extension) - viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name - viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in - viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths - viper.AddConfigPath(".") // optionally look for config in the working directory - - if *configFlag != "" { - viper.SetConfigFile(*configFlag) - } - - if err := reloadConfig(); err != nil { - log.WithError(err).Fatalln("Unable to load configuration") - } - - // Start check loop - go servers.checkLoop() - - log.Info("Starting") - - r := chi.NewRouter() - - r.Use(RealIPMiddleware) - r.Use(logger.Logger("router", log.StandardLogger())) - - r.Head("/status", statusHandler) - r.Get("/status", statusHandler) - r.Get("/mirrors", legacyMirrorsHandler) - r.Get("/mirrors/{server}.svg", mirrorStatusHandler) - r.Get("/mirrors.json", mirrorsHandler) - r.Post("/reload", reloadHandler) - r.Get("/dl_map", dlMapHandler) - r.Get("/geoip", geoIPHandler) - r.Get("/metrics", promhttp.Handler().ServeHTTP) - - r.NotFound(redirectHandler) - - go http.ListenAndServe(viper.GetString("bind"), r) - - log.Info("Ready") - - c := make(chan os.Signal) - - signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP) - - for { - sig := <-c - - if sig != syscall.SIGHUP { - break - } - - err := reloadConfig() - - if err != nil { - log.WithError(err).Warning("Did not reload configuration due to error") - } - } -} diff --git a/map.go b/map.go index 265862d..d372447 100644 --- a/map.go +++ b/map.go @@ -1,4 +1,4 @@ -package main +package redirector import ( "encoding/csv" diff --git a/map_test.go b/map_test.go index c9fa3ca..e149f35 100644 --- a/map_test.go +++ b/map_test.go @@ -1,4 +1,4 @@ -package main +package redirector import ( . "github.com/onsi/ginkgo/v2" diff --git a/middleware.go b/middleware/middleware.go similarity index 98% rename from middleware.go rename to middleware/middleware.go index 3d08b00..22f4610 100644 --- a/middleware.go +++ b/middleware/middleware.go @@ -1,4 +1,4 @@ -package main +package middleware import ( "net" diff --git a/mirrors.go b/mirrors.go index 269da2b..462166a 100644 --- a/mirrors.go +++ b/mirrors.go @@ -1,4 +1,4 @@ -package main +package redirector import ( _ "embed" @@ -11,16 +11,16 @@ import ( // legacyMirrorsHandler will list the mirrors by region in the legacy format // it is preferred to use mirrors.json, but this handler is here for build support -func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { +func (r *Redirector) legacyMirrorsHandler(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Type", "application/json") mirrorOutput := make(map[string][]string) - for region, mirrors := range regionMap { + for region, mirrors := range r.regionMap { list := make([]string, len(mirrors)) for i, mirror := range mirrors { - list[i] = r.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/") + list[i] = req.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/") } mirrorOutput[region] = list @@ -30,9 +30,9 @@ func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { } // mirrorsHandler is a simple handler that will return the list of servers -func mirrorsHandler(w http.ResponseWriter, r *http.Request) { +func (r *Redirector) mirrorsHandler(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(servers) + json.NewEncoder(w).Encode(r.servers) } var ( @@ -48,8 +48,8 @@ var ( // mirrorStatusHandler is a fancy svg-returning handler. // it is used to display mirror statuses on a config repo of sorts -func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { - serverHost := chi.URLParam(r, "server") +func (r *Redirector) mirrorStatusHandler(w http.ResponseWriter, req *http.Request) { + serverHost := chi.URLParam(req, "server") w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8") w.Header().Set("Cache-Control", "max-age=120") @@ -61,7 +61,7 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { serverHost = strings.Replace(serverHost, "_", ".", -1) - server, ok := hostMap[serverHost] + server, ok := r.hostMap[serverHost] if !ok { w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown))) @@ -77,7 +77,7 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("ETag", "\""+key+"\"") - if match := r.Header.Get("If-None-Match"); match != "" { + if match := req.Header.Get("If-None-Match"); match != "" { if strings.Trim(match, "\"") == key { w.WriteHeader(http.StatusNotModified) return diff --git a/redirector.go b/redirector.go new file mode 100644 index 0000000..269bf86 --- /dev/null +++ b/redirector.go @@ -0,0 +1,130 @@ +package redirector + +import ( + "github.com/armbian/redirector/middleware" + "github.com/chi-middleware/logrus-logger" + "github.com/go-chi/chi/v5" + lru "github.com/hashicorp/golang-lru" + "github.com/oschwald/maxminddb-golang" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/promhttp" + log "github.com/sirupsen/logrus" + "net/http" +) + +var ( + redirectsServed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "armbian_router_redirects", + Help: "The total number of processed redirects", + }) + + downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{ + Name: "armbian_router_download_maps", + Help: "The total number of mapped download paths", + }) +) + +type Redirector struct { + config *Config + db *maxminddb.Reader + servers ServerList + regionMap map[string][]*Server + hostMap map[string]*Server + dlMap map[string]string + topChoices int + serverCache *lru.Cache + checks []ServerCheck +} + +type LocationLookup struct { + Location struct { + Latitude float64 `maxminddb:"latitude"` + Longitude float64 `maxminddb:"longitude"` + } `maxminddb:"location"` +} + +// City represents a MaxmindDB city +type City struct { + Continent struct { + Code string `maxminddb:"code" json:"code"` + GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"` + Names map[string]string `maxminddb:"names" json:"names"` + } `maxminddb:"continent" json:"continent"` + Country struct { + GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"` + IsoCode string `maxminddb:"iso_code" json:"iso_code"` + Names map[string]string `maxminddb:"names" json:"names"` + } `maxminddb:"country" json:"country"` + Location struct { + AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'` + Latitude float64 `maxminddb:"latitude" json:"latitude"` + Longitude float64 `maxminddb:"longitude" json:"longitude"` + } `maxminddb:"location"` + RegisteredCountry struct { + GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"` + IsoCode string `maxminddb:"iso_code" json:"iso_code"` + Names map[string]string `maxminddb:"names" json:"names"` + } `maxminddb:"registered_country" json:"registered_country"` +} + +type ServerConfig struct { + Server string `mapstructure:"server" yaml:"server"` + Latitude float64 `mapstructure:"latitude" yaml:"latitude"` + Longitude float64 `mapstructure:"longitude" yaml:"longitude"` + Continent string `mapstructure:"continent"` + Weight int `mapstructure:"weight" yaml:"weight"` + Protocols []string `mapstructure:"protocols" yaml:"protocols"` +} + +// New creates a new instance of Redirector +func New(config *Config) *Redirector { + r := &Redirector{ + config: config, + } + + r.checks = []ServerCheck{ + r.checkHttp("http"), + r.checkTLS, + } + + return r +} + +func (r *Redirector) Start() http.Handler { + if err := r.ReloadConfig(); err != nil { + log.WithError(err).Fatalln("Unable to load configuration") + } + + log.Info("Starting check loop") + + // Start check loop + go r.servers.checkLoop(r.checks) + + log.Info("Setting up routes") + + router := chi.NewRouter() + + router.Use(middleware.RealIPMiddleware) + router.Use(logger.Logger("router", log.StandardLogger())) + + router.Head("/status", r.statusHandler) + router.Get("/status", r.statusHandler) + router.Get("/mirrors", r.legacyMirrorsHandler) + router.Get("/mirrors/{server}.svg", r.mirrorStatusHandler) + router.Get("/mirrors.json", r.mirrorsHandler) + router.Post("/reload", r.reloadHandler) + router.Get("/dl_map", r.dlMapHandler) + router.Get("/geoip", r.geoIPHandler) + router.Get("/metrics", promhttp.Handler().ServeHTTP) + + router.NotFound(r.redirectHandler) + + if r.config.BindAddress != "" { + log.WithField("bind", r.config.BindAddress).Info("Binding to address") + + go http.ListenAndServe(r.config.BindAddress, router) + } + + return router +} diff --git a/servers.go b/servers.go index 425936c..7961ef1 100644 --- a/servers.go +++ b/servers.go @@ -1,7 +1,6 @@ -package main +package redirector import ( - "crypto/tls" "github.com/jmcvetta/randutil" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -20,13 +19,6 @@ var ( return http.ErrUseLastResponse }, } - - checkTLSConfig *tls.Config = nil - - checks = []serverCheck{ - checkHttp, - checkTLS, - } ) // Server represents a download server @@ -38,14 +30,15 @@ type Server struct { Longitude float64 `json:"longitude"` Weight int `json:"weight"` Continent string `json:"continent"` + Protocols ProtocolList `json:"protocols"` Redirects prometheus.Counter `json:"-"` LastChange time.Time `json:"lastChange"` } -type serverCheck func(server *Server, logFields log.Fields) (bool, error) +type ServerCheck func(server *Server, logFields log.Fields) (bool, error) // checkStatus runs all status checks against a server -func (server *Server) checkStatus() { +func (server *Server) checkStatus(checks []ServerCheck) { logFields := log.Fields{ "host": server.Host, } @@ -87,19 +80,19 @@ func (server *Server) checkStatus() { type ServerList []*Server -func (s ServerList) checkLoop() { +func (s ServerList) checkLoop(checks []ServerCheck) { t := time.NewTicker(60 * time.Second) for { <-t.C - s.Check() + s.Check(checks) } } // Check will request the index from all servers // If a server does not respond in 10 seconds, it is considered offline. // This will wait until all checks are complete. -func (s ServerList) Check() { +func (s ServerList) Check(checks []ServerCheck) { var wg sync.WaitGroup for _, server := range s { @@ -108,7 +101,7 @@ func (s ServerList) Check() { go func(server *Server) { defer wg.Done() - server.checkStatus() + server.checkStatus(checks) }(server) } @@ -127,12 +120,12 @@ type DistanceList []ComputedDistance // Closest will use GeoIP on the IP provided and find the closest servers. // When we have a list of x servers closest, we can choose a random or weighted one. // Return values are the closest server, the distance, and if an error occurred. -func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { - choiceInterface, exists := serverCache.Get(ip.String()) +func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, float64, error) { + choiceInterface, exists := r.serverCache.Get(scheme + "_" + ip.String()) if !exists { var city LocationLookup - err := db.Lookup(ip, &city) + err := r.db.Lookup(ip, &city) if err != nil { return nil, -1, err @@ -141,7 +134,7 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { c := make(DistanceList, len(s)) for i, server := range s { - if !server.Available { + if !server.Available || !server.Protocols.Contains(scheme) { continue } @@ -158,9 +151,9 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { return c[i].Distance < c[j].Distance }) - choiceCount := topChoices + choiceCount := r.config.TopChoices - if len(c) < topChoices { + if len(c) < r.config.TopChoices { choiceCount = len(c) } @@ -179,7 +172,7 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { choiceInterface = choices - serverCache.Add(ip.String(), choiceInterface) + r.serverCache.Add(scheme+"_"+ip.String(), choiceInterface) } choice, err := randutil.WeightedChoice(choiceInterface.([]randutil.Choice)) @@ -192,9 +185,9 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { if !dist.Server.Available { // Choose a new server and refresh cache - serverCache.Remove(ip.String()) + r.serverCache.Remove(scheme + "_" + ip.String()) - return s.Closest(ip) + return s.Closest(r, scheme, ip) } return dist.Server, dist.Distance, nil @@ -206,9 +199,10 @@ func hsin(theta float64) float64 { } // Distance function returns the distance (in meters) between two points of -// a given longitude and latitude relatively accurately (using a spherical -// approximation of the Earth) through the Haversine Distance Formula for -// great arc distance on a sphere with accuracy for small distances +// +// a given longitude and latitude relatively accurately (using a spherical +// approximation of the Earth) through the Haversine Distance Formula for +// great arc distance on a sphere with accuracy for small distances // // point coordinates are supplied in degrees and converted into rad. in the func // diff --git a/util.go b/util.go index a79ece5..d52af81 100644 --- a/util.go +++ b/util.go @@ -1,10 +1,10 @@ -package main +package redirector import "math/rand" var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") -func randSeq(n int) string { +func RandomSequence(n int) string { b := make([]rune, n) for i := range b { b[i] = letters[rand.Intn(len(letters))] diff --git a/util/certificates.go b/util/certificates.go new file mode 100644 index 0000000..a1c76fd --- /dev/null +++ b/util/certificates.go @@ -0,0 +1,46 @@ +package util + +import ( + "crypto/x509" + "github.com/gwatts/rootcerts/certparse" + log "github.com/sirupsen/logrus" + "net/http" +) + +const ( + defaultDownloadURL = "https://github.com/mozilla/gecko-dev/blob/master/security/nss/lib/ckfw/builtins/certdata.txt?raw=true" +) + +func LoadCACerts() (*x509.CertPool, error) { + res, err := http.Get(defaultDownloadURL) + + if err != nil { + return nil, err + } + + defer res.Body.Close() + + certs, err := certparse.ReadTrustedCerts(res.Body) + + if err != nil { + return nil, err + } + + pool := x509.NewCertPool() + + var count int + + for _, cert := range certs { + if cert.Trust&certparse.ServerTrustedDelegator == 0 { + continue + } + + count++ + + pool.AddCert(cert.Cert) + } + + log.WithField("certs", count).Info("Loaded root cas") + + return pool, nil +}