Clean up proxying to use direct protocol instead of reading with http.ReadRequest
continuous-integration/drone/push Build is passing Details

Prevents mangling and changing of headers
This commit is contained in:
Tyler 2021-12-20 19:44:41 -05:00
parent 04a7914b1a
commit 21d76077f3
6 changed files with 370 additions and 296 deletions

View File

@ -1,10 +1,10 @@
kind: pipeline
name: default
type: docker
steps:
- name: build
image: tystuyfzand/goc:latest
group: build
volumes:
- name: build
path: /build

View File

@ -1,154 +1,195 @@
package client
import (
"bufio"
"crypto/tls"
log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common"
"golang.org/x/crypto/ssh"
"io"
"net"
"net/http"
"net/url"
"bufio"
"crypto/tls"
log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common"
"golang.org/x/crypto/ssh"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
)
// New creates a new client with the specified server and backend
func New(server, backend string, signer ssh.Signer) *Client {
config := &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
}
config := &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
}
backendUrl, err := url.Parse(backend)
backendUrl, err := url.Parse(backend)
if err != nil {
return nil
}
if err != nil {
return nil
}
host, port, _ := net.SplitHostPort(backendUrl.Host)
if backendUrl.Scheme == "" {
backendUrl.Scheme = "http"
}
if port == "" {
port = "80"
}
host, port, _ := net.SplitHostPort(backendUrl.Host)
dialHost := net.JoinHostPort(host, port)
if port == "" {
port = "80"
}
return &Client{
config: config,
server: server,
backendUrl: backendUrl,
dialHost: dialHost,
tlsConfig: &tls.Config{ServerName: host, InsecureSkipVerify: true},
}
dialHost := net.JoinHostPort(host, port)
return &Client{
config: config,
server: server,
backendUrl: backendUrl,
dialHost: dialHost,
tlsConfig: &tls.Config{ServerName: host, InsecureSkipVerify: true},
}
}
// Client is a remote tunnel client
type Client struct {
config *ssh.ClientConfig
tlsConfig *tls.Config
config *ssh.ClientConfig
tlsConfig *tls.Config
server string
backendUrl *url.URL
dialHost string
server string
backendUrl *url.URL
dialHost string
}
// Start connects to the server over TCP and starts the tunnel
func (c *Client) Start() error {
log.WithFields(log.Fields{
"server": c.server,
}).Info("Dialing server")
log.WithFields(log.Fields{
"server": c.server,
}).Info("Dialing server")
conn, err := ssh.Dial("tcp", c.server, c.config)
conn, err := ssh.Dial("tcp", c.server, c.config)
if err != nil {
return err
}
if err != nil {
return err
}
payload := ssh.Marshal(common.RemoteForwardRequest{
RequestedHost: "",
})
if c.backendUrl.Scheme == "http" || c.backendUrl.Scheme == "https" {
go c.handleHttpRequests(conn)
}
success, replyData, err := conn.SendRequest("http-forward", true, payload)
if !success || err != nil {
log.WithError(err).Fatalln("Unable to start forward request")
}
log.WithFields(log.Fields{
"success": success,
}).Info("Got response")
var response common.RemoteForwardSuccess
if err := ssh.Unmarshal(replyData, &response); err != nil {
log.WithError(err).Fatalln("Unable to unmarshal data")
}
defer func() {
payload := ssh.Marshal(common.RemoteForwardCancelRequest{Host: response.Host})
conn.SendRequest("cancel-http-forward", false, payload)
}()
log.WithField("host", response.Host).Info("Bound host")
ch := conn.HandleChannelOpen(common.ForwardedHTTPChannelType)
for {
newCh := <-ch
if newCh == nil {
break
}
ch, r, err := newCh.Accept()
if err != nil {
log.WithError(err).Warning("Error accepting channel")
continue
}
go ssh.DiscardRequests(r)
go c.proxyRequest(ch)
}
return nil
return nil
}
func (c *Client) handleHttpRequests(conn *ssh.Client) {
payload := ssh.Marshal(common.RemoteForwardRequest{
RequestedHost: "",
})
success, replyData, err := conn.SendRequest("http-forward", true, payload)
if !success || err != nil {
log.WithFields(log.Fields{
"error": err,
"message": string(replyData),
}).Fatalln("Unable to start forwarding")
}
log.WithFields(log.Fields{
"success": success,
}).Info("Got response")
var response common.RemoteForwardSuccess
if err := ssh.Unmarshal(replyData, &response); err != nil {
log.WithError(err).Fatalln("Unable to unmarshal data")
}
defer func() {
payload := ssh.Marshal(common.RemoteForwardCancelRequest{Host: response.Host})
conn.SendRequest("cancel-http-forward", false, payload)
}()
log.WithField("host", response.Host).Info("Bound host")
ch := conn.HandleChannelOpen(common.ForwardedHTTPChannelType)
for {
newCh := <-ch
if newCh == nil {
break
}
ch, r, err := newCh.Accept()
if err != nil {
log.WithError(err).Warning("Error accepting channel")
continue
}
go ssh.DiscardRequests(r)
go c.proxyRequest(ch)
}
}
// proxyRequest handles a request from the ssh channel and forwards it to the local http server
func (c *Client) proxyRequest(rw io.ReadWriteCloser) {
tcpConn, err := net.Dial("tcp", c.dialHost)
tcpConn, err := net.Dial("tcp", c.dialHost)
if err != nil {
rw.Close()
return
}
if err != nil {
rw.Close()
return
}
if c.backendUrl.Scheme == "https" || c.backendUrl.Scheme == "wss" {
// Wrap with TLS
tcpConn = tls.Client(tcpConn, c.tlsConfig)
}
if c.backendUrl.Scheme == "https" || c.backendUrl.Scheme == "wss" {
// Wrap with TLS
tcpConn = tls.Client(tcpConn, c.tlsConfig)
}
defer rw.Close()
defer tcpConn.Close()
defer rw.Close()
defer tcpConn.Close()
bufferedCh := bufio.NewReader(rw)
bufferedCh := bufio.NewReader(rw)
req, err := http.ReadRequest(bufferedCh)
tp := textproto.NewReader(bufferedCh)
if err != nil {
log.WithError(err).Warning("Unable to read request header from ch")
return
}
var s string
if s, err = tp.ReadLine(); err != nil {
return
}
// TODO: By parsing the request, we're overriding some fields. Perhaps we want to read only the header and then write the body manually?
// Write the first response line as-is
tcpConn.Write([]byte(s + "\r\n"))
// Override host
req.URL.Scheme = c.backendUrl.Scheme
req.URL.Host = c.backendUrl.Host
req.Host = c.backendUrl.Host
// Read headers and proxy each to the output
mimeHeader, err := tp.ReadMIMEHeader()
go req.Write(tcpConn)
if err != nil {
return
}
io.Copy(rw, tcpConn)
}
// Modify and return our headers
headers := http.Header(mimeHeader)
headers.Set("Host", c.backendUrl.Host)
headers.Write(tcpConn)
// End headers
tcpConn.Write([]byte("\r\n"))
contentLength, err := strconv.ParseInt(headers.Get("Content-Length"), 10, 64)
if err == nil && contentLength > 0 {
// Copy request to the tcpConn
_, err := io.Copy(tcpConn, io.LimitReader(bufferedCh, int64(contentLength)))
if err != nil {
log.WithError(err).Warning("Connection error on body read, closing")
return
}
}
// Copy the response back to the tunnel server
io.Copy(rw, tcpConn)
}

View File

@ -11,6 +11,8 @@ import (
"io/ioutil"
)
// LoadOrGenerateKey generates an RSA Private Key and saves it to a file
// passphrase can be used to (insecurely) encrypt it, though it's deprecated
func LoadOrGenerateKey(fs afero.Fs, file, passphrase string) (*rsa.PrivateKey, error) {
if file == "" {
return GenRSA(4096)
@ -31,10 +33,22 @@ func LoadOrGenerateKey(fs afero.Fs, file, passphrase string) (*rsa.PrivateKey, e
return nil, err
}
err = pem.Encode(f, &pem.Block{
defer f.Close()
pemBlock := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
}
if passphrase != "" {
pemBlock, err = x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", pemBlock.Bytes, []byte(passphrase), x509.PEMCipherAES128)
if err != nil {
return nil, err
}
}
err = pem.Encode(f, pemBlock)
if err != nil {
return nil, err

View File

@ -4,18 +4,23 @@ const (
ForwardedHTTPChannelType = "forwarded-http"
)
// RemoteForwardRequest represents a forwarding request
type RemoteForwardRequest struct {
RequestedHost string
}
// RemoteForwardSuccess returns when a successful request is processed
// Host represents the assigned remote host
type RemoteForwardSuccess struct {
Host string
}
// RemoteForwardCancelRequest represents a forwarding cancel request
type RemoteForwardCancelRequest struct {
Host string
}
// RemoteForwardChannelData is sent when opening a channel to say which host/client ip is accessed
type RemoteForwardChannelData struct {
Host string
ClientIP string

View File

@ -1,37 +1,44 @@
package server
import (
"bufio"
"bytes"
_ "embed"
"math/rand"
"time"
"bufio"
"bytes"
_ "embed"
"math/rand"
"strings"
"time"
)
var (
//go:embed animals.txt
animalBytes []byte
animals []string
//go:embed animals.txt
animalBytes []byte
animals []string
)
func init() {
animals = make([]string, 0)
animals = make([]string, 0)
s := bufio.NewScanner(bytes.NewReader(animalBytes))
s := bufio.NewScanner(bytes.NewReader(animalBytes))
for s.Scan() {
animals = append(animals, s.Text())
}
for s.Scan() {
line := strings.TrimSpace(s.Text())
rand.Seed(time.Now().UTC().UnixNano())
if line[0] == '#' {
continue
}
animals = append(animals, line)
}
rand.Seed(time.Now().UTC().UnixNano())
}
// RandomAnimal is a basic HostProvider using animal names
func RandomAnimal() string {
return animals[rand.Intn(len(animals))]
return animals[rand.Intn(len(animals))]
}
// DenyAll is a HostValidator to deny all custom requests
func DenyAll(host string) bool {
return false
return false
}

View File

@ -1,17 +1,17 @@
package server
import (
"bufio"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common"
gossh "golang.org/x/crypto/ssh"
"io"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"bufio"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"gogrok.ccatss.dev/common"
gossh "golang.org/x/crypto/ssh"
"io"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
)
// HostProvider is a func to provide a host + subdomain
@ -24,215 +24,222 @@ type HostValidator func(host string) bool
// adding the HandleSSHRequest callback to the server's RequestHandlers under
// tcpip-forward and cancel-tcpip-forward.
type ForwardedHTTPHandler struct {
forwards map[string]*gossh.ServerConn
provider HostProvider
validator HostValidator
sync.RWMutex
forwards map[string]*gossh.ServerConn
provider HostProvider
validator HostValidator
sync.RWMutex
}
// HandlerOption represents a func used to assign options to a ForwardedHTTPHandler
type HandlerOption func(h *ForwardedHTTPHandler)
// WithProvider sets a default domain provider
func WithProvider(provider HostProvider) HandlerOption {
return func(h *ForwardedHTTPHandler) {
h.provider = provider
}
return func(h *ForwardedHTTPHandler) {
h.provider = provider
}
}
// WithValidator sets a host validator to use for validation of custom hosts
func WithValidator(validator HostValidator) HandlerOption {
return func(h *ForwardedHTTPHandler) {
h.validator = validator
}
return func(h *ForwardedHTTPHandler) {
h.validator = validator
}
}
func NewHttpHandler(opts ...HandlerOption) ForwardHandler {
h := &ForwardedHTTPHandler{
forwards: make(map[string]*gossh.ServerConn),
provider: RandomAnimal,
}
h := &ForwardedHTTPHandler{
forwards: make(map[string]*gossh.ServerConn),
provider: RandomAnimal,
validator: DenyAll,
}
for _, opt := range opts {
opt(h)
}
for _, opt := range opts {
opt(h)
}
return h
return h
}
// ServeHTTP mocks an http server endpoint that uses Request.Host to forward requests
func (h *ForwardedHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.RLock()
sshConn, ok := h.forwards[r.Host]
h.RUnlock()
h.RLock()
sshConn, ok := h.forwards[r.Host]
h.RUnlock()
if !ok {
log.Warning("Unknown host ", r.Host)
http.Error(w, "not found", http.StatusNotFound)
return
}
if !ok {
log.Warning("Unknown host ", r.Host)
http.Error(w, "not found", http.StatusNotFound)
return
}
payload := gossh.Marshal(&common.RemoteForwardChannelData{
Host: r.Host,
ClientIP: r.RemoteAddr,
})
payload := gossh.Marshal(&common.RemoteForwardChannelData{
Host: r.Host,
ClientIP: r.RemoteAddr,
})
ch, reqs, err := sshConn.OpenChannel(common.ForwardedHTTPChannelType, payload)
ch, reqs, err := sshConn.OpenChannel(common.ForwardedHTTPChannelType, payload)
if err != nil {
return
}
if err != nil {
log.WithError(err).Warning("Unable to open ssh connection channel")
return
}
go gossh.DiscardRequests(reqs)
go gossh.DiscardRequests(reqs)
defer ch.Close()
defer ch.Close()
// Ensure we have Connection: close, keep alive isn't supported
r.Header.Set("Connection", "close")
// Ensure we have Connection: close, keep alive isn't supported
r.Header.Set("Connection", "close")
// Write the request to our channel
r.Write(ch)
// Write the request to our channel
r.Write(ch)
// Read the response
bufReader := bufio.NewReader(ch)
// Read the response
bufReader := bufio.NewReader(ch)
tp := textproto.NewReader(bufReader)
tp := textproto.NewReader(bufReader)
var s string
if s, err = tp.ReadLine(); err != nil {
w.WriteHeader(http.StatusBadGateway)
return
}
var s string
if s, err = tp.ReadLine(); err != nil {
w.WriteHeader(http.StatusBadGateway)
return
}
_, responseCodeStr, _, ok := parseResponseLine(s)
_, responseCodeStr, _, ok := parseResponseLine(s)
if !ok {
w.WriteHeader(http.StatusBadGateway)
return
}
if !ok {
w.WriteHeader(http.StatusBadGateway)
log.Warning("Backend returned unexpected response line")
return
}
responseCode, err := strconv.Atoi(responseCodeStr)
responseCode, err := strconv.Atoi(responseCodeStr)
if responseCode < http.StatusContinue || responseCode > http.StatusNetworkAuthenticationRequired {
return
}
if responseCode < http.StatusContinue || responseCode > http.StatusNetworkAuthenticationRequired {
return
}
mimeHeader, err := tp.ReadMIMEHeader()
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
if err != nil {
return
}
for k, v := range mimeHeader {
w.Header()[k] = v
}
for k, v := range mimeHeader {
w.Header()[k] = v
}
// Set our forwarded address
forwardedHeader := r.Header.Get("X-Forwarded-For")
// Set our forwarded address
forwardedHeader := r.Header.Get("X-Forwarded-For")
// TODO: Only trust this from trusted sources.
if forwardedHeader != "" {
forwardedHeader = strings.Join(append([]string{r.RemoteAddr}, strings.Split(",", forwardedHeader)...), ",")
} else {
forwardedHeader = r.RemoteAddr
}
// TODO: Only trust this from trusted sources.
if forwardedHeader != "" {
forwardedHeader = strings.Join(append([]string{r.RemoteAddr}, strings.Split(",", forwardedHeader)...), ",")
} else {
forwardedHeader = r.RemoteAddr
}
w.Header().Set("X-Forwarded-For", forwardedHeader)
w.Header().Set("X-Forwarded-For", forwardedHeader)
if r.TLS != nil {
w.Header().Set("X-Forwarded-Proto", "https")
}
if r.TLS != nil {
w.Header().Set("X-Forwarded-Proto", "https")
}
w.WriteHeader(responseCode)
w.WriteHeader(responseCode)
io.Copy(w, bufReader)
io.Copy(w, bufReader)
}
// parseResponseLine parses "HTTP/1.1 200 OK" into its three parts.
func parseResponseLine(line string) (httpVersion, responseCode, responseText string, ok bool) {
s1 := strings.Index(line, " ")
s2 := strings.Index(line[s1+1:], " ")
if s1 < 0 || s2 < 0 {
return
}
s2 += s1 + 1
return line[:s1], line[s1+1 : s2], line[s2+1:], true
s1 := strings.Index(line, " ")
s2 := strings.Index(line[s1+1:], " ")
if s1 < 0 || s2 < 0 {
return
}
s2 += s1 + 1
return line[:s1], line[s1+1 : s2], line[s2+1:], true
}
// HandleSSHRequest handles incoming ssh requests.
func (h *ForwardedHTTPHandler) HandleSSHRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
switch req.Type {
case "http-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-forward")
return false, []byte{}
}
switch req.Type {
case "http-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-forward")
return false, []byte{}
}
host := reqPayload.RequestedHost
host := reqPayload.RequestedHost
if host != "" && h.validator != nil && !h.validator(host) {
return false, []byte("invalid host " + host)
}
if host != "" && h.validator != nil && !h.validator(host) {
return false, []byte("invalid host " + host)
}
// Validate host
if host == "" {
host = h.provider()
}
// Validate host
if host == "" {
host = h.provider()
}
h.RLock()
for {
_, exists := h.forwards[host]
h.RLock()
for {
_, exists := h.forwards[host]
if !exists {
break
}
if !exists {
break
}
host = h.provider()
}
h.RUnlock()
host = h.provider()
}
h.RUnlock()
h.Lock()
h.forwards[host] = conn
h.Unlock()
log.WithField("host", host).Info("Registered host")
h.Lock()
h.forwards[host] = conn
h.Unlock()
log.WithField("host", host).Info("Registered host")
go func() {
<-ctx.Done()
go func() {
<-ctx.Done()
log.WithField("host", host).Info("Removed host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
}()
log.WithField("host", host).Info("Removed host")
h.Lock()
delete(h.forwards, host)
h.Unlock()
}()
return true, gossh.Marshal(&common.RemoteForwardSuccess{
Host: host,
})
return true, gossh.Marshal(&common.RemoteForwardSuccess{
Host: host,
})
case "cancel-http-forward":
var reqPayload common.RemoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
return false, []byte{}
}
h.Lock()
delete(h.forwards, reqPayload.Host)
h.Unlock()
return true, nil
case "cancel-http-forward":
var reqPayload common.RemoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
return false, []byte{}
}
h.Lock()
delete(h.forwards, reqPayload.Host)
h.Unlock()
return true, nil
case "http-register-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
case "http-register-forward":
var reqPayload common.RemoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
// TODO: log parse failure
log.WithError(err).Warning("Error parsing payload for http-register-forward")
return false, []byte{}
}
log.Println("Key:", ctx.Value("publicKey"))
log.Println("Key:", ctx.Value("publicKey"))
// Claimed forward via SSH Public Key
return true, nil
default:
return false, nil
}
// Claimed forward via SSH Public Key
return true, nil
default:
return false, nil
}
}