Somewhat big rewrite, adds host registration and other improvements to structure
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/tag Build is passing Details

This commit is contained in:
Tyler 2021-12-22 23:32:59 -05:00
parent cc2988f6a5
commit f9f3cec5a2
17 changed files with 913 additions and 252 deletions

View File

@ -6,6 +6,15 @@ A simple, easy to use ngrok alternative (self hosted!)
The server and client can also be easily embedded into your applications, see the 'server' and 'client' directories.
Features
--------
* HTTP and HTTPS handling
* Public key authentication
* Authorized key whitelists
* Registration of reserved hosts
* Forwarding system that allows easy additions of other protocols (Coming soon: TCP and TCP+TLS, with host-based TLS support)
Example usage
-------------
@ -81,4 +90,13 @@ services:
environment:
- GOGROK_DOMAINS=gogrok.ccatss.dev
- GOGROK_AUTHORIZED_KEY_FILE=/config/authorized_keys
```
```
Host Registration
-----------------
Gogrok lets you register your own custom hosts that are attached to your public key.
On the server, make sure to run the server with the flag `--store=PATH_TO_DB.db`
Use `gogrok register` and `gogrok unregister` to manage registered hosts to your client key.

View File

@ -1,197 +1,181 @@
package client
import (
"bufio"
"crypto/tls"
log "github.com/sirupsen/logrus"
"errors"
"gogrok.ccatss.dev/common"
"golang.org/x/crypto/ssh"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
)
var (
ErrUnsupportedBackend = errors.New("unsupported backend type")
)
// New creates a new client with the specified server and backend
func New(server, backend string, signer ssh.Signer) *Client {
func New(server string, signer ssh.Signer) *Client {
return &Client{
server: server,
signer: signer,
}
}
// Client is a remote tunnel client
type Client struct {
conn *ssh.Client
server string
signer ssh.Signer
}
// Open opens a connection to the server
// Note: This is called automatically on client operations.
func (c *Client) Open() error {
if c.conn != nil {
return nil
}
config := &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
ssh.PublicKeys(c.signer),
},
}
conn, err := ssh.Dial("tcp", c.server, config)
if err != nil {
return err
}
c.conn = conn
return nil
}
func (c *Client) Close() error {
if c.conn == nil {
return nil
}
return c.conn.Close()
}
// Start connects to the server over TCP and starts the tunnel
func (c *Client) Start(backend, requestedHost string) (string, error) {
if err := c.Open(); err != nil {
return "", err
}
backendUrl, err := url.Parse(backend)
if err != nil {
return nil
return "", nil
}
if backendUrl.Scheme == "" {
backendUrl.Scheme = "http"
}
host, port, _ := net.SplitHostPort(backendUrl.Host)
if backendUrl.Scheme == "http" || backendUrl.Scheme == "https" {
proxy := NewHTTPProxy(backendUrl)
if port == "" {
port = "80"
return c.StartHTTPForwarding(proxy, requestedHost)
}
dialHost := net.JoinHostPort(host, port)
return "", ErrUnsupportedBackend
}
return &Client{
config: config,
server: server,
backendUrl: backendUrl,
dialHost: dialHost,
tlsConfig: &tls.Config{ServerName: host, InsecureSkipVerify: true},
// Register a host as reserved with the server
func (c *Client) Register(host string) error {
if err := c.Open(); err != nil {
return err
}
}
// Client is a remote tunnel client
type Client struct {
config *ssh.ClientConfig
tlsConfig *tls.Config
payload := ssh.Marshal(common.HostRegisterRequest{
Host: host,
})
server string
backendUrl *url.URL
dialHost string
}
// Start connects to the server over TCP and starts the tunnel
func (c *Client) Start() error {
log.WithFields(log.Fields{
"server": c.server,
}).Debug("Dialing server")
conn, err := ssh.Dial("tcp", c.server, c.config)
success, replyData, err := c.conn.SendRequest(common.HttpRegisterHost, true, payload)
if err != nil {
return err
}
if c.backendUrl.Scheme == "http" || c.backendUrl.Scheme == "https" {
go c.handleHttpRequests(conn)
if !success {
return errors.New(string(replyData))
}
var res common.HostRegisterSuccess
if err = ssh.Unmarshal(replyData, &res); err != nil {
return err
}
return nil
}
func (c *Client) handleHttpRequests(conn *ssh.Client) {
payload := ssh.Marshal(common.RemoteForwardRequest{
RequestedHost: "",
})
log.Debug("Requesting http-forward...")
success, replyData, err := conn.SendRequest("http-forward", true, payload)
if !success || err != nil {
log.WithFields(log.Fields{
"error": err,
"message": string(replyData),
}).Fatalln("Unable to start forwarding")
// Unregister a reserved host with the server
func (c *Client) Unregister(host string) error {
if err := c.Open(); err != nil {
return err
}
log.WithFields(log.Fields{
"success": success,
}).Debug("Received successful response from http-forward request")
payload := ssh.Marshal(common.HostRegisterRequest{
Host: host,
})
success, replyData, err := c.conn.SendRequest(common.HttpUnregisterHost, true, payload)
if err != nil {
return err
}
if !success {
return errors.New(string(replyData))
}
var res common.HostRegisterSuccess
if err = ssh.Unmarshal(replyData, &res); err != nil {
return err
}
return nil
}
// StartHTTPForwarding starts a basic http proxy/forwarding service
func (c *Client) StartHTTPForwarding(proxy *HTTPProxy, requestedHost string) (string, error) {
payload := ssh.Marshal(common.RemoteForwardRequest{
RequestedHost: requestedHost,
})
success, replyData, err := c.conn.SendRequest(common.HttpForward, true, payload)
if err != nil {
return "", err
}
if !success {
return "", errors.New(string(replyData))
}
var response common.RemoteForwardSuccess
if err := ssh.Unmarshal(replyData, &response); err != nil {
log.WithError(err).Fatalln("Unable to unmarshal data")
return "", err
}
defer func() {
ch := c.conn.HandleChannelOpen(common.ForwardedHTTPChannelType)
go func() {
proxy.acceptConnections(ch)
payload := ssh.Marshal(common.RemoteForwardCancelRequest{Host: response.Host})
conn.SendRequest("cancel-http-forward", false, payload)
c.conn.SendRequest(common.CancelHttpForward, false, payload)
}()
log.WithField("host", response.Host).Info("Bound host")
ch := conn.HandleChannelOpen(common.ForwardedHTTPChannelType)
for {
newCh := <-ch
if newCh == nil {
break
}
ch, r, err := newCh.Accept()
if err != nil {
log.WithError(err).Warning("Error accepting channel")
continue
}
go ssh.DiscardRequests(r)
go c.proxyRequest(ch)
}
}
// proxyRequest handles a request from the ssh channel and forwards it to the local http server
func (c *Client) proxyRequest(rw io.ReadWriteCloser) {
tcpConn, err := net.Dial("tcp", c.dialHost)
if err != nil {
rw.Close()
return
}
if c.backendUrl.Scheme == "https" || c.backendUrl.Scheme == "wss" {
// Wrap with TLS
tcpConn = tls.Client(tcpConn, c.tlsConfig)
}
defer rw.Close()
defer tcpConn.Close()
bufferedCh := bufio.NewReader(rw)
tp := textproto.NewReader(bufferedCh)
var s string
if s, err = tp.ReadLine(); err != nil {
return
}
// Write the first response line as-is
tcpConn.Write([]byte(s + "\r\n"))
// Read headers and proxy each to the output
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
// Modify and return our headers
headers := http.Header(mimeHeader)
headers.Set("Host", c.backendUrl.Host)
headers.Write(tcpConn)
// End headers
tcpConn.Write([]byte("\r\n"))
contentLength, err := strconv.ParseInt(headers.Get("Content-Length"), 10, 64)
if err == nil && contentLength > 0 {
// Copy request to the tcpConn
_, err := io.Copy(tcpConn, io.LimitReader(bufferedCh, int64(contentLength)))
if err != nil {
log.WithError(err).Warning("Connection error on body read, closing")
return
}
}
// Copy the response back to the tunnel server
io.Copy(rw, tcpConn)
return response.Host, nil
}

124
client/http.go Normal file
View File

@ -0,0 +1,124 @@
package client
import (
"bufio"
"crypto/tls"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
)
type Proxy interface {
Handle(rw io.ReadWriteCloser)
}
// HTTPProxy is a proxy implementation to pass http requests.
type HTTPProxy struct {
dialHost string
backendUrl *url.URL
tlsConfig *tls.Config
}
// NewHTTPProxy parses the backend url and creates a new proxy for it
func NewHTTPProxy(backendUrl *url.URL) *HTTPProxy {
host, port, _ := net.SplitHostPort(backendUrl.Host)
if port == "" {
port = "80"
}
tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: true}
dialHost := net.JoinHostPort(host, port)
return &HTTPProxy{
dialHost: dialHost,
backendUrl: backendUrl,
tlsConfig: tlsConfig,
}
}
func (p *HTTPProxy) acceptConnections(ch <-chan ssh.NewChannel) {
for {
newCh := <-ch
if newCh == nil {
break
}
ch, r, err := newCh.Accept()
if err != nil {
log.WithError(err).Warning("Error accepting channel")
continue
}
go ssh.DiscardRequests(r)
go p.Handle(ch)
}
}
// Handle a request from the ssh channel and forwards it to the local http server
func (p *HTTPProxy) Handle(rw io.ReadWriteCloser) {
tcpConn, err := net.Dial("tcp", p.dialHost)
if err != nil {
rw.Close()
return
}
if p.backendUrl.Scheme == "https" || p.backendUrl.Scheme == "wss" {
// Wrap with TLS
tcpConn = tls.Client(tcpConn, p.tlsConfig)
}
defer rw.Close()
defer tcpConn.Close()
bufferedCh := bufio.NewReader(rw)
tp := textproto.NewReader(bufferedCh)
var s string
if s, err = tp.ReadLine(); err != nil {
return
}
// Write the first response line as-is
tcpConn.Write([]byte(s + "\r\n"))
// Read headers and proxy each to the output
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
// Modify and return our headers
headers := http.Header(mimeHeader)
headers.Set("Host", p.backendUrl.Host)
headers.Write(tcpConn)
// End headers
tcpConn.Write([]byte("\r\n"))
contentLength, err := strconv.ParseInt(headers.Get("Content-Length"), 10, 64)
if err == nil && contentLength > 0 {
// Copy request to the tcpConn
_, err := io.Copy(tcpConn, io.LimitReader(bufferedCh, int64(contentLength)))
if err != nil {
log.WithError(err).Warning("Connection error on body read, closing")
return
}
}
// Copy the response back to the tunnel server
io.Copy(rw, tcpConn)
}

View File

@ -2,6 +2,7 @@ package cmd
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/afero"
"github.com/spf13/cobra"
@ -21,11 +22,40 @@ var (
func init() {
clientCmd.Flags().String("server", "localhost:2222", "Gogrok Server Address")
clientCmd.Flags().String("key", "", "Client key file")
clientCmd.Flags().String("passphrase", "", "Client key passphrase")
clientCmd.Flags().String("host", "", "Requested host to register")
rootCmd.AddCommand(clientCmd)
}
func clientPreRun(cmd *cobra.Command, args []string) {
viper.SetDefault("gogrok.server", "localhost:2222")
setValueFromFlag(cmd.Flags(), "server", "gogrok.server", false)
setValueFromFlag(cmd.Flags(), "key", "gogrok.clientKey", false)
setValueFromFlag(cmd.Flags(), "passphrase", "gogrok.clientKeyPassphrase", false)
}
func loadClientKey() ssh.Signer {
clientKey := viper.GetString("gogrok.clientKey")
if clientKey == "" {
clientKey = path.Join(viper.GetString("gogrok.storageDir"), "client.key")
}
key, err := common.LoadOrGenerateKey(afero.NewOsFs(), clientKey, viper.GetString("gogrok.clientKeyPassphrase"))
if err != nil {
log.WithError(err).Fatalln("Unable to load client key")
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
log.WithError(err).Fatalln("Unable to create signer from client key")
}
return signer
}
var clientCmd = &cobra.Command{
Use: "client",
Short: "Start the gogrok client",
@ -35,39 +65,25 @@ var clientCmd = &cobra.Command{
}
return nil
},
PreRun: clientPreRun,
Run: func(cmd *cobra.Command, args []string) {
viper.SetDefault("gogrok.server", "localhost:2222")
setValueFromFlag(cmd.Flags(), "host", "gogrok.clientHost", false)
setValueFromFlag(cmd.Flags(), "server", "gogrok.server", false)
setValueFromFlag(cmd.Flags(), "key", "gogrok.clientKey", false)
setValueFromFlag(cmd.Flags(), "passphrase", "gogrok.clientKeyPassphrase", false)
c := client.New(viper.GetString("gogrok.server"), loadClientKey())
clientKey := viper.GetString("gogrok.clientKey")
if clientKey == "" {
clientKey = path.Join(viper.GetString("gogrok.storageDir"), "client.key")
}
key, err := common.LoadOrGenerateKey(afero.NewOsFs(), clientKey, viper.GetString("gogrok.clientKeyPassphrase"))
host, err := c.Start(args[0], viper.GetString("gogrok.clientHost"))
if err != nil {
log.WithError(err).Fatalln("Unable to load client key")
fmt.Fprintln(os.Stderr, "Unable to start server: "+err.Error())
os.Exit(1)
}
signer, err := ssh.NewSignerFromKey(key)
cmd.Println("Successfully bound host and started proxy")
log.WithField("host", host).Info("Successfully bound host and started proxy")
if err != nil {
log.WithError(err).Fatalln("Unable to create signer from client key")
}
// Default command is client
c := client.New(viper.GetString("gogrok.server"), args[0], signer)
err = c.Start()
if err != nil {
log.WithError(err).Fatalln("Unable to connect to server")
}
cmd.Println("Endpoints:")
cmd.Printf("http://%s\n", host)
cmd.Printf("https://%s\n", host)
sig := make(chan os.Signal)

44
cmd/register.go Normal file
View File

@ -0,0 +1,44 @@
package cmd
import (
"errors"
"fmt"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"gogrok.ccatss.dev/client"
"os"
)
var (
ErrNoHost = errors.New("no host specified")
)
func init() {
registerCmd.Flags().String("server", "localhost:2222", "Gogrok Server Address")
rootCmd.AddCommand(registerCmd)
}
var registerCmd = &cobra.Command{
Use: "register",
Short: "Register a host with a gogrok server",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return ErrNoHost
}
return nil
},
PreRun: clientPreRun,
Run: func(cmd *cobra.Command, args []string) {
// Default command is client
c := client.New(viper.GetString("gogrok.server"), loadClientKey())
err := c.Register(args[0])
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to register: "+err.Error())
os.Exit(1)
}
cmd.Println("Successfully registered host " + args[0])
},
}

View File

@ -30,6 +30,8 @@ func Execute() {
func init() {
cobra.OnInitialize(initConfig)
rootCmd.Flags().String("key", "", "Server/Client key file")
rootCmd.Flags().String("passphrase", "", "Server/Client key passphrase")
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.gogrok.yaml)")
rootCmd.PersistentFlags().Bool("viper", true, "use Viper for configuration")
viper.BindPFlag("useViper", rootCmd.PersistentFlags().Lookup("viper"))

View File

@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
"gogrok.ccatss.dev/common"
"gogrok.ccatss.dev/server"
"gogrok.ccatss.dev/server/store"
gossh "golang.org/x/crypto/ssh"
"math/rand"
"path"
@ -22,6 +23,7 @@ func init() {
serveCmd.Flags().String("http", ":8080", "HTTP Server Bind Address")
serveCmd.Flags().String("keys", "", "Authorized keys file to control access")
serveCmd.Flags().StringSlice("domains", nil, "Domains to use for ")
serveCmd.Flags().String("store", "", "Store file to use when allowing host registration")
rootCmd.AddCommand(serveCmd)
}
@ -38,6 +40,7 @@ var serveCmd = &cobra.Command{
setValueFromFlag(cmd.Flags(), "http", "gogrok.httpAddress", false)
setValueFromFlag(cmd.Flags(), "keys", "gogrok.authorizedKeyFile", false)
setValueFromFlag(cmd.Flags(), "domains", "gogrok.domains", false)
setValueFromFlag(cmd.Flags(), "store", "gogrok.store", false)
key, err := common.LoadOrGenerateKey(baseFs, path.Join(viper.GetString("gogrok.storageDir"), "server.key"), "")
@ -70,16 +73,56 @@ var serveCmd = &cobra.Command{
}
opts = append(opts, server.WithAuthorizedKeys(authorizedKeys))
log.WithField("keyFile", authorizedKeysFile).Info("Authorizing public keys on connection")
}
handlerOpts := make([]server.HandlerOption, 0)
if domains := viper.GetStringSlice("gogrok.domains"); domains != nil {
generator := func() string {
return server.RandomAnimal() + "." + domains[rand.Intn(len(domains))]
}
handler := server.NewHttpHandler(server.WithProvider(generator))
validator := server.ValidateMulti(server.DenyPrefixIn(server.Animals()), server.SuffixIn(domains))
opts = append(opts, server.WithForwardHandler(handler))
handlerOpts = append(handlerOpts, server.WithProvider(generator), server.WithValidator(validator))
log.WithField("domains", domains).Info("Registered domains for random use")
}
if storeUri := viper.GetString("gogrok.store"); storeUri != "" {
driver := "bolt"
if idx := strings.Index(storeUri, "://"); idx != -1 {
driver = storeUri[0:idx]
storeUri = storeUri[idx+3:]
}
var s store.Store
switch driver {
case "bolt":
fallthrough
default:
log.WithField("path", storeUri).Info("Using bolt store")
s, err = store.NewBoltStore(storeUri)
}
if err != nil {
log.WithError(err).Fatalln("Unable to create data store")
return
}
log.WithField("driver", driver).Info("Host store set, registration enabled")
handlerOpts = append(handlerOpts, server.WithStore(s))
}
if len(handlerOpts) > 0 {
handler := server.NewHttpHandler(handlerOpts...)
opts = append(opts, server.WithForwardHandler("http", handler))
}
s, err := server.New(opts...)

39
cmd/unregister.go Normal file
View File

@ -0,0 +1,39 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"gogrok.ccatss.dev/client"
"os"
)
func init() {
unregisterCmd.Flags().String("server", "localhost:2222", "Gogrok Server Address")
rootCmd.AddCommand(unregisterCmd)
}
var unregisterCmd = &cobra.Command{
Use: "unregister",
Short: "Unregister a host with a gogrok server",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return ErrNoHost
}
return nil
},
PreRun: clientPreRun,
Run: func(cmd *cobra.Command, args []string) {
// Default command is client
c := client.New(viper.GetString("gogrok.server"), loadClientKey())
err := c.Unregister(args[0])
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to unregister: "+err.Error())
os.Exit(1)
}
cmd.Println("Successfully unregistered host " + args[0])
},
}

8
common/constants.go Normal file
View File

@ -0,0 +1,8 @@
package common
const (
HttpForward = "http-forward"
CancelHttpForward = "cancel-http-forward"
HttpRegisterHost = "http-register-host"
HttpUnregisterHost = "http-unregister-host"
)

View File

@ -7,6 +7,7 @@ const (
// RemoteForwardRequest represents a forwarding request
type RemoteForwardRequest struct {
RequestedHost string
Force bool
}
// RemoteForwardSuccess returns when a successful request is processed
@ -25,3 +26,13 @@ type RemoteForwardChannelData struct {
Host string
ClientIP string
}
// HostRegisterRequest is used when registering a host
type HostRegisterRequest struct {
Host string
}
// HostRegisterSuccess is the response from the server for a Claim request
type HostRegisterSuccess struct {
Host string
}

1
go.mod
View File

@ -3,6 +3,7 @@ module gogrok.ccatss.dev
go 1.17
require (
github.com/boltdb/bolt v1.3.1
github.com/gliderlabs/ssh v0.3.3
github.com/pkg/errors v0.9.1
github.com/sirupsen/logrus v1.8.1

2
go.sum
View File

@ -66,6 +66,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4=
github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=

View File

@ -38,7 +38,58 @@ func RandomAnimal() string {
return animals[rand.Intn(len(animals))]
}
// Animals returns the animal slice
func Animals() []string {
return animals
}
// DenyAll is a HostValidator to deny all custom requests
func DenyAll(host string) bool {
return false
}
// DenyPrefixIn denies hosts in slice s
func DenyPrefixIn(s []string) HostValidator {
return func(host string) bool {
idx := strings.Index(host, ".")
if idx != -1 {
host = host[0:idx]
}
for _, val := range s {
if val == host {
return false
}
}
return true
}
}
// SuffixIn checks hosts for a suffix value
// Note: Suffix is automatically prepended with .
func SuffixIn(s []string) HostValidator {
return func(host string) bool {
for _, val := range s {
if strings.HasSuffix(host, "."+val) {
return true
}
}
return false
}
}
// ValidateMulti checks all specified validators before denying hosts
func ValidateMulti(validators ...HostValidator) HostValidator {
return func(host string) bool {
for _, validator := range validators {
if !validator(host) {
return false
}
}
return true
}
}

View File

@ -2,16 +2,20 @@ package server
import (
"bufio"
"bytes"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common"
"gogrok.ccatss.dev/server/store"
gossh "golang.org/x/crypto/ssh"
"io"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"time"
)
// HostProvider is a func to provide a host + subdomain
@ -24,12 +28,19 @@ type HostValidator func(host string) bool
// adding the HandleSSHRequest callback to the server's RequestHandlers under
// tcpip-forward and cancel-tcpip-forward.
type ForwardedHTTPHandler struct {
forwards map[string]*gossh.ServerConn
forwards map[string]*Forward
provider HostProvider
validator HostValidator
store store.Store
sync.RWMutex
}
// Forward contains the forwarded connection
type Forward struct {
Conn *gossh.ServerConn
Key ssh.PublicKey
}
// HandlerOption represents a func used to assign options to a ForwardedHTTPHandler
type HandlerOption func(h *ForwardedHTTPHandler)
@ -47,9 +58,16 @@ func WithValidator(validator HostValidator) HandlerOption {
}
}
// WithStore assigns a host store to use for storage
func WithStore(s store.Store) HandlerOption {
return func(h *ForwardedHTTPHandler) {
h.store = s
}
}
func NewHttpHandler(opts ...HandlerOption) ForwardHandler {
h := &ForwardedHTTPHandler{
forwards: make(map[string]*gossh.ServerConn),
forwards: make(map[string]*Forward),
provider: RandomAnimal,
validator: DenyAll,
}
@ -61,15 +79,26 @@ func NewHttpHandler(opts ...HandlerOption) ForwardHandler {
return h
}
// RequestTypes lets the server know which request types this handler can use
func (h *ForwardedHTTPHandler) RequestTypes() []string {
return []string{
common.HttpForward,
common.CancelHttpForward,
common.HttpRegisterHost,
common.HttpUnregisterHost,
}
}
// ServeHTTP mocks an http server endpoint that uses Request.Host to forward requests
func (h *ForwardedHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.RLock()
sshConn, ok := h.forwards[r.Host]
fw, ok := h.forwards[r.Host]
h.RUnlock()
if !ok {
log.Warning("Unknown host ", r.Host)
http.Error(w, "not found", http.StatusNotFound)
log.Println("Valid hosts:", h.forwards)
return
}
@ -78,7 +107,7 @@ func (h *ForwardedHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
ClientIP: r.RemoteAddr,
})
ch, reqs, err := sshConn.OpenChannel(common.ForwardedHTTPChannelType, payload)
ch, reqs, err := fw.Conn.OpenChannel(common.ForwardedHTTPChannelType, payload)
if err != nil {
log.WithError(err).Warning("Unable to open ssh connection channel")
@ -162,84 +191,230 @@ func parseResponseLine(line string) (httpVersion, responseCode, responseText str
return line[:s1], line[s1+1 : s2], line[s2+1:], true
}
func (h *ForwardedHTTPHandler) checkHostOwnership(host, owner string) bool {
hostModel, err := h.store.Get(host)
if err != nil {
return false
}
return hostModel.Owner == owner
}
// HandleSSHRequest handles incoming ssh requests.
func (h *ForwardedHTTPHandler) HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
log.WithField("type", req.Type).Info("Handling request")
switch req.Type {
case "http-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-forward")
return false, []byte{}
}
case common.HttpForward:
return h.handleForwardRequest(ctx, conn, req)
case common.CancelHttpForward:
return h.handleCancelRequest(ctx, req)
case common.HttpRegisterHost:
return h.handleRegisterRequest(ctx, conn, req)
case common.HttpUnregisterHost:
return h.handleUnregisterRequest(ctx, req)
default:
return false, nil
}
}
host := reqPayload.RequestedHost
func (h *ForwardedHTTPHandler) handleForwardRequest(ctx ssh.Context, conn *gossh.ServerConn, req *gossh.Request) (bool, []byte) {
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for http-forward")
return false, []byte{}
}
if host != "" && h.validator != nil && !h.validator(host) {
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
keyStr := string(bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
host := strings.ToLower(reqPayload.RequestedHost)
if host != "" {
if h.validator != nil && !h.validator(host) {
return false, []byte("invalid host " + host)
}
// Validate host
if host == "" {
host = h.provider()
hostModel, err := h.store.Get(host)
if hostModel == nil || err != nil {
return false, []byte("host not registered")
}
if hostModel.Owner != keyStr {
return false, []byte("host claimed and not owned by current key")
}
h.RLock()
current, exists := h.forwards[host]
h.RUnlock()
if exists && !reqPayload.Force {
return false, []byte("host already in use and force not set")
}
if exists {
// Force old connection to close
current.Conn.Close()
}
hostModel.LastUse = time.Now()
// Save model last use time
h.store.Add(*hostModel)
} else {
host = h.provider()
}
// Validate host
if host == "" {
h.RLock()
for {
host = h.provider()
_, exists := h.forwards[host]
if !exists {
break
}
host = h.provider()
}
h.RUnlock()
h.Lock()
h.forwards[host] = conn
h.Unlock()
log.WithField("host", host).Info("Registered host")
go func() {
<-ctx.Done()
log.WithField("host", host).Info("Removed host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
}()
return true, gossh.Marshal(&common.RemoteForwardSuccess{
Host: host,
})
case "cancel-http-forward":
var reqPayload common.RemoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
return false, []byte{}
}
h.Lock()
delete(h.forwards, reqPayload.Host)
h.Unlock()
return true, nil
case "http-register-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
log.Println("Key:", ctx.Value("publicKey"))
// Claimed forward via SSH Public Key
return true, nil
default:
return false, nil
}
log.WithField("host", host).Info("Registering host")
h.Lock()
h.forwards[host] = &Forward{
Conn: conn,
Key: pubKey,
}
h.Unlock()
log.WithField("host", host).Info("Registered host")
go func() {
<-ctx.Done()
log.WithField("host", host).Info("Removed host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
}()
return true, gossh.Marshal(&common.RemoteForwardSuccess{
Host: host,
})
}
func (h *ForwardedHTTPHandler) handleCancelRequest(ctx ssh.Context, req *gossh.Request) (bool, []byte) {
var reqPayload common.RemoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for cancel-http-forward")
return false, []byte{}
}
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
host := strings.ToLower(reqPayload.Host)
h.RLock()
fw, exists := h.forwards[host]
h.RUnlock()
if !exists {
return false, []byte("host not found")
}
if !bytes.Equal(pubKey.Marshal(), fw.Key.Marshal()) {
return false, []byte("host not owned by key")
}
log.WithField("host", host).Info("Unregistering host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
return true, nil
}
func (h *ForwardedHTTPHandler) handleRegisterRequest(ctx ssh.Context, conn *gossh.ServerConn, req *gossh.Request) (bool, []byte) {
var reqPayload common.HostRegisterRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
keyStr := string(bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
host := strings.ToLower(reqPayload.Host)
if host == "" || h.validator != nil && !h.validator(host) {
log.WithField("host", host).Warning("Host failed validation")
return false, []byte("invalid host " + host)
}
if h.store.Has(host) {
log.WithField("host", host).Warning("Host is already taken")
return false, []byte("host is already taken")
}
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
log.WithField("ip", ip).WithField("host", host).Info("Registering host")
err := h.store.Add(store.Host{
Host: host,
Owner: keyStr,
IP: ip,
Created: time.Now(),
LastUse: time.Now(),
})
if err != nil {
return false, []byte(err.Error())
}
return true, gossh.Marshal(common.HostRegisterSuccess{
Host: host,
})
}
func (h *ForwardedHTTPHandler) handleUnregisterRequest(ctx ssh.Context, req *gossh.Request) (bool, []byte) {
var reqPayload common.HostRegisterRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
keyStr := string(bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
host := strings.ToLower(reqPayload.Host)
if host == "" || h.validator != nil && !h.validator(host) {
return false, []byte("invalid host " + host)
}
hostModel, err := h.store.Get(host)
if hostModel == nil || err != nil {
return false, []byte(err.Error())
}
if hostModel.Owner != keyStr {
return false, []byte("this host is not owned by you")
}
h.store.Remove(host)
return true, gossh.Marshal(common.HostRegisterSuccess{
Host: host,
})
}

View File

@ -15,12 +15,13 @@ import (
// ForwardHandler is an interface defining the handler type for forwarding
type ForwardHandler interface {
HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte)
RequestTypes() []string
}
// Server is a struct containing our ssh server, forwarding handler, and other attributes
type Server struct {
sshServer *ssh.Server
forwardHandler ForwardHandler
sshServer *ssh.Server
forwardHandlers map[string]ForwardHandler
sshBindAddress string
hostSigners []ssh.Signer
@ -33,9 +34,9 @@ type Option func(s *Server)
// WithForwardHandler lets custom forwarding handlers be registered.
// This will support multiple handlers eventually, for HTTP, TCP, etc.
func WithForwardHandler(f ForwardHandler) Option {
func WithForwardHandler(protocol string, handler ForwardHandler) Option {
return func(s *Server) {
s.forwardHandler = f
s.forwardHandlers[protocol] = handler
}
}
@ -76,14 +77,18 @@ func WithAuthorizedKeys(authorizedKeys []string) Option {
// New creates a new Server instance with a range of options.
func New(options ...Option) (*Server, error) {
s := &Server{}
s := &Server{
forwardHandlers: make(map[string]ForwardHandler),
}
for _, opt := range options {
opt(s)
}
if s.forwardHandler == nil {
s.forwardHandler = NewHttpHandler(WithProvider(RandomAnimal))
if len(s.forwardHandlers) == 0 {
httpHandler := NewHttpHandler(WithProvider(RandomAnimal))
s.forwardHandlers["http"] = httpHandler
}
if s.hostSigners == nil || len(s.hostSigners) < 1 {
@ -106,9 +111,10 @@ func New(options ...Option) (*Server, error) {
// TODO: Add TCP handler using the same idea, potentially support multiple forwardHandlers
if _, ok := s.forwardHandler.(http.Handler); ok {
requestHandlers["http-forward"] = s.forwardHandler.HandleSSHRequest
requestHandlers["cancel-http-forward"] = s.forwardHandler.HandleSSHRequest
for _, handler := range s.forwardHandlers {
for _, requestType := range handler.RequestTypes() {
requestHandlers[requestType] = handler.HandleSSHRequest
}
}
s.sshServer = &ssh.Server{
@ -170,22 +176,36 @@ func (s *Server) Start() error {
// StartHTTP is a convenience method to start a basic http server.
// This uses s.forwardHandler if http.Handler is implemented to serve requests.
func (s *Server) StartHTTP(bind string) error {
if h, ok := s.forwardHandler.(http.Handler); ok {
httpServer := &http.Server{
Addr: bind,
Handler: h,
}
httpHandler := s.forwardHandlers["http"]
return httpServer.ListenAndServe()
if httpHandler == nil {
return errors.New("http handler not registered")
}
return errors.New("forwarding handler doesn't support http")
if _, ok := httpHandler.(http.Handler); !ok {
return errors.New("http handler cannot handle http requests")
}
httpServer := &http.Server{
Addr: bind,
Handler: httpHandler.(http.Handler),
}
return httpServer.ListenAndServe()
}
// ServeHTTP is a passthrough to forwardHandler's ServeHTTP
// This can be used to use your own http server implementation, or for TLS/etc
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h, ok := s.forwardHandler.(http.Handler); ok {
h.ServeHTTP(w, r)
httpHandler := s.forwardHandlers["http"]
if httpHandler == nil {
return
}
if _, ok := httpHandler.(http.Handler); !ok {
return
}
httpHandler.(http.Handler).ServeHTTP(w, r)
}

96
server/store/boltdb.go Normal file
View File

@ -0,0 +1,96 @@
package store
import (
"encoding/json"
"github.com/boltdb/bolt"
)
type BoltStore struct {
db *bolt.DB
}
// NewBoltStore creates a new boltdb backed Store instance
func NewBoltStore(path string) (Store, error) {
db, err := bolt.Open(path, 0644, bolt.DefaultOptions)
if err != nil {
return nil, err
}
err = db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte("hosts"))
return err
})
if err != nil {
return nil, err
}
return &BoltStore{
db: db,
}, nil
}
// Has checks if the host exists in the hosts bucket
func (b *BoltStore) Has(host string) bool {
var exists bool
b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
if b.Get([]byte(host)) != nil {
exists = true
}
return nil
})
return exists
}
// Get retrieves and deserializes a host from the hosts bucket
func (b *BoltStore) Get(key string) (*Host, error) {
var host Host
err := b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
data := b.Get([]byte(key))
if data == nil {
return ErrNoHost
}
return json.Unmarshal(data, &host)
})
if err != nil {
return nil, err
}
return &host, nil
}
// Add updates the hosts bucket and puts a json-serialized version of Host
func (b *BoltStore) Add(host Host) error {
data, err := json.Marshal(host)
if err != nil {
return err
}
return b.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
return b.Put([]byte(host.Host), data)
})
}
// Remove updates the hosts bucket and deletes the key
func (b *BoltStore) Remove(key string) error {
return b.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
return b.Delete([]byte(key))
})
}

27
server/store/store.go Normal file
View File

@ -0,0 +1,27 @@
package store
import (
"errors"
"time"
)
var (
ErrNoHost = errors.New("host not found")
)
// Store represents an interface to retrieve and store hosts
type Store interface {
Has(key string) bool
Get(key string) (*Host, error)
Add(host Host) error
Remove(key string) error
}
// Host represents a claimed host
type Host struct {
Host string `json:"host"`
Owner string `json:"owner"`
IP string `json:"ip"`
Created time.Time `json:"created"`
LastUse time.Time `json:"lastUse"`
}