Resolve issues with checks, forced http detection
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
7f27df70ae
commit
caa0fb43e4
|
@ -4,3 +4,4 @@ dlrouter-apt.yaml
|
||||||
*.yaml
|
*.yaml
|
||||||
!dlrouter.yaml
|
!dlrouter.yaml
|
||||||
*.exe
|
*.exe
|
||||||
|
cover.out
|
9
check.go
9
check.go
|
@ -16,6 +16,7 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrHttpsRedirect = errors.New("unexpected forced https redirect")
|
ErrHttpsRedirect = errors.New("unexpected forced https redirect")
|
||||||
|
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
|
||||||
ErrCertExpired = errors.New("certificate is expired")
|
ErrCertExpired = errors.New("certificate is expired")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ func (r *Redirector) checkHttp(scheme string) ServerCheck {
|
||||||
// checkHttp checks a URL for validity, and checks redirects
|
// checkHttp checks a URL for validity, and checks redirects
|
||||||
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
|
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: "http",
|
Scheme: scheme,
|
||||||
Host: server.Host,
|
Host: server.Host,
|
||||||
Path: server.Path,
|
Path: server.Path,
|
||||||
}
|
}
|
||||||
|
@ -41,7 +42,7 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := checkClient.Do(req)
|
res, err := r.config.checkClient.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -102,8 +103,8 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
|
||||||
|
|
||||||
if newUrl.Scheme == "https" {
|
if newUrl.Scheme == "https" {
|
||||||
return false, ErrHttpsRedirect
|
return false, ErrHttpsRedirect
|
||||||
} else if originatingScheme == "https" && newUrl.Scheme == "https" {
|
} else if originatingScheme == "https" && newUrl.Scheme == "http" {
|
||||||
return false, ErrHttpsRedirect
|
return false, ErrHttpRedirect
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|
|
@ -64,9 +64,8 @@ var _ = Describe("Check suite", func() {
|
||||||
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
}))
|
}))
|
||||||
r = New(&Config{
|
r = New(&Config{})
|
||||||
RootCAs: x509.NewCertPool(),
|
r.config.SetRootCAs(x509.NewCertPool())
|
||||||
})
|
|
||||||
})
|
})
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
httpServer.Close()
|
httpServer.Close()
|
||||||
|
@ -98,17 +97,6 @@ var _ = Describe("Check suite", func() {
|
||||||
Expect(res).To(BeTrue())
|
Expect(res).To(BeTrue())
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
})
|
})
|
||||||
It("Should return an error when redirected to https", func() {
|
|
||||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1))
|
|
||||||
w.WriteHeader(http.StatusMovedPermanently)
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := r.checkHttpScheme(server, "http", log.Fields{})
|
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
|
||||||
Expect(err).To(Equal(ErrHttpsRedirect))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
Context("TLS Checks", func() {
|
Context("TLS Checks", func() {
|
||||||
var (
|
var (
|
||||||
|
@ -137,26 +125,48 @@ var _ = Describe("Check suite", func() {
|
||||||
Certificates: []tls.Certificate{tlsPair},
|
Certificates: []tls.Certificate{tlsPair},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
|
||||||
|
pool.AddCert(x509Cert)
|
||||||
|
|
||||||
|
r.config.SetRootCAs(pool)
|
||||||
|
|
||||||
httpServer.StartTLS()
|
httpServer.StartTLS()
|
||||||
setupServer()
|
setupServer()
|
||||||
}
|
}
|
||||||
|
Context("HTTPS Checks", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
||||||
|
})
|
||||||
|
It("Should return an error when redirected to http from https", func() {
|
||||||
|
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1))
|
||||||
|
w.WriteHeader(http.StatusMovedPermanently)
|
||||||
|
}
|
||||||
|
|
||||||
|
logFields := log.Fields{}
|
||||||
|
|
||||||
|
res, err := r.checkHttpScheme(server, "https", logFields)
|
||||||
|
|
||||||
|
Expect(logFields["url"]).ToNot(BeEmpty())
|
||||||
|
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
|
||||||
|
Expect(err).To(Equal(ErrHttpRedirect))
|
||||||
|
Expect(res).To(BeFalse())
|
||||||
|
})
|
||||||
|
})
|
||||||
Context("CA Tests", func() {
|
Context("CA Tests", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
||||||
})
|
})
|
||||||
It("Should fail due to invalid ca", func() {
|
It("Should fail due to invalid ca", func() {
|
||||||
|
r.config.SetRootCAs(x509.NewCertPool())
|
||||||
|
|
||||||
res, err := r.checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
Expect(err).ToNot(BeNil())
|
Expect(err).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
|
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
|
||||||
pool := x509.NewCertPool()
|
|
||||||
|
|
||||||
pool.AddCert(x509Cert)
|
|
||||||
|
|
||||||
r.config.RootCAs = pool
|
|
||||||
|
|
||||||
res, err := r.checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
|
@ -167,13 +177,6 @@ var _ = Describe("Check suite", func() {
|
||||||
It("Should fail due to not yet valid certificate", func() {
|
It("Should fail due to not yet valid certificate", func() {
|
||||||
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
|
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
|
||||||
|
|
||||||
// Trust our certs
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
|
|
||||||
pool.AddCert(x509Cert)
|
|
||||||
|
|
||||||
r.config.RootCAs = pool
|
|
||||||
|
|
||||||
// Check TLS
|
// Check TLS
|
||||||
res, err := r.checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
|
@ -183,13 +186,6 @@ var _ = Describe("Check suite", func() {
|
||||||
It("Should fail due to expired certificate", func() {
|
It("Should fail due to expired certificate", func() {
|
||||||
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
|
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
|
||||||
|
|
||||||
// Trust our certs
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
|
|
||||||
pool.AddCert(x509Cert)
|
|
||||||
|
|
||||||
r.config.RootCAs = pool
|
|
||||||
|
|
||||||
// Check TLS
|
// Check TLS
|
||||||
res, err := r.checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
|
|
24
config.go
24
config.go
|
@ -1,6 +1,7 @@
|
||||||
package redirector
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
lru "github.com/hashicorp/golang-lru"
|
lru "github.com/hashicorp/golang-lru"
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
@ -9,9 +10,11 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -24,6 +27,27 @@ type Config struct {
|
||||||
ServerList []ServerConfig `mapstructure:"servers"`
|
ServerList []ServerConfig `mapstructure:"servers"`
|
||||||
ReloadFunc func()
|
ReloadFunc func()
|
||||||
RootCAs *x509.CertPool
|
RootCAs *x509.CertPool
|
||||||
|
checkClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRootCAs sets the root ca files, and creates the http client for checks
|
||||||
|
// This **MUST** be called before r.checkClient is used.
|
||||||
|
func (c *Config) SetRootCAs(cas *x509.CertPool) {
|
||||||
|
c.RootCAs = cas
|
||||||
|
|
||||||
|
t := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: cas,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.checkClient = &http.Client{
|
||||||
|
Transport: t,
|
||||||
|
Timeout: 20 * time.Second,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProtocolList []string
|
type ProtocolList []string
|
||||||
|
|
|
@ -35,6 +35,7 @@ type Redirector struct {
|
||||||
topChoices int
|
topChoices int
|
||||||
serverCache *lru.Cache
|
serverCache *lru.Cache
|
||||||
checks []ServerCheck
|
checks []ServerCheck
|
||||||
|
checkClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocationLookup struct {
|
type LocationLookup struct {
|
||||||
|
|
10
servers.go
10
servers.go
|
@ -6,21 +6,11 @@ import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
checkClient = &http.Client{
|
|
||||||
Timeout: 20 * time.Second,
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server represents a download server
|
// Server represents a download server
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Available bool `json:"available"`
|
Available bool `json:"available"`
|
||||||
|
|
Loading…
Reference in New Issue