Compare commits

...

5 Commits

Author SHA1 Message Date
Tyler
1bb0440b9d Make key and passphrase global to all commands as intended
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing
2022-01-31 19:55:23 -05:00
Tyler
56df578d6e Minor change to handler type checking to avoid second cast
All checks were successful
continuous-integration/drone/push Build is passing
2021-12-23 00:15:54 -05:00
Tyler
f9f3cec5a2 Somewhat big rewrite, adds host registration and other improvements to structure
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing
2021-12-22 23:32:59 -05:00
Tyler
cc2988f6a5 Fix entrypoint, update license and readme
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing
2021-12-20 22:15:36 -05:00
Tyler
79e9d8c5f5 Enable env variables, move client configuration to match server config
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing
2021-12-20 21:48:33 -05:00
19 changed files with 1065 additions and 299 deletions

View File

@ -1,7 +1,4 @@
ISC License: Copyright (c) 2021 Tyler Stuyfzand <admin@meow.tf>
Copyright (c) 2004-2010 by Internet Systems Consortium, Inc. ("ISC")
Copyright (c) 1995-2003 by Internet Software Consortium
Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.

101
README.md
View File

@ -1,3 +1,102 @@
# gogrok # gogrok
A simple, easy to use ngrok alternative (self hosted!) [![Build Status](https://drone.meow.tf/api/badges/gogrok/gogrok/status.svg)](https://drone.meow.tf/gogrok/gogrok)
A simple, easy to use ngrok alternative (self hosted!)
The server and client can also be easily embedded into your applications, see the 'server' and 'client' directories.
Features
--------
* HTTP and HTTPS handling
* Public key authentication
* Authorized key whitelists
* Registration of reserved hosts
* Forwarding system that allows easy additions of other protocols (Coming soon: TCP and TCP+TLS, with host-based TLS support)
Example usage
-------------
By default, the first time you run gogrok it'll generate both a server and a client certificate. These will be stored in ~/.gogrok, but can be overridden with the `gogrok.storageDir` option (or GOGROK_STORAGE_DIR environment variable)
Server:
`gogrok serve`
Client:
`gogrok --server=localhost:2222 http://localhost:3000`
Server
------
```
$ ./gogrok serve --help
Start the gogrok server
Usage:
gogrok serve [flags]
Flags:
--bind string SSH Server Bind Address (default ":2222")
--domains strings Domains to use for
-h, --help help for serve
--http string HTTP Server Bind Address (default ":8080")
--keys string Authorized keys file to control access
Global Flags:
--config string config file (default is $HOME/.gogrok.yaml)
--viper use Viper for configuration (default true)
```
Client
------
```
$ ./gogrok client --help
Start the gogrok client
Usage:
gogrok client [flags]
Flags:
-h, --help help for client
--key string Client key file
--passphrase string Client key passphrase
--server string Gogrok Server Address (default "localhost:2222")
Global Flags:
--config string config file (default is $HOME/.gogrok.yaml)
--viper use Viper for configuration (default true)
```
Docker
------
Example docker compose file. Caddy is suggested as a frontend using dns via cloudflare and DNS-01 for wildcards.
```yaml
version: '3.7'
services:
gogrok:
image: tystuyfzand/gogrok:latest
ports:
- 2222:2222
- 8080:8080
volumes:
- /docker/gogrok/config:/config
environment:
- GOGROK_DOMAINS=gogrok.ccatss.dev
- GOGROK_AUTHORIZED_KEY_FILE=/config/authorized_keys
```
Host Registration
-----------------
Gogrok lets you register your own custom hosts that are attached to your public key.
On the server, make sure to run the server with the flag `--store=PATH_TO_DB.db`
Use `gogrok register` and `gogrok unregister` to manage registered hosts to your client key.

View File

@ -1,197 +1,181 @@
package client package client
import ( import (
"bufio" "errors"
"crypto/tls"
log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common" "gogrok.ccatss.dev/common"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"io"
"net" "net"
"net/http"
"net/textproto"
"net/url" "net/url"
"strconv" )
var (
ErrUnsupportedBackend = errors.New("unsupported backend type")
) )
// New creates a new client with the specified server and backend // New creates a new client with the specified server and backend
func New(server, backend string, signer ssh.Signer) *Client { func New(server string, signer ssh.Signer) *Client {
return &Client{
server: server,
signer: signer,
}
}
// Client is a remote tunnel client
type Client struct {
conn *ssh.Client
server string
signer ssh.Signer
}
// Open opens a connection to the server
// Note: This is called automatically on client operations.
func (c *Client) Open() error {
if c.conn != nil {
return nil
}
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil return nil
}, },
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer), ssh.PublicKeys(c.signer),
}, },
} }
conn, err := ssh.Dial("tcp", c.server, config)
if err != nil {
return err
}
c.conn = conn
return nil
}
func (c *Client) Close() error {
if c.conn == nil {
return nil
}
return c.conn.Close()
}
// Start connects to the server over TCP and starts the tunnel
func (c *Client) Start(backend, requestedHost string) (string, error) {
if err := c.Open(); err != nil {
return "", err
}
backendUrl, err := url.Parse(backend) backendUrl, err := url.Parse(backend)
if err != nil { if err != nil {
return nil return "", nil
} }
if backendUrl.Scheme == "" { if backendUrl.Scheme == "" {
backendUrl.Scheme = "http" backendUrl.Scheme = "http"
} }
host, port, _ := net.SplitHostPort(backendUrl.Host) if backendUrl.Scheme == "http" || backendUrl.Scheme == "https" {
proxy := NewHTTPProxy(backendUrl)
if port == "" { return c.StartHTTPForwarding(proxy, requestedHost)
port = "80"
} }
dialHost := net.JoinHostPort(host, port) return "", ErrUnsupportedBackend
}
return &Client{ // Register a host as reserved with the server
config: config, func (c *Client) Register(host string) error {
server: server, if err := c.Open(); err != nil {
backendUrl: backendUrl, return err
dialHost: dialHost,
tlsConfig: &tls.Config{ServerName: host, InsecureSkipVerify: true},
} }
}
// Client is a remote tunnel client payload := ssh.Marshal(common.HostRegisterRequest{
type Client struct { Host: host,
config *ssh.ClientConfig })
tlsConfig *tls.Config
server string success, replyData, err := c.conn.SendRequest(common.HttpRegisterHost, true, payload)
backendUrl *url.URL
dialHost string
}
// Start connects to the server over TCP and starts the tunnel
func (c *Client) Start() error {
log.WithFields(log.Fields{
"server": c.server,
}).Debug("Dialing server")
conn, err := ssh.Dial("tcp", c.server, c.config)
if err != nil { if err != nil {
return err return err
} }
if c.backendUrl.Scheme == "http" || c.backendUrl.Scheme == "https" { if !success {
go c.handleHttpRequests(conn) return errors.New(string(replyData))
}
var res common.HostRegisterSuccess
if err = ssh.Unmarshal(replyData, &res); err != nil {
return err
} }
return nil return nil
} }
func (c *Client) handleHttpRequests(conn *ssh.Client) { // Unregister a reserved host with the server
payload := ssh.Marshal(common.RemoteForwardRequest{ func (c *Client) Unregister(host string) error {
RequestedHost: "", if err := c.Open(); err != nil {
}) return err
log.Debug("Requesting http-forward...")
success, replyData, err := conn.SendRequest("http-forward", true, payload)
if !success || err != nil {
log.WithFields(log.Fields{
"error": err,
"message": string(replyData),
}).Fatalln("Unable to start forwarding")
} }
log.WithFields(log.Fields{ payload := ssh.Marshal(common.HostRegisterRequest{
"success": success, Host: host,
}).Debug("Received successful response from http-forward request") })
success, replyData, err := c.conn.SendRequest(common.HttpUnregisterHost, true, payload)
if err != nil {
return err
}
if !success {
return errors.New(string(replyData))
}
var res common.HostRegisterSuccess
if err = ssh.Unmarshal(replyData, &res); err != nil {
return err
}
return nil
}
// StartHTTPForwarding starts a basic http proxy/forwarding service
func (c *Client) StartHTTPForwarding(proxy *HTTPProxy, requestedHost string) (string, error) {
payload := ssh.Marshal(common.RemoteForwardRequest{
RequestedHost: requestedHost,
})
success, replyData, err := c.conn.SendRequest(common.HttpForward, true, payload)
if err != nil {
return "", err
}
if !success {
return "", errors.New(string(replyData))
}
var response common.RemoteForwardSuccess var response common.RemoteForwardSuccess
if err := ssh.Unmarshal(replyData, &response); err != nil { if err := ssh.Unmarshal(replyData, &response); err != nil {
log.WithError(err).Fatalln("Unable to unmarshal data") return "", err
} }
defer func() { ch := c.conn.HandleChannelOpen(common.ForwardedHTTPChannelType)
go func() {
proxy.acceptConnections(ch)
payload := ssh.Marshal(common.RemoteForwardCancelRequest{Host: response.Host}) payload := ssh.Marshal(common.RemoteForwardCancelRequest{Host: response.Host})
conn.SendRequest("cancel-http-forward", false, payload) c.conn.SendRequest(common.CancelHttpForward, false, payload)
}() }()
log.WithField("host", response.Host).Info("Bound host") return response.Host, nil
ch := conn.HandleChannelOpen(common.ForwardedHTTPChannelType)
for {
newCh := <-ch
if newCh == nil {
break
}
ch, r, err := newCh.Accept()
if err != nil {
log.WithError(err).Warning("Error accepting channel")
continue
}
go ssh.DiscardRequests(r)
go c.proxyRequest(ch)
}
}
// proxyRequest handles a request from the ssh channel and forwards it to the local http server
func (c *Client) proxyRequest(rw io.ReadWriteCloser) {
tcpConn, err := net.Dial("tcp", c.dialHost)
if err != nil {
rw.Close()
return
}
if c.backendUrl.Scheme == "https" || c.backendUrl.Scheme == "wss" {
// Wrap with TLS
tcpConn = tls.Client(tcpConn, c.tlsConfig)
}
defer rw.Close()
defer tcpConn.Close()
bufferedCh := bufio.NewReader(rw)
tp := textproto.NewReader(bufferedCh)
var s string
if s, err = tp.ReadLine(); err != nil {
return
}
// Write the first response line as-is
tcpConn.Write([]byte(s + "\r\n"))
// Read headers and proxy each to the output
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
// Modify and return our headers
headers := http.Header(mimeHeader)
headers.Set("Host", c.backendUrl.Host)
headers.Write(tcpConn)
// End headers
tcpConn.Write([]byte("\r\n"))
contentLength, err := strconv.ParseInt(headers.Get("Content-Length"), 10, 64)
if err == nil && contentLength > 0 {
// Copy request to the tcpConn
_, err := io.Copy(tcpConn, io.LimitReader(bufferedCh, int64(contentLength)))
if err != nil {
log.WithError(err).Warning("Connection error on body read, closing")
return
}
}
// Copy the response back to the tunnel server
io.Copy(rw, tcpConn)
} }

124
client/http.go Normal file
View File

@ -0,0 +1,124 @@
package client
import (
"bufio"
"crypto/tls"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
)
type Proxy interface {
Handle(rw io.ReadWriteCloser)
}
// HTTPProxy is a proxy implementation to pass http requests.
type HTTPProxy struct {
dialHost string
backendUrl *url.URL
tlsConfig *tls.Config
}
// NewHTTPProxy parses the backend url and creates a new proxy for it
func NewHTTPProxy(backendUrl *url.URL) *HTTPProxy {
host, port, _ := net.SplitHostPort(backendUrl.Host)
if port == "" {
port = "80"
}
tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: true}
dialHost := net.JoinHostPort(host, port)
return &HTTPProxy{
dialHost: dialHost,
backendUrl: backendUrl,
tlsConfig: tlsConfig,
}
}
func (p *HTTPProxy) acceptConnections(ch <-chan ssh.NewChannel) {
for {
newCh := <-ch
if newCh == nil {
break
}
ch, r, err := newCh.Accept()
if err != nil {
log.WithError(err).Warning("Error accepting channel")
continue
}
go ssh.DiscardRequests(r)
go p.Handle(ch)
}
}
// Handle a request from the ssh channel and forwards it to the local http server
func (p *HTTPProxy) Handle(rw io.ReadWriteCloser) {
tcpConn, err := net.Dial("tcp", p.dialHost)
if err != nil {
rw.Close()
return
}
if p.backendUrl.Scheme == "https" || p.backendUrl.Scheme == "wss" {
// Wrap with TLS
tcpConn = tls.Client(tcpConn, p.tlsConfig)
}
defer rw.Close()
defer tcpConn.Close()
bufferedCh := bufio.NewReader(rw)
tp := textproto.NewReader(bufferedCh)
var s string
if s, err = tp.ReadLine(); err != nil {
return
}
// Write the first response line as-is
tcpConn.Write([]byte(s + "\r\n"))
// Read headers and proxy each to the output
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
// Modify and return our headers
headers := http.Header(mimeHeader)
headers.Set("Host", p.backendUrl.Host)
headers.Write(tcpConn)
// End headers
tcpConn.Write([]byte("\r\n"))
contentLength, err := strconv.ParseInt(headers.Get("Content-Length"), 10, 64)
if err == nil && contentLength > 0 {
// Copy request to the tcpConn
_, err := io.Copy(tcpConn, io.LimitReader(bufferedCh, int64(contentLength)))
if err != nil {
log.WithError(err).Warning("Connection error on body read, closing")
return
}
}
// Copy the response back to the tunnel server
io.Copy(rw, tcpConn)
}

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"errors" "errors"
"fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -16,20 +17,45 @@ import (
) )
var ( var (
serverHost string
clientKey string
clientKeyPassphrase string
ErrNoEndpoint = errors.New("no http(s) endpoint provided") ErrNoEndpoint = errors.New("no http(s) endpoint provided")
) )
func init() { func init() {
clientCmd.PersistentFlags().StringVar(&serverHost, "server", "localhost:2222", "Gogrok Server Address") clientCmd.Flags().String("server", "localhost:2222", "Gogrok Server Address")
clientCmd.PersistentFlags().StringVar(&clientKey, "key", "", "Client key file") clientCmd.Flags().String("host", "", "Requested host to register")
clientCmd.PersistentFlags().StringVar(&clientKeyPassphrase, "passphrase", "", "Client key passphrase")
rootCmd.AddCommand(clientCmd) rootCmd.AddCommand(clientCmd)
} }
func clientPreRun(cmd *cobra.Command, args []string) {
viper.SetDefault("gogrok.server", "localhost:2222")
setValueFromFlag(cmd.Flags(), "server", "gogrok.server", false)
setValueFromFlag(cmd.Flags(), "key", "gogrok.clientKey", false)
setValueFromFlag(cmd.Flags(), "passphrase", "gogrok.clientKeyPassphrase", false)
}
func loadClientKey() ssh.Signer {
clientKey := viper.GetString("gogrok.clientKey")
if clientKey == "" {
clientKey = path.Join(viper.GetString("gogrok.storageDir"), "client.key")
}
key, err := common.LoadOrGenerateKey(afero.NewOsFs(), clientKey, viper.GetString("gogrok.clientKeyPassphrase"))
if err != nil {
log.WithError(err).Fatalln("Unable to load client key")
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
log.WithError(err).Fatalln("Unable to create signer from client key")
}
return signer
}
var clientCmd = &cobra.Command{ var clientCmd = &cobra.Command{
Use: "client", Use: "client",
Short: "Start the gogrok client", Short: "Start the gogrok client",
@ -39,31 +65,25 @@ var clientCmd = &cobra.Command{
} }
return nil return nil
}, },
PreRun: clientPreRun,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
if clientKey == "" { setValueFromFlag(cmd.Flags(), "host", "gogrok.clientHost", false)
clientKey = path.Join(viper.GetString("gogrok.storageDir"), "client.key")
}
key, err := common.LoadOrGenerateKey(afero.NewOsFs(), clientKey, clientKeyPassphrase) c := client.New(viper.GetString("gogrok.server"), loadClientKey())
host, err := c.Start(args[0], viper.GetString("gogrok.clientHost"))
if err != nil { if err != nil {
log.WithError(err).Fatalln("Unable to load client key") fmt.Fprintln(os.Stderr, "Unable to start server: "+err.Error())
os.Exit(1)
} }
signer, err := ssh.NewSignerFromKey(key) cmd.Println("Successfully bound host and started proxy")
log.WithField("host", host).Info("Successfully bound host and started proxy")
if err != nil { cmd.Println("Endpoints:")
log.WithError(err).Fatalln("Unable to create signer from client key") cmd.Printf("http://%s\n", host)
} cmd.Printf("https://%s\n", host)
// Default command is client
c := client.New(serverHost, args[0], signer)
err = c.Start()
if err != nil {
log.WithError(err).Fatalln("Unable to connect to server")
}
sig := make(chan os.Signal) sig := make(chan os.Signal)

44
cmd/register.go Normal file
View File

@ -0,0 +1,44 @@
package cmd
import (
"errors"
"fmt"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"gogrok.ccatss.dev/client"
"os"
)
var (
ErrNoHost = errors.New("no host specified")
)
func init() {
registerCmd.Flags().String("server", "localhost:2222", "Gogrok Server Address")
rootCmd.AddCommand(registerCmd)
}
var registerCmd = &cobra.Command{
Use: "register",
Short: "Register a host with a gogrok server",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return ErrNoHost
}
return nil
},
PreRun: clientPreRun,
Run: func(cmd *cobra.Command, args []string) {
// Default command is client
c := client.New(viper.GetString("gogrok.server"), loadClientKey())
err := c.Register(args[0])
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to register: "+err.Error())
os.Exit(1)
}
cmd.Println("Successfully registered host " + args[0])
},
}

View File

@ -1,67 +1,84 @@
package cmd package cmd
import ( import (
"fmt" "fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"os" "os"
"path" "path"
) )
var ( var (
cfgFile string cfgFile string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "gogrok", Use: "gogrok",
Short: "Gogrok is a very simple and easy to use reverse tunnel server", Short: "Gogrok is a very simple and easy to use reverse tunnel server",
Long: `A simple and easy to use remote tunnel server`, Long: `A simple and easy to use remote tunnel server`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
clientCmd.Run(cmd, args) clientCmd.Run(cmd, args)
}, },
} }
) )
func Execute() { func Execute() {
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
os.Exit(1) os.Exit(1)
} }
} }
func init() { func init() {
cobra.OnInitialize(initConfig) cobra.OnInitialize(initConfig)
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.gogrok.yaml)") rootCmd.PersistentFlags().String("key", "", "Server/Client key file")
rootCmd.PersistentFlags().Bool("viper", true, "use Viper for configuration") rootCmd.PersistentFlags().String("passphrase", "", "Server/Client key passphrase")
viper.BindPFlag("useViper", rootCmd.PersistentFlags().Lookup("viper")) rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.gogrok.yaml)")
rootCmd.PersistentFlags().Bool("viper", true, "use Viper for configuration")
viper.BindPFlag("useViper", rootCmd.PersistentFlags().Lookup("viper"))
} }
func initConfig() { func initConfig() {
if cfgFile != "" { home, err := os.UserHomeDir()
// Use config file from the flag. cobra.CheckErr(err)
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := os.UserHomeDir()
cobra.CheckErr(err)
viper.SetDefault("gogrok.storageDir", path.Join(home, ".gogrok")) if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
// Search config in home directory with name ".cobra" (without extension). // Search config in home directory with name ".cobra" (without extension).
viper.AddConfigPath(home) viper.AddConfigPath(home)
viper.AddConfigPath(path.Join(home, ".gogrok")) viper.AddConfigPath(path.Join(home, ".gogrok"))
viper.SetConfigType("yaml") viper.SetConfigType("yaml")
viper.SetConfigName(".gogrok") viper.SetConfigName(".gogrok")
} }
viper.AutomaticEnv() viper.SetDefault("gogrok.storageDir", path.Join(home, ".gogrok"))
if err := viper.ReadInConfig(); err == nil { // Generic binds
fmt.Println("Using config file:", viper.ConfigFileUsed()) viper.BindEnv("gogrok.storageDir", "GOGROK_STORAGE_DIR")
}
storageDir := viper.GetString("gogrok.storageDir") // Server binds
viper.BindEnv("gogrok.sshAddress", "GOGROK_SSH_ADDRESS")
viper.BindEnv("gogrok.httpAddress", "GOGROK_HTTP_ADDRESS")
viper.BindEnv("gogrok.authorizedKeyFile", "GOGROK_AUTHORIZED_KEY_FILE")
viper.BindEnv("gogrok.domains", "GOGROK_DOMAINS")
if _, err := os.Stat(storageDir); os.IsNotExist(err) { // Client binds
os.MkdirAll(storageDir, 0755) viper.BindEnv("gogrok.clientKey", "GOGROK_CLIENT_KEY")
} viper.BindEnv("gogrok.clientKeyPassphrase", "GOGROK_CLIENT_KEY_PASS")
} viper.BindEnv("gogrok.server", "GOGROK_SERVER")
viper.AutomaticEnv()
if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
}
storageDir := viper.GetString("gogrok.storageDir")
if _, err := os.Stat(storageDir); os.IsNotExist(err) {
os.MkdirAll(storageDir, 0755)
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
"gogrok.ccatss.dev/common" "gogrok.ccatss.dev/common"
"gogrok.ccatss.dev/server" "gogrok.ccatss.dev/server"
"gogrok.ccatss.dev/server/store"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"math/rand" "math/rand"
"path" "path"
@ -22,6 +23,7 @@ func init() {
serveCmd.Flags().String("http", ":8080", "HTTP Server Bind Address") serveCmd.Flags().String("http", ":8080", "HTTP Server Bind Address")
serveCmd.Flags().String("keys", "", "Authorized keys file to control access") serveCmd.Flags().String("keys", "", "Authorized keys file to control access")
serveCmd.Flags().StringSlice("domains", nil, "Domains to use for ") serveCmd.Flags().StringSlice("domains", nil, "Domains to use for ")
serveCmd.Flags().String("store", "", "Store file to use when allowing host registration")
rootCmd.AddCommand(serveCmd) rootCmd.AddCommand(serveCmd)
} }
@ -38,6 +40,7 @@ var serveCmd = &cobra.Command{
setValueFromFlag(cmd.Flags(), "http", "gogrok.httpAddress", false) setValueFromFlag(cmd.Flags(), "http", "gogrok.httpAddress", false)
setValueFromFlag(cmd.Flags(), "keys", "gogrok.authorizedKeyFile", false) setValueFromFlag(cmd.Flags(), "keys", "gogrok.authorizedKeyFile", false)
setValueFromFlag(cmd.Flags(), "domains", "gogrok.domains", false) setValueFromFlag(cmd.Flags(), "domains", "gogrok.domains", false)
setValueFromFlag(cmd.Flags(), "store", "gogrok.store", false)
key, err := common.LoadOrGenerateKey(baseFs, path.Join(viper.GetString("gogrok.storageDir"), "server.key"), "") key, err := common.LoadOrGenerateKey(baseFs, path.Join(viper.GetString("gogrok.storageDir"), "server.key"), "")
@ -70,16 +73,56 @@ var serveCmd = &cobra.Command{
} }
opts = append(opts, server.WithAuthorizedKeys(authorizedKeys)) opts = append(opts, server.WithAuthorizedKeys(authorizedKeys))
log.WithField("keyFile", authorizedKeysFile).Info("Authorizing public keys on connection")
} }
handlerOpts := make([]server.HandlerOption, 0)
if domains := viper.GetStringSlice("gogrok.domains"); domains != nil { if domains := viper.GetStringSlice("gogrok.domains"); domains != nil {
generator := func() string { generator := func() string {
return server.RandomAnimal() + "." + domains[rand.Intn(len(domains))] return server.RandomAnimal() + "." + domains[rand.Intn(len(domains))]
} }
handler := server.NewHttpHandler(server.WithProvider(generator)) validator := server.ValidateMulti(server.DenyPrefixIn(server.Animals()), server.SuffixIn(domains))
opts = append(opts, server.WithForwardHandler(handler)) handlerOpts = append(handlerOpts, server.WithProvider(generator), server.WithValidator(validator))
log.WithField("domains", domains).Info("Registered domains for random use")
}
if storeUri := viper.GetString("gogrok.store"); storeUri != "" {
driver := "bolt"
if idx := strings.Index(storeUri, "://"); idx != -1 {
driver = storeUri[0:idx]
storeUri = storeUri[idx+3:]
}
var s store.Store
switch driver {
case "bolt":
fallthrough
default:
log.WithField("path", storeUri).Info("Using bolt store")
s, err = store.NewBoltStore(storeUri)
}
if err != nil {
log.WithError(err).Fatalln("Unable to create data store")
return
}
log.WithField("driver", driver).Info("Host store set, registration enabled")
handlerOpts = append(handlerOpts, server.WithStore(s))
}
if len(handlerOpts) > 0 {
handler := server.NewHttpHandler(handlerOpts...)
opts = append(opts, server.WithForwardHandler("http", handler))
} }
s, err := server.New(opts...) s, err := server.New(opts...)

39
cmd/unregister.go Normal file
View File

@ -0,0 +1,39 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"gogrok.ccatss.dev/client"
"os"
)
func init() {
unregisterCmd.Flags().String("server", "localhost:2222", "Gogrok Server Address")
rootCmd.AddCommand(unregisterCmd)
}
var unregisterCmd = &cobra.Command{
Use: "unregister",
Short: "Unregister a host with a gogrok server",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return ErrNoHost
}
return nil
},
PreRun: clientPreRun,
Run: func(cmd *cobra.Command, args []string) {
// Default command is client
c := client.New(viper.GetString("gogrok.server"), loadClientKey())
err := c.Unregister(args[0])
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to unregister: "+err.Error())
os.Exit(1)
}
cmd.Println("Successfully unregistered host " + args[0])
},
}

8
common/constants.go Normal file
View File

@ -0,0 +1,8 @@
package common
const (
HttpForward = "http-forward"
CancelHttpForward = "cancel-http-forward"
HttpRegisterHost = "http-register-host"
HttpUnregisterHost = "http-unregister-host"
)

View File

@ -7,6 +7,7 @@ const (
// RemoteForwardRequest represents a forwarding request // RemoteForwardRequest represents a forwarding request
type RemoteForwardRequest struct { type RemoteForwardRequest struct {
RequestedHost string RequestedHost string
Force bool
} }
// RemoteForwardSuccess returns when a successful request is processed // RemoteForwardSuccess returns when a successful request is processed
@ -25,3 +26,13 @@ type RemoteForwardChannelData struct {
Host string Host string
ClientIP string ClientIP string
} }
// HostRegisterRequest is used when registering a host
type HostRegisterRequest struct {
Host string
}
// HostRegisterSuccess is the response from the server for a Claim request
type HostRegisterSuccess struct {
Host string
}

1
go.mod
View File

@ -3,6 +3,7 @@ module gogrok.ccatss.dev
go 1.17 go 1.17
require ( require (
github.com/boltdb/bolt v1.3.1
github.com/gliderlabs/ssh v0.3.3 github.com/gliderlabs/ssh v0.3.3
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1

2
go.sum
View File

@ -66,6 +66,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4=
github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=

View File

@ -1,3 +1,7 @@
#!/bin/sh #!/bin/sh
if [ -z "$GOGROK_ARGS" ]; then
GOGROK_ARGS="serve"
fi
/usr/bin/gogrok $GOGROK_ARGS /usr/bin/gogrok $GOGROK_ARGS

View File

@ -38,7 +38,58 @@ func RandomAnimal() string {
return animals[rand.Intn(len(animals))] return animals[rand.Intn(len(animals))]
} }
// Animals returns the animal slice
func Animals() []string {
return animals
}
// DenyAll is a HostValidator to deny all custom requests // DenyAll is a HostValidator to deny all custom requests
func DenyAll(host string) bool { func DenyAll(host string) bool {
return false return false
} }
// DenyPrefixIn denies hosts in slice s
func DenyPrefixIn(s []string) HostValidator {
return func(host string) bool {
idx := strings.Index(host, ".")
if idx != -1 {
host = host[0:idx]
}
for _, val := range s {
if val == host {
return false
}
}
return true
}
}
// SuffixIn checks hosts for a suffix value
// Note: Suffix is automatically prepended with .
func SuffixIn(s []string) HostValidator {
return func(host string) bool {
for _, val := range s {
if strings.HasSuffix(host, "."+val) {
return true
}
}
return false
}
}
// ValidateMulti checks all specified validators before denying hosts
func ValidateMulti(validators ...HostValidator) HostValidator {
return func(host string) bool {
for _, validator := range validators {
if !validator(host) {
return false
}
}
return true
}
}

View File

@ -2,16 +2,20 @@ package server
import ( import (
"bufio" "bufio"
"bytes"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common" "gogrok.ccatss.dev/common"
"gogrok.ccatss.dev/server/store"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"io" "io"
"net"
"net/http" "net/http"
"net/textproto" "net/textproto"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
) )
// HostProvider is a func to provide a host + subdomain // HostProvider is a func to provide a host + subdomain
@ -24,12 +28,19 @@ type HostValidator func(host string) bool
// adding the HandleSSHRequest callback to the server's RequestHandlers under // adding the HandleSSHRequest callback to the server's RequestHandlers under
// tcpip-forward and cancel-tcpip-forward. // tcpip-forward and cancel-tcpip-forward.
type ForwardedHTTPHandler struct { type ForwardedHTTPHandler struct {
forwards map[string]*gossh.ServerConn forwards map[string]*Forward
provider HostProvider provider HostProvider
validator HostValidator validator HostValidator
store store.Store
sync.RWMutex sync.RWMutex
} }
// Forward contains the forwarded connection
type Forward struct {
Conn *gossh.ServerConn
Key ssh.PublicKey
}
// HandlerOption represents a func used to assign options to a ForwardedHTTPHandler // HandlerOption represents a func used to assign options to a ForwardedHTTPHandler
type HandlerOption func(h *ForwardedHTTPHandler) type HandlerOption func(h *ForwardedHTTPHandler)
@ -47,9 +58,16 @@ func WithValidator(validator HostValidator) HandlerOption {
} }
} }
// WithStore assigns a host store to use for storage
func WithStore(s store.Store) HandlerOption {
return func(h *ForwardedHTTPHandler) {
h.store = s
}
}
func NewHttpHandler(opts ...HandlerOption) ForwardHandler { func NewHttpHandler(opts ...HandlerOption) ForwardHandler {
h := &ForwardedHTTPHandler{ h := &ForwardedHTTPHandler{
forwards: make(map[string]*gossh.ServerConn), forwards: make(map[string]*Forward),
provider: RandomAnimal, provider: RandomAnimal,
validator: DenyAll, validator: DenyAll,
} }
@ -61,15 +79,26 @@ func NewHttpHandler(opts ...HandlerOption) ForwardHandler {
return h return h
} }
// RequestTypes lets the server know which request types this handler can use
func (h *ForwardedHTTPHandler) RequestTypes() []string {
return []string{
common.HttpForward,
common.CancelHttpForward,
common.HttpRegisterHost,
common.HttpUnregisterHost,
}
}
// ServeHTTP mocks an http server endpoint that uses Request.Host to forward requests // ServeHTTP mocks an http server endpoint that uses Request.Host to forward requests
func (h *ForwardedHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ForwardedHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.RLock() h.RLock()
sshConn, ok := h.forwards[r.Host] fw, ok := h.forwards[r.Host]
h.RUnlock() h.RUnlock()
if !ok { if !ok {
log.Warning("Unknown host ", r.Host) log.Warning("Unknown host ", r.Host)
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
log.Println("Valid hosts:", h.forwards)
return return
} }
@ -78,7 +107,7 @@ func (h *ForwardedHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
ClientIP: r.RemoteAddr, ClientIP: r.RemoteAddr,
}) })
ch, reqs, err := sshConn.OpenChannel(common.ForwardedHTTPChannelType, payload) ch, reqs, err := fw.Conn.OpenChannel(common.ForwardedHTTPChannelType, payload)
if err != nil { if err != nil {
log.WithError(err).Warning("Unable to open ssh connection channel") log.WithError(err).Warning("Unable to open ssh connection channel")
@ -162,84 +191,230 @@ func parseResponseLine(line string) (httpVersion, responseCode, responseText str
return line[:s1], line[s1+1 : s2], line[s2+1:], true return line[:s1], line[s1+1 : s2], line[s2+1:], true
} }
func (h *ForwardedHTTPHandler) checkHostOwnership(host, owner string) bool {
hostModel, err := h.store.Get(host)
if err != nil {
return false
}
return hostModel.Owner == owner
}
// HandleSSHRequest handles incoming ssh requests. // HandleSSHRequest handles incoming ssh requests.
func (h *ForwardedHTTPHandler) HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) { func (h *ForwardedHTTPHandler) HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
log.WithField("type", req.Type).Info("Handling request")
switch req.Type { switch req.Type {
case "http-forward": case common.HttpForward:
var reqPayload common.RemoteForwardRequest return h.handleForwardRequest(ctx, conn, req)
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { case common.CancelHttpForward:
// TODO: log parse failure return h.handleCancelRequest(ctx, req)
log.WithError(err).Warning("Error parsing payload for http-forward") case common.HttpRegisterHost:
return false, []byte{} return h.handleRegisterRequest(ctx, conn, req)
} case common.HttpUnregisterHost:
return h.handleUnregisterRequest(ctx, req)
default:
return false, nil
}
}
host := reqPayload.RequestedHost func (h *ForwardedHTTPHandler) handleForwardRequest(ctx ssh.Context, conn *gossh.ServerConn, req *gossh.Request) (bool, []byte) {
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for http-forward")
return false, []byte{}
}
if host != "" && h.validator != nil && !h.validator(host) { pubKey := ctx.Value("publicKey").(ssh.PublicKey)
keyStr := string(bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
host := strings.ToLower(reqPayload.RequestedHost)
if host != "" {
if h.validator != nil && !h.validator(host) {
return false, []byte("invalid host " + host) return false, []byte("invalid host " + host)
} }
// Validate host hostModel, err := h.store.Get(host)
if host == "" {
host = h.provider() if hostModel == nil || err != nil {
return false, []byte("host not registered")
}
if hostModel.Owner != keyStr {
return false, []byte("host claimed and not owned by current key")
} }
h.RLock()
current, exists := h.forwards[host]
h.RUnlock()
if exists && !reqPayload.Force {
return false, []byte("host already in use and force not set")
}
if exists {
// Force old connection to close
current.Conn.Close()
}
hostModel.LastUse = time.Now()
// Save model last use time
h.store.Add(*hostModel)
} else {
host = h.provider()
}
// Validate host
if host == "" {
h.RLock() h.RLock()
for { for {
host = h.provider()
_, exists := h.forwards[host] _, exists := h.forwards[host]
if !exists { if !exists {
break break
} }
host = h.provider()
} }
h.RUnlock() h.RUnlock()
h.Lock()
h.forwards[host] = conn
h.Unlock()
log.WithField("host", host).Info("Registered host")
go func() {
<-ctx.Done()
log.WithField("host", host).Info("Removed host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
}()
return true, gossh.Marshal(&common.RemoteForwardSuccess{
Host: host,
})
case "cancel-http-forward":
var reqPayload common.RemoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
return false, []byte{}
}
h.Lock()
delete(h.forwards, reqPayload.Host)
h.Unlock()
return true, nil
case "http-register-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
log.Println("Key:", ctx.Value("publicKey"))
// Claimed forward via SSH Public Key
return true, nil
default:
return false, nil
} }
log.WithField("host", host).Info("Registering host")
h.Lock()
h.forwards[host] = &Forward{
Conn: conn,
Key: pubKey,
}
h.Unlock()
log.WithField("host", host).Info("Registered host")
go func() {
<-ctx.Done()
log.WithField("host", host).Info("Removed host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
}()
return true, gossh.Marshal(&common.RemoteForwardSuccess{
Host: host,
})
}
func (h *ForwardedHTTPHandler) handleCancelRequest(ctx ssh.Context, req *gossh.Request) (bool, []byte) {
var reqPayload common.RemoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for cancel-http-forward")
return false, []byte{}
}
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
host := strings.ToLower(reqPayload.Host)
h.RLock()
fw, exists := h.forwards[host]
h.RUnlock()
if !exists {
return false, []byte("host not found")
}
if !bytes.Equal(pubKey.Marshal(), fw.Key.Marshal()) {
return false, []byte("host not owned by key")
}
log.WithField("host", host).Info("Unregistering host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
return true, nil
}
func (h *ForwardedHTTPHandler) handleRegisterRequest(ctx ssh.Context, conn *gossh.ServerConn, req *gossh.Request) (bool, []byte) {
var reqPayload common.HostRegisterRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
keyStr := string(bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
host := strings.ToLower(reqPayload.Host)
if host == "" || h.validator != nil && !h.validator(host) {
log.WithField("host", host).Warning("Host failed validation")
return false, []byte("invalid host " + host)
}
if h.store.Has(host) {
log.WithField("host", host).Warning("Host is already taken")
return false, []byte("host is already taken")
}
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
log.WithField("ip", ip).WithField("host", host).Info("Registering host")
err := h.store.Add(store.Host{
Host: host,
Owner: keyStr,
IP: ip,
Created: time.Now(),
LastUse: time.Now(),
})
if err != nil {
return false, []byte(err.Error())
}
return true, gossh.Marshal(common.HostRegisterSuccess{
Host: host,
})
}
func (h *ForwardedHTTPHandler) handleUnregisterRequest(ctx ssh.Context, req *gossh.Request) (bool, []byte) {
var reqPayload common.HostRegisterRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
pubKey := ctx.Value("publicKey").(ssh.PublicKey)
keyStr := string(bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
host := strings.ToLower(reqPayload.Host)
if host == "" || h.validator != nil && !h.validator(host) {
return false, []byte("invalid host " + host)
}
hostModel, err := h.store.Get(host)
if hostModel == nil || err != nil {
return false, []byte(err.Error())
}
if hostModel.Owner != keyStr {
return false, []byte("this host is not owned by you")
}
h.store.Remove(host)
return true, gossh.Marshal(common.HostRegisterSuccess{
Host: host,
})
} }

View File

@ -15,12 +15,13 @@ import (
// ForwardHandler is an interface defining the handler type for forwarding // ForwardHandler is an interface defining the handler type for forwarding
type ForwardHandler interface { type ForwardHandler interface {
HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte)
RequestTypes() []string
} }
// Server is a struct containing our ssh server, forwarding handler, and other attributes // Server is a struct containing our ssh server, forwarding handler, and other attributes
type Server struct { type Server struct {
sshServer *ssh.Server sshServer *ssh.Server
forwardHandler ForwardHandler forwardHandlers map[string]ForwardHandler
sshBindAddress string sshBindAddress string
hostSigners []ssh.Signer hostSigners []ssh.Signer
@ -33,9 +34,9 @@ type Option func(s *Server)
// WithForwardHandler lets custom forwarding handlers be registered. // WithForwardHandler lets custom forwarding handlers be registered.
// This will support multiple handlers eventually, for HTTP, TCP, etc. // This will support multiple handlers eventually, for HTTP, TCP, etc.
func WithForwardHandler(f ForwardHandler) Option { func WithForwardHandler(protocol string, handler ForwardHandler) Option {
return func(s *Server) { return func(s *Server) {
s.forwardHandler = f s.forwardHandlers[protocol] = handler
} }
} }
@ -76,14 +77,18 @@ func WithAuthorizedKeys(authorizedKeys []string) Option {
// New creates a new Server instance with a range of options. // New creates a new Server instance with a range of options.
func New(options ...Option) (*Server, error) { func New(options ...Option) (*Server, error) {
s := &Server{} s := &Server{
forwardHandlers: make(map[string]ForwardHandler),
}
for _, opt := range options { for _, opt := range options {
opt(s) opt(s)
} }
if s.forwardHandler == nil { if len(s.forwardHandlers) == 0 {
s.forwardHandler = NewHttpHandler(WithProvider(RandomAnimal)) httpHandler := NewHttpHandler(WithProvider(RandomAnimal))
s.forwardHandlers["http"] = httpHandler
} }
if s.hostSigners == nil || len(s.hostSigners) < 1 { if s.hostSigners == nil || len(s.hostSigners) < 1 {
@ -106,9 +111,10 @@ func New(options ...Option) (*Server, error) {
// TODO: Add TCP handler using the same idea, potentially support multiple forwardHandlers // TODO: Add TCP handler using the same idea, potentially support multiple forwardHandlers
if _, ok := s.forwardHandler.(http.Handler); ok { for _, handler := range s.forwardHandlers {
requestHandlers["http-forward"] = s.forwardHandler.HandleSSHRequest for _, requestType := range handler.RequestTypes() {
requestHandlers["cancel-http-forward"] = s.forwardHandler.HandleSSHRequest requestHandlers[requestType] = handler.HandleSSHRequest
}
} }
s.sshServer = &ssh.Server{ s.sshServer = &ssh.Server{
@ -170,22 +176,40 @@ func (s *Server) Start() error {
// StartHTTP is a convenience method to start a basic http server. // StartHTTP is a convenience method to start a basic http server.
// This uses s.forwardHandler if http.Handler is implemented to serve requests. // This uses s.forwardHandler if http.Handler is implemented to serve requests.
func (s *Server) StartHTTP(bind string) error { func (s *Server) StartHTTP(bind string) error {
if h, ok := s.forwardHandler.(http.Handler); ok { handler := s.forwardHandlers["http"]
httpServer := &http.Server{
Addr: bind,
Handler: h,
}
return httpServer.ListenAndServe() if handler == nil {
return errors.New("http handler not registered")
} }
return errors.New("forwarding handler doesn't support http") httpHandler, ok := handler.(http.Handler)
if !ok {
return errors.New("http handler cannot handle http requests")
}
httpServer := &http.Server{
Addr: bind,
Handler: httpHandler,
}
return httpServer.ListenAndServe()
} }
// ServeHTTP is a passthrough to forwardHandler's ServeHTTP // ServeHTTP is a passthrough to forwardHandler's ServeHTTP
// This can be used to use your own http server implementation, or for TLS/etc // This can be used to use your own http server implementation, or for TLS/etc
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h, ok := s.forwardHandler.(http.Handler); ok { handler := s.forwardHandlers["http"]
h.ServeHTTP(w, r)
if handler == nil {
return
} }
httpHandler, ok := handler.(http.Handler)
if !ok {
return
}
httpHandler.ServeHTTP(w, r)
} }

96
server/store/boltdb.go Normal file
View File

@ -0,0 +1,96 @@
package store
import (
"encoding/json"
"github.com/boltdb/bolt"
)
type BoltStore struct {
db *bolt.DB
}
// NewBoltStore creates a new boltdb backed Store instance
func NewBoltStore(path string) (Store, error) {
db, err := bolt.Open(path, 0644, bolt.DefaultOptions)
if err != nil {
return nil, err
}
err = db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte("hosts"))
return err
})
if err != nil {
return nil, err
}
return &BoltStore{
db: db,
}, nil
}
// Has checks if the host exists in the hosts bucket
func (b *BoltStore) Has(host string) bool {
var exists bool
b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
if b.Get([]byte(host)) != nil {
exists = true
}
return nil
})
return exists
}
// Get retrieves and deserializes a host from the hosts bucket
func (b *BoltStore) Get(key string) (*Host, error) {
var host Host
err := b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
data := b.Get([]byte(key))
if data == nil {
return ErrNoHost
}
return json.Unmarshal(data, &host)
})
if err != nil {
return nil, err
}
return &host, nil
}
// Add updates the hosts bucket and puts a json-serialized version of Host
func (b *BoltStore) Add(host Host) error {
data, err := json.Marshal(host)
if err != nil {
return err
}
return b.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
return b.Put([]byte(host.Host), data)
})
}
// Remove updates the hosts bucket and deletes the key
func (b *BoltStore) Remove(key string) error {
return b.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("hosts"))
return b.Delete([]byte(key))
})
}

27
server/store/store.go Normal file
View File

@ -0,0 +1,27 @@
package store
import (
"errors"
"time"
)
var (
ErrNoHost = errors.New("host not found")
)
// Store represents an interface to retrieve and store hosts
type Store interface {
Has(key string) bool
Get(key string) (*Host, error)
Add(host Host) error
Remove(key string) error
}
// Host represents a claimed host
type Host struct {
Host string `json:"host"`
Owner string `json:"owner"`
IP string `json:"ip"`
Created time.Time `json:"created"`
LastUse time.Time `json:"lastUse"`
}