diff --git a/go.mod b/go.mod index 1f7aa38..904b22b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module meow.tf/armbian-router go 1.17 require ( + github.com/go-chi/chi/v5 v5.0.7 github.com/oschwald/maxminddb-golang v1.8.0 github.com/prometheus/client_golang v1.11.0 github.com/sirupsen/logrus v1.8.1 diff --git a/go.sum b/go.sum index f4b908f..fe84c51 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= +github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= +github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= diff --git a/http.go b/http.go index ae0de0e..d506f7b 100644 --- a/http.go +++ b/http.go @@ -44,12 +44,6 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { return } - scheme := r.URL.Scheme - - if scheme == "" { - scheme = "https" - } - redirectPath := path.Join(server.Path, r.URL.Path) if dlMap != nil { @@ -60,7 +54,7 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { } u := &url.URL{ - Scheme: scheme, + Scheme: realProto(r), Host: server.Host, Path: redirectPath, } diff --git a/main.go b/main.go index adfbcc1..3592b31 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/oschwald/maxminddb-golang" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -92,16 +94,19 @@ func main() { log.Info("Starting") - mux := http.NewServeMux() + r := chi.NewRouter() - mux.HandleFunc("/status", statusHandler) - mux.HandleFunc("/mirrors", mirrorsHandler) - mux.HandleFunc("/reload", reloadHandler) - mux.HandleFunc("/dl_map", dlMapHandler) - mux.Handle("/metrics", promhttp.Handler()) - mux.HandleFunc("/", RealIPMiddleware(redirectHandler)) + r.Use(RealIPMiddleware) + r.Use(middleware.Logger) - http.ListenAndServe(viper.GetString("bind"), mux) + r.HandleFunc("/status", statusHandler) + r.HandleFunc("/mirrors", mirrorsHandler) + r.HandleFunc("/reload", reloadHandler) + r.HandleFunc("/dl_map", dlMapHandler) + r.Handle("/metrics", promhttp.Handler()) + r.HandleFunc("/", redirectHandler) + + http.ListenAndServe(viper.GetString("bind"), r) } var metricReplacer = strings.NewReplacer(".", "_", "-", "_") diff --git a/middleware.go b/middleware.go index 06e4b63..3412882 100644 --- a/middleware.go +++ b/middleware.go @@ -15,8 +15,8 @@ var ( // RealIPMiddleware is an implementation of reverse proxy checks. // It uses the remote address to find the originating IP, as well as protocol -func RealIPMiddleware(f http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func RealIPMiddleware(f http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Treat unix socket as 127.0.0.1 if r.RemoteAddr == "@" { r.RemoteAddr = "127.0.0.1:0" @@ -29,7 +29,7 @@ func RealIPMiddleware(f http.HandlerFunc) http.HandlerFunc { return } - if net.ParseIP(host).IsPrivate() { + if !net.ParseIP(host).IsPrivate() { f.ServeHTTP(w, r) return } @@ -43,7 +43,7 @@ func RealIPMiddleware(f http.HandlerFunc) http.HandlerFunc { } f.ServeHTTP(w, r) - } + }) } func realIP(r *http.Request) string {