Resolve issues with checks, forced http detection
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Tyler 2022-08-15 02:37:51 -04:00
parent 7f27df70ae
commit caa0fb43e4
6 changed files with 62 additions and 49 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ dlrouter-apt.yaml
*.yaml *.yaml
!dlrouter.yaml !dlrouter.yaml
*.exe *.exe
cover.out

View File

@ -16,6 +16,7 @@ import (
var ( var (
ErrHttpsRedirect = errors.New("unexpected forced https redirect") ErrHttpsRedirect = errors.New("unexpected forced https redirect")
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
ErrCertExpired = errors.New("certificate is expired") ErrCertExpired = errors.New("certificate is expired")
) )
@ -28,7 +29,7 @@ func (r *Redirector) checkHttp(scheme string) ServerCheck {
// checkHttp checks a URL for validity, and checks redirects // checkHttp checks a URL for validity, and checks redirects
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) { func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
u := &url.URL{ u := &url.URL{
Scheme: "http", Scheme: scheme,
Host: server.Host, Host: server.Host,
Path: server.Path, Path: server.Path,
} }
@ -41,7 +42,7 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
return false, err return false, err
} }
res, err := checkClient.Do(req) res, err := r.config.checkClient.Do(req)
if err != nil { if err != nil {
return false, err return false, err
@ -102,8 +103,8 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
if newUrl.Scheme == "https" { if newUrl.Scheme == "https" {
return false, ErrHttpsRedirect return false, ErrHttpsRedirect
} else if originatingScheme == "https" && newUrl.Scheme == "https" { } else if originatingScheme == "https" && newUrl.Scheme == "http" {
return false, ErrHttpsRedirect return false, ErrHttpRedirect
} }
return true, nil return true, nil

View File

@ -64,9 +64,8 @@ var _ = Describe("Check suite", func() {
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r) handler(w, r)
})) }))
r = New(&Config{ r = New(&Config{})
RootCAs: x509.NewCertPool(), r.config.SetRootCAs(x509.NewCertPool())
})
}) })
AfterEach(func() { AfterEach(func() {
httpServer.Close() httpServer.Close()
@ -98,17 +97,6 @@ var _ = Describe("Check suite", func() {
Expect(res).To(BeTrue()) Expect(res).To(BeTrue())
Expect(err).To(BeNil()) Expect(err).To(BeNil())
}) })
It("Should return an error when redirected to https", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1))
w.WriteHeader(http.StatusMovedPermanently)
}
res, err := r.checkHttpScheme(server, "http", log.Fields{})
Expect(res).To(BeFalse())
Expect(err).To(Equal(ErrHttpsRedirect))
})
}) })
Context("TLS Checks", func() { Context("TLS Checks", func() {
var ( var (
@ -137,26 +125,48 @@ var _ = Describe("Check suite", func() {
Certificates: []tls.Certificate{tlsPair}, Certificates: []tls.Certificate{tlsPair},
} }
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.SetRootCAs(pool)
httpServer.StartTLS() httpServer.StartTLS()
setupServer() setupServer()
} }
Context("HTTPS Checks", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should return an error when redirected to http from https", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1))
w.WriteHeader(http.StatusMovedPermanently)
}
logFields := log.Fields{}
res, err := r.checkHttpScheme(server, "https", logFields)
Expect(logFields["url"]).ToNot(BeEmpty())
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
Expect(err).To(Equal(ErrHttpRedirect))
Expect(res).To(BeFalse())
})
})
Context("CA Tests", func() { Context("CA Tests", func() {
BeforeEach(func() { BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour)) setupCerts(time.Now(), time.Now().Add(24*time.Hour))
}) })
It("Should fail due to invalid ca", func() { It("Should fail due to invalid ca", func() {
r.config.SetRootCAs(x509.NewCertPool())
res, err := r.checkTLS(server, log.Fields{}) res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse()) Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil()) Expect(err).ToNot(BeNil())
}) })
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() { It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.RootCAs = pool
res, err := r.checkTLS(server, log.Fields{}) res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse()) Expect(res).To(BeFalse())
@ -167,13 +177,6 @@ var _ = Describe("Check suite", func() {
It("Should fail due to not yet valid certificate", func() { It("Should fail due to not yet valid certificate", func() {
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour)) setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.RootCAs = pool
// Check TLS // Check TLS
res, err := r.checkTLS(server, log.Fields{}) res, err := r.checkTLS(server, log.Fields{})
@ -183,13 +186,6 @@ var _ = Describe("Check suite", func() {
It("Should fail due to expired certificate", func() { It("Should fail due to expired certificate", func() {
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour)) setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.RootCAs = pool
// Check TLS // Check TLS
res, err := r.checkTLS(server, log.Fields{}) res, err := r.checkTLS(server, log.Fields{})

View File

@ -1,6 +1,7 @@
package redirector package redirector
import ( import (
"crypto/tls"
"crypto/x509" "crypto/x509"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
@ -9,9 +10,11 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
) )
type Config struct { type Config struct {
@ -24,6 +27,27 @@ type Config struct {
ServerList []ServerConfig `mapstructure:"servers"` ServerList []ServerConfig `mapstructure:"servers"`
ReloadFunc func() ReloadFunc func()
RootCAs *x509.CertPool RootCAs *x509.CertPool
checkClient *http.Client
}
// SetRootCAs sets the root ca files, and creates the http client for checks
// This **MUST** be called before r.checkClient is used.
func (c *Config) SetRootCAs(cas *x509.CertPool) {
c.RootCAs = cas
t := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: cas,
},
}
c.checkClient = &http.Client{
Transport: t,
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
} }
type ProtocolList []string type ProtocolList []string

View File

@ -35,6 +35,7 @@ type Redirector struct {
topChoices int topChoices int
serverCache *lru.Cache serverCache *lru.Cache
checks []ServerCheck checks []ServerCheck
checkClient *http.Client
} }
type LocationLookup struct { type LocationLookup struct {

View File

@ -6,21 +6,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math" "math"
"net" "net"
"net/http"
"sort" "sort"
"sync" "sync"
"time" "time"
) )
var (
checkClient = &http.Client{
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
)
// Server represents a download server // Server represents a download server
type Server struct { type Server struct {
Available bool `json:"available"` Available bool `json:"available"`