Resolve issue with returning a lower number of servers, add auth to reload
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/tag Build is passing Details

This commit is contained in:
Tyler 2022-04-02 14:35:46 -04:00
parent e5434a9a7b
commit 9caa391601
4 changed files with 73 additions and 3 deletions

15
http.go
View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/jmcvetta/randutil" "github.com/jmcvetta/randutil"
"github.com/spf13/viper"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -117,6 +118,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) {
} }
func reloadHandler(w http.ResponseWriter, r *http.Request) { func reloadHandler(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
token = token[strings.Index(token, " ")+1:]
if token != viper.GetString("reloadToken") {
w.WriteHeader(http.StatusUnauthorized)
return
}
reloadConfig() reloadConfig()
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)

View File

@ -79,14 +79,20 @@ type ServerConfig struct {
var ( var (
configFlag = flag.String("config", "", "configuration file path") configFlag = flag.String("config", "", "configuration file path")
flagDebug = flag.Bool("debug", false, "Enable debug logging")
) )
func main() { func main() {
flag.Parse() flag.Parse()
if *flagDebug {
log.SetLevel(log.DebugLevel)
}
viper.SetDefault("bind", ":8080") viper.SetDefault("bind", ":8080")
viper.SetDefault("cacheSize", 1024) viper.SetDefault("cacheSize", 1024)
viper.SetDefault("topChoices", 3) viper.SetDefault("topChoices", 3)
viper.SetDefault("reloadKey", randSeq(32))
viper.SetConfigName("dlrouter") // name of config file (without extension) viper.SetConfigName("dlrouter") // name of config file (without extension)
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name

View File

@ -7,6 +7,7 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
@ -33,7 +34,7 @@ type Server struct {
} }
func (server *Server) checkStatus() { func (server *Server) checkStatus() {
req, err := http.NewRequest(http.MethodGet, "https://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil) req, err := http.NewRequest(http.MethodGet, "http://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil)
req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")") req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")
@ -72,6 +73,35 @@ func (server *Server) checkStatus() {
} }
if res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound || res.StatusCode == http.StatusNotFound { if res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound || res.StatusCode == http.StatusNotFound {
if res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound {
location := res.Header.Get("Location")
responseFields["url"] = location
log.WithFields(responseFields).Debug("Server responded with redirect")
newUrl, err := url.Parse(location)
if err != nil {
if server.Available {
log.WithFields(responseFields).Warning("Server returned invalid url")
server.Available = false
server.LastChange = time.Now()
}
return
}
if newUrl.Scheme == "https" {
if server.Available {
responseFields["url"] = location
log.WithFields(responseFields).Warning("Server returned https url for http request")
server.Available = false
server.LastChange = time.Now()
}
return
}
}
if !server.Available { if !server.Available {
server.Available = true server.Available = true
server.LastChange = time.Now() server.LastChange = time.Now()
@ -161,9 +191,15 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
return c[i].Distance < c[j].Distance return c[i].Distance < c[j].Distance
}) })
choices := make([]randutil.Choice, topChoices) choiceCount := topChoices
for i, item := range c[0:topChoices] { if len(c) < topChoices {
choiceCount = len(c)
}
choices := make([]randutil.Choice, choiceCount)
for i, item := range c[0:choiceCount] {
if item.Server == nil { if item.Server == nil {
continue continue
} }

13
util.go Normal file
View File

@ -0,0 +1,13 @@
package main
import "math/rand"
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func randSeq(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}