Resolve issues with checks, forced http detection
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
7f27df70ae
commit
caa0fb43e4
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,4 +3,5 @@ userdata.csv
|
||||
dlrouter-apt.yaml
|
||||
*.yaml
|
||||
!dlrouter.yaml
|
||||
*.exe
|
||||
*.exe
|
||||
cover.out
|
9
check.go
9
check.go
@ -16,6 +16,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrHttpsRedirect = errors.New("unexpected forced https redirect")
|
||||
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
|
||||
ErrCertExpired = errors.New("certificate is expired")
|
||||
)
|
||||
|
||||
@ -28,7 +29,7 @@ func (r *Redirector) checkHttp(scheme string) ServerCheck {
|
||||
// checkHttp checks a URL for validity, and checks redirects
|
||||
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: scheme,
|
||||
Host: server.Host,
|
||||
Path: server.Path,
|
||||
}
|
||||
@ -41,7 +42,7 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
|
||||
return false, err
|
||||
}
|
||||
|
||||
res, err := checkClient.Do(req)
|
||||
res, err := r.config.checkClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -102,8 +103,8 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
|
||||
|
||||
if newUrl.Scheme == "https" {
|
||||
return false, ErrHttpsRedirect
|
||||
} else if originatingScheme == "https" && newUrl.Scheme == "https" {
|
||||
return false, ErrHttpsRedirect
|
||||
} else if originatingScheme == "https" && newUrl.Scheme == "http" {
|
||||
return false, ErrHttpRedirect
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
@ -64,9 +64,8 @@ var _ = Describe("Check suite", func() {
|
||||
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r)
|
||||
}))
|
||||
r = New(&Config{
|
||||
RootCAs: x509.NewCertPool(),
|
||||
})
|
||||
r = New(&Config{})
|
||||
r.config.SetRootCAs(x509.NewCertPool())
|
||||
})
|
||||
AfterEach(func() {
|
||||
httpServer.Close()
|
||||
@ -98,17 +97,6 @@ var _ = Describe("Check suite", func() {
|
||||
Expect(res).To(BeTrue())
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
It("Should return an error when redirected to https", func() {
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1))
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
}
|
||||
|
||||
res, err := r.checkHttpScheme(server, "http", log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).To(Equal(ErrHttpsRedirect))
|
||||
})
|
||||
})
|
||||
Context("TLS Checks", func() {
|
||||
var (
|
||||
@ -137,26 +125,48 @@ var _ = Describe("Check suite", func() {
|
||||
Certificates: []tls.Certificate{tlsPair},
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
|
||||
pool.AddCert(x509Cert)
|
||||
|
||||
r.config.SetRootCAs(pool)
|
||||
|
||||
httpServer.StartTLS()
|
||||
setupServer()
|
||||
}
|
||||
Context("HTTPS Checks", func() {
|
||||
BeforeEach(func() {
|
||||
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
||||
})
|
||||
It("Should return an error when redirected to http from https", func() {
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1))
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
}
|
||||
|
||||
logFields := log.Fields{}
|
||||
|
||||
res, err := r.checkHttpScheme(server, "https", logFields)
|
||||
|
||||
Expect(logFields["url"]).ToNot(BeEmpty())
|
||||
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
|
||||
Expect(err).To(Equal(ErrHttpRedirect))
|
||||
Expect(res).To(BeFalse())
|
||||
})
|
||||
})
|
||||
Context("CA Tests", func() {
|
||||
BeforeEach(func() {
|
||||
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
||||
})
|
||||
It("Should fail due to invalid ca", func() {
|
||||
r.config.SetRootCAs(x509.NewCertPool())
|
||||
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).ToNot(BeNil())
|
||||
})
|
||||
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
|
||||
pool := x509.NewCertPool()
|
||||
|
||||
pool.AddCert(x509Cert)
|
||||
|
||||
r.config.RootCAs = pool
|
||||
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
@ -167,13 +177,6 @@ var _ = Describe("Check suite", func() {
|
||||
It("Should fail due to not yet valid certificate", func() {
|
||||
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
|
||||
|
||||
// Trust our certs
|
||||
pool := x509.NewCertPool()
|
||||
|
||||
pool.AddCert(x509Cert)
|
||||
|
||||
r.config.RootCAs = pool
|
||||
|
||||
// Check TLS
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
|
||||
@ -183,13 +186,6 @@ var _ = Describe("Check suite", func() {
|
||||
It("Should fail due to expired certificate", func() {
|
||||
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
|
||||
|
||||
// Trust our certs
|
||||
pool := x509.NewCertPool()
|
||||
|
||||
pool.AddCert(x509Cert)
|
||||
|
||||
r.config.RootCAs = pool
|
||||
|
||||
// Check TLS
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
|
||||
|
24
config.go
24
config.go
@ -1,6 +1,7 @@
|
||||
package redirector
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
@ -9,9 +10,11 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@ -24,6 +27,27 @@ type Config struct {
|
||||
ServerList []ServerConfig `mapstructure:"servers"`
|
||||
ReloadFunc func()
|
||||
RootCAs *x509.CertPool
|
||||
checkClient *http.Client
|
||||
}
|
||||
|
||||
// SetRootCAs sets the root ca files, and creates the http client for checks
|
||||
// This **MUST** be called before r.checkClient is used.
|
||||
func (c *Config) SetRootCAs(cas *x509.CertPool) {
|
||||
c.RootCAs = cas
|
||||
|
||||
t := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: cas,
|
||||
},
|
||||
}
|
||||
|
||||
c.checkClient = &http.Client{
|
||||
Transport: t,
|
||||
Timeout: 20 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type ProtocolList []string
|
||||
|
@ -35,6 +35,7 @@ type Redirector struct {
|
||||
topChoices int
|
||||
serverCache *lru.Cache
|
||||
checks []ServerCheck
|
||||
checkClient *http.Client
|
||||
}
|
||||
|
||||
type LocationLookup struct {
|
||||
|
10
servers.go
10
servers.go
@ -6,21 +6,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
checkClient = &http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Server represents a download server
|
||||
type Server struct {
|
||||
Available bool `json:"available"`
|
||||
|
Loading…
Reference in New Issue
Block a user