A bit of cleanup and documentation
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Tyler 2022-01-09 23:53:23 -05:00
parent c98b04f9c1
commit e6d3782450
4 changed files with 257 additions and 251 deletions

67
http.go
View File

@ -1,51 +1,52 @@
package main package main
import ( import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
) )
func statusRequest(w http.ResponseWriter, r *http.Request) { func statusRequest(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func redirectRequest(w http.ResponseWriter, r *http.Request) { func redirectRequest(w http.ResponseWriter, r *http.Request) {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr) ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
if ip.IsPrivate() { // TODO: This is temporary to allow testing on private addresses.
ip = net.ParseIP("1.1.1.1") if ip.IsPrivate() {
} ip = net.ParseIP("1.1.1.1")
}
server, distance, err := settings.Servers.Closest(ip) server, distance, err := servers.Closest(ip)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
scheme := r.URL.Scheme scheme := r.URL.Scheme
if scheme == "" { if scheme == "" {
scheme = "https" scheme = "https"
} }
u := &url.URL{ u := &url.URL{
Scheme: scheme, Scheme: scheme,
Host: server.Host, Host: server.Host,
Path: path.Join(server.Path, r.URL.Path), Path: path.Join(server.Path, r.URL.Path),
} }
w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance))
w.Header().Set("Location", u.String()) w.Header().Set("Location", u.String())
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }

165
main.go
View File

@ -1,117 +1,114 @@
package main package main
import ( import (
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/viper" "github.com/spf13/viper"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
) )
var ( var (
db *maxminddb.Reader db *maxminddb.Reader
settings = &Settings{} servers ServerList
) )
// City represents a MaxmindDB city
type City struct { type City struct {
Location struct { Location struct {
Latitude float64 `maxminddb:"latitude"` Latitude float64 `maxminddb:"latitude"`
Longitude float64 `maxminddb:"longitude"` Longitude float64 `maxminddb:"longitude"`
} `maxminddb:"location"` } `maxminddb:"location"`
}
type Settings struct {
Servers ServerList
} }
func main() { func main() {
viper.SetConfigName("dlrouter") // name of config file (without extension) viper.SetConfigName("dlrouter") // name of config file (without extension)
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
viper.AddConfigPath(".") // optionally look for config in the working directory viper.AddConfigPath(".") // optionally look for config in the working directory
err := viper.ReadInConfig() // Find and read the config file err := viper.ReadInConfig() // Find and read the config file
if err != nil { // Handle errors reading the config file if err != nil { // Handle errors reading the config file
log.WithError(err).Fatalln("Unable to load config file") log.WithError(err).Fatalln("Unable to load config file")
} }
db, err = maxminddb.Open(viper.GetString("geodb")) db, err = maxminddb.Open(viper.GetString("geodb"))
if err != nil { if err != nil {
log.WithError(err).Fatalln("Unable to open database") log.WithError(err).Fatalln("Unable to open database")
} }
servers := viper.GetStringSlice("servers") serverList := viper.GetStringSlice("servers")
for _, server := range servers { for _, server := range serverList {
var prefix string var prefix string
if !strings.HasPrefix(server, "http") { if !strings.HasPrefix(server, "http") {
prefix = "https://" prefix = "https://"
} }
u, err := url.Parse(prefix + server) u, err := url.Parse(prefix + server)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"error": err, "error": err,
"server": server, "server": server,
}).Warning("Server is invalid") }).Warning("Server is invalid")
continue continue
} }
ips, err := net.LookupIP(u.Host) ips, err := net.LookupIP(u.Host)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"error": err, "error": err,
"server": server, "server": server,
}).Warning("Could not resolve address") }).Warning("Could not resolve address")
continue continue
} }
var city City var city City
err = db.Lookup(ips[0], &city) err = db.Lookup(ips[0], &city)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"error": err, "error": err,
"server": server, "server": server,
}).Warning("Could not geolocate address") }).Warning("Could not geolocate address")
continue continue
} }
settings.Servers = append(settings.Servers, &Server{ servers = append(servers, &Server{
Host: u.Host, Host: u.Host,
Path: u.Path, Path: u.Path,
Latitude: city.Location.Latitude, Latitude: city.Location.Latitude,
Longitude: city.Location.Longitude, Longitude: city.Location.Longitude,
}) })
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"server": u.Host, "server": u.Host,
"path": u.Path, "path": u.Path,
"latitude": city.Location.Latitude, "latitude": city.Location.Latitude,
"longitude": city.Location.Longitude, "longitude": city.Location.Longitude,
}).Info("Added server") }).Info("Added server")
} }
log.Info("Servers added, checking statuses") log.Info("Servers added, checking statuses")
// Force initial check before running // Force initial check before running
settings.Servers.Check() servers.Check()
// Start check loop // Start check loop
go settings.Servers.checkLoop() go servers.checkLoop()
log.Info("Starting") log.Info("Starting")
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/status", RealIPMiddleware(statusRequest)) mux.HandleFunc("/status", RealIPMiddleware(statusRequest))
mux.HandleFunc("/", RealIPMiddleware(redirectRequest)) mux.HandleFunc("/", RealIPMiddleware(redirectRequest))
http.ListenAndServe(":8080", mux) http.ListenAndServe(":8080", mux)
} }

View File

@ -1,86 +1,87 @@
package main package main
import ( import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
) )
var ( var (
xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto")
xRealIP = http.CanonicalHeaderKey("X-Real-IP") xRealIP = http.CanonicalHeaderKey("X-Real-IP")
forwardLimit = 5 forwardLimit = 5
) )
// RealIPMiddleware is an implementation of reverse proxy checks.
// It uses the remote address to find the originating IP, as well as protocol
func RealIPMiddleware(f http.HandlerFunc) http.HandlerFunc { func RealIPMiddleware(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Treat unix socket as 127.0.0.1 // Treat unix socket as 127.0.0.1
if r.RemoteAddr == "@" { if r.RemoteAddr == "@" {
r.RemoteAddr = "127.0.0.1:0" r.RemoteAddr = "127.0.0.1:0"
} }
host, _, err := net.SplitHostPort(r.RemoteAddr) host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
f.ServeHTTP(w, r) f.ServeHTTP(w, r)
return return
} }
if net.ParseIP(host).IsPrivate() { if net.ParseIP(host).IsPrivate() {
f.ServeHTTP(w, r) f.ServeHTTP(w, r)
return return
} }
if rip := realIP(r); len(rip) > 0 { if rip := realIP(r); len(rip) > 0 {
r.RemoteAddr = net.JoinHostPort(rip, "0") r.RemoteAddr = net.JoinHostPort(rip, "0")
} }
if rproto := realProto(r); len(rproto) > 0 { if rproto := realProto(r); len(rproto) > 0 {
r.URL.Scheme = rproto r.URL.Scheme = rproto
} }
f.ServeHTTP(w, r) f.ServeHTTP(w, r)
} }
} }
func realIP(r *http.Request) string { func realIP(r *http.Request) string {
var ip string var ip string
if xrip := r.Header.Get(xRealIP); xrip != "" { if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" { } else if xff := r.Header.Get(xForwardedFor); xff != "" {
p := 0 p := 0
for i := forwardLimit; i > 0; i-- { for i := forwardLimit; i > 0; i-- {
if p > 0 { if p > 0 {
xff = xff[:p-2] xff = xff[:p-2]
} }
p = strings.LastIndex(xff, ", ") p = strings.LastIndex(xff, ", ")
if p < 0 { if p < 0 {
p = 0 p = 0
break break
} else { } else {
p += 2 p += 2
} }
} }
ip = xff[p:] ip = xff[p:]
} }
return ip return ip
} }
func realProto(r *http.Request) string { func realProto(r *http.Request) string {
proto := "http" proto := "http"
if r.TLS != nil { if r.TLS != nil {
proto = "https" proto = "https"
} }
if xproto := r.Header.Get(xForwardedProto); xproto != "" { if xproto := r.Header.Get(xForwardedProto); xproto != "" {
proto = xproto proto = xproto
} }
return proto return proto
} }

View File

@ -1,107 +1,114 @@
package main package main
import ( import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math" "math"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
var ( var (
checkClient = &http.Client{ checkClient = &http.Client{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
} }
) )
type Server struct { type Server struct {
Available bool Available bool
Host string Host string
Path string Path string
Latitude float64 Latitude float64
Longitude float64 Longitude float64
} }
type ServerList []*Server type ServerList []*Server
func (s ServerList) checkLoop() { func (s ServerList) checkLoop() {
t := time.NewTicker(60 * time.Second) t := time.NewTicker(60 * time.Second)
for { for {
<- t.C <-t.C
s.Check() s.Check()
} }
} }
// Check will request the index from all servers
// If a server does not respond in 10 seconds, it is considered offline.
// This will wait until all checks are complete.
func (s ServerList) Check() { func (s ServerList) Check() {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, server := range s { for _, server := range s {
wg.Add(1) wg.Add(1)
go func(server *Server) { go func(server *Server) {
req, err := http.NewRequest(http.MethodGet, "https://" + server.Host + "/" + strings.TrimLeft(server.Path, "/"), nil) req, err := http.NewRequest(http.MethodGet, "https://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil)
req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go " + runtime.Version() + ")") req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")
if err != nil { if err != nil {
return // This should never happen.
} log.WithError(err).Warning("Invalid request! This should not happen, please check config.")
return
}
res, err := checkClient.Do(req) res, err := checkClient.Do(req)
if err != nil { if err != nil {
log.WithField("server", server.Host).Info("Server went offline") log.WithField("server", server.Host).Info("Server went offline")
server.Available = false server.Available = false
return return
} }
if (res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound) && if (res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound) &&
!server.Available { !server.Available {
server.Available = true server.Available = true
log.WithField("server", server.Host).Info("Server is online") log.WithField("server", server.Host).Info("Server is online")
} }
wg.Done() wg.Done()
}(server) }(server)
} }
wg.Wait() wg.Wait()
} }
// Closest will use GeoIP on the IP provided and find the closest server.
// Return values are the closest server, the distance, and if an error occurred.
func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
var city City var city City
err := db.Lookup(ip, &city) err := db.Lookup(ip, &city)
if err != nil { if err != nil {
return nil, -1, err return nil, -1, err
} }
var closest *Server var closest *Server
var closestDistance float64 = -1 var closestDistance float64 = -1
for _, server := range s { for _, server := range s {
if !server.Available { if !server.Available {
continue continue
} }
distance := Distance(city.Location.Latitude, city.Location.Longitude, server.Latitude, server.Longitude) distance := Distance(city.Location.Latitude, city.Location.Longitude, server.Latitude, server.Longitude)
if closestDistance == -1 || distance < closestDistance { if closestDistance == -1 || distance < closestDistance {
closestDistance = distance closestDistance = distance
closest = server closest = server
} }
} }
return closest, closestDistance, nil return closest, closestDistance, nil
} }
// haversin(θ) function // haversin(θ) function
func hsin(theta float64) float64 { func hsin(theta float64) float64 {
return math.Pow(math.Sin(theta/2), 2) return math.Pow(math.Sin(theta/2), 2)
} }
// Distance function returns the distance (in meters) between two points of // Distance function returns the distance (in meters) between two points of
@ -114,18 +121,18 @@ func hsin(theta float64) float64 {
// distance returned is METERS!!!!!! // distance returned is METERS!!!!!!
// http://en.wikipedia.org/wiki/Haversine_formula // http://en.wikipedia.org/wiki/Haversine_formula
func Distance(lat1, lon1, lat2, lon2 float64) float64 { func Distance(lat1, lon1, lat2, lon2 float64) float64 {
// convert to radians // convert to radians
// must cast radius as float to multiply later // must cast radius as float to multiply later
var la1, lo1, la2, lo2, r float64 var la1, lo1, la2, lo2, r float64
la1 = lat1 * math.Pi / 180 la1 = lat1 * math.Pi / 180
lo1 = lon1 * math.Pi / 180 lo1 = lon1 * math.Pi / 180
la2 = lat2 * math.Pi / 180 la2 = lat2 * math.Pi / 180
lo2 = lon2 * math.Pi / 180 lo2 = lon2 * math.Pi / 180
r = 6378100 // Earth radius in METERS r = 6378100 // Earth radius in METERS
// calculate // calculate
h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1) h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1)
return 2 * r * math.Asin(math.Sqrt(h)) return 2 * r * math.Asin(math.Sqrt(h))
} }