diff --git a/config.go b/config.go index c2c9ed7..30cbbea 100644 --- a/config.go +++ b/config.go @@ -46,6 +46,17 @@ func reloadConfig() { // Reload server list reloadServers() + // Create mirror map + mirrors := make(map[string][]*Server) + + for _, server := range servers { + mirrors[server.Continent] = append(mirrors[server.Continent], server) + } + + mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...) + + mirrorMap = mirrors + // Check top choices size if topChoices > len(servers) { topChoices = len(servers) diff --git a/http.go b/http.go index 89e01e1..b9db356 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package main import ( "encoding/json" "fmt" + "github.com/jmcvetta/randutil" "net" "net/http" "net/url" @@ -19,21 +20,19 @@ func statusHandler(w http.ResponseWriter, r *http.Request) { func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - mirrors := make(map[string][]string) + mirrorOutput := make(map[string][]string) - for _, server := range servers { - u := &url.URL{ - Scheme: r.URL.Scheme, - Host: server.Host, - Path: server.Path, + for region, mirrors := range mirrorMap { + list := make([]string, len(mirrors)) + + for i, mirror := range mirrors { + list[i] = r.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/") } - mirrors[server.Continent] = append(mirrors[server.Continent], u.String()) + mirrorOutput[region] = list } - mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...) - - json.NewEncoder(w).Encode(mirrors) + json.NewEncoder(w).Encode(mirrorOutput) } func mirrorsHandler(w http.ResponseWriter, r *http.Request) { @@ -61,11 +60,47 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { ip = net.ParseIP(overrideIP) } - server, distance, err := servers.Closest(ip) + var server *Server + var distance float64 - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + if strings.HasPrefix(r.URL.Path, "/region") { + parts := strings.Split(r.URL.Path, "/") + + // region = parts[2] + if mirrors, ok := mirrorMap[parts[2]]; ok { + choices := make([]randutil.Choice, len(mirrors)) + + for i, item := range mirrors { + if !item.Available { + continue + } + + choices[i] = randutil.Choice{ + Weight: item.Weight, + Item: item, + } + } + + choice, err := randutil.WeightedChoice(choices) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + server = choice.Item.(*Server) + + r.URL.Path = strings.Join(parts[3:], "/") + } + } + + if server == nil { + server, distance, err = servers.Closest(ip) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } scheme := r.URL.Scheme @@ -102,7 +137,10 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { server.Redirects.Inc() redirectsServed.Inc() - w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) + if distance > 0 { + w.Header().Set("X-Geo-Distance", fmt.Sprintf("%f", distance)) + } + w.Header().Set("Location", u.String()) w.WriteHeader(http.StatusFound) } diff --git a/main.go b/main.go index 7d589fa..7dad8f2 100644 --- a/main.go +++ b/main.go @@ -18,8 +18,9 @@ import ( ) var ( - db *maxminddb.Reader - servers ServerList + db *maxminddb.Reader + servers ServerList + mirrorMap map[string][]*Server dlMap map[string]string @@ -110,6 +111,7 @@ func main() { r.Use(RealIPMiddleware) r.Use(logger.Logger("router", log.StandardLogger())) + r.Head("/status", statusHandler) r.Get("/status", statusHandler) r.Get("/mirrors", legacyMirrorsHandler) r.Get("/mirrors.json", mirrorsHandler) diff --git a/servers.go b/servers.go index 22ef708..5bca3aa 100644 --- a/servers.go +++ b/servers.go @@ -160,6 +160,10 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { choices := make([]randutil.Choice, topChoices) for i, item := range c[0:topChoices] { + if item.Server == nil { + continue + } + choices[i] = randutil.Choice{ Weight: item.Server.Weight, Item: item,