From e77b9fdbd19f64e2ce71e5f01ac9e2e296dde4cc Mon Sep 17 00:00:00 2001 From: Dmitriy Goraschenko Date: Wed, 19 Dec 2018 11:49:00 +0400 Subject: [PATCH] add websocket handler --- https.go | 38 +++++++++++++++++++++++++++++++++++++- proxy.go | 15 +++++++++++---- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/https.go b/https.go index 12de7511..2f3d07c1 100644 --- a/https.go +++ b/https.go @@ -2,6 +2,7 @@ package goproxy import ( "bufio" + "bytes" "crypto/tls" "errors" "io" @@ -184,8 +185,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Warnf("Cannot handshake client %v %v", r.Host, err) return } - defer rawClientTls.Close() + clientTlsReader := bufio.NewReader(rawClientTls) + for !isEof(clientTlsReader) { req, err := http.ReadRequest(clientTlsReader) var ctx = &ProxyCtx{Req: req, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy, UserData: ctx.UserData} @@ -209,6 +211,13 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request req, resp := proxy.filterRequest(req, ctx) if resp == nil { + if req.Header.Get("Upgrade") != "" { + proxy.WebSocketHandler.ServeHTTP(dumbResponseWriter{rawClientTls}, req) + return + } else { + defer rawClientTls.Close() + } + if err != nil { ctx.Warnf("Illegal URL %s", "https://"+r.Host+req.URL.Path) return @@ -239,6 +248,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request resp.Header.Del("Content-Length") resp.Header.Set("Transfer-Encoding", "chunked") // Force connection close otherwise chrome will keep CONNECT tunnel open forever + resp.Header.Set("Connection", "close") if err := resp.Header.Write(rawClientTls); err != nil { ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err) @@ -248,7 +258,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err) return } + chunked := newChunkedWriter(rawClientTls) + if _, err := io.Copy(chunked, resp.Body); err != nil { ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err) return @@ -277,6 +289,30 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request } } +type dumbResponseWriter struct { + net.Conn +} + +func (dumb dumbResponseWriter) Header() http.Header { + // panic("Header() should not be called on this ResponseWriter") + return make(http.Header) +} + +func (dumb dumbResponseWriter) Write(buf []byte) (int, error) { + if bytes.Equal(buf, []byte("HTTP/1.0 200 OK\r\n\r\n")) { + return len(buf), nil // throw away the HTTP OK response from the faux CONNECT request + } + return dumb.Conn.Write(buf) +} + +func (dumb dumbResponseWriter) WriteHeader(code int) { + // panic("WriteHeader() should not be called on this ResponseWriter") +} + +func (dumb dumbResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return dumb, bufio.NewReadWriter(bufio.NewReader(dumb), bufio.NewWriter(dumb)), nil +} + func httpError(w io.WriteCloser, ctx *ProxyCtx, err error) { if _, err := io.WriteString(w, "HTTP/1.1 502 Bad Gateway\r\n\r\n"); err != nil { ctx.Warnf("Error responding to client: %s", err) diff --git a/proxy.go b/proxy.go index 2509f6d1..afbcb569 100644 --- a/proxy.go +++ b/proxy.go @@ -22,10 +22,11 @@ type ProxyHttpServer struct { Verbose bool Logger Logger NonproxyHandler http.Handler - reqHandlers []ReqHandler - respHandlers []RespHandler - httpsHandlers []HttpsHandler - Tr *http.Transport + WebSocketHandler http.Handler + reqHandlers []ReqHandler + respHandlers []RespHandler + httpsHandlers []HttpsHandler + Tr *http.Transport // ConnectDial will be used to create TCP connections for CONNECT requests // if nil Tr.Dial will be used ConnectDial func(network string, addr string) (net.Conn, error) @@ -110,9 +111,15 @@ func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) proxy.NonproxyHandler.ServeHTTP(w, r) return } + r, resp := proxy.filterRequest(r, ctx) if resp == nil { + if r.Header.Get("Upgrade") != "" { + proxy.WebSocketHandler.ServeHTTP(w, r) + return + } + removeProxyHeaders(ctx, r) resp, err = ctx.RoundTrip(r) if err != nil {