16 Commits

Author SHA1 Message Date
8a4d02c6e2 Cleanup hosts map
All checks were successful
continuous-integration/drone/push Build is passing
2022-08-18 23:43:58 -04:00
d22b87da10 Fix config loading issue when servers don't resolve 2022-08-18 23:42:54 -04:00
4d7c836810 Update README to include check information and code quality
All checks were successful
continuous-integration/drone/push Build is passing
2022-08-15 02:46:25 -04:00
43205337c1 Update README
All checks were successful
continuous-integration/drone/push Build is passing
2022-08-15 02:41:31 -04:00
caa0fb43e4 Resolve issues with checks, forced http detection
All checks were successful
continuous-integration/drone/push Build is passing
2022-08-15 02:37:51 -04:00
7f27df70ae Remove dependency on go.mod/go.sum for ginkgo running
Some checks failed
continuous-integration/drone/push Build is failing
2022-08-15 02:17:12 -04:00
e7236b13de Massive refactoring, struct cleanup, supporting more features
Some checks failed
continuous-integration/drone/push Build is failing
Features:
- Protocol lists (http, https), managed by http responses
- Working TLS Checks
- Root certificate parsing for TLS checks
- Moving configuration into a Config struct, no more direct viper access
2022-08-15 02:16:22 -04:00
3e7782e5ec Merge pull request 'Initial testing and improvements of code' (#1) from feature/testing into master
All checks were successful
continuous-integration/drone/push Build is passing
Reviewed-on: #1
2022-08-14 08:03:49 +00:00
2f71e97f2e Improve readme with proper information and configuration
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/pr Build is passing
2022-08-14 04:02:10 -04:00
5ff4aa9fae go.sum magically removed ginkgo
All checks were successful
continuous-integration/drone/push Build is passing
2022-08-14 03:50:11 -04:00
c4bd02485c Fix tests and add coverage
Some checks failed
continuous-integration/drone/push Build is failing
2022-08-14 03:49:28 -04:00
3f71aced93 Improve test coverage and documentation
Some checks failed
continuous-integration/drone/push Build is failing
2022-08-14 03:42:49 -04:00
3c5656284c Fix extra character on cgo environment variable
All checks were successful
continuous-integration/drone/push Build is passing
2022-08-06 16:25:05 -04:00
08da75d309 Disable cgo 2022-08-06 16:24:45 -04:00
8ea77adee2 Update go.sum
Some checks failed
continuous-integration/drone/push Build is failing
2022-08-06 16:23:50 -04:00
91b99572c2 Ensure ginkgo is installed
Some checks failed
continuous-integration/drone/push Build is failing
2022-08-06 16:22:51 -04:00
21 changed files with 1000 additions and 347 deletions

View File

@ -10,7 +10,11 @@ steps:
path: /build path: /build
commands: commands:
- go mod download - go mod download
- ginkgo . - go install -mod=mod github.com/onsi/ginkgo/v2/ginkgo
- ginkgo --randomize-all --p --cover --coverprofile=cover.out .
- go tool cover -func=cover.out
environment:
CGO_ENABLED: '0'
- name: build - name: build
image: tystuyfzand/goc:latest image: tystuyfzand/goc:latest
volumes: volumes:

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ dlrouter-apt.yaml
*.yaml *.yaml
!dlrouter.yaml !dlrouter.yaml
*.exe *.exe
cover.out

133
README.md
View File

@ -5,7 +5,8 @@ This repository contains a redirect service for Armbian downloads, apt, etc.
It uses multiple current technologies and best practices, including: It uses multiple current technologies and best practices, including:
- Go 1.17/1.18 - Go 1.19
- Ginkgo v2 and Gomega testing framework
- GeoIP + Distance routing - GeoIP + Distance routing
- Server weighting, pooling (top x servers are served instead of a single one) - Server weighting, pooling (top x servers are served instead of a single one)
- Health checks (HTTP, TLS) - Health checks (HTTP, TLS)
@ -13,6 +14,132 @@ It uses multiple current technologies and best practices, including:
Code Quality Code Quality
------------ ------------
The code quality isn't the greatest/top tier. All code lives in the "main" package and should be moved at some point. The code quality isn't the greatest/top tier. Work is being done towards cleaning it up and standardizing it, writing tests, etc.
Regardless, it is meant to be simple and easy to understand. All contributions are welcome, see the `check_test.go` file for example tests.
Checks
------
The supported checks are HTTP and TLS.
### HTTP
Verifies server accessibility via HTTP. If the server returns a forced redirect to an `https://` url, it is considered to be https-only.
If the server responds on the `https` url with a forced `http` redirect, it will be marked down due to misconfiguration. Requests should never downgrade.
### TLS
Certificate checking to ensure no servers are used which have invalid/expired certificates. This check is written to use the Mozilla ca certificate list, loaded on start/config load, to verify roots.
OS certificate trusts WERE being used to do this, however some issues with the date validation (which could be user error) caused the move to the ca bundle, which could be considered more usable.
Note: This downloads from github every startup/reload. This should be a reliable process, as long as Mozilla doesn't deprecate their repo. Their HG URL is super slow.
Configuration
-------------
### Modes
#### Redirect
Standard redirect functionality
#### Download Mapping
Uses the `dl_map` configuration variable to enable mapping of paths to new paths.
Think symlinks, but in a generated file.
### Mirrors
Mirror targets with trailing slash are placed in the yaml configuration file.
### Example YAML
```yaml
# GeoIP Database Path
geodb: GeoLite2-City.mmdb
# Comment out to disable
dl_map: userdata.csv
# LRU Cache Size (in items)
cacheSize: 1024
# Server definition
# Weights are just like nginx, where if it's > 1 it'll be chosen x out of x + total times
# By default, the top 3 servers are used for choosing the best.
# server = full url or host+path
# weight = int
# optional: latitude, longitude (float)
# optional: protocols (list/array)
servers:
- server: armbian.12z.eu/apt/
- server: armbian.chi.auroradev.org/apt/
weight: 15
latitude: 41.8879
longitude: -88.1995
# Example of a server with additional protocols (rsync)
# Useful for defining servers which could be used for rsync sources
- server: mirrors.dotsrc.org/armbian-apt/
weight: 15
protocols:
- rsync
````
## API
`/status`
Meant for a simple health check (nginx/etc can 502 or similar if down)
`/reload`
Flushes cache and reloads configuration and mapping. Requires reloadToken to be set in the configuration, and a matching token provided in `Authorization: Bearer TOKEN`
`/mirrors`
Shows all mirrors in the legacy (by region) format
`/mirrors.json`
Shows all mirrors in the new JSON format. Example:
```json
[
{
"available":true,
"host":"imola.armbian.com",
"path":"/apt/",
"latitude":46.0503,
"longitude":14.5046,
"weight":10,
"continent":"EU",
"lastChange":"2022-08-12T06:52:35.029565986Z"
}
]
```
`/mirrors/{server}.svg`
Magic SVG path to show badges based on server status, for use in dynamic mirror lists.
`/dl_map`
Shows json-encoded download mappings
`/geoip`
Shows GeoIP information for the requester
`/region/REGIONCODE/PATH`
Using this magic path will redirect to the desired region:
* NA - North America
* EU - Europe
* AS - Asia
`/metrics`
Prometheus metrics endpoint. Metrics aren't considered private, thus are exposed to the public.

View File

@ -1,4 +1,4 @@
package main package redirector
import ( import (
"testing" "testing"

106
check.go
View File

@ -1,4 +1,4 @@
package main package redirector
import ( import (
"crypto/tls" "crypto/tls"
@ -6,21 +6,30 @@ import (
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net"
"net/http" "net/http"
"net/url" "net/url"
"runtime" "runtime"
"strings"
"time" "time"
) )
var ( var (
ErrHttpsRedirect = errors.New("unexpected forced https redirect") ErrHttpsRedirect = errors.New("unexpected forced https redirect")
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
ErrCertExpired = errors.New("certificate is expired") ErrCertExpired = errors.New("certificate is expired")
) )
func (r *Redirector) checkHttp(scheme string) ServerCheck {
return func(server *Server, logFields log.Fields) (bool, error) {
return r.checkHttpScheme(server, scheme, logFields)
}
}
// checkHttp checks a URL for validity, and checks redirects // checkHttp checks a URL for validity, and checks redirects
func checkHttp(server *Server, logFields log.Fields) (bool, error) { func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
u := &url.URL{ u := &url.URL{
Scheme: "http", Scheme: scheme,
Host: server.Host, Host: server.Host,
Path: server.Path, Path: server.Path,
} }
@ -33,7 +42,7 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) {
return false, err return false, err
} }
res, err := checkClient.Do(req) res, err := r.config.checkClient.Do(req)
if err != nil { if err != nil {
return false, err return false, err
@ -47,13 +56,20 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) {
logFields["url"] = location logFields["url"] = location
// Check that we don't redirect to https from a http url switch u.Scheme {
if u.Scheme == "http" { case "http":
res, err := checkRedirect(location) res, err := r.checkRedirect(u.Scheme, location)
if !res || err != nil { if !res || err != nil {
return res, err // If we don't support http, we remove it from supported protocols
server.Protocols = server.Protocols.Remove("http")
} else {
// Otherwise, we verify https support
r.checkProtocol(server, "https")
} }
case "https":
// We don't want to allow downgrading, so this is an error.
return r.checkRedirect(u.Scheme, location)
} }
} }
@ -65,8 +81,20 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) {
return false, nil return false, nil
} }
func (r *Redirector) checkProtocol(server *Server, scheme string) {
res, err := r.checkHttpScheme(server, scheme, log.Fields{})
if !res || err != nil {
return
}
if !server.Protocols.Contains(scheme) {
server.Protocols = server.Protocols.Append(scheme)
}
}
// checkRedirect parses a location header response and checks the scheme // checkRedirect parses a location header response and checks the scheme
func checkRedirect(locationHeader string) (bool, error) { func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
newUrl, err := url.Parse(locationHeader) newUrl, err := url.Parse(locationHeader)
if err != nil { if err != nil {
@ -75,14 +103,41 @@ func checkRedirect(locationHeader string) (bool, error) {
if newUrl.Scheme == "https" { if newUrl.Scheme == "https" {
return false, ErrHttpsRedirect return false, ErrHttpsRedirect
} else if originatingScheme == "https" && newUrl.Scheme == "http" {
return false, ErrHttpRedirect
} }
return true, nil return true, nil
} }
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired. // checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
func checkTLS(server *Server, logFields log.Fields) (bool, error) { func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) {
conn, err := tls.Dial("tcp", server.Host+":443", nil) var host, port string
var err error
if strings.Contains(server.Host, ":") {
host, port, err = net.SplitHostPort(server.Host)
if err != nil {
return false, err
}
} else {
host = server.Host
}
log.WithFields(log.Fields{
"server": server.Host,
"host": host,
"port": port,
}).Info("Checking TLS server")
if port == "" {
port = "443"
}
conn, err := tls.Dial("tcp", host+":"+port, &tls.Config{
RootCAs: r.config.RootCAs,
})
if err != nil { if err != nil {
return false, err return false, err
@ -100,19 +155,39 @@ func checkTLS(server *Server, logFields log.Fields) (bool, error) {
state := conn.ConnectionState() state := conn.ConnectionState()
peerPool := x509.NewCertPool()
for _, intermediate := range state.PeerCertificates {
if !intermediate.IsCA {
continue
}
peerPool.AddCert(intermediate)
}
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Roots: r.config.RootCAs,
Intermediates: peerPool,
CurrentTime: time.Now(), CurrentTime: time.Now(),
} }
for _, cert := range state.PeerCertificates { // We want only the leaf certificate, as this will verify up the chain for us.
cert := state.PeerCertificates[0]
if _, err := cert.Verify(opts); err != nil { if _, err := cert.Verify(opts); err != nil {
logFields["peerCert"] = cert.Subject.String() logFields["peerCert"] = cert.Subject.String()
if authErr, ok := err.(x509.UnknownAuthorityError); ok {
logFields["authCert"] = authErr.Cert.Subject.String()
logFields["ca"] = authErr.Cert.Issuer
}
return false, err return false, err
} }
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
logFields["peerCert"] = cert.Subject.String()
return false, err return false, err
} }
}
for _, chain := range state.VerifiedChains { for _, chain := range state.VerifiedChains {
for _, cert := range chain { for _, cert := range chain {
@ -123,5 +198,10 @@ func checkTLS(server *Server, logFields log.Fields) (bool, error) {
} }
} }
// If https is valid, append it
if !server.Protocols.Contains("https") {
server.Protocols = server.Protocols.Append("https")
}
return true, nil return true, nil
} }

View File

@ -1,64 +1,197 @@
package main package redirector
import ( import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math/big"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"time"
) )
func genTestCerts(notBefore, notAfter time.Time) (*pem.Block, *pem.Block, error) {
// Create a Certificate Authority Cert
template := x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{CommonName: "localhost"},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: notBefore,
NotAfter: notAfter,
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
}
// Create a Private Key
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, fmt.Errorf("Could not generate rsa key - %s", err)
}
// Use CA Cert to sign a CSR and create a Public Cert
csr := &key.PublicKey
cert, err := x509.CreateCertificate(rand.Reader, &template, &template, csr, key)
if err != nil {
return nil, nil, fmt.Errorf("Could not generate certificate - %s", err)
}
// Convert keys into pem.Block
c := &pem.Block{Type: "CERTIFICATE", Bytes: cert}
k := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
return c, k, nil
}
var _ = Describe("Check suite", func() { var _ = Describe("Check suite", func() {
Context("HTTP Checks", func() {
var ( var (
httpServer *httptest.Server httpServer *httptest.Server
server *Server server *Server
handler http.HandlerFunc handler http.HandlerFunc
r *Redirector
) )
BeforeEach(func() { BeforeEach(func() {
httpServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r) handler(w, r)
})) }))
r = New(&Config{})
r.config.SetRootCAs(x509.NewCertPool())
})
AfterEach(func() {
httpServer.Close()
})
setupServer := func() {
u, err := url.Parse(httpServer.URL) u, err := url.Parse(httpServer.URL)
if err != nil { if err != nil {
panic(err) panic(err)
} }
server = &Server{ server = &Server{
Host: u.Host, Host: u.Host,
Path: u.Path, Path: u.Path,
} }
}) }
AfterEach(func() {
httpServer.Close() Context("HTTP Checks", func() {
BeforeEach(func() {
httpServer.Start()
setupServer()
}) })
It("Should successfully check for connectivity", func() { It("Should successfully check for connectivity", func() {
handler = func(w http.ResponseWriter, r *http.Request) { handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
res, err := checkHttp(server, log.Fields{}) res, err := r.checkHttpScheme(server, "http", log.Fields{})
Expect(res).To(BeTrue()) Expect(res).To(BeTrue())
Expect(err).To(BeNil()) Expect(err).To(BeNil())
}) })
It("Should return an error when redirected to https", func() { })
Context("TLS Checks", func() {
var (
x509Cert *x509.Certificate
)
setupCerts := func(notBefore, notAfter time.Time) {
cert, key, err := genTestCerts(notBefore, notAfter)
if err != nil {
panic("Unable to generate test certs")
}
x509Cert, err = x509.ParseCertificate(cert.Bytes)
if err != nil {
panic("Unable to parse certificate from bytes: " + err.Error())
}
tlsPair, err := tls.X509KeyPair(pem.EncodeToMemory(cert), pem.EncodeToMemory(key))
if err != nil {
panic("Unable to load tls key pair: " + err.Error())
}
httpServer.TLS = &tls.Config{
Certificates: []tls.Certificate{tlsPair},
}
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.SetRootCAs(pool)
httpServer.StartTLS()
setupServer()
}
Context("HTTPS Checks", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should return an error when redirected to http from https", func() {
handler = func(w http.ResponseWriter, r *http.Request) { handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1)) w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1))
w.WriteHeader(http.StatusMovedPermanently) w.WriteHeader(http.StatusMovedPermanently)
} }
res, err := checkHttp(server, log.Fields{}) logFields := log.Fields{}
res, err := r.checkHttpScheme(server, "https", logFields)
Expect(logFields["url"]).ToNot(BeEmpty())
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
Expect(err).To(Equal(ErrHttpRedirect))
Expect(res).To(BeFalse())
})
})
Context("CA Tests", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should fail due to invalid ca", func() {
r.config.SetRootCAs(x509.NewCertPool())
res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse()) Expect(res).To(BeFalse())
Expect(err).To(Equal(ErrHttpsRedirect)) Expect(err).ToNot(BeNil())
}) })
}) It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
Context("TLS Checks", func() { res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
})
Context("Expiration tests", func() {
It("Should fail due to not yet valid certificate", func() {
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
It("Should fail due to expired certificate", func() {
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
})
}) })
}) })

112
cmd/main.go Normal file
View File

@ -0,0 +1,112 @@
package main
import (
"flag"
"github.com/armbian/redirector"
"github.com/armbian/redirector/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"os"
"os/signal"
"syscall"
)
var (
configFlag = flag.String("config", "", "configuration file path")
flagDebug = flag.Bool("debug", false, "Enable debug logging")
)
func main() {
flag.Parse()
if *flagDebug {
log.SetLevel(log.DebugLevel)
}
viper.SetDefault("bind", ":8080")
viper.SetDefault("cacheSize", 1024)
viper.SetDefault("topChoices", 3)
viper.SetDefault("reloadKey", redirector.RandomSequence(32))
viper.SetConfigName("dlrouter") // name of config file (without extension)
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
viper.AddConfigPath(".") // optionally look for config in the working directory
if *configFlag != "" {
viper.SetConfigFile(*configFlag)
}
config := &redirector.Config{}
loadConfig := func(fatal bool) {
log.Info("Reading configuration")
// Bind reload to reading in the viper config, then deserializing
if err := viper.ReadInConfig(); err != nil {
log.WithError(err).Error("Unable to unmarshal configuration")
if fatal {
os.Exit(1)
}
}
log.Info("Unmarshalling configuration")
if err := viper.Unmarshal(config); err != nil {
log.WithError(err).Error("Unable to unmarshal configuration")
if fatal {
os.Exit(1)
}
}
log.Info("Updating root certificates")
certs, err := util.LoadCACerts()
if err != nil {
log.WithError(err).Error("Unable to load certificates")
if fatal {
os.Exit(1)
}
}
config.RootCAs = certs
}
config.ReloadFunc = func() {
loadConfig(false)
}
loadConfig(true)
redir := redirector.New(config)
// Because we have a bind address, we can start it without the return value.
redir.Start()
log.Info("Ready")
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP)
for {
sig := <-c
if sig != syscall.SIGHUP {
break
}
loadConfig(false)
err := redir.ReloadConfig()
if err != nil {
log.WithError(err).Warning("Did not reload configuration due to error")
}
}
}

208
config.go
View File

@ -1,99 +1,175 @@
package main package redirector
import ( import (
"crypto/tls"
"crypto/x509"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
) )
func reloadConfig() { type Config struct {
BindAddress string `mapstructure:"bind"`
GeoDBPath string `mapstructure:"geodb"`
MapFile string `mapstructure:"dl_map"`
CacheSize int `mapstructure:"cacheSize"`
TopChoices int `mapstructure:"topChoices"`
ReloadToken string `mapstructure:"reloadToken"`
ServerList []ServerConfig `mapstructure:"servers"`
ReloadFunc func()
RootCAs *x509.CertPool
checkClient *http.Client
}
// SetRootCAs sets the root ca files, and creates the http client for checks
// This **MUST** be called before r.checkClient is used.
func (c *Config) SetRootCAs(cas *x509.CertPool) {
c.RootCAs = cas
t := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: cas,
},
}
c.checkClient = &http.Client{
Transport: t,
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
type ProtocolList []string
func (p ProtocolList) Contains(value string) bool {
for _, val := range p {
if value == val {
return true
}
}
return false
}
func (p ProtocolList) Append(value string) ProtocolList {
return append(p, value)
}
func (p ProtocolList) Remove(value string) ProtocolList {
index := -1
for i, val := range p {
if value == val {
index = i
break
}
}
if index == -1 {
return p
}
p[index] = p[len(p)-1]
return p[:len(p)-1]
}
func (r *Redirector) ReloadConfig() error {
log.Info("Loading configuration...") log.Info("Loading configuration...")
err := viper.ReadInConfig() // Find and read the config file var err error
if err != nil { // Handle errors reading the config file
log.WithError(err).Fatalln("Unable to load config file")
}
// db will never be reloaded.
if db == nil {
// Load maxmind database // Load maxmind database
db, err = maxminddb.Open(viper.GetString("geodb")) if r.db != nil {
err = r.db.Close()
if err != nil { if err != nil {
log.WithError(err).Fatalln("Unable to open database") return errors.Wrap(err, "Unable to close database")
} }
} }
// db can be hot-reloaded if the file changed
r.db, err = maxminddb.Open(r.config.GeoDBPath)
if err != nil {
return errors.Wrap(err, "Unable to open database")
}
// Refresh server cache if size changed // Refresh server cache if size changed
if serverCache == nil { if r.serverCache == nil {
serverCache, err = lru.New(viper.GetInt("cacheSize")) r.serverCache, err = lru.New(r.config.CacheSize)
} else { } else {
serverCache.Resize(viper.GetInt("cacheSize")) r.serverCache.Resize(r.config.CacheSize)
} }
// Purge the cache to ensure we don't have any invalid servers in it // Purge the cache to ensure we don't have any invalid servers in it
serverCache.Purge() r.serverCache.Purge()
// Set top choice count
topChoices = viper.GetInt("topChoices")
// Reload map file // Reload map file
reloadMap() if err := r.reloadMap(); err != nil {
return errors.Wrap(err, "Unable to load map file")
}
// Reload server list // Reload server list
reloadServers() if err := r.reloadServers(); err != nil {
return errors.Wrap(err, "Unable to load servers")
}
// Create mirror map // Create mirror map
mirrors := make(map[string][]*Server) mirrors := make(map[string][]*Server)
for _, server := range servers { for _, server := range r.servers {
mirrors[server.Continent] = append(mirrors[server.Continent], server) mirrors[server.Continent] = append(mirrors[server.Continent], server)
} }
mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...) mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...)
regionMap = mirrors r.regionMap = mirrors
hosts := make(map[string]*Server) hosts := make(map[string]*Server)
for _, server := range servers { for _, server := range r.servers {
hosts[server.Host] = server hosts[server.Host] = server
} }
hostMap = hosts r.hostMap = hosts
// Check top choices size // Check top choices size
if topChoices > len(servers) { if r.config.TopChoices > len(r.servers) {
topChoices = len(servers) r.config.TopChoices = len(r.servers)
} }
// Force check // Force check
go servers.Check() go r.servers.Check(r.checks)
return nil
} }
func reloadServers() { func (r *Redirector) reloadServers() error {
var serverList []ServerConfig log.WithField("count", len(r.config.ServerList)).Info("Loading servers")
viper.UnmarshalKey("servers", &serverList)
var wg sync.WaitGroup var wg sync.WaitGroup
existing := make(map[string]int) existing := make(map[string]int)
for i, server := range servers { for i, server := range r.servers {
existing[server.Host] = i existing[server.Host] = i
} }
hosts := make(map[string]bool) hosts := make(map[string]bool)
for _, server := range serverList { var hostsLock sync.Mutex
for _, server := range r.config.ServerList {
wg.Add(1) wg.Add(1)
var prefix string var prefix string
@ -109,11 +185,9 @@ func reloadServers() {
"error": err, "error": err,
"server": server, "server": server,
}).Warning("Server is invalid") }).Warning("Server is invalid")
return return err
} }
hosts[u.Host] = true
i := -1 i := -1
if v, exists := existing[u.Host]; exists { if v, exists := existing[u.Host]; exists {
@ -123,19 +197,28 @@ func reloadServers() {
go func(i int, server ServerConfig, u *url.URL) { go func(i int, server ServerConfig, u *url.URL) {
defer wg.Done() defer wg.Done()
s := addServer(server, u) s, err := r.addServer(server, u)
if err != nil {
log.WithError(err).Warning("Unable to add server")
return
}
hostsLock.Lock()
hosts[u.Host] = true
hostsLock.Unlock()
if _, ok := existing[u.Host]; ok { if _, ok := existing[u.Host]; ok {
s.Redirects = servers[i].Redirects s.Redirects = r.servers[i].Redirects
servers[i] = s r.servers[i] = s
} else { } else {
s.Redirects = promauto.NewCounter(prometheus.CounterOpts{ s.Redirects = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_redirects_" + metricReplacer.Replace(u.Host), Name: "armbian_router_redirects_" + metricReplacer.Replace(u.Host),
Help: "The number of redirects for server " + u.Host, Help: "The number of redirects for server " + u.Host,
}) })
servers = append(servers, s) r.servers = append(r.servers, s)
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"server": u.Host, "server": u.Host,
@ -150,24 +233,26 @@ func reloadServers() {
wg.Wait() wg.Wait()
// Remove servers that no longer exist in the config // Remove servers that no longer exist in the config
for i := len(servers) - 1; i >= 0; i-- { for i := len(r.servers) - 1; i >= 0; i-- {
if _, exists := hosts[servers[i].Host]; exists { if _, exists := hosts[r.servers[i].Host]; exists {
continue continue
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"server": servers[i].Host, "server": r.servers[i].Host,
}).Info("Removed server") }).Info("Removed server")
servers = append(servers[:i], servers[i+1:]...) r.servers = append(r.servers[:i], r.servers[i+1:]...)
} }
return nil
} }
var metricReplacer = strings.NewReplacer(".", "_", "-", "_") var metricReplacer = strings.NewReplacer(".", "_", "-", "_")
// addServer takes ServerConfig and constructs a server. // addServer takes ServerConfig and constructs a server.
// This will create duplicate servers, but it will overwrite existing ones when changed. // This will create duplicate servers, but it will overwrite existing ones when changed.
func addServer(server ServerConfig, u *url.URL) *Server { func (r *Redirector) addServer(server ServerConfig, u *url.URL) (*Server, error) {
s := &Server{ s := &Server{
Available: true, Available: true,
Host: u.Host, Host: u.Host,
@ -176,6 +261,15 @@ func addServer(server ServerConfig, u *url.URL) *Server {
Longitude: server.Longitude, Longitude: server.Longitude,
Continent: server.Continent, Continent: server.Continent,
Weight: server.Weight, Weight: server.Weight,
Protocols: ProtocolList{"http", "https"},
}
if len(server.Protocols) > 0 {
for _, proto := range server.Protocols {
if !s.Protocols.Contains(proto) {
s.Protocols = s.Protocols.Append(proto)
}
}
} }
// Defaults to 10 to allow servers to be set lower for lower priority // Defaults to 10 to allow servers to be set lower for lower priority
@ -190,11 +284,11 @@ func addServer(server ServerConfig, u *url.URL) *Server {
"error": err, "error": err,
"server": s.Host, "server": s.Host,
}).Warning("Could not resolve address") }).Warning("Could not resolve address")
return nil return nil, err
} }
var city City var city City
err = db.Lookup(ips[0], &city) err = r.db.Lookup(ips[0], &city)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -202,7 +296,7 @@ func addServer(server ServerConfig, u *url.URL) *Server {
"server": s.Host, "server": s.Host,
"ip": ips[0], "ip": ips[0],
}).Warning("Could not geolocate address") }).Warning("Could not geolocate address")
return nil return nil, err
} }
if s.Continent == "" { if s.Continent == "" {
@ -214,23 +308,25 @@ func addServer(server ServerConfig, u *url.URL) *Server {
s.Longitude = city.Location.Longitude s.Longitude = city.Location.Longitude
} }
return s return s, nil
} }
func reloadMap() { func (r *Redirector) reloadMap() error {
mapFile := viper.GetString("dl_map") mapFile := r.config.MapFile
if mapFile == "" { if mapFile == "" {
return return nil
} }
log.WithField("file", mapFile).Info("Loading download map") log.WithField("file", mapFile).Info("Loading download map")
newMap, err := loadMap(mapFile) newMap, err := loadMapFile(mapFile)
if err != nil { if err != nil {
return return err
} }
dlMap = newMap r.dlMap = newMap
return nil
} }

View File

@ -34,6 +34,10 @@ servers:
- server: mirrors.bfsu.edu.cn/armbian/ - server: mirrors.bfsu.edu.cn/armbian/
- server: mirrors.dotsrc.org/armbian-apt/ - server: mirrors.dotsrc.org/armbian-apt/
weight: 15 weight: 15
protocols:
- http
- https
- rsync
- server: mirrors.netix.net/armbian/apt/ - server: mirrors.netix.net/armbian/apt/
- server: mirrors.nju.edu.cn/armbian/ - server: mirrors.nju.edu.cn/armbian/
- server: mirrors.sustech.edu.cn/armbian/ - server: mirrors.sustech.edu.cn/armbian/

6
go.mod
View File

@ -1,15 +1,17 @@
module meow.tf/armbian-router module github.com/armbian/redirector
go 1.17 go 1.19
require ( require (
github.com/chi-middleware/logrus-logger v0.2.0 github.com/chi-middleware/logrus-logger v0.2.0
github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/chi/v5 v5.0.7
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff
github.com/onsi/ginkgo/v2 v2.1.4 github.com/onsi/ginkgo/v2 v2.1.4
github.com/onsi/gomega v1.20.0 github.com/onsi/gomega v1.20.0
github.com/oschwald/maxminddb-golang v1.8.0 github.com/oschwald/maxminddb-golang v1.8.0
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/spf13/viper v1.10.1 github.com/spf13/viper v1.10.1

3
go.sum
View File

@ -208,6 +208,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0= github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM= github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d h1:Kp5G1kHMb2fAD9OiqWDXro4qLB8bQ2NusoorYya4Lbo=
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g=
github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0= github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0=
github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms= github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@ -322,6 +324,7 @@ github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhEC
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

87
http.go
View File

@ -1,10 +1,9 @@
package main package redirector
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/jmcvetta/randutil" "github.com/jmcvetta/randutil"
"github.com/spf13/viper"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -13,13 +12,20 @@ import (
"strings" "strings"
) )
func statusHandler(w http.ResponseWriter, r *http.Request) { // statusHandler is a simple handler that will always return 200 OK with a body of "OK"
func (r *Redirector) statusHandler(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if req.Method != http.MethodHead {
w.Write([]byte("OK")) w.Write([]byte("OK"))
} }
}
func redirectHandler(w http.ResponseWriter, r *http.Request) { // redirectHandler is the default "not found" handler which handles redirects
ipStr, _, err := net.SplitHostPort(r.RemoteAddr) // if the environment variable OVERRIDE_IP is set, it will use that ip address
// this is useful for local testing when you're on the local network
func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -41,11 +47,13 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
var server *Server var server *Server
var distance float64 var distance float64
if strings.HasPrefix(r.URL.Path, "/region") { // If the path has a prefix of region/NA, it will use specific regions instead
parts := strings.Split(r.URL.Path, "/") // of the default geographical distance
if strings.HasPrefix(req.URL.Path, "/region") {
parts := strings.Split(req.URL.Path, "/")
// region = parts[2] // region = parts[2]
if mirrors, ok := regionMap[parts[2]]; ok { if mirrors, ok := r.regionMap[parts[2]]; ok {
choices := make([]randutil.Choice, len(mirrors)) choices := make([]randutil.Choice, len(mirrors))
for i, item := range mirrors { for i, item := range mirrors {
@ -68,12 +76,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
server = choice.Item.(*Server) server = choice.Item.(*Server)
r.URL.Path = strings.Join(parts[3:], "/") req.URL.Path = strings.Join(parts[3:], "/")
} }
} }
// If we don't have a scheme, we'll use http by default
scheme := req.URL.Scheme
if scheme == "" {
scheme = "http"
}
// If none of the above exceptions are matched, we use the geographical distance based on IP
if server == nil { if server == nil {
server, distance, err = servers.Closest(ip) server, distance, err = r.servers.Closest(r, scheme, ip)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -81,25 +97,24 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
scheme := r.URL.Scheme // redirectPath is a combination of server path (which can be something like /armbian)
// and the URL path.
// Example: /armbian + /some/path = /armbian/some/path
redirectPath := path.Join(server.Path, req.URL.Path)
if scheme == "" { // If we have a dlMap, we map the url to a final path instead
scheme = "https" if r.dlMap != nil {
} if newPath, exists := r.dlMap[strings.TrimLeft(req.URL.Path, "/")]; exists {
redirectPath := path.Join(server.Path, r.URL.Path)
if dlMap != nil {
if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists {
downloadsMapped.Inc() downloadsMapped.Inc()
redirectPath = path.Join(server.Path, newPath) redirectPath = path.Join(server.Path, newPath)
} }
} }
if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") { if strings.HasSuffix(req.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") {
redirectPath += "/" redirectPath += "/"
} }
// We need to build the final url now
u := &url.URL{ u := &url.URL{
Scheme: scheme, Scheme: scheme,
Host: server.Host, Host: server.Host,
@ -109,6 +124,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
server.Redirects.Inc() server.Redirects.Inc()
redirectsServed.Inc() redirectsServed.Inc()
// If we used geographical distance, we add an X-Geo-Distance header for debug.
if distance > 0 { if distance > 0 {
w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance))
} }
@ -117,8 +133,15 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
func reloadHandler(w http.ResponseWriter, r *http.Request) { // reloadHandler is an http handler which lets us reload the server configuration
token := r.Header.Get("Authorization") // It is only enabled when the reloadToken is set in the configuration
func (r *Redirector) reloadHandler(w http.ResponseWriter, req *http.Request) {
if r.config.ReloadToken == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
token := req.Header.Get("Authorization")
if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") { if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
@ -127,30 +150,34 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) {
token = token[strings.Index(token, " ")+1:] token = token[strings.Index(token, " ")+1:]
if token != viper.GetString("reloadToken") { if token != r.config.ReloadToken {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
reloadConfig() if err := r.ReloadConfig(); err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte("OK")) w.Write([]byte("OK"))
} }
func dlMapHandler(w http.ResponseWriter, r *http.Request) { func (r *Redirector) dlMapHandler(w http.ResponseWriter, req *http.Request) {
if dlMap == nil { if r.dlMap == nil {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(dlMap) json.NewEncoder(w).Encode(r.dlMap)
} }
func geoIPHandler(w http.ResponseWriter, r *http.Request) { func (r *Redirector) geoIPHandler(w http.ResponseWriter, req *http.Request) {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr) ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -160,7 +187,7 @@ func geoIPHandler(w http.ResponseWriter, r *http.Request) {
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
var city City var city City
err = db.Lookup(ip, &city) err = r.db.Lookup(ip, &city)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

148
main.go
View File

@ -1,148 +0,0 @@
package main
import (
"flag"
"github.com/chi-middleware/logrus-logger"
"github.com/go-chi/chi/v5"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"net/http"
"os"
"os/signal"
"syscall"
)
var (
db *maxminddb.Reader
servers ServerList
regionMap map[string][]*Server
hostMap map[string]*Server
dlMap map[string]string
topChoices int
redirectsServed = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_redirects",
Help: "The total number of processed redirects",
})
downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_download_maps",
Help: "The total number of mapped download paths",
})
serverCache *lru.Cache
)
type LocationLookup struct {
Location struct {
Latitude float64 `maxminddb:"latitude"`
Longitude float64 `maxminddb:"longitude"`
} `maxminddb:"location"`
}
// City represents a MaxmindDB city
type City struct {
Continent struct {
Code string `maxminddb:"code" json:"code"`
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"continent" json:"continent"`
Country struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"country" json:"country"`
Location struct {
AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'`
Latitude float64 `maxminddb:"latitude" json:"latitude"`
Longitude float64 `maxminddb:"longitude" json:"longitude"`
} `maxminddb:"location"`
RegisteredCountry struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"registered_country" json:"registered_country"`
}
type ServerConfig struct {
Server string `mapstructure:"server" yaml:"server"`
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
Continent string `mapstructure:"continent"`
Weight int `mapstructure:"weight" yaml:"weight"`
}
var (
configFlag = flag.String("config", "", "configuration file path")
flagDebug = flag.Bool("debug", false, "Enable debug logging")
)
func main() {
flag.Parse()
if *flagDebug {
log.SetLevel(log.DebugLevel)
}
viper.SetDefault("bind", ":8080")
viper.SetDefault("cacheSize", 1024)
viper.SetDefault("topChoices", 3)
viper.SetDefault("reloadKey", randSeq(32))
viper.SetConfigName("dlrouter") // name of config file (without extension)
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
viper.AddConfigPath(".") // optionally look for config in the working directory
if *configFlag != "" {
viper.SetConfigFile(*configFlag)
}
reloadConfig()
// Start check loop
go servers.checkLoop()
log.Info("Starting")
r := chi.NewRouter()
r.Use(RealIPMiddleware)
r.Use(logger.Logger("router", log.StandardLogger()))
r.Head("/status", statusHandler)
r.Get("/status", statusHandler)
r.Get("/mirrors", legacyMirrorsHandler)
r.Get("/mirrors/{server}.svg", mirrorStatusHandler)
r.Get("/mirrors.json", mirrorsHandler)
r.Post("/reload", reloadHandler)
r.Get("/dl_map", dlMapHandler)
r.Get("/geoip", geoIPHandler)
r.Get("/metrics", promhttp.Handler().ServeHTTP)
r.NotFound(redirectHandler)
go http.ListenAndServe(viper.GetString("bind"), r)
log.Info("Ready")
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP)
for {
sig := <-c
if sig != syscall.SIGHUP {
break
}
reloadConfig()
}
}

10
map.go
View File

@ -1,4 +1,4 @@
package main package redirector
import ( import (
"encoding/csv" "encoding/csv"
@ -7,7 +7,8 @@ import (
"strings" "strings"
) )
func loadMap(file string) (map[string]string, error) { // loadMapFile loads a file as a map
func loadMapFile(file string) (map[string]string, error) {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
@ -16,6 +17,11 @@ func loadMap(file string) (map[string]string, error) {
defer f.Close() defer f.Close()
return loadMap(f)
}
// loadMap loads a pipe separated file of mappings
func loadMap(f io.Reader) (map[string]string, error) {
m := make(map[string]string) m := make(map[string]string)
r := csv.NewReader(f) r := csv.NewReader(f)

16
map_test.go Normal file
View File

@ -0,0 +1,16 @@
package redirector
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"strings"
)
var _ = Describe("Map", func() {
It("Should successfully load the map", func() {
m, err := loadMap(strings.NewReader(`bananapi/Bullseye_current|bananapi/archive/Armbian_21.08.1_Bananapi_bullseye_current_5.10.60.img.xz|Aug 26 2021|332M`))
Expect(err).To(BeNil())
Expect(m["bananapi/Bullseye_current"]).To(Equal("bananapi/archive/Armbian_21.08.1_Bananapi_bullseye_current_5.10.60.img.xz"))
})
})

View File

@ -1,4 +1,4 @@
package main package middleware
import ( import (
"net" "net"

View File

@ -1,23 +1,26 @@
package main package redirector
import ( import (
_ "embed" _ "embed"
"encoding/json" "encoding/json"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"net/http" "net/http"
"strconv"
"strings" "strings"
) )
func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { // legacyMirrorsHandler will list the mirrors by region in the legacy format
// it is preferred to use mirrors.json, but this handler is here for build support
func (r *Redirector) legacyMirrorsHandler(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
mirrorOutput := make(map[string][]string) mirrorOutput := make(map[string][]string)
for region, mirrors := range regionMap { for region, mirrors := range r.regionMap {
list := make([]string, len(mirrors)) list := make([]string, len(mirrors))
for i, mirror := range mirrors { for i, mirror := range mirrors {
list[i] = r.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/") list[i] = req.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/")
} }
mirrorOutput[region] = list mirrorOutput[region] = list
@ -26,9 +29,10 @@ func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(mirrorOutput) json.NewEncoder(w).Encode(mirrorOutput)
} }
func mirrorsHandler(w http.ResponseWriter, r *http.Request) { // mirrorsHandler is a simple handler that will return the list of servers
func (r *Redirector) mirrorsHandler(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(servers) json.NewEncoder(w).Encode(r.servers)
} }
var ( var (
@ -42,10 +46,13 @@ var (
statusUnknown []byte statusUnknown []byte
) )
func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) { // mirrorStatusHandler is a fancy svg-returning handler.
serverHost := chi.URLParam(r, "server") // it is used to display mirror statuses on a config repo of sorts
func (r *Redirector) mirrorStatusHandler(w http.ResponseWriter, req *http.Request) {
serverHost := chi.URLParam(req, "server")
w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8") w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8")
w.Header().Set("Cache-Control", "max-age=120")
if serverHost == "" { if serverHost == "" {
w.Write(statusUnknown) w.Write(statusUnknown)
@ -54,16 +61,34 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
serverHost = strings.Replace(serverHost, "_", ".", -1) serverHost = strings.Replace(serverHost, "_", ".", -1)
server, ok := hostMap[serverHost] server, ok := r.hostMap[serverHost]
if !ok { if !ok {
w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown)))
w.Write(statusUnknown) w.Write(statusUnknown)
return return
} }
key := "offline"
if server.Available { if server.Available {
key = "online"
}
w.Header().Set("ETag", "\""+key+"\"")
if match := req.Header.Get("If-None-Match"); match != "" {
if strings.Trim(match, "\"") == key {
w.WriteHeader(http.StatusNotModified)
return
}
}
if server.Available {
w.Header().Set("Content-Length", strconv.Itoa(len(statusUp)))
w.Write(statusUp) w.Write(statusUp)
} else { } else {
w.Header().Set("Content-Length", strconv.Itoa(len(statusDown)))
w.Write(statusDown) w.Write(statusDown)
} }
} }

131
redirector.go Normal file
View File

@ -0,0 +1,131 @@
package redirector
import (
"github.com/armbian/redirector/middleware"
"github.com/chi-middleware/logrus-logger"
"github.com/go-chi/chi/v5"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"net/http"
)
var (
redirectsServed = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_redirects",
Help: "The total number of processed redirects",
})
downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_download_maps",
Help: "The total number of mapped download paths",
})
)
type Redirector struct {
config *Config
db *maxminddb.Reader
servers ServerList
regionMap map[string][]*Server
hostMap map[string]*Server
dlMap map[string]string
topChoices int
serverCache *lru.Cache
checks []ServerCheck
checkClient *http.Client
}
type LocationLookup struct {
Location struct {
Latitude float64 `maxminddb:"latitude"`
Longitude float64 `maxminddb:"longitude"`
} `maxminddb:"location"`
}
// City represents a MaxmindDB city
type City struct {
Continent struct {
Code string `maxminddb:"code" json:"code"`
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"continent" json:"continent"`
Country struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"country" json:"country"`
Location struct {
AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'`
Latitude float64 `maxminddb:"latitude" json:"latitude"`
Longitude float64 `maxminddb:"longitude" json:"longitude"`
} `maxminddb:"location"`
RegisteredCountry struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"registered_country" json:"registered_country"`
}
type ServerConfig struct {
Server string `mapstructure:"server" yaml:"server"`
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
Continent string `mapstructure:"continent"`
Weight int `mapstructure:"weight" yaml:"weight"`
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
}
// New creates a new instance of Redirector
func New(config *Config) *Redirector {
r := &Redirector{
config: config,
}
r.checks = []ServerCheck{
r.checkHttp("http"),
r.checkTLS,
}
return r
}
func (r *Redirector) Start() http.Handler {
if err := r.ReloadConfig(); err != nil {
log.WithError(err).Fatalln("Unable to load configuration")
}
log.Info("Starting check loop")
// Start check loop
go r.servers.checkLoop(r.checks)
log.Info("Setting up routes")
router := chi.NewRouter()
router.Use(middleware.RealIPMiddleware)
router.Use(logger.Logger("router", log.StandardLogger()))
router.Head("/status", r.statusHandler)
router.Get("/status", r.statusHandler)
router.Get("/mirrors", r.legacyMirrorsHandler)
router.Get("/mirrors/{server}.svg", r.mirrorStatusHandler)
router.Get("/mirrors.json", r.mirrorsHandler)
router.Post("/reload", r.reloadHandler)
router.Get("/dl_map", r.dlMapHandler)
router.Get("/geoip", r.geoIPHandler)
router.Get("/metrics", promhttp.Handler().ServeHTTP)
router.NotFound(r.redirectHandler)
if r.config.BindAddress != "" {
log.WithField("bind", r.config.BindAddress).Info("Binding to address")
go http.ListenAndServe(r.config.BindAddress, router)
}
return router
}

View File

@ -1,4 +1,4 @@
package main package redirector
import ( import (
"github.com/jmcvetta/randutil" "github.com/jmcvetta/randutil"
@ -6,26 +6,12 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math" "math"
"net" "net"
"net/http"
"sort" "sort"
"sync" "sync"
"time" "time"
) )
var ( // Server represents a download server
checkClient = &http.Client{
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
checks = []serverCheck{
checkHttp,
checkTLS,
}
)
type Server struct { type Server struct {
Available bool `json:"available"` Available bool `json:"available"`
Host string `json:"host"` Host string `json:"host"`
@ -34,14 +20,15 @@ type Server struct {
Longitude float64 `json:"longitude"` Longitude float64 `json:"longitude"`
Weight int `json:"weight"` Weight int `json:"weight"`
Continent string `json:"continent"` Continent string `json:"continent"`
Protocols ProtocolList `json:"protocols"`
Redirects prometheus.Counter `json:"-"` Redirects prometheus.Counter `json:"-"`
LastChange time.Time `json:"lastChange"` LastChange time.Time `json:"lastChange"`
} }
type serverCheck func(server *Server, logFields log.Fields) (bool, error) type ServerCheck func(server *Server, logFields log.Fields) (bool, error)
// checkStatus runs all status checks against a server // checkStatus runs all status checks against a server
func (server *Server) checkStatus() { func (server *Server) checkStatus(checks []ServerCheck) {
logFields := log.Fields{ logFields := log.Fields{
"host": server.Host, "host": server.Host,
} }
@ -83,19 +70,19 @@ func (server *Server) checkStatus() {
type ServerList []*Server type ServerList []*Server
func (s ServerList) checkLoop() { func (s ServerList) checkLoop(checks []ServerCheck) {
t := time.NewTicker(60 * time.Second) t := time.NewTicker(60 * time.Second)
for { for {
<-t.C <-t.C
s.Check() s.Check(checks)
} }
} }
// Check will request the index from all servers // Check will request the index from all servers
// If a server does not respond in 10 seconds, it is considered offline. // If a server does not respond in 10 seconds, it is considered offline.
// This will wait until all checks are complete. // This will wait until all checks are complete.
func (s ServerList) Check() { func (s ServerList) Check(checks []ServerCheck) {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, server := range s { for _, server := range s {
@ -104,7 +91,7 @@ func (s ServerList) Check() {
go func(server *Server) { go func(server *Server) {
defer wg.Done() defer wg.Done()
server.checkStatus() server.checkStatus(checks)
}(server) }(server)
} }
@ -123,12 +110,12 @@ type DistanceList []ComputedDistance
// Closest will use GeoIP on the IP provided and find the closest servers. // Closest will use GeoIP on the IP provided and find the closest servers.
// When we have a list of x servers closest, we can choose a random or weighted one. // When we have a list of x servers closest, we can choose a random or weighted one.
// Return values are the closest server, the distance, and if an error occurred. // Return values are the closest server, the distance, and if an error occurred.
func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, float64, error) {
choiceInterface, exists := serverCache.Get(ip.String()) choiceInterface, exists := r.serverCache.Get(scheme + "_" + ip.String())
if !exists { if !exists {
var city LocationLookup var city LocationLookup
err := db.Lookup(ip, &city) err := r.db.Lookup(ip, &city)
if err != nil { if err != nil {
return nil, -1, err return nil, -1, err
@ -137,7 +124,7 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
c := make(DistanceList, len(s)) c := make(DistanceList, len(s))
for i, server := range s { for i, server := range s {
if !server.Available { if !server.Available || !server.Protocols.Contains(scheme) {
continue continue
} }
@ -154,9 +141,9 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
return c[i].Distance < c[j].Distance return c[i].Distance < c[j].Distance
}) })
choiceCount := topChoices choiceCount := r.config.TopChoices
if len(c) < topChoices { if len(c) < r.config.TopChoices {
choiceCount = len(c) choiceCount = len(c)
} }
@ -175,7 +162,7 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
choiceInterface = choices choiceInterface = choices
serverCache.Add(ip.String(), choiceInterface) r.serverCache.Add(scheme+"_"+ip.String(), choiceInterface)
} }
choice, err := randutil.WeightedChoice(choiceInterface.([]randutil.Choice)) choice, err := randutil.WeightedChoice(choiceInterface.([]randutil.Choice))
@ -188,9 +175,9 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
if !dist.Server.Available { if !dist.Server.Available {
// Choose a new server and refresh cache // Choose a new server and refresh cache
serverCache.Remove(ip.String()) r.serverCache.Remove(scheme + "_" + ip.String())
return s.Closest(ip) return s.Closest(r, scheme, ip)
} }
return dist.Server, dist.Distance, nil return dist.Server, dist.Distance, nil
@ -202,6 +189,7 @@ func hsin(theta float64) float64 {
} }
// Distance function returns the distance (in meters) between two points of // Distance function returns the distance (in meters) between two points of
//
// a given longitude and latitude relatively accurately (using a spherical // a given longitude and latitude relatively accurately (using a spherical
// approximation of the Earth) through the Haversine Distance Formula for // approximation of the Earth) through the Haversine Distance Formula for
// great arc distance on a sphere with accuracy for small distances // great arc distance on a sphere with accuracy for small distances

View File

@ -1,10 +1,10 @@
package main package redirector
import "math/rand" import "math/rand"
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func randSeq(n int) string { func RandomSequence(n int) string {
b := make([]rune, n) b := make([]rune, n)
for i := range b { for i := range b {
b[i] = letters[rand.Intn(len(letters))] b[i] = letters[rand.Intn(len(letters))]

46
util/certificates.go Normal file
View File

@ -0,0 +1,46 @@
package util
import (
"crypto/x509"
"github.com/gwatts/rootcerts/certparse"
log "github.com/sirupsen/logrus"
"net/http"
)
const (
defaultDownloadURL = "https://github.com/mozilla/gecko-dev/blob/master/security/nss/lib/ckfw/builtins/certdata.txt?raw=true"
)
func LoadCACerts() (*x509.CertPool, error) {
res, err := http.Get(defaultDownloadURL)
if err != nil {
return nil, err
}
defer res.Body.Close()
certs, err := certparse.ReadTrustedCerts(res.Body)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
var count int
for _, cert := range certs {
if cert.Trust&certparse.ServerTrustedDelegator == 0 {
continue
}
count++
pool.AddCert(cert.Cert)
}
log.WithField("certs", count).Info("Loaded root cas")
return pool, nil
}