diff --git a/check.go b/check.go new file mode 100644 index 0000000..f809b6b --- /dev/null +++ b/check.go @@ -0,0 +1,127 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "net/http" + "net/url" + "runtime" + "time" +) + +var ( + ErrHttpsRedirect = errors.New("unexpected forced https redirect") + ErrCertExpired = errors.New("certificate is expired") +) + +// checkHttp checks a URL for validity, and checks redirects +func checkHttp(server *Server, logFields log.Fields) (bool, error) { + u := &url.URL{ + Scheme: "http", + Host: server.Host, + Path: server.Path, + } + + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + + req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")") + + if err != nil { + return false, err + } + + res, err := checkClient.Do(req) + + if err != nil { + return false, err + } + + logFields["responseCode"] = res.StatusCode + + if res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound || res.StatusCode == http.StatusNotFound { + if res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound { + location := res.Header.Get("Location") + + logFields["url"] = location + + // Check that we don't redirect to https from a http url + if u.Scheme == "http" { + res, err := checkRedirect(location) + + if !res || err != nil { + return res, err + } + } + } + + return true, nil + } + + logFields["cause"] = fmt.Sprintf("Unexpected http status %d", res.StatusCode) + + return false, nil +} + +// checkRedirect parses a location header response and checks the scheme +func checkRedirect(locationHeader string) (bool, error) { + newUrl, err := url.Parse(locationHeader) + + if err != nil { + return false, err + } + + if newUrl.Scheme == "https" { + return false, ErrHttpsRedirect + } + + return true, nil +} + +// checkTLS checks tls certificates from a host, ensures they're valid, and not expired. +func checkTLS(server *Server, logFields log.Fields) (bool, error) { + conn, err := tls.Dial("tcp", server.Host+":443", nil) + + if err != nil { + return false, err + } + + defer conn.Close() + + err = conn.VerifyHostname(server.Host) + + if err != nil { + return false, err + } + + now := time.Now() + + state := conn.ConnectionState() + + opts := x509.VerifyOptions{ + CurrentTime: time.Now(), + } + + for _, cert := range state.PeerCertificates { + if _, err := cert.Verify(opts); err != nil { + logFields["peerCert"] = cert.Subject.String() + return false, err + } + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return false, err + } + } + + for _, chain := range state.VerifiedChains { + for _, cert := range chain { + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + logFields["cert"] = cert.Subject.String() + return false, ErrCertExpired + } + } + } + + return true, nil +} diff --git a/config.go b/config.go index 1b17844..ab7171b 100644 --- a/config.go +++ b/config.go @@ -39,6 +39,9 @@ func reloadConfig() { serverCache.Resize(viper.GetInt("cacheSize")) } + // Purge the cache to ensure we don't have any invalid servers in it + serverCache.Purge() + // Set top choice count topChoices = viper.GetInt("topChoices") diff --git a/servers.go b/servers.go index 26aecaa..5b2c1ed 100644 --- a/servers.go +++ b/servers.go @@ -7,10 +7,7 @@ import ( "math" "net" "net/http" - "net/url" - "runtime" "sort" - "strings" "sync" "time" ) @@ -18,6 +15,14 @@ import ( var ( checkClient = &http.Client{ Timeout: 20 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + checks = []serverCheck{ + checkHttp, + checkTLS, } ) @@ -33,87 +38,45 @@ type Server struct { LastChange time.Time `json:"lastChange"` } +type serverCheck func(server *Server, logFields log.Fields) (bool, error) + +// checkStatus runs all status checks against a server func (server *Server) checkStatus() { - req, err := http.NewRequest(http.MethodGet, "http://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil) - - req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")") - - if err != nil { - // This should never happen. - log.WithFields(log.Fields{ - "server": server.Host, - "error": err, - }).Warning("Invalid request! This should not happen, please check config.") - return + logFields := log.Fields{ + "host": server.Host, } - res, err := checkClient.Do(req) + var res bool + var err error - if err != nil { + for _, check := range checks { + res, err = check(server, logFields) + + if err != nil { + logFields["error"] = err + } + + if !res { + break + } + } + + if !res { if server.Available { - log.WithFields(log.Fields{ - "server": server.Host, - "error": err, - }).Info("Server went offline") + log.WithFields(logFields).Info("Server went offline") server.Available = false server.LastChange = time.Now() } else { - log.WithFields(log.Fields{ - "server": server.Host, - "error": err, - }).Debug("Server is still offline") + log.WithFields(logFields).Debug("Server is still offline") } + return - } - - responseFields := log.Fields{ - "server": server.Host, - "responseCode": res.StatusCode, - } - - if res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound || res.StatusCode == http.StatusNotFound { - if res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound { - location := res.Header.Get("Location") - - responseFields["url"] = location - - log.WithFields(responseFields).Debug("Server responded with redirect") - - newUrl, err := url.Parse(location) - - if err != nil { - if server.Available { - log.WithFields(responseFields).Warning("Server returned invalid url") - server.Available = false - server.LastChange = time.Now() - } - return - } - - if newUrl.Scheme == "https" { - if server.Available { - responseFields["url"] = location - log.WithFields(responseFields).Warning("Server returned https url for http request") - server.Available = false - server.LastChange = time.Now() - } - return - } - } - + } else { if !server.Available { server.Available = true server.LastChange = time.Now() - log.WithFields(responseFields).Info("Server is online") - } - } else { - log.WithFields(responseFields).Debug("Server status not known") - - if server.Available { - log.WithFields(responseFields).Info("Server went offline") - server.Available = false - server.LastChange = time.Now() + log.WithFields(logFields).Info("Server is online") } } } @@ -223,6 +186,13 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { dist := choice.Item.(ComputedDistance) + if !dist.Server.Available { + // Choose a new server and refresh cache + serverCache.Remove(ip.String()) + + return s.Closest(ip) + } + return dist.Server, dist.Distance, nil } @@ -233,7 +203,7 @@ func hsin(theta float64) float64 { // Distance function returns the distance (in meters) between two points of // a given longitude and latitude relatively accurately (using a spherical -// approximation of the Earth) through the Haversin Distance Formula for +// approximation of the Earth) through the Haversine Distance Formula for // great arc distance on a sphere with accuracy for small distances // // point coordinates are supplied in degrees and converted into rad. in the func