diff --git a/.gitignore b/.gitignore index 9e4a423..22b7cfc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ userdata.csv dlrouter-apt.yaml *.yaml !dlrouter.yaml -*.exe \ No newline at end of file +*.exe +cover.out \ No newline at end of file diff --git a/check.go b/check.go index 93552d4..a9e587d 100644 --- a/check.go +++ b/check.go @@ -16,6 +16,7 @@ import ( var ( ErrHttpsRedirect = errors.New("unexpected forced https redirect") + ErrHttpRedirect = errors.New("unexpected redirect to insecure url") ErrCertExpired = errors.New("certificate is expired") ) @@ -28,7 +29,7 @@ func (r *Redirector) checkHttp(scheme string) ServerCheck { // checkHttp checks a URL for validity, and checks redirects func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) { u := &url.URL{ - Scheme: "http", + Scheme: scheme, Host: server.Host, Path: server.Path, } @@ -41,7 +42,7 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo return false, err } - res, err := checkClient.Do(req) + res, err := r.config.checkClient.Do(req) if err != nil { return false, err @@ -102,8 +103,8 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo if newUrl.Scheme == "https" { return false, ErrHttpsRedirect - } else if originatingScheme == "https" && newUrl.Scheme == "https" { - return false, ErrHttpsRedirect + } else if originatingScheme == "https" && newUrl.Scheme == "http" { + return false, ErrHttpRedirect } return true, nil diff --git a/check_test.go b/check_test.go index 5be3df9..2a5bed0 100644 --- a/check_test.go +++ b/check_test.go @@ -64,9 +64,8 @@ var _ = Describe("Check suite", func() { httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler(w, r) })) - r = New(&Config{ - RootCAs: x509.NewCertPool(), - }) + r = New(&Config{}) + r.config.SetRootCAs(x509.NewCertPool()) }) AfterEach(func() { httpServer.Close() @@ -98,17 +97,6 @@ var _ = Describe("Check suite", func() { Expect(res).To(BeTrue()) Expect(err).To(BeNil()) }) - It("Should return an error when redirected to https", func() { - handler = func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1)) - w.WriteHeader(http.StatusMovedPermanently) - } - - res, err := r.checkHttpScheme(server, "http", log.Fields{}) - - Expect(res).To(BeFalse()) - Expect(err).To(Equal(ErrHttpsRedirect)) - }) }) Context("TLS Checks", func() { var ( @@ -137,26 +125,48 @@ var _ = Describe("Check suite", func() { Certificates: []tls.Certificate{tlsPair}, } + pool := x509.NewCertPool() + + pool.AddCert(x509Cert) + + r.config.SetRootCAs(pool) + httpServer.StartTLS() setupServer() } + Context("HTTPS Checks", func() { + BeforeEach(func() { + setupCerts(time.Now(), time.Now().Add(24*time.Hour)) + }) + It("Should return an error when redirected to http from https", func() { + handler = func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1)) + w.WriteHeader(http.StatusMovedPermanently) + } + + logFields := log.Fields{} + + res, err := r.checkHttpScheme(server, "https", logFields) + + Expect(logFields["url"]).ToNot(BeEmpty()) + Expect(logFields["url"]).ToNot(Equal(httpServer.URL)) + Expect(err).To(Equal(ErrHttpRedirect)) + Expect(res).To(BeFalse()) + }) + }) Context("CA Tests", func() { BeforeEach(func() { setupCerts(time.Now(), time.Now().Add(24*time.Hour)) }) It("Should fail due to invalid ca", func() { + r.config.SetRootCAs(x509.NewCertPool()) + res, err := r.checkTLS(server, log.Fields{}) Expect(res).To(BeFalse()) Expect(err).ToNot(BeNil()) }) It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() { - pool := x509.NewCertPool() - - pool.AddCert(x509Cert) - - r.config.RootCAs = pool - res, err := r.checkTLS(server, log.Fields{}) Expect(res).To(BeFalse()) @@ -167,13 +177,6 @@ var _ = Describe("Check suite", func() { It("Should fail due to not yet valid certificate", func() { setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour)) - // Trust our certs - pool := x509.NewCertPool() - - pool.AddCert(x509Cert) - - r.config.RootCAs = pool - // Check TLS res, err := r.checkTLS(server, log.Fields{}) @@ -183,13 +186,6 @@ var _ = Describe("Check suite", func() { It("Should fail due to expired certificate", func() { setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour)) - // Trust our certs - pool := x509.NewCertPool() - - pool.AddCert(x509Cert) - - r.config.RootCAs = pool - // Check TLS res, err := r.checkTLS(server, log.Fields{}) diff --git a/config.go b/config.go index e105437..c23bbbc 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package redirector import ( + "crypto/tls" "crypto/x509" lru "github.com/hashicorp/golang-lru" "github.com/oschwald/maxminddb-golang" @@ -9,9 +10,11 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" log "github.com/sirupsen/logrus" "net" + "net/http" "net/url" "strings" "sync" + "time" ) type Config struct { @@ -24,6 +27,27 @@ type Config struct { ServerList []ServerConfig `mapstructure:"servers"` ReloadFunc func() RootCAs *x509.CertPool + checkClient *http.Client +} + +// SetRootCAs sets the root ca files, and creates the http client for checks +// This **MUST** be called before r.checkClient is used. +func (c *Config) SetRootCAs(cas *x509.CertPool) { + c.RootCAs = cas + + t := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: cas, + }, + } + + c.checkClient = &http.Client{ + Transport: t, + Timeout: 20 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } } type ProtocolList []string diff --git a/redirector.go b/redirector.go index 269bf86..76e7e7a 100644 --- a/redirector.go +++ b/redirector.go @@ -35,6 +35,7 @@ type Redirector struct { topChoices int serverCache *lru.Cache checks []ServerCheck + checkClient *http.Client } type LocationLookup struct { diff --git a/servers.go b/servers.go index 7961ef1..edf3762 100644 --- a/servers.go +++ b/servers.go @@ -6,21 +6,11 @@ import ( log "github.com/sirupsen/logrus" "math" "net" - "net/http" "sort" "sync" "time" ) -var ( - checkClient = &http.Client{ - Timeout: 20 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } -) - // Server represents a download server type Server struct { Available bool `json:"available"`