diff --git a/check.go b/check.go index f809b6b..afdb0fd 100644 --- a/check.go +++ b/check.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" log "github.com/sirupsen/logrus" + "net" "net/http" "net/url" "runtime" @@ -82,7 +83,13 @@ func checkRedirect(locationHeader string) (bool, error) { // checkTLS checks tls certificates from a host, ensures they're valid, and not expired. func checkTLS(server *Server, logFields log.Fields) (bool, error) { - conn, err := tls.Dial("tcp", server.Host+":443", nil) + host, port, err := net.SplitHostPort(server.Host) + + if port == "" { + port = "443" + } + + conn, err := tls.Dial("tcp", host+":"+port, checkTLSConfig) if err != nil { return false, err diff --git a/check_test.go b/check_test.go index 2418498..7b4c758 100644 --- a/check_test.go +++ b/check_test.go @@ -1,40 +1,88 @@ package main import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "math/big" + "net" "net/http" "net/http/httptest" "net/url" "strings" + "time" ) +func genTestCerts(notBefore, notAfter time.Time) (*pem.Block, *pem.Block, error) { + // Create a Certificate Authority Cert + template := x509.Certificate{ + SerialNumber: big.NewInt(0), + Subject: pkix.Name{CommonName: "localhost"}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: notBefore, + NotAfter: notAfter, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + } + + // Create a Private Key + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, fmt.Errorf("Could not generate rsa key - %s", err) + } + + // Use CA Cert to sign a CSR and create a Public Cert + csr := &key.PublicKey + cert, err := x509.CreateCertificate(rand.Reader, &template, &template, csr, key) + if err != nil { + return nil, nil, fmt.Errorf("Could not generate certificate - %s", err) + } + + // Convert keys into pem.Block + c := &pem.Block{Type: "CERTIFICATE", Bytes: cert} + k := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)} + return c, k, nil +} + var _ = Describe("Check suite", func() { + var ( + httpServer *httptest.Server + server *Server + handler http.HandlerFunc + ) + BeforeEach(func() { + httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler(w, r) + })) + }) + AfterEach(func() { + httpServer.Close() + }) + setupServer := func() { + u, err := url.Parse(httpServer.URL) + + if err != nil { + panic(err) + } + server = &Server{ + Host: u.Host, + Path: u.Path, + } + } + Context("HTTP Checks", func() { - var ( - httpServer *httptest.Server - server *Server - handler http.HandlerFunc - ) BeforeEach(func() { - httpServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handler(w, r) - })) - - u, err := url.Parse(httpServer.URL) - - if err != nil { - panic(err) - } - - server = &Server{ - Host: u.Host, - Path: u.Path, - } - }) - AfterEach(func() { - httpServer.Close() + httpServer.Start() + setupServer() }) It("Should successfully check for connectivity", func() { handler = func(w http.ResponseWriter, r *http.Request) { @@ -59,6 +107,96 @@ var _ = Describe("Check suite", func() { }) }) Context("TLS Checks", func() { + var ( + x509Cert *x509.Certificate + ) + setupCerts := func(notBefore, notAfter time.Time) { + cert, key, err := genTestCerts(notBefore, notAfter) + if err != nil { + panic("Unable to generate test certs") + } + + x509Cert, err = x509.ParseCertificate(cert.Bytes) + + if err != nil { + panic("Unable to parse certificate from bytes: " + err.Error()) + } + + tlsPair, err := tls.X509KeyPair(pem.EncodeToMemory(cert), pem.EncodeToMemory(key)) + + if err != nil { + panic("Unable to load tls key pair: " + err.Error()) + } + + httpServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{tlsPair}, + } + + httpServer.StartTLS() + setupServer() + } + Context("CA Tests", func() { + BeforeEach(func() { + setupCerts(time.Now(), time.Now().Add(24*time.Hour)) + }) + It("Should fail due to invalid ca", func() { + res, err := checkTLS(server, log.Fields{}) + + Expect(res).To(BeFalse()) + Expect(err).ToNot(BeNil()) + }) + It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() { + pool := x509.NewCertPool() + + pool.AddCert(x509Cert) + + checkTLSConfig = &tls.Config{RootCAs: pool} + + res, err := checkTLS(server, log.Fields{}) + + Expect(res).To(BeFalse()) + Expect(err).ToNot(BeNil()) + + checkTLSConfig = nil + }) + }) + Context("Expiration tests", func() { + AfterEach(func() { + checkTLSConfig = nil + }) + It("Should fail due to not yet valid certificate", func() { + setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour)) + + // Trust our certs + pool := x509.NewCertPool() + + pool.AddCert(x509Cert) + + checkTLSConfig = &tls.Config{RootCAs: pool} + + // Check TLS + res, err := checkTLS(server, log.Fields{}) + + Expect(res).To(BeFalse()) + Expect(err).ToNot(BeNil()) + }) + It("Should fail due to expired certificate", func() { + setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour)) + + // Trust our certs + pool := x509.NewCertPool() + + pool.AddCert(x509Cert) + + checkTLSConfig = &tls.Config{RootCAs: pool} + + // Check TLS + res, err := checkTLS(server, log.Fields{}) + + Expect(res).To(BeFalse()) + Expect(err).ToNot(BeNil()) + }) + }) }) }) diff --git a/config.go b/config.go index ab7171b..216eaf9 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package main import ( lru "github.com/hashicorp/golang-lru" "github.com/oschwald/maxminddb-golang" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" log "github.com/sirupsen/logrus" @@ -13,13 +14,13 @@ import ( "sync" ) -func reloadConfig() { +func reloadConfig() error { log.Info("Loading configuration...") err := viper.ReadInConfig() // Find and read the config file if err != nil { // Handle errors reading the config file - log.WithError(err).Fatalln("Unable to load config file") + return errors.Wrap(err, "Unable to read configuration") } // db will never be reloaded. @@ -28,7 +29,7 @@ func reloadConfig() { db, err = maxminddb.Open(viper.GetString("geodb")) if err != nil { - log.WithError(err).Fatalln("Unable to open database") + return errors.Wrap(err, "Unable to open database") } } @@ -46,10 +47,14 @@ func reloadConfig() { topChoices = viper.GetInt("topChoices") // Reload map file - reloadMap() + if err := reloadMap(); err != nil { + return errors.Wrap(err, "Unable to load map file") + } // Reload server list - reloadServers() + if err := reloadServers(); err != nil { + return errors.Wrap(err, "Unable to load servers") + } // Create mirror map mirrors := make(map[string][]*Server) @@ -77,11 +82,16 @@ func reloadConfig() { // Force check go servers.Check() + + return nil } -func reloadServers() { +func reloadServers() error { var serverList []ServerConfig - viper.UnmarshalKey("servers", &serverList) + + if err := viper.UnmarshalKey("servers", &serverList); err != nil { + return err + } var wg sync.WaitGroup @@ -109,7 +119,7 @@ func reloadServers() { "error": err, "server": server, }).Warning("Server is invalid") - return + return err } hosts[u.Host] = true @@ -161,6 +171,8 @@ func reloadServers() { servers = append(servers[:i], servers[i+1:]...) } + + return nil } var metricReplacer = strings.NewReplacer(".", "_", "-", "_") @@ -217,20 +229,22 @@ func addServer(server ServerConfig, u *url.URL) *Server { return s } -func reloadMap() { +func reloadMap() error { mapFile := viper.GetString("dl_map") if mapFile == "" { - return + return nil } log.WithField("file", mapFile).Info("Loading download map") - newMap, err := loadMap(mapFile) + newMap, err := loadMapFile(mapFile) if err != nil { - return + return err } dlMap = newMap + + return nil } diff --git a/http.go b/http.go index fd47234..ec8f511 100644 --- a/http.go +++ b/http.go @@ -13,11 +13,18 @@ import ( "strings" ) +// statusHandler is a simple handler that will always return 200 OK with a body of "OK" func statusHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) + + if r.Method != http.MethodHead { + w.Write([]byte("OK")) + } } +// redirectHandler is the default "not found" handler which handles redirects +// if the environment variable OVERRIDE_IP is set, it will use that ip address +// this is useful for local testing when you're on the local network func redirectHandler(w http.ResponseWriter, r *http.Request) { ipStr, _, err := net.SplitHostPort(r.RemoteAddr) @@ -41,6 +48,8 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { var server *Server var distance float64 + // If the path has a prefix of region/NA, it will use specific regions instead + // of the default geographical distance if strings.HasPrefix(r.URL.Path, "/region") { parts := strings.Split(r.URL.Path, "/") @@ -72,6 +81,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { } } + // If none of the above exceptions are matched, we use the geographical distance based on IP if server == nil { server, distance, err = servers.Closest(ip) @@ -81,14 +91,19 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { } } + // If we don't have a scheme, we'll use https by default scheme := r.URL.Scheme if scheme == "" { scheme = "https" } + // redirectPath is a combination of server path (which can be something like /armbian) + // and the URL path. + // Example: /armbian + /some/path = /armbian/some/path redirectPath := path.Join(server.Path, r.URL.Path) + // If we have a dlMap, we map the url to a final path instead if dlMap != nil { if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists { downloadsMapped.Inc() @@ -100,6 +115,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { redirectPath += "/" } + // We need to build the final url now u := &url.URL{ Scheme: scheme, Host: server.Host, @@ -109,6 +125,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { server.Redirects.Inc() redirectsServed.Inc() + // If we used geographical distance, we add an X-Geo-Distance header for debug. if distance > 0 { w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) } @@ -117,7 +134,16 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusFound) } +// reloadHandler is an http handler which lets us reload the server configuration +// It is only enabled when the reloadToken is set in the configuration func reloadHandler(w http.ResponseWriter, r *http.Request) { + expectedToken := viper.GetString("reloadToken") + + if expectedToken == "" { + w.WriteHeader(http.StatusUnauthorized) + return + } + token := r.Header.Get("Authorization") if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") { @@ -127,12 +153,16 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) { token = token[strings.Index(token, " ")+1:] - if token != viper.GetString("reloadToken") { + if token != expectedToken { w.WriteHeader(http.StatusUnauthorized) return } - reloadConfig() + if err := reloadConfig(); err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) diff --git a/main.go b/main.go index 9900b1c..9115c55 100644 --- a/main.go +++ b/main.go @@ -104,7 +104,9 @@ func main() { viper.SetConfigFile(*configFlag) } - reloadConfig() + if err := reloadConfig(); err != nil { + log.WithError(err).Fatalln("Unable to load configuration") + } // Start check loop go servers.checkLoop() @@ -143,6 +145,10 @@ func main() { break } - reloadConfig() + err := reloadConfig() + + if err != nil { + log.WithError(err).Warning("Did not reload configuration due to error") + } } } diff --git a/map.go b/map.go index f6019f9..265862d 100644 --- a/map.go +++ b/map.go @@ -7,7 +7,8 @@ import ( "strings" ) -func loadMap(file string) (map[string]string, error) { +// loadMapFile loads a file as a map +func loadMapFile(file string) (map[string]string, error) { f, err := os.Open(file) if err != nil { @@ -16,6 +17,11 @@ func loadMap(file string) (map[string]string, error) { defer f.Close() + return loadMap(f) +} + +// loadMap loads a pipe separated file of mappings +func loadMap(f io.Reader) (map[string]string, error) { m := make(map[string]string) r := csv.NewReader(f) diff --git a/map_test.go b/map_test.go new file mode 100644 index 0000000..c9fa3ca --- /dev/null +++ b/map_test.go @@ -0,0 +1,16 @@ +package main + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "strings" +) + +var _ = Describe("Map", func() { + It("Should successfully load the map", func() { + m, err := loadMap(strings.NewReader(`bananapi/Bullseye_current|bananapi/archive/Armbian_21.08.1_Bananapi_bullseye_current_5.10.60.img.xz|Aug 26 2021|332M`)) + + Expect(err).To(BeNil()) + Expect(m["bananapi/Bullseye_current"]).To(Equal("bananapi/archive/Armbian_21.08.1_Bananapi_bullseye_current_5.10.60.img.xz")) + }) +}) diff --git a/mirrors.go b/mirrors.go index ed92411..269da2b 100644 --- a/mirrors.go +++ b/mirrors.go @@ -5,9 +5,12 @@ import ( "encoding/json" "github.com/go-chi/chi/v5" "net/http" + "strconv" "strings" ) +// legacyMirrorsHandler will list the mirrors by region in the legacy format +// it is preferred to use mirrors.json, but this handler is here for build support func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -26,6 +29,7 @@ func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(mirrorOutput) } +// mirrorsHandler is a simple handler that will return the list of servers func mirrorsHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(servers) @@ -42,10 +46,13 @@ var ( statusUnknown []byte ) +// mirrorStatusHandler is a fancy svg-returning handler. +// it is used to display mirror statuses on a config repo of sorts func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { serverHost := chi.URLParam(r, "server") w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8") + w.Header().Set("Cache-Control", "max-age=120") if serverHost == "" { w.Write(statusUnknown) @@ -57,13 +64,31 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { server, ok := hostMap[serverHost] if !ok { + w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown))) w.Write(statusUnknown) return } + key := "offline" + if server.Available { + key = "online" + } + + w.Header().Set("ETag", "\""+key+"\"") + + if match := r.Header.Get("If-None-Match"); match != "" { + if strings.Trim(match, "\"") == key { + w.WriteHeader(http.StatusNotModified) + return + } + } + + if server.Available { + w.Header().Set("Content-Length", strconv.Itoa(len(statusUp))) w.Write(statusUp) } else { + w.Header().Set("Content-Length", strconv.Itoa(len(statusDown))) w.Write(statusDown) } } diff --git a/servers.go b/servers.go index 5b2c1ed..425936c 100644 --- a/servers.go +++ b/servers.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "github.com/jmcvetta/randutil" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -20,12 +21,15 @@ var ( }, } + checkTLSConfig *tls.Config = nil + checks = []serverCheck{ checkHttp, checkTLS, } ) +// Server represents a download server type Server struct { Available bool `json:"available"` Host string `json:"host"`