Improve test coverage and documentation
continuous-integration/drone/push Build is failing Details

This commit is contained in:
Tyler 2022-08-14 03:42:49 -04:00
parent 3c5656284c
commit 3f71aced93
9 changed files with 287 additions and 41 deletions

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net"
"net/http" "net/http"
"net/url" "net/url"
"runtime" "runtime"
@ -82,7 +83,13 @@ func checkRedirect(locationHeader string) (bool, error) {
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired. // checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
func checkTLS(server *Server, logFields log.Fields) (bool, error) { func checkTLS(server *Server, logFields log.Fields) (bool, error) {
conn, err := tls.Dial("tcp", server.Host+":443", nil) host, port, err := net.SplitHostPort(server.Host)
if port == "" {
port = "443"
}
conn, err := tls.Dial("tcp", host+":"+port, checkTLSConfig)
if err != nil { if err != nil {
return false, err return false, err

View File

@ -1,40 +1,88 @@
package main package main
import ( import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math/big"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"time"
) )
func genTestCerts(notBefore, notAfter time.Time) (*pem.Block, *pem.Block, error) {
// Create a Certificate Authority Cert
template := x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{CommonName: "localhost"},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: notBefore,
NotAfter: notAfter,
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
}
// Create a Private Key
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, fmt.Errorf("Could not generate rsa key - %s", err)
}
// Use CA Cert to sign a CSR and create a Public Cert
csr := &key.PublicKey
cert, err := x509.CreateCertificate(rand.Reader, &template, &template, csr, key)
if err != nil {
return nil, nil, fmt.Errorf("Could not generate certificate - %s", err)
}
// Convert keys into pem.Block
c := &pem.Block{Type: "CERTIFICATE", Bytes: cert}
k := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
return c, k, nil
}
var _ = Describe("Check suite", func() { var _ = Describe("Check suite", func() {
Context("HTTP Checks", func() {
var ( var (
httpServer *httptest.Server httpServer *httptest.Server
server *Server server *Server
handler http.HandlerFunc handler http.HandlerFunc
) )
BeforeEach(func() { BeforeEach(func() {
httpServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r) handler(w, r)
})) }))
})
AfterEach(func() {
httpServer.Close()
})
setupServer := func() {
u, err := url.Parse(httpServer.URL) u, err := url.Parse(httpServer.URL)
if err != nil { if err != nil {
panic(err) panic(err)
} }
server = &Server{ server = &Server{
Host: u.Host, Host: u.Host,
Path: u.Path, Path: u.Path,
} }
}) }
AfterEach(func() {
httpServer.Close() Context("HTTP Checks", func() {
BeforeEach(func() {
httpServer.Start()
setupServer()
}) })
It("Should successfully check for connectivity", func() { It("Should successfully check for connectivity", func() {
handler = func(w http.ResponseWriter, r *http.Request) { handler = func(w http.ResponseWriter, r *http.Request) {
@ -59,6 +107,96 @@ var _ = Describe("Check suite", func() {
}) })
}) })
Context("TLS Checks", func() { Context("TLS Checks", func() {
var (
x509Cert *x509.Certificate
)
setupCerts := func(notBefore, notAfter time.Time) {
cert, key, err := genTestCerts(notBefore, notAfter)
if err != nil {
panic("Unable to generate test certs")
}
x509Cert, err = x509.ParseCertificate(cert.Bytes)
if err != nil {
panic("Unable to parse certificate from bytes: " + err.Error())
}
tlsPair, err := tls.X509KeyPair(pem.EncodeToMemory(cert), pem.EncodeToMemory(key))
if err != nil {
panic("Unable to load tls key pair: " + err.Error())
}
httpServer.TLS = &tls.Config{
Certificates: []tls.Certificate{tlsPair},
}
httpServer.StartTLS()
setupServer()
}
Context("CA Tests", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should fail due to invalid ca", func() {
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
checkTLSConfig = &tls.Config{RootCAs: pool}
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
checkTLSConfig = nil
})
})
Context("Expiration tests", func() {
AfterEach(func() {
checkTLSConfig = nil
})
It("Should fail due to not yet valid certificate", func() {
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
checkTLSConfig = &tls.Config{RootCAs: pool}
// Check TLS
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
It("Should fail due to expired certificate", func() {
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
checkTLSConfig = &tls.Config{RootCAs: pool}
// Check TLS
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
})
}) })
}) })

View File

@ -3,6 +3,7 @@ package main
import ( import (
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -13,13 +14,13 @@ import (
"sync" "sync"
) )
func reloadConfig() { func reloadConfig() error {
log.Info("Loading configuration...") log.Info("Loading configuration...")
err := viper.ReadInConfig() // Find and read the config file err := viper.ReadInConfig() // Find and read the config file
if err != nil { // Handle errors reading the config file if err != nil { // Handle errors reading the config file
log.WithError(err).Fatalln("Unable to load config file") return errors.Wrap(err, "Unable to read configuration")
} }
// db will never be reloaded. // db will never be reloaded.
@ -28,7 +29,7 @@ func reloadConfig() {
db, err = maxminddb.Open(viper.GetString("geodb")) db, err = maxminddb.Open(viper.GetString("geodb"))
if err != nil { if err != nil {
log.WithError(err).Fatalln("Unable to open database") return errors.Wrap(err, "Unable to open database")
} }
} }
@ -46,10 +47,14 @@ func reloadConfig() {
topChoices = viper.GetInt("topChoices") topChoices = viper.GetInt("topChoices")
// Reload map file // Reload map file
reloadMap() if err := reloadMap(); err != nil {
return errors.Wrap(err, "Unable to load map file")
}
// Reload server list // Reload server list
reloadServers() if err := reloadServers(); err != nil {
return errors.Wrap(err, "Unable to load servers")
}
// Create mirror map // Create mirror map
mirrors := make(map[string][]*Server) mirrors := make(map[string][]*Server)
@ -77,11 +82,16 @@ func reloadConfig() {
// Force check // Force check
go servers.Check() go servers.Check()
return nil
} }
func reloadServers() { func reloadServers() error {
var serverList []ServerConfig var serverList []ServerConfig
viper.UnmarshalKey("servers", &serverList)
if err := viper.UnmarshalKey("servers", &serverList); err != nil {
return err
}
var wg sync.WaitGroup var wg sync.WaitGroup
@ -109,7 +119,7 @@ func reloadServers() {
"error": err, "error": err,
"server": server, "server": server,
}).Warning("Server is invalid") }).Warning("Server is invalid")
return return err
} }
hosts[u.Host] = true hosts[u.Host] = true
@ -161,6 +171,8 @@ func reloadServers() {
servers = append(servers[:i], servers[i+1:]...) servers = append(servers[:i], servers[i+1:]...)
} }
return nil
} }
var metricReplacer = strings.NewReplacer(".", "_", "-", "_") var metricReplacer = strings.NewReplacer(".", "_", "-", "_")
@ -217,20 +229,22 @@ func addServer(server ServerConfig, u *url.URL) *Server {
return s return s
} }
func reloadMap() { func reloadMap() error {
mapFile := viper.GetString("dl_map") mapFile := viper.GetString("dl_map")
if mapFile == "" { if mapFile == "" {
return return nil
} }
log.WithField("file", mapFile).Info("Loading download map") log.WithField("file", mapFile).Info("Loading download map")
newMap, err := loadMap(mapFile) newMap, err := loadMapFile(mapFile)
if err != nil { if err != nil {
return return err
} }
dlMap = newMap dlMap = newMap
return nil
} }

34
http.go
View File

@ -13,11 +13,18 @@ import (
"strings" "strings"
) )
// statusHandler is a simple handler that will always return 200 OK with a body of "OK"
func statusHandler(w http.ResponseWriter, r *http.Request) { func statusHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if r.Method != http.MethodHead {
w.Write([]byte("OK")) w.Write([]byte("OK"))
} }
}
// redirectHandler is the default "not found" handler which handles redirects
// if the environment variable OVERRIDE_IP is set, it will use that ip address
// this is useful for local testing when you're on the local network
func redirectHandler(w http.ResponseWriter, r *http.Request) { func redirectHandler(w http.ResponseWriter, r *http.Request) {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr) ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
@ -41,6 +48,8 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
var server *Server var server *Server
var distance float64 var distance float64
// If the path has a prefix of region/NA, it will use specific regions instead
// of the default geographical distance
if strings.HasPrefix(r.URL.Path, "/region") { if strings.HasPrefix(r.URL.Path, "/region") {
parts := strings.Split(r.URL.Path, "/") parts := strings.Split(r.URL.Path, "/")
@ -72,6 +81,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
// If none of the above exceptions are matched, we use the geographical distance based on IP
if server == nil { if server == nil {
server, distance, err = servers.Closest(ip) server, distance, err = servers.Closest(ip)
@ -81,14 +91,19 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
// If we don't have a scheme, we'll use https by default
scheme := r.URL.Scheme scheme := r.URL.Scheme
if scheme == "" { if scheme == "" {
scheme = "https" scheme = "https"
} }
// redirectPath is a combination of server path (which can be something like /armbian)
// and the URL path.
// Example: /armbian + /some/path = /armbian/some/path
redirectPath := path.Join(server.Path, r.URL.Path) redirectPath := path.Join(server.Path, r.URL.Path)
// If we have a dlMap, we map the url to a final path instead
if dlMap != nil { if dlMap != nil {
if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists { if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists {
downloadsMapped.Inc() downloadsMapped.Inc()
@ -100,6 +115,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
redirectPath += "/" redirectPath += "/"
} }
// We need to build the final url now
u := &url.URL{ u := &url.URL{
Scheme: scheme, Scheme: scheme,
Host: server.Host, Host: server.Host,
@ -109,6 +125,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
server.Redirects.Inc() server.Redirects.Inc()
redirectsServed.Inc() redirectsServed.Inc()
// If we used geographical distance, we add an X-Geo-Distance header for debug.
if distance > 0 { if distance > 0 {
w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance))
} }
@ -117,7 +134,16 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
// reloadHandler is an http handler which lets us reload the server configuration
// It is only enabled when the reloadToken is set in the configuration
func reloadHandler(w http.ResponseWriter, r *http.Request) { func reloadHandler(w http.ResponseWriter, r *http.Request) {
expectedToken := viper.GetString("reloadToken")
if expectedToken == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
token := r.Header.Get("Authorization") token := r.Header.Get("Authorization")
if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") { if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") {
@ -127,12 +153,16 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) {
token = token[strings.Index(token, " ")+1:] token = token[strings.Index(token, " ")+1:]
if token != viper.GetString("reloadToken") { if token != expectedToken {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
reloadConfig() if err := reloadConfig(); err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte("OK")) w.Write([]byte("OK"))

10
main.go
View File

@ -104,7 +104,9 @@ func main() {
viper.SetConfigFile(*configFlag) viper.SetConfigFile(*configFlag)
} }
reloadConfig() if err := reloadConfig(); err != nil {
log.WithError(err).Fatalln("Unable to load configuration")
}
// Start check loop // Start check loop
go servers.checkLoop() go servers.checkLoop()
@ -143,6 +145,10 @@ func main() {
break break
} }
reloadConfig() err := reloadConfig()
if err != nil {
log.WithError(err).Warning("Did not reload configuration due to error")
}
} }
} }

8
map.go
View File

@ -7,7 +7,8 @@ import (
"strings" "strings"
) )
func loadMap(file string) (map[string]string, error) { // loadMapFile loads a file as a map
func loadMapFile(file string) (map[string]string, error) {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
@ -16,6 +17,11 @@ func loadMap(file string) (map[string]string, error) {
defer f.Close() defer f.Close()
return loadMap(f)
}
// loadMap loads a pipe separated file of mappings
func loadMap(f io.Reader) (map[string]string, error) {
m := make(map[string]string) m := make(map[string]string)
r := csv.NewReader(f) r := csv.NewReader(f)

16
map_test.go Normal file
View File

@ -0,0 +1,16 @@
package main
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"strings"
)
var _ = Describe("Map", func() {
It("Should successfully load the map", func() {
m, err := loadMap(strings.NewReader(`bananapi/Bullseye_current|bananapi/archive/Armbian_21.08.1_Bananapi_bullseye_current_5.10.60.img.xz|Aug 26 2021|332M`))
Expect(err).To(BeNil())
Expect(m["bananapi/Bullseye_current"]).To(Equal("bananapi/archive/Armbian_21.08.1_Bananapi_bullseye_current_5.10.60.img.xz"))
})
})

View File

@ -5,9 +5,12 @@ import (
"encoding/json" "encoding/json"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"net/http" "net/http"
"strconv"
"strings" "strings"
) )
// legacyMirrorsHandler will list the mirrors by region in the legacy format
// it is preferred to use mirrors.json, but this handler is here for build support
func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -26,6 +29,7 @@ func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(mirrorOutput) json.NewEncoder(w).Encode(mirrorOutput)
} }
// mirrorsHandler is a simple handler that will return the list of servers
func mirrorsHandler(w http.ResponseWriter, r *http.Request) { func mirrorsHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(servers) json.NewEncoder(w).Encode(servers)
@ -42,10 +46,13 @@ var (
statusUnknown []byte statusUnknown []byte
) )
// mirrorStatusHandler is a fancy svg-returning handler.
// it is used to display mirror statuses on a config repo of sorts
func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
serverHost := chi.URLParam(r, "server") serverHost := chi.URLParam(r, "server")
w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8") w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8")
w.Header().Set("Cache-Control", "max-age=120")
if serverHost == "" { if serverHost == "" {
w.Write(statusUnknown) w.Write(statusUnknown)
@ -57,13 +64,31 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
server, ok := hostMap[serverHost] server, ok := hostMap[serverHost]
if !ok { if !ok {
w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown)))
w.Write(statusUnknown) w.Write(statusUnknown)
return return
} }
key := "offline"
if server.Available { if server.Available {
key = "online"
}
w.Header().Set("ETag", "\""+key+"\"")
if match := r.Header.Get("If-None-Match"); match != "" {
if strings.Trim(match, "\"") == key {
w.WriteHeader(http.StatusNotModified)
return
}
}
if server.Available {
w.Header().Set("Content-Length", strconv.Itoa(len(statusUp)))
w.Write(statusUp) w.Write(statusUp)
} else { } else {
w.Header().Set("Content-Length", strconv.Itoa(len(statusDown)))
w.Write(statusDown) w.Write(statusDown)
} }
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"crypto/tls"
"github.com/jmcvetta/randutil" "github.com/jmcvetta/randutil"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -20,12 +21,15 @@ var (
}, },
} }
checkTLSConfig *tls.Config = nil
checks = []serverCheck{ checks = []serverCheck{
checkHttp, checkHttp,
checkTLS, checkTLS,
} }
) )
// Server represents a download server
type Server struct { type Server struct {
Available bool `json:"available"` Available bool `json:"available"`
Host string `json:"host"` Host string `json:"host"`