Massive refactoring, struct cleanup, supporting more features
continuous-integration/drone/push Build is failing
Details
continuous-integration/drone/push Build is failing
Details
Features: - Protocol lists (http, https), managed by http responses - Working TLS Checks - Root certificate parsing for TLS checks - Moving configuration into a Config struct, no more direct viper access
This commit is contained in:
parent
3e7782e5ec
commit
e7236b13de
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
96
check.go
96
check.go
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -10,6 +10,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,8 +19,14 @@ var (
|
||||||
ErrCertExpired = errors.New("certificate is expired")
|
ErrCertExpired = errors.New("certificate is expired")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (r *Redirector) checkHttp(scheme string) ServerCheck {
|
||||||
|
return func(server *Server, logFields log.Fields) (bool, error) {
|
||||||
|
return r.checkHttpScheme(server, scheme, logFields)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// checkHttp checks a URL for validity, and checks redirects
|
// checkHttp checks a URL for validity, and checks redirects
|
||||||
func checkHttp(server *Server, logFields log.Fields) (bool, error) {
|
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
Host: server.Host,
|
Host: server.Host,
|
||||||
|
@ -48,13 +55,20 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) {
|
||||||
|
|
||||||
logFields["url"] = location
|
logFields["url"] = location
|
||||||
|
|
||||||
// Check that we don't redirect to https from a http url
|
switch u.Scheme {
|
||||||
if u.Scheme == "http" {
|
case "http":
|
||||||
res, err := checkRedirect(location)
|
res, err := r.checkRedirect(u.Scheme, location)
|
||||||
|
|
||||||
if !res || err != nil {
|
if !res || err != nil {
|
||||||
return res, err
|
// If we don't support http, we remove it from supported protocols
|
||||||
|
server.Protocols = server.Protocols.Remove("http")
|
||||||
|
} else {
|
||||||
|
// Otherwise, we verify https support
|
||||||
|
r.checkProtocol(server, "https")
|
||||||
}
|
}
|
||||||
|
case "https":
|
||||||
|
// We don't want to allow downgrading, so this is an error.
|
||||||
|
return r.checkRedirect(u.Scheme, location)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,8 +80,20 @@ func checkHttp(server *Server, logFields log.Fields) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Redirector) checkProtocol(server *Server, scheme string) {
|
||||||
|
res, err := r.checkHttpScheme(server, scheme, log.Fields{})
|
||||||
|
|
||||||
|
if !res || err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !server.Protocols.Contains(scheme) {
|
||||||
|
server.Protocols = server.Protocols.Append(scheme)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// checkRedirect parses a location header response and checks the scheme
|
// checkRedirect parses a location header response and checks the scheme
|
||||||
func checkRedirect(locationHeader string) (bool, error) {
|
func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
|
||||||
newUrl, err := url.Parse(locationHeader)
|
newUrl, err := url.Parse(locationHeader)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -76,20 +102,41 @@ func checkRedirect(locationHeader string) (bool, error) {
|
||||||
|
|
||||||
if newUrl.Scheme == "https" {
|
if newUrl.Scheme == "https" {
|
||||||
return false, ErrHttpsRedirect
|
return false, ErrHttpsRedirect
|
||||||
|
} else if originatingScheme == "https" && newUrl.Scheme == "https" {
|
||||||
|
return false, ErrHttpsRedirect
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
|
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
|
||||||
func checkTLS(server *Server, logFields log.Fields) (bool, error) {
|
func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) {
|
||||||
host, port, err := net.SplitHostPort(server.Host)
|
var host, port string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if strings.Contains(server.Host, ":") {
|
||||||
|
host, port, err = net.SplitHostPort(server.Host)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
host = server.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"server": server.Host,
|
||||||
|
"host": host,
|
||||||
|
"port": port,
|
||||||
|
}).Info("Checking TLS server")
|
||||||
|
|
||||||
if port == "" {
|
if port == "" {
|
||||||
port = "443"
|
port = "443"
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := tls.Dial("tcp", host+":"+port, checkTLSConfig)
|
conn, err := tls.Dial("tcp", host+":"+port, &tls.Config{
|
||||||
|
RootCAs: r.config.RootCAs,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -107,19 +154,39 @@ func checkTLS(server *Server, logFields log.Fields) (bool, error) {
|
||||||
|
|
||||||
state := conn.ConnectionState()
|
state := conn.ConnectionState()
|
||||||
|
|
||||||
|
peerPool := x509.NewCertPool()
|
||||||
|
|
||||||
|
for _, intermediate := range state.PeerCertificates {
|
||||||
|
if !intermediate.IsCA {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerPool.AddCert(intermediate)
|
||||||
|
}
|
||||||
|
|
||||||
opts := x509.VerifyOptions{
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: r.config.RootCAs,
|
||||||
|
Intermediates: peerPool,
|
||||||
CurrentTime: time.Now(),
|
CurrentTime: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, cert := range state.PeerCertificates {
|
// We want only the leaf certificate, as this will verify up the chain for us.
|
||||||
|
cert := state.PeerCertificates[0]
|
||||||
|
|
||||||
if _, err := cert.Verify(opts); err != nil {
|
if _, err := cert.Verify(opts); err != nil {
|
||||||
logFields["peerCert"] = cert.Subject.String()
|
logFields["peerCert"] = cert.Subject.String()
|
||||||
|
|
||||||
|
if authErr, ok := err.(x509.UnknownAuthorityError); ok {
|
||||||
|
logFields["authCert"] = authErr.Cert.Subject.String()
|
||||||
|
logFields["ca"] = authErr.Cert.Issuer
|
||||||
|
}
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
|
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
|
||||||
|
logFields["peerCert"] = cert.Subject.String()
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range state.VerifiedChains {
|
for _, chain := range state.VerifiedChains {
|
||||||
for _, cert := range chain {
|
for _, cert := range chain {
|
||||||
|
@ -130,5 +197,10 @@ func checkTLS(server *Server, logFields log.Fields) (bool, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If https is valid, append it
|
||||||
|
if !server.Protocols.Contains("https") {
|
||||||
|
server.Protocols = server.Protocols.Append("https")
|
||||||
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -58,11 +58,15 @@ var _ = Describe("Check suite", func() {
|
||||||
httpServer *httptest.Server
|
httpServer *httptest.Server
|
||||||
server *Server
|
server *Server
|
||||||
handler http.HandlerFunc
|
handler http.HandlerFunc
|
||||||
|
r *Redirector
|
||||||
)
|
)
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
}))
|
}))
|
||||||
|
r = New(&Config{
|
||||||
|
RootCAs: x509.NewCertPool(),
|
||||||
|
})
|
||||||
})
|
})
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
httpServer.Close()
|
httpServer.Close()
|
||||||
|
@ -89,7 +93,7 @@ var _ = Describe("Check suite", func() {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := checkHttp(server, log.Fields{})
|
res, err := r.checkHttpScheme(server, "http", log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeTrue())
|
Expect(res).To(BeTrue())
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
|
@ -100,7 +104,7 @@ var _ = Describe("Check suite", func() {
|
||||||
w.WriteHeader(http.StatusMovedPermanently)
|
w.WriteHeader(http.StatusMovedPermanently)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := checkHttp(server, log.Fields{})
|
res, err := r.checkHttpScheme(server, "http", log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
Expect(err).To(Equal(ErrHttpsRedirect))
|
Expect(err).To(Equal(ErrHttpsRedirect))
|
||||||
|
@ -141,7 +145,7 @@ var _ = Describe("Check suite", func() {
|
||||||
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
||||||
})
|
})
|
||||||
It("Should fail due to invalid ca", func() {
|
It("Should fail due to invalid ca", func() {
|
||||||
res, err := checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
Expect(err).ToNot(BeNil())
|
Expect(err).ToNot(BeNil())
|
||||||
|
@ -151,20 +155,15 @@ var _ = Describe("Check suite", func() {
|
||||||
|
|
||||||
pool.AddCert(x509Cert)
|
pool.AddCert(x509Cert)
|
||||||
|
|
||||||
checkTLSConfig = &tls.Config{RootCAs: pool}
|
r.config.RootCAs = pool
|
||||||
|
|
||||||
res, err := checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
Expect(err).ToNot(BeNil())
|
Expect(err).ToNot(BeNil())
|
||||||
|
|
||||||
checkTLSConfig = nil
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
Context("Expiration tests", func() {
|
Context("Expiration tests", func() {
|
||||||
AfterEach(func() {
|
|
||||||
checkTLSConfig = nil
|
|
||||||
})
|
|
||||||
It("Should fail due to not yet valid certificate", func() {
|
It("Should fail due to not yet valid certificate", func() {
|
||||||
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
|
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
|
||||||
|
|
||||||
|
@ -173,10 +172,10 @@ var _ = Describe("Check suite", func() {
|
||||||
|
|
||||||
pool.AddCert(x509Cert)
|
pool.AddCert(x509Cert)
|
||||||
|
|
||||||
checkTLSConfig = &tls.Config{RootCAs: pool}
|
r.config.RootCAs = pool
|
||||||
|
|
||||||
// Check TLS
|
// Check TLS
|
||||||
res, err := checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
Expect(err).ToNot(BeNil())
|
Expect(err).ToNot(BeNil())
|
||||||
|
@ -189,10 +188,10 @@ var _ = Describe("Check suite", func() {
|
||||||
|
|
||||||
pool.AddCert(x509Cert)
|
pool.AddCert(x509Cert)
|
||||||
|
|
||||||
checkTLSConfig = &tls.Config{RootCAs: pool}
|
r.config.RootCAs = pool
|
||||||
|
|
||||||
// Check TLS
|
// Check TLS
|
||||||
res, err := checkTLS(server, log.Fields{})
|
res, err := r.checkTLS(server, log.Fields{})
|
||||||
|
|
||||||
Expect(res).To(BeFalse())
|
Expect(res).To(BeFalse())
|
||||||
Expect(err).ToNot(BeNil())
|
Expect(err).ToNot(BeNil())
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"github.com/armbian/redirector"
|
||||||
|
"github.com/armbian/redirector/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
configFlag = flag.String("config", "", "configuration file path")
|
||||||
|
flagDebug = flag.Bool("debug", false, "Enable debug logging")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *flagDebug {
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
viper.SetDefault("bind", ":8080")
|
||||||
|
viper.SetDefault("cacheSize", 1024)
|
||||||
|
viper.SetDefault("topChoices", 3)
|
||||||
|
viper.SetDefault("reloadKey", redirector.RandomSequence(32))
|
||||||
|
|
||||||
|
viper.SetConfigName("dlrouter") // name of config file (without extension)
|
||||||
|
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
|
||||||
|
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
|
||||||
|
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
|
||||||
|
viper.AddConfigPath(".") // optionally look for config in the working directory
|
||||||
|
|
||||||
|
if *configFlag != "" {
|
||||||
|
viper.SetConfigFile(*configFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &redirector.Config{}
|
||||||
|
|
||||||
|
loadConfig := func(fatal bool) {
|
||||||
|
log.Info("Reading configuration")
|
||||||
|
|
||||||
|
// Bind reload to reading in the viper config, then deserializing
|
||||||
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
|
log.WithError(err).Error("Unable to unmarshal configuration")
|
||||||
|
|
||||||
|
if fatal {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Unmarshalling configuration")
|
||||||
|
|
||||||
|
if err := viper.Unmarshal(config); err != nil {
|
||||||
|
log.WithError(err).Error("Unable to unmarshal configuration")
|
||||||
|
|
||||||
|
if fatal {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Updating root certificates")
|
||||||
|
|
||||||
|
certs, err := util.LoadCACerts()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Unable to load certificates")
|
||||||
|
|
||||||
|
if fatal {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.RootCAs = certs
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ReloadFunc = func() {
|
||||||
|
loadConfig(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
loadConfig(true)
|
||||||
|
|
||||||
|
redir := redirector.New(config)
|
||||||
|
|
||||||
|
// Because we have a bind address, we can start it without the return value.
|
||||||
|
redir.Start()
|
||||||
|
|
||||||
|
log.Info("Ready")
|
||||||
|
|
||||||
|
c := make(chan os.Signal)
|
||||||
|
|
||||||
|
signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP)
|
||||||
|
|
||||||
|
for {
|
||||||
|
sig := <-c
|
||||||
|
|
||||||
|
if sig != syscall.SIGHUP {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
loadConfig(false)
|
||||||
|
|
||||||
|
err := redir.ReloadConfig()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warning("Did not reload configuration due to error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
151
config.go
151
config.go
|
@ -1,109 +1,149 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
lru "github.com/hashicorp/golang-lru"
|
lru "github.com/hashicorp/golang-lru"
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/viper"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
func reloadConfig() error {
|
type Config struct {
|
||||||
log.Info("Loading configuration...")
|
BindAddress string `mapstructure:"bind"`
|
||||||
|
GeoDBPath string `mapstructure:"geodb"`
|
||||||
err := viper.ReadInConfig() // Find and read the config file
|
MapFile string `mapstructure:"dl_map"`
|
||||||
|
CacheSize int `mapstructure:"cacheSize"`
|
||||||
if err != nil { // Handle errors reading the config file
|
TopChoices int `mapstructure:"topChoices"`
|
||||||
return errors.Wrap(err, "Unable to read configuration")
|
ReloadToken string `mapstructure:"reloadToken"`
|
||||||
|
ServerList []ServerConfig `mapstructure:"servers"`
|
||||||
|
ReloadFunc func()
|
||||||
|
RootCAs *x509.CertPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// db will never be reloaded.
|
type ProtocolList []string
|
||||||
if db == nil {
|
|
||||||
|
func (p ProtocolList) Contains(value string) bool {
|
||||||
|
for _, val := range p {
|
||||||
|
if value == val {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p ProtocolList) Append(value string) ProtocolList {
|
||||||
|
return append(p, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p ProtocolList) Remove(value string) ProtocolList {
|
||||||
|
index := -1
|
||||||
|
|
||||||
|
for i, val := range p {
|
||||||
|
if value == val {
|
||||||
|
index = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if index == -1 {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
p[index] = p[len(p)-1]
|
||||||
|
return p[:len(p)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Redirector) ReloadConfig() error {
|
||||||
|
log.Info("Loading configuration...")
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
// Load maxmind database
|
// Load maxmind database
|
||||||
db, err = maxminddb.Open(viper.GetString("geodb"))
|
if r.db != nil {
|
||||||
|
err = r.db.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Unable to close database")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// db can be hot-reloaded if the file changed
|
||||||
|
r.db, err = maxminddb.Open(r.config.GeoDBPath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Unable to open database")
|
return errors.Wrap(err, "Unable to open database")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh server cache if size changed
|
// Refresh server cache if size changed
|
||||||
if serverCache == nil {
|
if r.serverCache == nil {
|
||||||
serverCache, err = lru.New(viper.GetInt("cacheSize"))
|
r.serverCache, err = lru.New(r.config.CacheSize)
|
||||||
} else {
|
} else {
|
||||||
serverCache.Resize(viper.GetInt("cacheSize"))
|
r.serverCache.Resize(r.config.CacheSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purge the cache to ensure we don't have any invalid servers in it
|
// Purge the cache to ensure we don't have any invalid servers in it
|
||||||
serverCache.Purge()
|
r.serverCache.Purge()
|
||||||
|
|
||||||
// Set top choice count
|
|
||||||
topChoices = viper.GetInt("topChoices")
|
|
||||||
|
|
||||||
// Reload map file
|
// Reload map file
|
||||||
if err := reloadMap(); err != nil {
|
if err := r.reloadMap(); err != nil {
|
||||||
return errors.Wrap(err, "Unable to load map file")
|
return errors.Wrap(err, "Unable to load map file")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload server list
|
// Reload server list
|
||||||
if err := reloadServers(); err != nil {
|
if err := r.reloadServers(); err != nil {
|
||||||
return errors.Wrap(err, "Unable to load servers")
|
return errors.Wrap(err, "Unable to load servers")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create mirror map
|
// Create mirror map
|
||||||
mirrors := make(map[string][]*Server)
|
mirrors := make(map[string][]*Server)
|
||||||
|
|
||||||
for _, server := range servers {
|
for _, server := range r.servers {
|
||||||
mirrors[server.Continent] = append(mirrors[server.Continent], server)
|
mirrors[server.Continent] = append(mirrors[server.Continent], server)
|
||||||
}
|
}
|
||||||
|
|
||||||
mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...)
|
mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...)
|
||||||
|
|
||||||
regionMap = mirrors
|
r.regionMap = mirrors
|
||||||
|
|
||||||
hosts := make(map[string]*Server)
|
hosts := make(map[string]*Server)
|
||||||
|
|
||||||
for _, server := range servers {
|
for _, server := range r.servers {
|
||||||
hosts[server.Host] = server
|
hosts[server.Host] = server
|
||||||
}
|
}
|
||||||
|
|
||||||
hostMap = hosts
|
r.hostMap = hosts
|
||||||
|
|
||||||
// Check top choices size
|
// Check top choices size
|
||||||
if topChoices > len(servers) {
|
if r.config.TopChoices > len(r.servers) {
|
||||||
topChoices = len(servers)
|
r.config.TopChoices = len(r.servers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force check
|
// Force check
|
||||||
go servers.Check()
|
go r.servers.Check(r.checks)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadServers() error {
|
func (r *Redirector) reloadServers() error {
|
||||||
var serverList []ServerConfig
|
log.WithField("count", len(r.config.ServerList)).Info("Loading servers")
|
||||||
|
|
||||||
if err := viper.UnmarshalKey("servers", &serverList); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
existing := make(map[string]int)
|
existing := make(map[string]int)
|
||||||
|
|
||||||
for i, server := range servers {
|
for i, server := range r.servers {
|
||||||
existing[server.Host] = i
|
existing[server.Host] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts := make(map[string]bool)
|
hosts := make(map[string]bool)
|
||||||
|
|
||||||
for _, server := range serverList {
|
for _, server := range r.config.ServerList {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
var prefix string
|
var prefix string
|
||||||
|
@ -133,19 +173,19 @@ func reloadServers() error {
|
||||||
go func(i int, server ServerConfig, u *url.URL) {
|
go func(i int, server ServerConfig, u *url.URL) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
s := addServer(server, u)
|
s := r.addServer(server, u)
|
||||||
|
|
||||||
if _, ok := existing[u.Host]; ok {
|
if _, ok := existing[u.Host]; ok {
|
||||||
s.Redirects = servers[i].Redirects
|
s.Redirects = r.servers[i].Redirects
|
||||||
|
|
||||||
servers[i] = s
|
r.servers[i] = s
|
||||||
} else {
|
} else {
|
||||||
s.Redirects = promauto.NewCounter(prometheus.CounterOpts{
|
s.Redirects = promauto.NewCounter(prometheus.CounterOpts{
|
||||||
Name: "armbian_router_redirects_" + metricReplacer.Replace(u.Host),
|
Name: "armbian_router_redirects_" + metricReplacer.Replace(u.Host),
|
||||||
Help: "The number of redirects for server " + u.Host,
|
Help: "The number of redirects for server " + u.Host,
|
||||||
})
|
})
|
||||||
|
|
||||||
servers = append(servers, s)
|
r.servers = append(r.servers, s)
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"server": u.Host,
|
"server": u.Host,
|
||||||
|
@ -160,16 +200,16 @@ func reloadServers() error {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Remove servers that no longer exist in the config
|
// Remove servers that no longer exist in the config
|
||||||
for i := len(servers) - 1; i >= 0; i-- {
|
for i := len(r.servers) - 1; i >= 0; i-- {
|
||||||
if _, exists := hosts[servers[i].Host]; exists {
|
if _, exists := hosts[r.servers[i].Host]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"server": servers[i].Host,
|
"server": r.servers[i].Host,
|
||||||
}).Info("Removed server")
|
}).Info("Removed server")
|
||||||
|
|
||||||
servers = append(servers[:i], servers[i+1:]...)
|
r.servers = append(r.servers[:i], r.servers[i+1:]...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -179,7 +219,7 @@ var metricReplacer = strings.NewReplacer(".", "_", "-", "_")
|
||||||
|
|
||||||
// addServer takes ServerConfig and constructs a server.
|
// addServer takes ServerConfig and constructs a server.
|
||||||
// This will create duplicate servers, but it will overwrite existing ones when changed.
|
// This will create duplicate servers, but it will overwrite existing ones when changed.
|
||||||
func addServer(server ServerConfig, u *url.URL) *Server {
|
func (r *Redirector) addServer(server ServerConfig, u *url.URL) *Server {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
Available: true,
|
Available: true,
|
||||||
Host: u.Host,
|
Host: u.Host,
|
||||||
|
@ -188,6 +228,15 @@ func addServer(server ServerConfig, u *url.URL) *Server {
|
||||||
Longitude: server.Longitude,
|
Longitude: server.Longitude,
|
||||||
Continent: server.Continent,
|
Continent: server.Continent,
|
||||||
Weight: server.Weight,
|
Weight: server.Weight,
|
||||||
|
Protocols: ProtocolList{"http", "https"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(server.Protocols) > 0 {
|
||||||
|
for _, proto := range server.Protocols {
|
||||||
|
if !s.Protocols.Contains(proto) {
|
||||||
|
s.Protocols = s.Protocols.Append(proto)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Defaults to 10 to allow servers to be set lower for lower priority
|
// Defaults to 10 to allow servers to be set lower for lower priority
|
||||||
|
@ -206,7 +255,7 @@ func addServer(server ServerConfig, u *url.URL) *Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
var city City
|
var city City
|
||||||
err = db.Lookup(ips[0], &city)
|
err = r.db.Lookup(ips[0], &city)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
|
@ -229,8 +278,8 @@ func addServer(server ServerConfig, u *url.URL) *Server {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadMap() error {
|
func (r *Redirector) reloadMap() error {
|
||||||
mapFile := viper.GetString("dl_map")
|
mapFile := r.config.MapFile
|
||||||
|
|
||||||
if mapFile == "" {
|
if mapFile == "" {
|
||||||
return nil
|
return nil
|
||||||
|
@ -244,7 +293,7 @@ func reloadMap() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dlMap = newMap
|
r.dlMap = newMap
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,10 @@ servers:
|
||||||
- server: mirrors.bfsu.edu.cn/armbian/
|
- server: mirrors.bfsu.edu.cn/armbian/
|
||||||
- server: mirrors.dotsrc.org/armbian-apt/
|
- server: mirrors.dotsrc.org/armbian-apt/
|
||||||
weight: 15
|
weight: 15
|
||||||
|
protocols:
|
||||||
|
- http
|
||||||
|
- https
|
||||||
|
- rsync
|
||||||
- server: mirrors.netix.net/armbian/apt/
|
- server: mirrors.netix.net/armbian/apt/
|
||||||
- server: mirrors.nju.edu.cn/armbian/
|
- server: mirrors.nju.edu.cn/armbian/
|
||||||
- server: mirrors.sustech.edu.cn/armbian/
|
- server: mirrors.sustech.edu.cn/armbian/
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -1,10 +1,11 @@
|
||||||
module meow.tf/armbian-router
|
module github.com/armbian/redirector
|
||||||
|
|
||||||
go 1.17
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/chi-middleware/logrus-logger v0.2.0
|
github.com/chi-middleware/logrus-logger v0.2.0
|
||||||
github.com/go-chi/chi/v5 v5.0.7
|
github.com/go-chi/chi/v5 v5.0.7
|
||||||
|
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d
|
||||||
github.com/hashicorp/golang-lru v0.5.4
|
github.com/hashicorp/golang-lru v0.5.4
|
||||||
github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff
|
github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff
|
||||||
github.com/onsi/ginkgo/v2 v2.1.4
|
github.com/onsi/ginkgo/v2 v2.1.4
|
||||||
|
|
5
go.sum
5
go.sum
|
@ -128,7 +128,6 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
|
||||||
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
|
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
|
||||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
|
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
|
@ -198,7 +197,6 @@ github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLe
|
||||||
github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
|
||||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
|
@ -210,6 +208,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
|
||||||
github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
|
github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
|
||||||
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
|
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
|
||||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||||
|
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d h1:Kp5G1kHMb2fAD9OiqWDXro4qLB8bQ2NusoorYya4Lbo=
|
||||||
|
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g=
|
||||||
github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0=
|
github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0=
|
||||||
github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms=
|
github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms=
|
||||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
|
@ -680,7 +680,6 @@ golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20=
|
|
||||||
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|
67
http.go
67
http.go
|
@ -1,10 +1,9 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jmcvetta/randutil"
|
"github.com/jmcvetta/randutil"
|
||||||
"github.com/spf13/viper"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -14,10 +13,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// statusHandler is a simple handler that will always return 200 OK with a body of "OK"
|
// statusHandler is a simple handler that will always return 200 OK with a body of "OK"
|
||||||
func statusHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) statusHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
if r.Method != http.MethodHead {
|
if req.Method != http.MethodHead {
|
||||||
w.Write([]byte("OK"))
|
w.Write([]byte("OK"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,8 +24,8 @@ func statusHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// redirectHandler is the default "not found" handler which handles redirects
|
// redirectHandler is the default "not found" handler which handles redirects
|
||||||
// if the environment variable OVERRIDE_IP is set, it will use that ip address
|
// if the environment variable OVERRIDE_IP is set, it will use that ip address
|
||||||
// this is useful for local testing when you're on the local network
|
// this is useful for local testing when you're on the local network
|
||||||
func redirectHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
@ -50,11 +49,11 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// If the path has a prefix of region/NA, it will use specific regions instead
|
// If the path has a prefix of region/NA, it will use specific regions instead
|
||||||
// of the default geographical distance
|
// of the default geographical distance
|
||||||
if strings.HasPrefix(r.URL.Path, "/region") {
|
if strings.HasPrefix(req.URL.Path, "/region") {
|
||||||
parts := strings.Split(r.URL.Path, "/")
|
parts := strings.Split(req.URL.Path, "/")
|
||||||
|
|
||||||
// region = parts[2]
|
// region = parts[2]
|
||||||
if mirrors, ok := regionMap[parts[2]]; ok {
|
if mirrors, ok := r.regionMap[parts[2]]; ok {
|
||||||
choices := make([]randutil.Choice, len(mirrors))
|
choices := make([]randutil.Choice, len(mirrors))
|
||||||
|
|
||||||
for i, item := range mirrors {
|
for i, item := range mirrors {
|
||||||
|
@ -77,13 +76,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
server = choice.Item.(*Server)
|
server = choice.Item.(*Server)
|
||||||
|
|
||||||
r.URL.Path = strings.Join(parts[3:], "/")
|
req.URL.Path = strings.Join(parts[3:], "/")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we don't have a scheme, we'll use http by default
|
||||||
|
scheme := req.URL.Scheme
|
||||||
|
|
||||||
|
if scheme == "" {
|
||||||
|
scheme = "http"
|
||||||
|
}
|
||||||
|
|
||||||
// If none of the above exceptions are matched, we use the geographical distance based on IP
|
// If none of the above exceptions are matched, we use the geographical distance based on IP
|
||||||
if server == nil {
|
if server == nil {
|
||||||
server, distance, err = servers.Closest(ip)
|
server, distance, err = r.servers.Closest(r, scheme, ip)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
@ -91,27 +97,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we don't have a scheme, we'll use https by default
|
|
||||||
scheme := r.URL.Scheme
|
|
||||||
|
|
||||||
if scheme == "" {
|
|
||||||
scheme = "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
// redirectPath is a combination of server path (which can be something like /armbian)
|
// redirectPath is a combination of server path (which can be something like /armbian)
|
||||||
// and the URL path.
|
// and the URL path.
|
||||||
// Example: /armbian + /some/path = /armbian/some/path
|
// Example: /armbian + /some/path = /armbian/some/path
|
||||||
redirectPath := path.Join(server.Path, r.URL.Path)
|
redirectPath := path.Join(server.Path, req.URL.Path)
|
||||||
|
|
||||||
// If we have a dlMap, we map the url to a final path instead
|
// If we have a dlMap, we map the url to a final path instead
|
||||||
if dlMap != nil {
|
if r.dlMap != nil {
|
||||||
if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists {
|
if newPath, exists := r.dlMap[strings.TrimLeft(req.URL.Path, "/")]; exists {
|
||||||
downloadsMapped.Inc()
|
downloadsMapped.Inc()
|
||||||
redirectPath = path.Join(server.Path, newPath)
|
redirectPath = path.Join(server.Path, newPath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") {
|
if strings.HasSuffix(req.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") {
|
||||||
redirectPath += "/"
|
redirectPath += "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,15 +135,13 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// reloadHandler is an http handler which lets us reload the server configuration
|
// reloadHandler is an http handler which lets us reload the server configuration
|
||||||
// It is only enabled when the reloadToken is set in the configuration
|
// It is only enabled when the reloadToken is set in the configuration
|
||||||
func reloadHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) reloadHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
expectedToken := viper.GetString("reloadToken")
|
if r.config.ReloadToken == "" {
|
||||||
|
|
||||||
if expectedToken == "" {
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token := r.Header.Get("Authorization")
|
token := req.Header.Get("Authorization")
|
||||||
|
|
||||||
if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") {
|
if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
@ -153,12 +150,12 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
token = token[strings.Index(token, " ")+1:]
|
token = token[strings.Index(token, " ")+1:]
|
||||||
|
|
||||||
if token != expectedToken {
|
if token != r.config.ReloadToken {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := reloadConfig(); err != nil {
|
if err := r.ReloadConfig(); err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
w.Write([]byte(err.Error()))
|
w.Write([]byte(err.Error()))
|
||||||
return
|
return
|
||||||
|
@ -168,19 +165,19 @@ func reloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("OK"))
|
w.Write([]byte("OK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func dlMapHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) dlMapHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
if dlMap == nil {
|
if r.dlMap == nil {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(dlMap)
|
json.NewEncoder(w).Encode(r.dlMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
func geoIPHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) geoIPHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
@ -190,7 +187,7 @@ func geoIPHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
ip := net.ParseIP(ipStr)
|
ip := net.ParseIP(ipStr)
|
||||||
|
|
||||||
var city City
|
var city City
|
||||||
err = db.Lookup(ip, &city)
|
err = r.db.Lookup(ip, &city)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
|
154
main.go
154
main.go
|
@ -1,154 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"github.com/chi-middleware/logrus-logger"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
lru "github.com/hashicorp/golang-lru"
|
|
||||||
"github.com/oschwald/maxminddb-golang"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
db *maxminddb.Reader
|
|
||||||
servers ServerList
|
|
||||||
regionMap map[string][]*Server
|
|
||||||
hostMap map[string]*Server
|
|
||||||
dlMap map[string]string
|
|
||||||
topChoices int
|
|
||||||
|
|
||||||
redirectsServed = promauto.NewCounter(prometheus.CounterOpts{
|
|
||||||
Name: "armbian_router_redirects",
|
|
||||||
Help: "The total number of processed redirects",
|
|
||||||
})
|
|
||||||
|
|
||||||
downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{
|
|
||||||
Name: "armbian_router_download_maps",
|
|
||||||
Help: "The total number of mapped download paths",
|
|
||||||
})
|
|
||||||
|
|
||||||
serverCache *lru.Cache
|
|
||||||
)
|
|
||||||
|
|
||||||
type LocationLookup struct {
|
|
||||||
Location struct {
|
|
||||||
Latitude float64 `maxminddb:"latitude"`
|
|
||||||
Longitude float64 `maxminddb:"longitude"`
|
|
||||||
} `maxminddb:"location"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// City represents a MaxmindDB city
|
|
||||||
type City struct {
|
|
||||||
Continent struct {
|
|
||||||
Code string `maxminddb:"code" json:"code"`
|
|
||||||
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
|
|
||||||
Names map[string]string `maxminddb:"names" json:"names"`
|
|
||||||
} `maxminddb:"continent" json:"continent"`
|
|
||||||
Country struct {
|
|
||||||
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
|
|
||||||
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
|
|
||||||
Names map[string]string `maxminddb:"names" json:"names"`
|
|
||||||
} `maxminddb:"country" json:"country"`
|
|
||||||
Location struct {
|
|
||||||
AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'`
|
|
||||||
Latitude float64 `maxminddb:"latitude" json:"latitude"`
|
|
||||||
Longitude float64 `maxminddb:"longitude" json:"longitude"`
|
|
||||||
} `maxminddb:"location"`
|
|
||||||
RegisteredCountry struct {
|
|
||||||
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
|
|
||||||
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
|
|
||||||
Names map[string]string `maxminddb:"names" json:"names"`
|
|
||||||
} `maxminddb:"registered_country" json:"registered_country"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerConfig struct {
|
|
||||||
Server string `mapstructure:"server" yaml:"server"`
|
|
||||||
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
|
|
||||||
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
|
|
||||||
Continent string `mapstructure:"continent"`
|
|
||||||
Weight int `mapstructure:"weight" yaml:"weight"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
configFlag = flag.String("config", "", "configuration file path")
|
|
||||||
flagDebug = flag.Bool("debug", false, "Enable debug logging")
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if *flagDebug {
|
|
||||||
log.SetLevel(log.DebugLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
viper.SetDefault("bind", ":8080")
|
|
||||||
viper.SetDefault("cacheSize", 1024)
|
|
||||||
viper.SetDefault("topChoices", 3)
|
|
||||||
viper.SetDefault("reloadKey", randSeq(32))
|
|
||||||
|
|
||||||
viper.SetConfigName("dlrouter") // name of config file (without extension)
|
|
||||||
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
|
|
||||||
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
|
|
||||||
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
|
|
||||||
viper.AddConfigPath(".") // optionally look for config in the working directory
|
|
||||||
|
|
||||||
if *configFlag != "" {
|
|
||||||
viper.SetConfigFile(*configFlag)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := reloadConfig(); err != nil {
|
|
||||||
log.WithError(err).Fatalln("Unable to load configuration")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start check loop
|
|
||||||
go servers.checkLoop()
|
|
||||||
|
|
||||||
log.Info("Starting")
|
|
||||||
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
r.Use(RealIPMiddleware)
|
|
||||||
r.Use(logger.Logger("router", log.StandardLogger()))
|
|
||||||
|
|
||||||
r.Head("/status", statusHandler)
|
|
||||||
r.Get("/status", statusHandler)
|
|
||||||
r.Get("/mirrors", legacyMirrorsHandler)
|
|
||||||
r.Get("/mirrors/{server}.svg", mirrorStatusHandler)
|
|
||||||
r.Get("/mirrors.json", mirrorsHandler)
|
|
||||||
r.Post("/reload", reloadHandler)
|
|
||||||
r.Get("/dl_map", dlMapHandler)
|
|
||||||
r.Get("/geoip", geoIPHandler)
|
|
||||||
r.Get("/metrics", promhttp.Handler().ServeHTTP)
|
|
||||||
|
|
||||||
r.NotFound(redirectHandler)
|
|
||||||
|
|
||||||
go http.ListenAndServe(viper.GetString("bind"), r)
|
|
||||||
|
|
||||||
log.Info("Ready")
|
|
||||||
|
|
||||||
c := make(chan os.Signal)
|
|
||||||
|
|
||||||
signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP)
|
|
||||||
|
|
||||||
for {
|
|
||||||
sig := <-c
|
|
||||||
|
|
||||||
if sig != syscall.SIGHUP {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
err := reloadConfig()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warning("Did not reload configuration due to error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
2
map.go
2
map.go
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
20
mirrors.go
20
mirrors.go
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
@ -11,16 +11,16 @@ import (
|
||||||
|
|
||||||
// legacyMirrorsHandler will list the mirrors by region in the legacy format
|
// legacyMirrorsHandler will list the mirrors by region in the legacy format
|
||||||
// it is preferred to use mirrors.json, but this handler is here for build support
|
// it is preferred to use mirrors.json, but this handler is here for build support
|
||||||
func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) legacyMirrorsHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
mirrorOutput := make(map[string][]string)
|
mirrorOutput := make(map[string][]string)
|
||||||
|
|
||||||
for region, mirrors := range regionMap {
|
for region, mirrors := range r.regionMap {
|
||||||
list := make([]string, len(mirrors))
|
list := make([]string, len(mirrors))
|
||||||
|
|
||||||
for i, mirror := range mirrors {
|
for i, mirror := range mirrors {
|
||||||
list[i] = r.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/")
|
list[i] = req.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
mirrorOutput[region] = list
|
mirrorOutput[region] = list
|
||||||
|
@ -30,9 +30,9 @@ func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// mirrorsHandler is a simple handler that will return the list of servers
|
// mirrorsHandler is a simple handler that will return the list of servers
|
||||||
func mirrorsHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) mirrorsHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(servers)
|
json.NewEncoder(w).Encode(r.servers)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -48,8 +48,8 @@ var (
|
||||||
|
|
||||||
// mirrorStatusHandler is a fancy svg-returning handler.
|
// mirrorStatusHandler is a fancy svg-returning handler.
|
||||||
// it is used to display mirror statuses on a config repo of sorts
|
// it is used to display mirror statuses on a config repo of sorts
|
||||||
func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
|
func (r *Redirector) mirrorStatusHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
serverHost := chi.URLParam(r, "server")
|
serverHost := chi.URLParam(req, "server")
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8")
|
w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8")
|
||||||
w.Header().Set("Cache-Control", "max-age=120")
|
w.Header().Set("Cache-Control", "max-age=120")
|
||||||
|
@ -61,7 +61,7 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
serverHost = strings.Replace(serverHost, "_", ".", -1)
|
serverHost = strings.Replace(serverHost, "_", ".", -1)
|
||||||
|
|
||||||
server, ok := hostMap[serverHost]
|
server, ok := r.hostMap[serverHost]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown)))
|
w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown)))
|
||||||
|
@ -77,7 +77,7 @@ func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
w.Header().Set("ETag", "\""+key+"\"")
|
w.Header().Set("ETag", "\""+key+"\"")
|
||||||
|
|
||||||
if match := r.Header.Get("If-None-Match"); match != "" {
|
if match := req.Header.Get("If-None-Match"); match != "" {
|
||||||
if strings.Trim(match, "\"") == key {
|
if strings.Trim(match, "\"") == key {
|
||||||
w.WriteHeader(http.StatusNotModified)
|
w.WriteHeader(http.StatusNotModified)
|
||||||
return
|
return
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
package redirector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/armbian/redirector/middleware"
|
||||||
|
"github.com/chi-middleware/logrus-logger"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
lru "github.com/hashicorp/golang-lru"
|
||||||
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
redirectsServed = promauto.NewCounter(prometheus.CounterOpts{
|
||||||
|
Name: "armbian_router_redirects",
|
||||||
|
Help: "The total number of processed redirects",
|
||||||
|
})
|
||||||
|
|
||||||
|
downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{
|
||||||
|
Name: "armbian_router_download_maps",
|
||||||
|
Help: "The total number of mapped download paths",
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
type Redirector struct {
|
||||||
|
config *Config
|
||||||
|
db *maxminddb.Reader
|
||||||
|
servers ServerList
|
||||||
|
regionMap map[string][]*Server
|
||||||
|
hostMap map[string]*Server
|
||||||
|
dlMap map[string]string
|
||||||
|
topChoices int
|
||||||
|
serverCache *lru.Cache
|
||||||
|
checks []ServerCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocationLookup struct {
|
||||||
|
Location struct {
|
||||||
|
Latitude float64 `maxminddb:"latitude"`
|
||||||
|
Longitude float64 `maxminddb:"longitude"`
|
||||||
|
} `maxminddb:"location"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// City represents a MaxmindDB city
|
||||||
|
type City struct {
|
||||||
|
Continent struct {
|
||||||
|
Code string `maxminddb:"code" json:"code"`
|
||||||
|
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
|
||||||
|
Names map[string]string `maxminddb:"names" json:"names"`
|
||||||
|
} `maxminddb:"continent" json:"continent"`
|
||||||
|
Country struct {
|
||||||
|
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
|
||||||
|
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
|
||||||
|
Names map[string]string `maxminddb:"names" json:"names"`
|
||||||
|
} `maxminddb:"country" json:"country"`
|
||||||
|
Location struct {
|
||||||
|
AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'`
|
||||||
|
Latitude float64 `maxminddb:"latitude" json:"latitude"`
|
||||||
|
Longitude float64 `maxminddb:"longitude" json:"longitude"`
|
||||||
|
} `maxminddb:"location"`
|
||||||
|
RegisteredCountry struct {
|
||||||
|
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
|
||||||
|
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
|
||||||
|
Names map[string]string `maxminddb:"names" json:"names"`
|
||||||
|
} `maxminddb:"registered_country" json:"registered_country"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
Server string `mapstructure:"server" yaml:"server"`
|
||||||
|
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
|
||||||
|
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
|
||||||
|
Continent string `mapstructure:"continent"`
|
||||||
|
Weight int `mapstructure:"weight" yaml:"weight"`
|
||||||
|
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new instance of Redirector
|
||||||
|
func New(config *Config) *Redirector {
|
||||||
|
r := &Redirector{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.checks = []ServerCheck{
|
||||||
|
r.checkHttp("http"),
|
||||||
|
r.checkTLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Redirector) Start() http.Handler {
|
||||||
|
if err := r.ReloadConfig(); err != nil {
|
||||||
|
log.WithError(err).Fatalln("Unable to load configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Starting check loop")
|
||||||
|
|
||||||
|
// Start check loop
|
||||||
|
go r.servers.checkLoop(r.checks)
|
||||||
|
|
||||||
|
log.Info("Setting up routes")
|
||||||
|
|
||||||
|
router := chi.NewRouter()
|
||||||
|
|
||||||
|
router.Use(middleware.RealIPMiddleware)
|
||||||
|
router.Use(logger.Logger("router", log.StandardLogger()))
|
||||||
|
|
||||||
|
router.Head("/status", r.statusHandler)
|
||||||
|
router.Get("/status", r.statusHandler)
|
||||||
|
router.Get("/mirrors", r.legacyMirrorsHandler)
|
||||||
|
router.Get("/mirrors/{server}.svg", r.mirrorStatusHandler)
|
||||||
|
router.Get("/mirrors.json", r.mirrorsHandler)
|
||||||
|
router.Post("/reload", r.reloadHandler)
|
||||||
|
router.Get("/dl_map", r.dlMapHandler)
|
||||||
|
router.Get("/geoip", r.geoIPHandler)
|
||||||
|
router.Get("/metrics", promhttp.Handler().ServeHTTP)
|
||||||
|
|
||||||
|
router.NotFound(r.redirectHandler)
|
||||||
|
|
||||||
|
if r.config.BindAddress != "" {
|
||||||
|
log.WithField("bind", r.config.BindAddress).Info("Binding to address")
|
||||||
|
|
||||||
|
go http.ListenAndServe(r.config.BindAddress, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
42
servers.go
42
servers.go
|
@ -1,7 +1,6 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"github.com/jmcvetta/randutil"
|
"github.com/jmcvetta/randutil"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
@ -20,13 +19,6 @@ var (
|
||||||
return http.ErrUseLastResponse
|
return http.ErrUseLastResponse
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
checkTLSConfig *tls.Config = nil
|
|
||||||
|
|
||||||
checks = []serverCheck{
|
|
||||||
checkHttp,
|
|
||||||
checkTLS,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server represents a download server
|
// Server represents a download server
|
||||||
|
@ -38,14 +30,15 @@ type Server struct {
|
||||||
Longitude float64 `json:"longitude"`
|
Longitude float64 `json:"longitude"`
|
||||||
Weight int `json:"weight"`
|
Weight int `json:"weight"`
|
||||||
Continent string `json:"continent"`
|
Continent string `json:"continent"`
|
||||||
|
Protocols ProtocolList `json:"protocols"`
|
||||||
Redirects prometheus.Counter `json:"-"`
|
Redirects prometheus.Counter `json:"-"`
|
||||||
LastChange time.Time `json:"lastChange"`
|
LastChange time.Time `json:"lastChange"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverCheck func(server *Server, logFields log.Fields) (bool, error)
|
type ServerCheck func(server *Server, logFields log.Fields) (bool, error)
|
||||||
|
|
||||||
// checkStatus runs all status checks against a server
|
// checkStatus runs all status checks against a server
|
||||||
func (server *Server) checkStatus() {
|
func (server *Server) checkStatus(checks []ServerCheck) {
|
||||||
logFields := log.Fields{
|
logFields := log.Fields{
|
||||||
"host": server.Host,
|
"host": server.Host,
|
||||||
}
|
}
|
||||||
|
@ -87,19 +80,19 @@ func (server *Server) checkStatus() {
|
||||||
|
|
||||||
type ServerList []*Server
|
type ServerList []*Server
|
||||||
|
|
||||||
func (s ServerList) checkLoop() {
|
func (s ServerList) checkLoop(checks []ServerCheck) {
|
||||||
t := time.NewTicker(60 * time.Second)
|
t := time.NewTicker(60 * time.Second)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
<-t.C
|
<-t.C
|
||||||
s.Check()
|
s.Check(checks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check will request the index from all servers
|
// Check will request the index from all servers
|
||||||
// If a server does not respond in 10 seconds, it is considered offline.
|
// If a server does not respond in 10 seconds, it is considered offline.
|
||||||
// This will wait until all checks are complete.
|
// This will wait until all checks are complete.
|
||||||
func (s ServerList) Check() {
|
func (s ServerList) Check(checks []ServerCheck) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for _, server := range s {
|
for _, server := range s {
|
||||||
|
@ -108,7 +101,7 @@ func (s ServerList) Check() {
|
||||||
go func(server *Server) {
|
go func(server *Server) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
server.checkStatus()
|
server.checkStatus(checks)
|
||||||
}(server)
|
}(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,12 +120,12 @@ type DistanceList []ComputedDistance
|
||||||
// Closest will use GeoIP on the IP provided and find the closest servers.
|
// Closest will use GeoIP on the IP provided and find the closest servers.
|
||||||
// When we have a list of x servers closest, we can choose a random or weighted one.
|
// When we have a list of x servers closest, we can choose a random or weighted one.
|
||||||
// Return values are the closest server, the distance, and if an error occurred.
|
// Return values are the closest server, the distance, and if an error occurred.
|
||||||
func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
|
func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, float64, error) {
|
||||||
choiceInterface, exists := serverCache.Get(ip.String())
|
choiceInterface, exists := r.serverCache.Get(scheme + "_" + ip.String())
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
var city LocationLookup
|
var city LocationLookup
|
||||||
err := db.Lookup(ip, &city)
|
err := r.db.Lookup(ip, &city)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, -1, err
|
return nil, -1, err
|
||||||
|
@ -141,7 +134,7 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
|
||||||
c := make(DistanceList, len(s))
|
c := make(DistanceList, len(s))
|
||||||
|
|
||||||
for i, server := range s {
|
for i, server := range s {
|
||||||
if !server.Available {
|
if !server.Available || !server.Protocols.Contains(scheme) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,9 +151,9 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
|
||||||
return c[i].Distance < c[j].Distance
|
return c[i].Distance < c[j].Distance
|
||||||
})
|
})
|
||||||
|
|
||||||
choiceCount := topChoices
|
choiceCount := r.config.TopChoices
|
||||||
|
|
||||||
if len(c) < topChoices {
|
if len(c) < r.config.TopChoices {
|
||||||
choiceCount = len(c)
|
choiceCount = len(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,7 +172,7 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
|
||||||
|
|
||||||
choiceInterface = choices
|
choiceInterface = choices
|
||||||
|
|
||||||
serverCache.Add(ip.String(), choiceInterface)
|
r.serverCache.Add(scheme+"_"+ip.String(), choiceInterface)
|
||||||
}
|
}
|
||||||
|
|
||||||
choice, err := randutil.WeightedChoice(choiceInterface.([]randutil.Choice))
|
choice, err := randutil.WeightedChoice(choiceInterface.([]randutil.Choice))
|
||||||
|
@ -192,9 +185,9 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
|
||||||
|
|
||||||
if !dist.Server.Available {
|
if !dist.Server.Available {
|
||||||
// Choose a new server and refresh cache
|
// Choose a new server and refresh cache
|
||||||
serverCache.Remove(ip.String())
|
r.serverCache.Remove(scheme + "_" + ip.String())
|
||||||
|
|
||||||
return s.Closest(ip)
|
return s.Closest(r, scheme, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dist.Server, dist.Distance, nil
|
return dist.Server, dist.Distance, nil
|
||||||
|
@ -206,6 +199,7 @@ func hsin(theta float64) float64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Distance function returns the distance (in meters) between two points of
|
// Distance function returns the distance (in meters) between two points of
|
||||||
|
//
|
||||||
// a given longitude and latitude relatively accurately (using a spherical
|
// a given longitude and latitude relatively accurately (using a spherical
|
||||||
// approximation of the Earth) through the Haversine Distance Formula for
|
// approximation of the Earth) through the Haversine Distance Formula for
|
||||||
// great arc distance on a sphere with accuracy for small distances
|
// great arc distance on a sphere with accuracy for small distances
|
||||||
|
|
4
util.go
4
util.go
|
@ -1,10 +1,10 @@
|
||||||
package main
|
package redirector
|
||||||
|
|
||||||
import "math/rand"
|
import "math/rand"
|
||||||
|
|
||||||
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||||
|
|
||||||
func randSeq(n int) string {
|
func RandomSequence(n int) string {
|
||||||
b := make([]rune, n)
|
b := make([]rune, n)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
b[i] = letters[rand.Intn(len(letters))]
|
b[i] = letters[rand.Intn(len(letters))]
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"github.com/gwatts/rootcerts/certparse"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultDownloadURL = "https://github.com/mozilla/gecko-dev/blob/master/security/nss/lib/ckfw/builtins/certdata.txt?raw=true"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadCACerts() (*x509.CertPool, error) {
|
||||||
|
res, err := http.Get(defaultDownloadURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
certs, err := certparse.ReadTrustedCerts(res.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
|
||||||
|
var count int
|
||||||
|
|
||||||
|
for _, cert := range certs {
|
||||||
|
if cert.Trust&certparse.ServerTrustedDelegator == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
count++
|
||||||
|
|
||||||
|
pool.AddCert(cert.Cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithField("certs", count).Info("Loaded root cas")
|
||||||
|
|
||||||
|
return pool, nil
|
||||||
|
}
|
Loading…
Reference in New Issue