armbian-router/check.go

208 lines
4.7 KiB
Go
Raw Normal View History

package redirector
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"net"
"net/http"
"net/url"
"runtime"
"strings"
"time"
)
// Sentinel errors reported by the server checks in this file.
var (
// ErrHttpsRedirect is returned by checkRedirect when the Location header
// points at an https URL.
ErrHttpsRedirect = errors.New("unexpected forced https redirect")
// ErrHttpRedirect is returned by checkRedirect when a request made over
// https is redirected to a plain http URL.
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
// ErrCertExpired is returned by checkTLS when a certificate is outside
// its NotBefore/NotAfter validity window.
ErrCertExpired = errors.New("certificate is expired")
)
// checkHttp returns a ServerCheck that runs checkHttpScheme against a server
// using the given URL scheme.
func (r *Redirector) checkHttp(scheme string) ServerCheck {
	// Bind the scheme now; the server and log fields arrive when the check runs.
	var check ServerCheck = func(server *Server, logFields log.Fields) (bool, error) {
		return r.checkHttpScheme(server, scheme, logFields)
	}

	return check
}
// checkHttpScheme issues a GET for the server's host and path over the given
// scheme using the configured check client, and inspects the response status
// and any redirect it carries.
//
// The response code is recorded in logFields["responseCode"]. Statuses 200 and
// 404 pass. Statuses 301 and 302 are treated as redirects: the Location header
// is recorded in logFields["url"] and validated with checkRedirect.
//   - For an "http" check, a rejected redirect removes "http" from
//     server.Protocols; an accepted one triggers an additional "https" check
//     via checkProtocol. Either way the check itself passes.
//   - For an "https" check, the result of checkRedirect is returned directly.
//
// Any other status fails the check with a nil error and a description in
// logFields["cause"]. Request construction or transport failures are returned
// as errors.
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
	u := &url.URL{
		Scheme: scheme,
		Host:   server.Host,
		Path:   server.Path,
	}

	req, err := http.NewRequest(http.MethodGet, u.String(), nil)
	if err != nil {
		// req is nil when NewRequest fails, so it must not be touched before this check.
		return false, err
	}

	req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")

	res, err := r.config.checkClient.Do(req)
	if err != nil {
		return false, err
	}
	// The body is never read, but it must be closed to release the connection.
	defer res.Body.Close()

	logFields["responseCode"] = res.StatusCode

	switch res.StatusCode {
	case http.StatusOK, http.StatusNotFound:
		return true, nil
	case http.StatusMovedPermanently, http.StatusFound:
		// Redirects are examined below.
	default:
		logFields["cause"] = fmt.Sprintf("Unexpected http status %d", res.StatusCode)
		return false, nil
	}

	location := res.Header.Get("Location")
	logFields["url"] = location

	switch u.Scheme {
	case "http":
		ok, err := r.checkRedirect(u.Scheme, location)
		if !ok || err != nil {
			// If we don't support http, we remove it from supported protocols
			server.Protocols = server.Protocols.Remove("http")
		} else {
			// Otherwise, we verify https support
			r.checkProtocol(server, "https")
		}
	case "https":
		// We don't want to allow downgrading, so this is an error.
		return r.checkRedirect(u.Scheme, location)
	}

	return true, nil
}
// checkProtocol runs checkHttpScheme for the given scheme with a throwaway set
// of log fields and, when the check passes without error, appends the scheme
// to server.Protocols unless it is already present.
func (r *Redirector) checkProtocol(server *Server, scheme string) {
	passed, checkErr := r.checkHttpScheme(server, scheme, log.Fields{})

	// Only a clean pass may extend the protocol list.
	if checkErr == nil && passed && !server.Protocols.Contains(scheme) {
		server.Protocols = server.Protocols.Append(scheme)
	}
}
// checkRedirect parses a Location header value and decides whether the
// redirect is acceptable for a check made over originatingScheme.
//
// It returns (false, err) when the header cannot be parsed as a URL,
// (false, ErrHttpsRedirect) when the target scheme is https,
// (false, ErrHttpRedirect) when an https check is sent to an http target,
// and (true, nil) otherwise (including scheme-less, relative targets).
func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
	target, parseErr := url.Parse(locationHeader)
	if parseErr != nil {
		return false, parseErr
	}

	switch {
	case target.Scheme == "https":
		return false, ErrHttpsRedirect
	case target.Scheme == "http" && originatingScheme == "https":
		return false, ErrHttpRedirect
	default:
		return true, nil
	}
}
// checkTLS opens a TLS connection to the server and validates the certificates
// it presents.
//
// server.Host may be "host" or "host:port"; port 443 is used when none is
// given. The connection is made with r.config.RootCAs as the trusted roots.
// The check fails with an error when:
//   - the host cannot be split, or the TLS dial/handshake fails;
//   - the leaf certificate does not match the host name;
//   - the leaf certificate does not verify against the roots, using the CA
//     certificates the peer sent as intermediates;
//   - the leaf, or any certificate in a verified chain, is outside its
//     validity window (ErrCertExpired).
//
// Details of a failing certificate are written to logFields. On success,
// "https" is appended to server.Protocols if it is not already present.
func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) {
	var host, port string
	var err error

	if strings.Contains(server.Host, ":") {
		host, port, err = net.SplitHostPort(server.Host)
		if err != nil {
			return false, err
		}
	} else {
		host = server.Host
	}

	log.WithFields(log.Fields{
		"server": server.Host,
		"host":   host,
		"port":   port,
	}).Info("Checking TLS server")

	if port == "" {
		port = "443"
	}

	// JoinHostPort re-brackets IPv6 literals that SplitHostPort unwrapped.
	conn, err := tls.Dial("tcp", net.JoinHostPort(host, port), &tls.Config{
		RootCAs: r.config.RootCAs,
	})
	if err != nil {
		return false, err
	}
	defer conn.Close()

	// Verify against the bare host: a "host:port" string never matches a
	// certificate name.
	if err := conn.VerifyHostname(host); err != nil {
		return false, err
	}

	// A single timestamp keeps every validity comparison below consistent.
	now := time.Now()

	state := conn.ConnectionState()

	// Collect the CA certificates the peer sent to use as intermediates.
	peerPool := x509.NewCertPool()
	for _, intermediate := range state.PeerCertificates {
		if !intermediate.IsCA {
			continue
		}
		peerPool.AddCert(intermediate)
	}

	opts := x509.VerifyOptions{
		Roots:         r.config.RootCAs,
		Intermediates: peerPool,
		CurrentTime:   now,
	}

	// We want only the leaf certificate, as this will verify up the chain for us.
	cert := state.PeerCertificates[0]

	if _, verifyErr := cert.Verify(opts); verifyErr != nil {
		logFields["peerCert"] = cert.Subject.String()

		var authErr x509.UnknownAuthorityError
		if errors.As(verifyErr, &authErr) && authErr.Cert != nil {
			logFields["authCert"] = authErr.Cert.Subject.String()
			logFields["ca"] = authErr.Cert.Issuer
		}

		return false, verifyErr
	}

	if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
		logFields["peerCert"] = cert.Subject.String()
		// Report the expiry explicitly; no other error is pending here.
		return false, ErrCertExpired
	}

	for _, chain := range state.VerifiedChains {
		for _, chainCert := range chain {
			if now.Before(chainCert.NotBefore) || now.After(chainCert.NotAfter) {
				logFields["cert"] = chainCert.Subject.String()
				return false, ErrCertExpired
			}
		}
	}

	// If https is valid, append it
	if !server.Protocols.Contains("https") {
		server.Protocols = server.Protocols.Append("https")
	}

	return true, nil
}