Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a websocket handler #350

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion https.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package goproxy

import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"io"
Expand Down Expand Up @@ -184,8 +185,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
ctx.Warnf("Cannot handshake client %v %v", r.Host, err)
return
}
defer rawClientTls.Close()

clientTlsReader := bufio.NewReader(rawClientTls)

for !isEof(clientTlsReader) {
req, err := http.ReadRequest(clientTlsReader)
var ctx = &ProxyCtx{Req: req, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy, UserData: ctx.UserData}
Expand All @@ -209,6 +211,13 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request

req, resp := proxy.filterRequest(req, ctx)
if resp == nil {
if req.Header.Get("Upgrade") != "" {
proxy.WebSocketHandler.ServeHTTP(dumbResponseWriter{rawClientTls}, req)
return
} else {
defer rawClientTls.Close()
}

if err != nil {
ctx.Warnf("Illegal URL %s", "https://"+r.Host+req.URL.Path)
return
Expand Down Expand Up @@ -239,6 +248,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
resp.Header.Del("Content-Length")
resp.Header.Set("Transfer-Encoding", "chunked")
// Force connection close otherwise chrome will keep CONNECT tunnel open forever

resp.Header.Set("Connection", "close")
if err := resp.Header.Write(rawClientTls); err != nil {
ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err)
Expand All @@ -248,7 +258,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err)
return
}

chunked := newChunkedWriter(rawClientTls)

if _, err := io.Copy(chunked, resp.Body); err != nil {
ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err)
return
Expand Down Expand Up @@ -277,6 +289,30 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
}
}

type dumbResponseWriter struct {
net.Conn
}

func (dumb dumbResponseWriter) Header() http.Header {
// panic("Header() should not be called on this ResponseWriter")
return make(http.Header)
}

func (dumb dumbResponseWriter) Write(buf []byte) (int, error) {
if bytes.Equal(buf, []byte("HTTP/1.0 200 OK\r\n\r\n")) {
return len(buf), nil // throw away the HTTP OK response from the faux CONNECT request
}
return dumb.Conn.Write(buf)
}

func (dumb dumbResponseWriter) WriteHeader(code int) {
// panic("WriteHeader() should not be called on this ResponseWriter")
}

func (dumb dumbResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return dumb, bufio.NewReadWriter(bufio.NewReader(dumb), bufio.NewWriter(dumb)), nil
}

func httpError(w io.WriteCloser, ctx *ProxyCtx, err error) {
if _, err := io.WriteString(w, "HTTP/1.1 502 Bad Gateway\r\n\r\n"); err != nil {
ctx.Warnf("Error responding to client: %s", err)
Expand Down
15 changes: 11 additions & 4 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ type ProxyHttpServer struct {
Verbose bool
Logger Logger
NonproxyHandler http.Handler
reqHandlers []ReqHandler
respHandlers []RespHandler
httpsHandlers []HttpsHandler
Tr *http.Transport
WebSocketHandler http.Handler
reqHandlers []ReqHandler
respHandlers []RespHandler
httpsHandlers []HttpsHandler
Tr *http.Transport
// ConnectDial will be used to create TCP connections for CONNECT requests
// if nil Tr.Dial will be used
ConnectDial func(network string, addr string) (net.Conn, error)
Expand Down Expand Up @@ -110,9 +111,15 @@ func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request)
proxy.NonproxyHandler.ServeHTTP(w, r)
return
}

r, resp := proxy.filterRequest(r, ctx)

if resp == nil {
if r.Header.Get("Upgrade") != "" {
proxy.WebSocketHandler.ServeHTTP(w, r)
return
}

removeProxyHeaders(ctx, r)
resp, err = ctx.RoundTrip(r)
if err != nil {
Expand Down