diff --git a/http.go b/http.go index dc84525..3c7b610 100644 --- a/http.go +++ b/http.go @@ -1,51 +1,52 @@ package main import ( - "fmt" - "net" - "net/http" - "net/url" - "path" + "fmt" + "net" + "net/http" + "net/url" + "path" ) func statusRequest(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusOK) } func redirectRequest(w http.ResponseWriter, r *http.Request) { - ipStr, _, err := net.SplitHostPort(r.RemoteAddr) + ipStr, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - ip := net.ParseIP(ipStr) + ip := net.ParseIP(ipStr) - if ip.IsPrivate() { - ip = net.ParseIP("1.1.1.1") - } + // TODO: This is temporary to allow testing on private addresses. + if ip.IsPrivate() { + ip = net.ParseIP("1.1.1.1") + } - server, distance, err := settings.Servers.Closest(ip) + server, distance, err := servers.Closest(ip) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - scheme := r.URL.Scheme + scheme := r.URL.Scheme - if scheme == "" { - scheme = "https" - } + if scheme == "" { + scheme = "https" + } - u := &url.URL{ - Scheme: scheme, - Host: server.Host, - Path: path.Join(server.Path, r.URL.Path), - } + u := &url.URL{ + Scheme: scheme, + Host: server.Host, + Path: path.Join(server.Path, r.URL.Path), + } - w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) - w.Header().Set("Location", u.String()) - w.WriteHeader(http.StatusFound) -} \ No newline at end of file + w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) + w.Header().Set("Location", u.String()) + w.WriteHeader(http.StatusFound) +} diff --git a/main.go b/main.go index cef72dc..8c34f25 100644 --- a/main.go +++ b/main.go @@ -1,117 +1,114 @@ package main import ( - "github.com/oschwald/maxminddb-golang" - log "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "net" - "net/http" - "net/url" - "strings" + "github.com/oschwald/maxminddb-golang" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "net" + "net/http" + "net/url" + "strings" ) var ( - db *maxminddb.Reader - settings = &Settings{} + db *maxminddb.Reader + servers ServerList ) +// City represents a MaxmindDB city type City struct { - Location struct { - Latitude float64 `maxminddb:"latitude"` - Longitude float64 `maxminddb:"longitude"` - } `maxminddb:"location"` -} - -type Settings struct { - Servers ServerList + Location struct { + Latitude float64 `maxminddb:"latitude"` + Longitude float64 `maxminddb:"longitude"` + } `maxminddb:"location"` } func main() { - viper.SetConfigName("dlrouter") // name of config file (without extension) - viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name - viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in - viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths - viper.AddConfigPath(".") // optionally look for config in the working directory - err := viper.ReadInConfig() // Find and read the config file + viper.SetConfigName("dlrouter") // name of config file (without extension) + viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name + viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in + viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths + viper.AddConfigPath(".") // optionally look for config in the working directory + err := viper.ReadInConfig() // Find and read the config file - if err != nil { // Handle errors reading the config file - log.WithError(err).Fatalln("Unable to load config file") - } + if err != nil { // Handle errors reading the config file + log.WithError(err).Fatalln("Unable to load config file") + } - db, err = maxminddb.Open(viper.GetString("geodb")) + db, err = maxminddb.Open(viper.GetString("geodb")) - if err != nil { - log.WithError(err).Fatalln("Unable to open database") - } + if err != nil { + log.WithError(err).Fatalln("Unable to open database") + } - servers := viper.GetStringSlice("servers") + serverList := viper.GetStringSlice("servers") - for _, server := range servers { - var prefix string + for _, server := range serverList { + var prefix string - if !strings.HasPrefix(server, "http") { - prefix = "https://" - } + if !strings.HasPrefix(server, "http") { + prefix = "https://" + } - u, err := url.Parse(prefix + server) + u, err := url.Parse(prefix + server) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "server": server, - }).Warning("Server is invalid") - continue - } + if err != nil { + log.WithFields(log.Fields{ + "error": err, + "server": server, + }).Warning("Server is invalid") + continue + } - ips, err := net.LookupIP(u.Host) + ips, err := net.LookupIP(u.Host) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "server": server, - }).Warning("Could not resolve address") - continue - } + if err != nil { + log.WithFields(log.Fields{ + "error": err, + "server": server, + }).Warning("Could not resolve address") + continue + } - var city City - err = db.Lookup(ips[0], &city) + var city City + err = db.Lookup(ips[0], &city) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "server": server, - }).Warning("Could not geolocate address") - continue - } + if err != nil { + log.WithFields(log.Fields{ + "error": err, + "server": server, + }).Warning("Could not geolocate address") + continue + } - settings.Servers = append(settings.Servers, &Server{ - Host: u.Host, - Path: u.Path, - Latitude: city.Location.Latitude, - Longitude: city.Location.Longitude, - }) + servers = append(servers, &Server{ + Host: u.Host, + Path: u.Path, + Latitude: city.Location.Latitude, + Longitude: city.Location.Longitude, + }) - log.WithFields(log.Fields{ - "server": u.Host, - "path": u.Path, - "latitude": city.Location.Latitude, - "longitude": city.Location.Longitude, - }).Info("Added server") - } + log.WithFields(log.Fields{ + "server": u.Host, + "path": u.Path, + "latitude": city.Location.Latitude, + "longitude": city.Location.Longitude, + }).Info("Added server") + } - log.Info("Servers added, checking statuses") - // Force initial check before running - settings.Servers.Check() + log.Info("Servers added, checking statuses") + // Force initial check before running + servers.Check() - // Start check loop - go settings.Servers.checkLoop() + // Start check loop + go servers.checkLoop() - log.Info("Starting") + log.Info("Starting") - mux := http.NewServeMux() + mux := http.NewServeMux() - mux.HandleFunc("/status", RealIPMiddleware(statusRequest)) - mux.HandleFunc("/", RealIPMiddleware(redirectRequest)) + mux.HandleFunc("/status", RealIPMiddleware(statusRequest)) + mux.HandleFunc("/", RealIPMiddleware(redirectRequest)) - http.ListenAndServe(":8080", mux) + http.ListenAndServe(":8080", mux) } diff --git a/middleware.go b/middleware.go index 76eaafd..06e4b63 100644 --- a/middleware.go +++ b/middleware.go @@ -1,86 +1,87 @@ package main import ( - "net" - "net/http" - "strings" + "net" + "net/http" + "strings" ) var ( - xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") - xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") - xRealIP = http.CanonicalHeaderKey("X-Real-IP") - forwardLimit = 5 + xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") + xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") + xRealIP = http.CanonicalHeaderKey("X-Real-IP") + forwardLimit = 5 ) +// RealIPMiddleware is an implementation of reverse proxy checks. +// It uses the remote address to find the originating IP, as well as protocol func RealIPMiddleware(f http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Treat unix socket as 127.0.0.1 - if r.RemoteAddr == "@" { - r.RemoteAddr = "127.0.0.1:0" - } + return func(w http.ResponseWriter, r *http.Request) { + // Treat unix socket as 127.0.0.1 + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:0" + } - host, _, err := net.SplitHostPort(r.RemoteAddr) + host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - f.ServeHTTP(w, r) - return - } + if err != nil { + f.ServeHTTP(w, r) + return + } - if net.ParseIP(host).IsPrivate() { - f.ServeHTTP(w, r) - return - } + if net.ParseIP(host).IsPrivate() { + f.ServeHTTP(w, r) + return + } - if rip := realIP(r); len(rip) > 0 { - r.RemoteAddr = net.JoinHostPort(rip, "0") - } + if rip := realIP(r); len(rip) > 0 { + r.RemoteAddr = net.JoinHostPort(rip, "0") + } - if rproto := realProto(r); len(rproto) > 0 { - r.URL.Scheme = rproto - } + if rproto := realProto(r); len(rproto) > 0 { + r.URL.Scheme = rproto + } - f.ServeHTTP(w, r) - } + f.ServeHTTP(w, r) + } } func realIP(r *http.Request) string { - var ip string + var ip string - if xrip := r.Header.Get(xRealIP); xrip != "" { - ip = xrip - } else if xff := r.Header.Get(xForwardedFor); xff != "" { - p := 0 - for i := forwardLimit; i > 0; i-- { - if p > 0 { - xff = xff[:p-2] - } - p = strings.LastIndex(xff, ", ") - if p < 0 { - p = 0 - break - } else { - p += 2 - } - } + if xrip := r.Header.Get(xRealIP); xrip != "" { + ip = xrip + } else if xff := r.Header.Get(xForwardedFor); xff != "" { + p := 0 + for i := forwardLimit; i > 0; i-- { + if p > 0 { + xff = xff[:p-2] + } + p = strings.LastIndex(xff, ", ") + if p < 0 { + p = 0 + break + } else { + p += 2 + } + } - ip = xff[p:] - } + ip = xff[p:] + } - return ip + return ip } func realProto(r *http.Request) string { - proto := "http" + proto := "http" - if r.TLS != nil { - proto = "https" - } + if r.TLS != nil { + proto = "https" + } - if xproto := r.Header.Get(xForwardedProto); xproto != "" { - proto = xproto - } + if xproto := r.Header.Get(xForwardedProto); xproto != "" { + proto = xproto + } - return proto + return proto } - diff --git a/servers.go b/servers.go index 94ba791..32f7946 100644 --- a/servers.go +++ b/servers.go @@ -1,107 +1,114 @@ package main import ( - log "github.com/sirupsen/logrus" - "math" - "net" - "net/http" - "runtime" - "strings" - "sync" - "time" + log "github.com/sirupsen/logrus" + "math" + "net" + "net/http" + "runtime" + "strings" + "sync" + "time" ) var ( - checkClient = &http.Client{ - Timeout: 10 * time.Second, - } + checkClient = &http.Client{ + Timeout: 10 * time.Second, + } ) type Server struct { - Available bool - Host string - Path string - Latitude float64 - Longitude float64 + Available bool + Host string + Path string + Latitude float64 + Longitude float64 } type ServerList []*Server func (s ServerList) checkLoop() { - t := time.NewTicker(60 * time.Second) + t := time.NewTicker(60 * time.Second) - for { - <- t.C + for { + <-t.C - s.Check() - } + s.Check() + } } +// Check will request the index from all servers +// If a server does not respond in 10 seconds, it is considered offline. +// This will wait until all checks are complete. func (s ServerList) Check() { - var wg sync.WaitGroup + var wg sync.WaitGroup - for _, server := range s { - wg.Add(1) + for _, server := range s { + wg.Add(1) - go func(server *Server) { - req, err := http.NewRequest(http.MethodGet, "https://" + server.Host + "/" + strings.TrimLeft(server.Path, "/"), nil) + go func(server *Server) { + req, err := http.NewRequest(http.MethodGet, "https://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil) - req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go " + runtime.Version() + ")") + req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")") - if err != nil { - return - } + if err != nil { + // This should never happen. + log.WithError(err).Warning("Invalid request! This should not happen, please check config.") + return + } - res, err := checkClient.Do(req) + res, err := checkClient.Do(req) - if err != nil { - log.WithField("server", server.Host).Info("Server went offline") - server.Available = false - return - } + if err != nil { + log.WithField("server", server.Host).Info("Server went offline") + server.Available = false + return + } - if (res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound) && - !server.Available { - server.Available = true - log.WithField("server", server.Host).Info("Server is online") - } - wg.Done() - }(server) - } + if (res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound) && + !server.Available { + server.Available = true + log.WithField("server", server.Host).Info("Server is online") + } + wg.Done() + }(server) + } - wg.Wait() + wg.Wait() } +// Closest will use GeoIP on the IP provided and find the closest server. +// Return values are the closest server, the distance, and if an error occurred. func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { - var city City - err := db.Lookup(ip, &city) + var city City + err := db.Lookup(ip, &city) - if err != nil { - return nil, -1, err - } + if err != nil { + return nil, -1, err + } - var closest *Server - var closestDistance float64 = -1 + var closest *Server + var closestDistance float64 = -1 - for _, server := range s { - if !server.Available { - continue - } + for _, server := range s { + if !server.Available { + continue + } - distance := Distance(city.Location.Latitude, city.Location.Longitude, server.Latitude, server.Longitude) + distance := Distance(city.Location.Latitude, city.Location.Longitude, server.Latitude, server.Longitude) - if closestDistance == -1 || distance < closestDistance { - closestDistance = distance - closest = server - } - } + if closestDistance == -1 || distance < closestDistance { + closestDistance = distance + closest = server + } + } - return closest, closestDistance, nil + return closest, closestDistance, nil } // haversin(θ) function func hsin(theta float64) float64 { - return math.Pow(math.Sin(theta/2), 2) + return math.Pow(math.Sin(theta/2), 2) } // Distance function returns the distance (in meters) between two points of @@ -114,18 +121,18 @@ func hsin(theta float64) float64 { // distance returned is METERS!!!!!! // http://en.wikipedia.org/wiki/Haversine_formula func Distance(lat1, lon1, lat2, lon2 float64) float64 { - // convert to radians - // must cast radius as float to multiply later - var la1, lo1, la2, lo2, r float64 - la1 = lat1 * math.Pi / 180 - lo1 = lon1 * math.Pi / 180 - la2 = lat2 * math.Pi / 180 - lo2 = lon2 * math.Pi / 180 + // convert to radians + // must cast radius as float to multiply later + var la1, lo1, la2, lo2, r float64 + la1 = lat1 * math.Pi / 180 + lo1 = lon1 * math.Pi / 180 + la2 = lat2 * math.Pi / 180 + lo2 = lon2 * math.Pi / 180 - r = 6378100 // Earth radius in METERS + r = 6378100 // Earth radius in METERS - // calculate - h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1) + // calculate + h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1) - return 2 * r * math.Asin(math.Sqrt(h)) -} \ No newline at end of file + return 2 * r * math.Asin(math.Sqrt(h)) +}