Resolve issues with checks, forced http detection
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Tyler 2022-08-15 02:37:51 -04:00
parent 7f27df70ae
commit caa0fb43e4
6 changed files with 62 additions and 49 deletions

3
.gitignore vendored
View File

@ -3,4 +3,5 @@ userdata.csv
dlrouter-apt.yaml
*.yaml
!dlrouter.yaml
*.exe
*.exe
cover.out

View File

@ -16,6 +16,7 @@ import (
var (
ErrHttpsRedirect = errors.New("unexpected forced https redirect")
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
ErrCertExpired = errors.New("certificate is expired")
)
@ -28,7 +29,7 @@ func (r *Redirector) checkHttp(scheme string) ServerCheck {
// checkHttp checks a URL for validity, and checks redirects
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
u := &url.URL{
Scheme: "http",
Scheme: scheme,
Host: server.Host,
Path: server.Path,
}
@ -41,7 +42,7 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
return false, err
}
res, err := checkClient.Do(req)
res, err := r.config.checkClient.Do(req)
if err != nil {
return false, err
@ -102,8 +103,8 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
if newUrl.Scheme == "https" {
return false, ErrHttpsRedirect
} else if originatingScheme == "https" && newUrl.Scheme == "https" {
return false, ErrHttpsRedirect
} else if originatingScheme == "https" && newUrl.Scheme == "http" {
return false, ErrHttpRedirect
}
return true, nil

View File

@ -64,9 +64,8 @@ var _ = Describe("Check suite", func() {
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r)
}))
r = New(&Config{
RootCAs: x509.NewCertPool(),
})
r = New(&Config{})
r.config.SetRootCAs(x509.NewCertPool())
})
AfterEach(func() {
httpServer.Close()
@ -98,17 +97,6 @@ var _ = Describe("Check suite", func() {
Expect(res).To(BeTrue())
Expect(err).To(BeNil())
})
It("Should return an error when redirected to https", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1))
w.WriteHeader(http.StatusMovedPermanently)
}
res, err := r.checkHttpScheme(server, "http", log.Fields{})
Expect(res).To(BeFalse())
Expect(err).To(Equal(ErrHttpsRedirect))
})
})
Context("TLS Checks", func() {
var (
@ -137,26 +125,48 @@ var _ = Describe("Check suite", func() {
Certificates: []tls.Certificate{tlsPair},
}
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.SetRootCAs(pool)
httpServer.StartTLS()
setupServer()
}
Context("HTTPS Checks", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should return an error when redirected to http from https", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1))
w.WriteHeader(http.StatusMovedPermanently)
}
logFields := log.Fields{}
res, err := r.checkHttpScheme(server, "https", logFields)
Expect(logFields["url"]).ToNot(BeEmpty())
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
Expect(err).To(Equal(ErrHttpRedirect))
Expect(res).To(BeFalse())
})
})
Context("CA Tests", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should fail due to invalid ca", func() {
r.config.SetRootCAs(x509.NewCertPool())
res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.RootCAs = pool
res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
@ -167,13 +177,6 @@ var _ = Describe("Check suite", func() {
It("Should fail due to not yet valid certificate", func() {
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.RootCAs = pool
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
@ -183,13 +186,6 @@ var _ = Describe("Check suite", func() {
It("Should fail due to expired certificate", func() {
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.RootCAs = pool
// Check TLS
res, err := r.checkTLS(server, log.Fields{})

View File

@ -1,6 +1,7 @@
package redirector
import (
"crypto/tls"
"crypto/x509"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang"
@ -9,9 +10,11 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
type Config struct {
@ -24,6 +27,27 @@ type Config struct {
ServerList []ServerConfig `mapstructure:"servers"`
ReloadFunc func()
RootCAs *x509.CertPool
checkClient *http.Client
}
// SetRootCAs sets the root ca files, and creates the http client for checks
// This **MUST** be called before r.checkClient is used.
func (c *Config) SetRootCAs(cas *x509.CertPool) {
c.RootCAs = cas
t := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: cas,
},
}
c.checkClient = &http.Client{
Transport: t,
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
type ProtocolList []string

View File

@ -35,6 +35,7 @@ type Redirector struct {
topChoices int
serverCache *lru.Cache
checks []ServerCheck
checkClient *http.Client
}
type LocationLookup struct {

View File

@ -6,21 +6,11 @@ import (
log "github.com/sirupsen/logrus"
"math"
"net"
"net/http"
"sort"
"sync"
"time"
)
var (
checkClient = &http.Client{
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
)
// Server represents a download server
type Server struct {
Available bool `json:"available"`